# Hugging Face Spaces status: Runtime error
| import spaces | |
| import torch | |
| import clip | |
| from PIL import Image | |
| from torch.cuda.amp import autocast as autocast | |
| from huggingface_hub import hf_hub_download | |
| from model import flow_model | |
| from augmentations_clip import DataAugmentationCLIP as DataAugmentationCLIP_test | |
# Hugging Face Hub repository that hosts both checkpoints used by the detector.
MODEL_REPO_ID: str = "davjoython/flow_fake"
# Checkpoint passed to clip.load() in FakeImageDetector.__init__.
CLIP_MODEL_FILENAME: str = "my_clip_ViT-L-14.pt"
# Checkpoint loaded with torch.load() and fed to flow_model().load_state_dict().
FLOW_MODEL_FILENAME: str = "flow_fake_detector_centercrop_v4.pth"
class FakeImageDetector:
    """Classify images as real or fake with a flow model over CLIP image features.

    Both checkpoints are downloaded from ``MODEL_REPO_ID`` on the Hugging Face
    Hub and loaded on the CPU at construction time. Each :meth:`detect` call
    moves the models to CUDA when ``torch.cuda.is_available()`` is true.
    NOTE(review): this looks intended for GPU-on-demand Spaces (the module
    imports ``spaces``); confirm against the caller.
    """

    def __init__(self):
        """Download both checkpoints, load them on the CPU, and build the transform.

        Files are fetched with ``hf_hub_download`` from ``MODEL_REPO_ID``:
        ``CLIP_MODEL_FILENAME`` is loaded via ``clip.load`` and
        ``FLOW_MODEL_FILENAME`` is loaded as a state dict into ``flow_model()``.
        Both models are put in eval mode.
        """
        # Informational only: detect() re-resolves the device on every call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}")

        print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...")
        clip_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CLIP_MODEL_FILENAME,
        )
        print("CLIP 模型已下载。")
        # The second value returned by clip.load (its preprocess transform) is
        # discarded; self.transform below is used for preprocessing instead.
        self.clip_model, _ = clip.load(clip_model_path, device="cpu")
        self.clip_model.eval()
        print("CLIP 模型已加载到 CPU。")

        print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...")
        flow_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=FLOW_MODEL_FILENAME,
        )
        print("Flow 模型已下载。")
        self.flow = flow_model()
        # NOTE(review): torch.load unpickles a file downloaded from the Hub.
        # Since the result is used as a plain state dict, consider passing
        # weights_only=True to refuse arbitrary pickled objects.
        self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu"))
        self.flow = self.flow.to("cpu")
        self.flow.eval()
        print("Flow 模型已加载到 CPU。")
        print("模型加载完成。")

        # Arguments are passed straight through to DataAugmentationCLIP; their
        # meaning is defined in augmentations_clip (TODO: document there).
        # detect() only uses the first entry of the returned "source" list.
        self.transform = DataAugmentationCLIP_test(
            (0.9, 1.0), (0.05, 0.4), 1,
            global_crops_size=224, local_crops_size=96,
        )

    @staticmethod
    def _score_from_latent(z):
        """Map flow-model output ``z`` to a scalar score in ``[0, 0.5]``.

        Computes ``1 - sigmoid(mean(z**2 / 10000, dim=1))`` in float32. Because
        the mean of squares is non-negative, the sigmoid is at least 0.5 and the
        score is at most 0.5. ``.item()`` requires a single element, so ``z`` is
        presumably shaped ``(1, D)``; confirm against ``flow_model``.
        """
        mean_sq = torch.mean(z.float() ** 2 / 10000, dim=1)
        return 1 - torch.sigmoid(mean_sq).item()

    def detect(self, image_pil, threshold=0.5):
        """Classify ``image_pil`` as real or fake.

        Args:
            image_pil: Input ``PIL.Image.Image``; converted to RGB first.
            threshold: The image is reported as fake when ``score > threshold``.

        Returns:
            ``(result_text, score)``: a human-readable verdict that includes the
            score to 10 decimal places, and the raw float score.

        Raises:
            TypeError: If ``image_pil`` is not a ``PIL.Image.Image``.

        Note:
            ``score`` never exceeds 0.5 (see ``_score_from_latent``). With the
            default ``threshold=0.5`` the verdict is therefore always "Real";
            pass a threshold below 0.5 for a "Fake" verdict to be possible.
        """
        if not isinstance(image_pil, Image.Image):
            raise TypeError("输入必须是 PIL Image 对象")
        img_rgb = image_pil.convert("RGB")

        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        # nn.Module.to() moves the module in place and returns it, so after a
        # CUDA call self.flow / self.clip_model remain on the GPU.
        flow_model_gpu = self.flow.to(current_device)
        clip_model_gpu = self.clip_model.to(current_device)

        transformed_img_dict = self.transform(img_rgb)
        img_tensor = transformed_img_dict["source"][0].unsqueeze(0).to(current_device)

        try:
            with torch.no_grad():
                if current_device == "cuda":
                    # torch.autocast replaces the deprecated
                    # torch.cuda.amp.autocast; both default to float16 on CUDA.
                    with torch.autocast(device_type="cuda"):
                        embedding = clip_model_gpu.visual(img_tensor.half())
                        z = flow_model_gpu(embedding)
                else:
                    embedding = clip_model_gpu.visual(img_tensor)
                    z = flow_model_gpu(embedding.float())
                score = self._score_from_latent(z)
        finally:
            # Release cached GPU memory even if the forward pass raised.
            if current_device == "cuda":
                torch.cuda.empty_cache()

        verdict = "伪造的 (Fake)" if score > threshold else "真实的 (Real)"
        result_text = f"结论: {verdict}\n分数: {score:.10f}"
        return result_text, score