Spaces:
Sleeping
Sleeping
| import torch | |
| import clip | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import argparse | |
| from .blipmodels import blip_decoder | |
class NeuralNet(nn.Module):
    """Three-layer MLP classification head: input -> hidden[0] -> hidden[1] -> classes.

    NOTE: the attribute names (dropout2, fc1, fc2, fc3) are part of the
    checkpoint contract -- whole-object torch.load pickles reference them,
    so they must not be renamed.
    """

    def __init__(self, input_size, hidden_size_list, num_classes):
        super(NeuralNet, self).__init__()
        # Construction order is kept as-is so RNG-seeded weight
        # initialization stays reproducible.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(input_size, hidden_size_list[0])
        self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
        self.fc3 = nn.Linear(hidden_size_list[1], num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch ``x``."""
        hidden = self.dropout2(F.relu(self.fc1(x)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
def load_models(device=None):
    """Load CLIP, BLIP and the linear classifier once and bundle them.

    Parameters
    ----------
    device : str or torch.device, optional
        Target device; defaults to "cuda" when available, else "cpu".

    Returns
    -------
    dict
        Keys: "device", "clip_model", "clip_preprocess", "blip",
        "clip_finetuned", "linear".
    """
    import os
    # Debug aid: the checkpoint paths below are resolved relative to CWD.
    print("Current working folder:", os.getcwd())

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

    image_size = 224
    blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
    blip.eval()
    blip = blip.to(device)

    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # load these checkpoints from a trusted source.
    # Load the fine-tuned CLIP. (Previously loaded and then silently
    # discarded; now exposed in the returned dict -- a backward-compatible
    # extra key.)
    clip_finetuned = torch.load("finetune_clip.pt", map_location=device,
                                weights_only=False).to(device)

    # Load the linear classifier. The dead `NeuralNet(1024, [512, 256], 2)`
    # instantiation that was immediately overwritten here has been removed.
    linear = torch.load("clip_linear.pt", map_location=device,
                        weights_only=False).to(device)
    linear.eval()

    return {
        "device": device,
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "blip": blip,
        "clip_finetuned": clip_finetuned,
        "linear": linear,
    }
def predict_image(image_path, models=None):
    """Predict a class for one image from joint CLIP image+caption features.

    Parameters
    ----------
    image_path : str
        Path to the image file.
    models : dict, optional
        Bundle produced by ``load_models``; loaded on demand when omitted.

    Returns
    -------
    tuple[int, list[float]]
        Predicted class index and the per-class softmax probabilities.
    """
    if models is None:
        models = load_models()
    device = models["device"]
    clip_model = models["clip_model"]
    clip_preprocess = models["clip_preprocess"]
    blip = models["blip"]
    linear = models["linear"]

    # Open the file once under a context manager so the handle is released
    # (the original opened the image twice and never closed either handle).
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')

    # 1. Generate a caption with BLIP.
    tform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img_tensor = tform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        caption = blip.generate(img_tensor, sample=False, num_beams=3,
                                max_length=60, min_length=5)
    text = clip.tokenize(list(caption)).to(device)

    # 2. CLIP's own preprocessing for the image branch. CLIP's transform
    # converts to RGB itself, so reusing the converted image is equivalent
    # to re-opening the file as the original code did.
    image = clip_preprocess(img).unsqueeze(0).to(device)

    # 3. Extract both embeddings, concatenate, and classify.
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(text)
        # assumes `linear` was trained on this concatenated width
        # (1024 = 512 image + 512 text for ViT-B/32) -- TODO confirm
        emb = torch.cat((image_features, text_features), 1)
        output = linear(emb.float())
        probs = torch.softmax(output, dim=1)

    pred = probs.argmax(1).item()
    probs_list = probs[0].cpu().numpy().tolist()
    return pred, probs_list
def main():
    """CLI entry point: classify a single image and print the result."""
    parser = argparse.ArgumentParser(description='De-Fake single image test')
    parser.add_argument('--image_path', default='CLIP.png', type=str)
    cli_args = parser.parse_args()

    bundle = load_models()
    label, label_probs = predict_image(cli_args.image_path, bundle)
    print("Prediction:", label)
    print("Probabilities:", label_probs)


if __name__ == "__main__":
    main()