File size: 3,621 Bytes
776deff
 
 
 
 
 
 
 
36f9974
776deff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03240ea
776deff
 
 
03240ea
776deff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import clip
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import argparse

from .blipmodels import blip_decoder

class NeuralNet(nn.Module):
    """Three-layer MLP head: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Linear.

    Args:
        input_size: dimensionality of the input feature vector.
        hidden_size_list: sizes of the two hidden layers, e.g. [512, 256].
        num_classes: number of output logits.
    """

    def __init__(self, input_size, hidden_size_list, num_classes):
        super(NeuralNet, self).__init__()
        # Attribute names must stay unchanged so saved state_dicts still load.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(input_size, hidden_size_list[0])
        self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
        self.fc3 = nn.Linear(hidden_size_list[1], num_classes)

    def forward(self, x):
        """Return raw (unsoftmaxed) class logits for a batch of features."""
        hidden = self.dropout2(F.relu(self.fc1(x)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def load_models(device=None):
    """
    Load CLIP, BLIP and the linear classifier once, returning them in a dict.

    Args:
        device: torch device string; defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        dict with keys "device", "clip_model", "clip_preprocess", "blip",
        "linear", and "clip_finetuned" (the finetuned CLIP checkpoint).

    Raises:
        FileNotFoundError: if "finetune_clip.pt" or "clip_linear.pt" are not
            present in the current working directory.
    """
    import os
    # Debug aid: checkpoints are loaded by relative path, so the CWD matters.
    print("Current working folder:", os.getcwd())

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

    image_size = 224
    blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
    blip.eval()
    blip = blip.to(device)

    # Load the finetuned CLIP checkpoint.
    # NOTE(review): previously this was loaded but never returned, so callers
    # silently kept using the base CLIP model; it is now exposed in the result
    # dict (backward-compatible addition — existing keys are unchanged).
    clip_finetuned = torch.load("finetune_clip.pt", map_location=device, weights_only=False).to(device)

    # Load the linear classifier checkpoint directly. (The previous code first
    # built a fresh NeuralNet(1024, [512, 256], 2) and immediately overwrote
    # it with torch.load — that dead instantiation has been removed.)
    linear = torch.load("clip_linear.pt", map_location=device, weights_only=False).to(device)
    linear.eval()

    return {
        "device": device,
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "blip": blip,
        "linear": linear,
        "clip_finetuned": clip_finetuned,
    }

def predict_image(image_path, models=None):
    """
    Classify a single image, returning the predicted class and probabilities.

    Args:
        image_path: path to the image file to classify.
        models: dict as returned by load_models(); loaded lazily when None
            (expensive — prefer passing a preloaded dict for repeated calls).

    Returns:
        (pred, probs_list): predicted class index (int) and the per-class
        softmax probabilities as a plain Python list.
    """
    if models is None:
        models = load_models()

    device = models["device"]
    clip_model = models["clip_model"]
    clip_preprocess = models["clip_preprocess"]
    blip = models["blip"]
    linear = models["linear"]

    # 1. Generate a caption for the image with BLIP.
    img = Image.open(image_path).convert('RGB')
    tform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img_tensor = tform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        caption = blip.generate(img_tensor, sample=False, num_beams=3, max_length=60, min_length=5)
    text = clip.tokenize(list(caption)).to(device)

    # 2. Preprocess the image for CLIP. Reuse the already-opened image instead
    # of reading the file from disk a second time; clip_preprocess performs its
    # own RGB conversion, so the result is unchanged.
    image = clip_preprocess(img).unsqueeze(0).to(device)

    # 3. Extract joint image+text features and classify with the linear head.
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(text)
        # Concatenate along the feature dimension: (1, 512) + (1, 512) -> (1, 1024).
        emb = torch.cat((image_features, text_features), 1)
        output = linear(emb.float())
        probs = torch.softmax(output, dim=1)
        pred = probs.argmax(1).item()
        probs_list = probs[0].cpu().numpy().tolist()

    return pred, probs_list

def main():
    """CLI entry point: classify one image given via --image_path and print the result."""
    parser = argparse.ArgumentParser(description='De-Fake single image test')
    parser.add_argument('--image_path', default='CLIP.png', type=str)
    args = parser.parse_args()

    loaded = load_models()
    label, probabilities = predict_image(args.image_path, loaded)
    print("Prediction:", label)
    print("Probabilities:", probabilities)


if __name__ == "__main__":
    main()