"""Gradio demo for face forgery detection.

Downloads the model registry, per-model decision thresholds, model source
files and weights from the Hugging Face Hub repository ``MODEL_REPO`` and
serves a small web UI that classifies an uploaded image as real or forged.
"""

import json
import importlib.util
import sys
from pathlib import Path

import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

MODEL_REPO = "OhMyYuwan/face-forgery-detection"


def _load_hub_json(filename):
    """Download *filename* from ``MODEL_REPO`` and return its parsed JSON."""
    # Explicit UTF-8 so decoding does not depend on the host locale.
    with open(hf_hub_download(MODEL_REPO, filename), encoding="utf-8") as f:
        return json.load(f)


# Maps model name -> entry holding at least "config_file" and "model_file".
REGISTRY = _load_hub_json("registry.json")["models"]
# Maps model name -> entry holding at least "threshold".
THRESHOLDS = _load_hub_json("optimal_thresholds.json")

# Preprocessing applied to every uploaded image: fixed 224x224 resize, tensor
# conversion, then per-channel normalisation (these are the widely used
# ImageNet channel statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiated models keyed by model name, so each one is built only once.
_model_cache = {}


def load_model_class(model_name):
    """Dynamically load model class from HuggingFace.

    Downloads ``<model_name>/model.py`` from ``MODEL_REPO``, imports it as the
    module ``<model_name>_model`` (registered in ``sys.modules``) and returns
    its ``OurNet`` attribute.

    Raises:
        ImportError: if no import spec/loader can be built for the file.
    """
    module_name = f"{model_name}_model"
    model_file = hf_hub_download(MODEL_REPO, f"{model_name}/model.py")
    spec = importlib.util.spec_from_file_location(module_name, model_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import model definition from {model_file}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be resolved by name while
    # its own body runs.
    sys.modules[module_name] = module
    # SECURITY: this executes Python code downloaded from the Hub. Only point
    # MODEL_REPO at a repository you control or fully trust.
    spec.loader.exec_module(module)
    return module.OurNet


def load_model(model_name):
    """Return the ready-to-use model for *model_name*, building it on first use.

    The model is constructed from its JSON config, its weights are loaded,
    it is moved to ``DEVICE``, switched to eval mode and cached in
    ``_model_cache``.

    Raises:
        KeyError: if *model_name* is not present in ``REGISTRY``.
    """
    if model_name in _model_cache:
        return _model_cache[model_name]

    entry = REGISTRY[model_name]
    config_path = hf_hub_download(MODEL_REPO, entry["config_file"])
    model_path = hf_hub_download(MODEL_REPO, entry["model_file"])

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Dynamically load model class
    OurNet = load_model_class(model_name)
    model = OurNet(config)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code embedded in the checkpoint.
    # NOTE(review): switch to weights_only=True if the checkpoints are plain
    # state dicts - confirm against the published files before changing.
    state = torch.load(model_path, map_location=DEVICE, weights_only=False)
    # strict=False: missing or unexpected keys are ignored rather than raising.
    model.load_state_dict(state, strict=False)
    model.to(DEVICE).eval()

    _model_cache[model_name] = model
    return model


def predict(image, model_name):
    """Classify *image* with the selected model for the Gradio UI.

    Args:
        image: PIL image supplied by the image component, or ``None`` when
            nothing has been uploaded.
        model_name: name selected in the model dropdown.

    Returns:
        A ``(result_markdown, details_text)`` pair. Both are empty strings
        when *image* is ``None``; on any failure the first element carries the
        error message and the second is empty.
    """
    if image is None:
        return "", ""

    # Broad catch is deliberate: this is the UI boundary, and any failure
    # (download, import, inference) is shown to the user instead of crashing.
    try:
        model = load_model(model_name)
        threshold = THRESHOLDS[model_name]["threshold"]

        x = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            # forward_det yields a pair; the second element is treated as a
            # single logit (``.item()`` requires exactly one value).
            _, det = model.forward_det(x)
            score = torch.sigmoid(det).item()

        is_forged = score > threshold
        label = "🔴 Forged" if is_forged else "✅ Real"
        # Confidence is the sigmoid score for the predicted class; it is not
        # rescaled around the threshold, so it can fall below 50% when the
        # threshold is not 0.5.
        confidence = score if is_forged else 1 - score

        result_text = f"**Prediction:** {label}\n**Confidence:** {confidence*100:.2f}%"
        details = f"Raw Score: {score:.4f}\nThreshold: {threshold:.4f}\nModel: {model_name}"
        return result_text, details
    except Exception as e:
        return f"❌ Error: {e}", ""


# Only support models that don't require local files
SUPPORTED_MODELS = [
    "convnext_base",
    "inceptionnext_base",
    "maxvit_base",
]

# Supported models that the registry actually provides, in registry order.
MODEL_NAMES = [m for m in REGISTRY if m in SUPPORTED_MODELS]

with gr.Blocks(title="Face Forgery Detection") as demo:
    gr.Markdown("# 🔍 Face Forgery Detection\nDetect whether a face image is real or forged using state-of-the-art models.")

    model_selector = gr.Dropdown(
        choices=MODEL_NAMES,
        # Avoid an IndexError at start-up when no supported model is listed
        # in the registry.
        value=MODEL_NAMES[0] if MODEL_NAMES else None,
        label="Select Model",
    )

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")

        with gr.Column():
            result_output = gr.Markdown(label="Result")
            details_output = gr.Textbox(label="Details", lines=3, interactive=False)

    detect_btn.click(
        predict,
        inputs=[img_input, model_selector],
        outputs=[result_output, details_output],
        api_name=False,
    )

# Start the server only when run as a script, so importing this module (for
# example to reuse `predict`) does not block on a running web server.
if __name__ == "__main__":
    demo.launch()