"""De-Fake single-image detector.

Pipeline: BLIP generates a caption for the image, CLIP embeds both the
image and the caption, and a small MLP classifies the concatenated
embedding as real vs. fake.
"""

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

import clip

from .blipmodels import blip_decoder


class NeuralNet(nn.Module):
    """Two-hidden-layer MLP classifier head.

    NOTE: this class must stay importable under this exact name and
    module — the checkpoint "clip_linear.pt" is a fully pickled
    NeuralNet instance and torch.load() needs the class definition to
    unpickle it.
    """

    def __init__(self, input_size, hidden_size_list, num_classes):
        super(NeuralNet, self).__init__()
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(input_size, hidden_size_list[0])
        self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
        self.fc3 = nn.Linear(hidden_size_list[1], num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        # Dropout is applied only after the first hidden layer, matching
        # how the shipped checkpoint was trained.
        out = self.dropout2(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out


def load_models(device=None):
    """Load CLIP, BLIP and the linear classifier (call once, reuse).

    Args:
        device: torch device string ("cuda"/"cpu"); auto-detected when None.

    Returns:
        dict with keys "device", "clip_model", "clip_preprocess",
        "blip" and "linear", suitable for predict_image(models=...).
    """
    import os
    # Checkpoints below are loaded by *relative* path, so the current
    # working directory matters — print it to ease debugging.
    print("Current working folder:", os.getcwd())

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

    image_size = 224
    blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
    blip.eval()
    blip = blip.to(device)

    # Load the linear classifier. The checkpoint is a fully pickled
    # NeuralNet (hence weights_only=False); pickle-based torch.load
    # executes arbitrary code, so only load trusted checkpoints.
    # (The original code also instantiated a fresh NeuralNet and loaded
    # an unused "finetune_clip.pt" here — both were dead code and have
    # been removed.)
    linear = torch.load("clip_linear.pt", map_location=device,
                        weights_only=False).to(device)
    linear.eval()

    return {
        "device": device,
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "blip": blip,
        "linear": linear,
    }


def predict_image(image_path, models=None):
    """Classify one image and return (predicted_class, probabilities).

    Args:
        image_path: path to the image file.
        models: dict as returned by load_models(); loaded on demand
            when None (slow — prefer passing a preloaded dict).

    Returns:
        (pred, probs_list): pred is the argmax class index (int),
        probs_list the softmax probabilities as a plain list of floats.
    """
    if models is None:
        models = load_models()
    device = models["device"]
    clip_model = models["clip_model"]
    clip_preprocess = models["clip_preprocess"]
    blip = models["blip"]
    linear = models["linear"]

    # 1. Caption the image with BLIP.
    img = Image.open(image_path).convert('RGB')
    tform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img_tensor = tform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        caption = blip.generate(img_tensor, sample=False, num_beams=3,
                                max_length=60, min_length=5)
    # truncate=True guards against captions longer than CLIP's 77-token
    # context window, which would otherwise raise a RuntimeError.
    text = clip.tokenize(list(caption), truncate=True).to(device)

    # 2. Preprocess the image for CLIP (reuse the already-opened image;
    # CLIP's preprocess converts to RGB itself).
    image = clip_preprocess(img).unsqueeze(0).to(device)

    # 3. Extract features and classify the concatenated embedding.
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(text)
        emb = torch.cat((image_features, text_features), 1)
        output = linear(emb.float())
        probs = torch.softmax(output, dim=1)

    pred = probs.argmax(1).item()
    probs_list = probs[0].cpu().numpy().tolist()
    return pred, probs_list


def main():
    """CLI entry point: classify a single image and print the result."""
    parser = argparse.ArgumentParser(description='De-Fake single image test')
    parser.add_argument('--image_path', default='CLIP.png', type=str)
    args = parser.parse_args()

    models = load_models()
    pred, probs = predict_image(args.image_path, models)
    print("Prediction:", pred)
    print("Probabilities:", probs)


if __name__ == "__main__":
    main()