File size: 3,155 Bytes
3bff21b
30a7879
 
 
 
 
3bff21b
30a7879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import spaces 
import torch
import clip  
from PIL import Image
from torch.cuda.amp import autocast as autocast
from huggingface_hub import hf_hub_download 


from model import flow_model 
from augmentations_clip import DataAugmentationCLIP as DataAugmentationCLIP_test

# Hugging Face Hub repository from which both checkpoints below are downloaded.
MODEL_REPO_ID: str = "davjoython/flow_fake"
# Checkpoint filenames inside MODEL_REPO_ID.
FLOW_MODEL_FILENAME: str = "flow_fake_detector_centercrop_v4.pth"
CLIP_MODEL_FILENAME: str = "my_clip_ViT-L-14.pt"


class FakeImageDetector:
    """Score images as real or fake with a CLIP visual encoder plus a flow model.

    Both checkpoints are downloaded from the Hugging Face Hub repository
    ``MODEL_REPO_ID`` and loaded onto the CPU at construction time. ``detect``
    moves them to CUDA when it is available at call time.
    """

    def __init__(self):
        """Download and load both models on the CPU and build the preprocessing transform."""
        # Recorded for logging only; ``detect`` re-checks CUDA availability at
        # call time instead of reusing this value.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}")

        print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...")
        clip_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CLIP_MODEL_FILENAME,
        )
        print("CLIP 模型已下载。")
        # Load on the CPU so construction does not require a GPU.
        self.clip_model, _ = clip.load(clip_model_path, device="cpu")
        self.clip_model.eval()
        print("CLIP 模型已加载到 CPU。")

        print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...")
        flow_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=FLOW_MODEL_FILENAME,
        )
        print("Flow 模型已下载。")
        self.flow = flow_model()
        # NOTE(security): torch.load unpickles the downloaded file. Only a
        # state_dict is needed here, so consider passing weights_only=True
        # (torch >= 1.13) once the checkpoint is confirmed compatible.
        self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu"))
        self.flow = self.flow.to("cpu")
        self.flow.eval()
        print("Flow 模型已加载到 CPU。")

        print("模型加载完成。")

        # Arguments are forwarded unchanged to DataAugmentationCLIP; their
        # meaning is defined in augmentations_clip (not visible here).
        self.transform = DataAugmentationCLIP_test(
            (0.9, 1.0), (0.05, 0.4), 1,
            global_crops_size=224, local_crops_size=96,
        )

    @spaces.GPU(duration=10)
    def detect(self, image_pil, threshold=0.5):
        """Classify ``image_pil`` as real or fake.

        Args:
            image_pil: Input image; it is converted to RGB before preprocessing.
            threshold: Scores strictly greater than this are reported as fake.

        Returns:
            A ``(result_text, score)`` tuple: a human-readable verdict string
            and the raw float score.

        Raises:
            TypeError: If ``image_pil`` is not a ``PIL.Image.Image``.

        NOTE(review): score = 1 - sigmoid(mean(z**2 / 10000)). The mean of
        squares is >= 0, so the score is always <= 0.5 and, with the default
        ``threshold=0.5``, every image is reported as real. Callers presumably
        pass a threshold below 0.5 -- confirm.
        """
        if not isinstance(image_pil, Image.Image):
            raise TypeError("输入必须是 PIL Image 对象")

        img_rgb = image_pil.convert("RGB")

        # Re-evaluated on every call rather than reusing self.device from __init__.
        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        use_cuda = current_device == "cuda"

        # nn.Module.to() moves a module in place, so self.flow and
        # self.clip_model remain on current_device after this call returns.
        flow_model_gpu = self.flow.to(current_device)
        clip_model_gpu = self.clip_model.to(current_device)

        transformed_img_dict = self.transform(img_rgb)
        # First entry of the "source" output, with a leading (presumably batch)
        # dimension added.
        img_tensor = transformed_img_dict["source"][0].unsqueeze(0).to(current_device)

        with torch.no_grad():
            if use_cuda:
                # Mixed precision on the GPU. torch.autocast(device_type="cuda")
                # is the non-deprecated equivalent of torch.cuda.amp.autocast().
                with torch.autocast(device_type="cuda"):
                    embedding = clip_model_gpu.visual(img_tensor.half())
                    z = flow_model_gpu(embedding)
            else:
                embedding = clip_model_gpu.visual(img_tensor)
                z = flow_model_gpu(embedding.float())
            # z is upcast to float32 before any arithmetic, so computing the
            # score outside the autocast region yields the same value.
            score = 1 - torch.sigmoid(torch.mean(z.float() ** 2 / 10000, dim=1)).item()

        if use_cuda:
            torch.cuda.empty_cache()

        if score > threshold:
            result_text = f"结论: 伪造的 (Fake)\n分数: {score:.10f}"
        else:
            result_text = f"结论: 真实的 (Real)\n分数: {score:.10f}"

        return result_text, score