# NOTE(review): `spaces` is deliberately kept as the first import; ZeroGPU
# reportedly needs it imported before torch initialises CUDA -- confirm
# before reordering.
import spaces
import torch
import clip
from PIL import Image
from torch.cuda.amp import autocast as autocast
from huggingface_hub import hf_hub_download

from model import flow_model
from augmentations_clip import DataAugmentationCLIP as DataAugmentationCLIP_test

MODEL_REPO_ID = "davjoython/flow_fake"
FLOW_MODEL_FILENAME = "flow_fake_detector_centercrop_v4.pth"
CLIP_MODEL_FILENAME = "my_clip_ViT-L-14.pt"

# Divisor applied to the squared flow output before averaging and squashing
# it through a sigmoid (see FakeImageDetector._score_from_latent).
_SCORE_DIVISOR = 10000


class FakeImageDetector:
    """Classify a PIL image as fake or real with a CLIP encoder and a flow model.

    Both checkpoints are downloaded from ``MODEL_REPO_ID`` and loaded on the
    CPU at construction time. ``detect`` moves them to CUDA when it is
    available and runs a single forward pass.
    """

    def __init__(self):
        """Download both checkpoints and build the models in eval mode on CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}")

        print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...")
        clip_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CLIP_MODEL_FILENAME,
        )
        print("CLIP 模型已下载。")

        # The preprocessing transform returned by clip.load is discarded;
        # self.transform below is used instead.
        self.clip_model, _ = clip.load(clip_model_path, device="cpu")
        self.clip_model.eval()
        print("CLIP 模型已加载到 CPU。")

        print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...")
        flow_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=FLOW_MODEL_FILENAME,
        )
        print("Flow 模型已下载。")

        self.flow = flow_model()
        # NOTE(security): torch.load unpickles the downloaded file. On PyTorch
        # versions where `weights_only` defaults to False this can execute
        # arbitrary code embedded in the checkpoint. Consider passing
        # weights_only=True once the checkpoint is confirmed to be a plain
        # state_dict.
        self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu"))
        self.flow = self.flow.to("cpu")
        self.flow.eval()
        print("Flow 模型已加载到 CPU。")

        print("模型加载完成。")

        # NOTE(review): the three positional arguments are presumably the
        # global crop scale, local crop scale and local crop count -- confirm
        # against DataAugmentationCLIP's signature.
        self.transform = DataAugmentationCLIP_test(
            (0.9, 1.0),
            (0.05, 0.4),
            1,
            global_crops_size=224,
            local_crops_size=96,
        )

    @staticmethod
    def _score_from_latent(z):
        """Return ``1 - sigmoid(mean(z**2 / _SCORE_DIVISOR, dim=1))`` as a float.

        The computation is done in float32 regardless of the dtype of ``z``.
        ``.item()`` requires the reduced tensor to hold exactly one element;
        this assumes ``z`` is 2-D with a batch of one -- TODO confirm against
        the flow model's output shape.
        """
        mean_energy = torch.mean(z.float() ** 2 / _SCORE_DIVISOR, dim=1)
        return 1 - torch.sigmoid(mean_energy).item()

    @spaces.GPU(duration=10)
    def detect(self, image_pil, threshold=0.5):
        """Score one image and report whether it is judged fake or real.

        Args:
            image_pil: The image to analyse, as a ``PIL.Image.Image``.
            threshold: Scores strictly greater than this are reported as fake.

        Returns:
            A ``(result_text, score)`` tuple: a human-readable verdict string
            and the score as a Python float.

        Raises:
            TypeError: If ``image_pil`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image_pil, Image.Image):
            raise TypeError("输入必须是 PIL Image 对象")

        img_rgb = image_pil.convert("RGB")

        # Re-evaluated on every call instead of reusing self.device, which
        # was computed once in __init__.
        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        use_cuda = current_device == "cuda"

        try:
            # nn.Module.to moves the module in place and returns it, so the
            # models stay on `current_device` after this call.
            flow_model_gpu = self.flow.to(current_device)
            clip_model_gpu = self.clip_model.to(current_device)

            transformed_img_dict = self.transform(img_rgb)
            # First "source" entry of the transform output, with a leading
            # batch dimension of one added.
            img_tensor = (
                transformed_img_dict["source"][0].unsqueeze(0).to(current_device)
            )

            with torch.no_grad():
                if use_cuda:
                    # Half-precision input under autocast on the GPU path.
                    with autocast():
                        embedding = clip_model_gpu.visual(img_tensor.half())
                        z = flow_model_gpu(embedding)
                else:
                    embedding = clip_model_gpu.visual(img_tensor)
                    z = flow_model_gpu(embedding.float())
                score = self._score_from_latent(z)
        finally:
            # Release cached CUDA blocks even when the transform or inference
            # raised, not only on the success path.
            if use_cuda:
                torch.cuda.empty_cache()

        if score > threshold:
            result_text = f"结论: 伪造的 (Fake)\n分数: {score:.10f}"
        else:
            result_text = f"结论: 真实的 (Real)\n分数: {score:.10f}"

        return result_text, score