# NOTE(review): the three lines below were stray non-Python text (apparently a
# pasted git log entry) that broke the module at import time; commented out:
#   home
#   add curl examples
#   36f9974
import torch
import clip
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import argparse
from .blipmodels import blip_decoder
class NeuralNet(nn.Module):
    """Two-hidden-layer MLP classifier: fc1 -> ReLU -> dropout -> fc2 -> ReLU -> fc3.

    Returns raw logits (no softmax); callers apply softmax themselves.
    Attribute names (fc1/fc2/fc3/dropout2) are part of the checkpoint contract
    and must not be renamed.
    """

    def __init__(self, input_size, hidden_size_list, num_classes):
        super(NeuralNet, self).__init__()
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(input_size, hidden_size_list[0])
        self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
        self.fc3 = nn.Linear(hidden_size_list[1], num_classes)

    def forward(self, x):
        # Dropout is applied only after the first hidden activation.
        hidden = self.dropout2(F.relu(self.fc1(x)))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
def load_models(device=None):
    """Load CLIP, BLIP, the finetuned CLIP, and the linear classifier once.

    Args:
        device: torch device string; defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        dict with keys "device", "clip_model", "clip_preprocess", "blip",
        "clip_finetuned", and "linear". ("clip_finetuned" is a new,
        backward-compatible key — previously the checkpoint was loaded and
        then silently discarded.)
    """
    import os
    # Checkpoints below are loaded by relative path; print the cwd to make
    # "file not found" failures easier to diagnose.
    print("Current working folder:", os.getcwd())
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    image_size = 224
    blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    blip = blip_decoder(pretrained=blip_url, image_size=image_size, vit='base')
    blip.eval()
    blip = blip.to(device)
    # Load the finetuned CLIP checkpoint.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # these checkpoint files from a trusted source.
    clip_finetuned = torch.load("finetune_clip.pt", map_location=device, weights_only=False).to(device)
    # Load the linear classifier. (The original code first built a fresh
    # NeuralNet(1024, [512, 256], 2) and immediately overwrote it with
    # torch.load — dead work, removed.)
    linear = torch.load("clip_linear.pt", map_location=device, weights_only=False).to(device)
    linear.eval()
    return {
        "device": device,
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "blip": blip,
        "clip_finetuned": clip_finetuned,  # previously loaded but dropped; now exposed
        "linear": linear,
    }
def predict_image(image_path, models=None):
    """Classify one image using a BLIP caption plus CLIP image/text features.

    Args:
        image_path: path to the image file.
        models: dict as returned by load_models(); loaded lazily when None.

    Returns:
        (pred, probs_list): predicted class index (int) and the per-class
        probabilities as a plain Python list of floats.
    """
    if models is None:
        models = load_models()
    device = models["device"]
    clip_model = models["clip_model"]
    clip_preprocess = models["clip_preprocess"]
    blip = models["blip"]
    linear = models["linear"]
    # 1. Generate a caption with BLIP (beam search, deterministic).
    img = Image.open(image_path).convert('RGB')
    tform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img_tensor = tform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        caption = blip.generate(img_tensor, sample=False, num_beams=3, max_length=60, min_length=5)
    text = clip.tokenize(list(caption)).to(device)
    # 2. Preprocess the image for CLIP. Reuse the already-decoded RGB image
    # instead of opening and decoding the file a second time; converting to
    # RGB before CLIP's resize also avoids resampling palette-mode images in
    # index space.
    image = clip_preprocess(img).unsqueeze(0).to(device)
    # 3. Extract features and classify.
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
        text_features = clip_model.encode_text(text)
        # Concatenate image+text features — presumably 512+512 = 1024-dim,
        # matching the NeuralNet(1024, ...) classifier; confirm against the
        # checkpoint.
        emb = torch.cat((image_features, text_features), 1)
        output = linear(emb.float())
        probs = torch.softmax(output, dim=1)
    pred = probs.argmax(1).item()
    probs_list = probs[0].cpu().numpy().tolist()
    return pred, probs_list
def main():
    """Command-line entry point: classify a single image and print the result."""
    arg_parser = argparse.ArgumentParser(description='De-Fake single image test')
    arg_parser.add_argument('--image_path', default='CLIP.png', type=str)
    cli_args = arg_parser.parse_args()
    label, probabilities = predict_image(cli_args.image_path, load_models())
    print("Prediction:", label)
    print("Probabilities:", probabilities)


if __name__ == "__main__":
    main()