# Flow_Fake_Demo / detector.py
# Uploaded by davjoython ("Upload detector.py"), commit 3bff21b (verified).
import spaces
import torch
import clip
from PIL import Image
from torch.cuda.amp import autocast as autocast
from huggingface_hub import hf_hub_download
from model import flow_model
from augmentations_clip import DataAugmentationCLIP as DataAugmentationCLIP_test
# Hugging Face Hub repository that hosts both checkpoints used by the detector,
# followed by the file names of each checkpoint inside that repository.
MODEL_REPO_ID: str = "davjoython/flow_fake"
CLIP_MODEL_FILENAME: str = "my_clip_ViT-L-14.pt"
FLOW_MODEL_FILENAME: str = "flow_fake_detector_centercrop_v4.pth"
class FakeImageDetector:
    """Score images as real or fake with a CLIP image encoder followed by a flow model.

    Both checkpoints are downloaded from ``MODEL_REPO_ID`` on the Hugging Face
    Hub and loaded on CPU at construction time; ``detect`` moves them to the
    device available when it runs.
    """

    # Divisor applied to the squared flow latent before averaging (part of the
    # original scoring formula; kept unchanged).
    _LATENT_SQ_SCALE = 10000

    def __init__(self):
        """Download the CLIP and flow checkpoints and load both models on CPU."""
        # Device seen at construction time (only used for the log line below).
        # detect() re-checks availability at call time.
        # NOTE(review): presumably because on ZeroGPU the GPU is only attached
        # inside the spaces.GPU-decorated call — confirm.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}")

        print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...")
        clip_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CLIP_MODEL_FILENAME,
        )
        print("CLIP 模型已下载。")
        # clip.load also returns a preprocessing transform; it is discarded
        # because a custom transform is built at the end of __init__.
        self.clip_model, _ = clip.load(clip_model_path, device="cpu")
        self.clip_model.eval()
        print("CLIP 模型已加载到 CPU。")

        print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...")
        flow_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=FLOW_MODEL_FILENAME,
        )
        print("Flow 模型已下载。")
        self.flow = flow_model()
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code. Only load from a repository you trust; if the
        # checkpoint is a plain state dict, consider passing weights_only=True.
        self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu"))
        self.flow = self.flow.to("cpu")
        self.flow.eval()
        print("Flow 模型已加载到 CPU。")
        print("模型加载完成。")

        # Positional arguments are forwarded to the project's
        # DataAugmentationCLIP. NOTE(review): presumably (global crop scale
        # range, local crop scale range, number of local crops) — confirm
        # against augmentations_clip.
        self.transform = DataAugmentationCLIP_test(
            (0.9, 1.0), (0.05, 0.4), 1,
            global_crops_size=224, local_crops_size=96,
        )

    def _latent_to_score(self, z):
        """Reduce flow output ``z`` to ``1 - sigmoid(mean(z**2 / scale, dim=1))``.

        The mean of squares is non-negative, so the sigmoid is >= 0.5 and the
        returned score always lies in [0, 0.5]. ``.item()`` requires the
        reduction over dim 1 to leave a single element (a batch of one).
        """
        # Square in float32 even if z came out of an fp16 autocast region.
        energy = torch.mean(z.float() ** 2 / self._LATENT_SQ_SCALE, dim=1)
        return 1 - torch.sigmoid(energy).item()

    @spaces.GPU(duration=10)
    def detect(self, image_pil, threshold=0.5):
        """Score one image and return a verdict string together with the raw score.

        Args:
            image_pil: Image to analyse; converted to RGB before preprocessing.
            threshold: Scores strictly greater than this are labelled fake.
                NOTE(review): the score is bounded to [0, 0.5] (see
                ``_latent_to_score``), so with the default of 0.5 the "Fake"
                verdict is unreachable. Callers must pass a calibrated
                threshold; the default is kept for backward compatibility.

        Returns:
            ``(result_text, score)``: the verdict text (score printed with ten
            decimals) and the float score.

        Raises:
            TypeError: If ``image_pil`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image_pil, Image.Image):
            raise TypeError("输入必须是 PIL Image 对象")
        img_rgb = image_pil.convert("RGB")

        current_device = "cuda" if torch.cuda.is_available() else "cpu"
        # nn.Module.to() moves parameters in place, so self.flow and
        # self.clip_model stay on current_device after this call returns.
        flow_model_gpu = self.flow.to(current_device)
        clip_model_gpu = self.clip_model.to(current_device)

        transformed_img_dict = self.transform(img_rgb)
        # First element of the transform's "source" entry, with a leading
        # batch dimension of 1 added (assumed to be a CHW image tensor — confirm).
        img_tensor = transformed_img_dict["source"][0].unsqueeze(0).to(current_device)

        try:
            with torch.no_grad():
                if current_device == "cuda":
                    # torch.autocast(device_type="cuda") is the non-deprecated
                    # spelling of torch.cuda.amp.autocast() (same fp16 defaults).
                    with torch.autocast(device_type="cuda"):
                        embedding = clip_model_gpu.visual(img_tensor.half())
                        z = flow_model_gpu(embedding)
                else:
                    embedding = clip_model_gpu.visual(img_tensor)
                    z = flow_model_gpu(embedding.float())
                score = self._latent_to_score(z)
        finally:
            # Release cached GPU memory even when inference raises.
            if current_device == "cuda":
                torch.cuda.empty_cache()

        if score > threshold:
            result_text = f"结论: 伪造的 (Fake)\n分数: {score:.10f}"
        else:
            result_text = f"结论: 真实的 (Real)\n分数: {score:.10f}"
        return result_text, score