Spaces:
Sleeping
Sleeping
import importlib.util
import json
import logging
import sys
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
# Hugging Face Hub repository that hosts the registry, thresholds, model code and weights.
MODEL_REPO = "OhMyYuwan/face-forgery-detection"

# Per-model entries (at least "config_file" and "model_file" are read by load_model()).
# JSON is UTF-8 by specification, so decode explicitly rather than relying on the
# platform's default encoding.
with open(hf_hub_download(MODEL_REPO, "registry.json"), encoding="utf-8") as f:
    REGISTRY = json.load(f)["models"]

# Decision threshold per model name; predict() reads THRESHOLDS[name]["threshold"].
with open(hf_hub_download(MODEL_REPO, "optimal_thresholds.json"), encoding="utf-8") as f:
    THRESHOLDS = json.load(f)
# Preprocessing applied to every uploaded image: resize to 224x224, convert to a
# tensor, then normalise with the standard ImageNet per-channel mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Models already built by load_model(), keyed by registry name.
_model_cache: dict = {}
def load_model_class(model_name):
    """Download ``<model_name>/model.py`` from MODEL_REPO and return its ``OurNet`` class.

    The downloaded file is imported as a module named ``f"{model_name}_model"`` and
    registered in ``sys.modules``, following the standard importlib
    "import a source file directly" recipe.

    SECURITY: this executes arbitrary Python code fetched from the Hub. Only point
    MODEL_REPO at a repository you trust.

    Raises:
        ImportError: if no import spec/loader can be created for the downloaded file.
        AttributeError: if the downloaded module does not define ``OurNet``.
        Any exception raised while executing the module is propagated unchanged.
    """
    module_name = f"{model_name}_model"
    model_file = hf_hub_download(MODEL_REPO, f"{model_name}/model.py")

    spec = importlib.util.spec_from_file_location(module_name, model_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {model_file!r}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing (as the regular import system does) so the module can
    # be looked up by name while it is being initialised.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module in sys.modules for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module.OurNet
def load_model(model_name):
    """Return the model registered as *model_name*, building it on first use.

    Downloads the config and weight files named in ``REGISTRY[model_name]``,
    constructs ``OurNet(config)`` from the repository's model code, loads the
    weights, moves the model to ``DEVICE`` and puts it in eval mode. The result is
    memoised in ``_model_cache`` so later calls return the same object.

    Raises:
        KeyError: if *model_name* is not present in ``REGISTRY``.
    """
    if model_name in _model_cache:
        return _model_cache[model_name]

    entry = REGISTRY[model_name]
    config_path = hf_hub_download(MODEL_REPO, entry["config_file"])
    model_path = hf_hub_download(MODEL_REPO, entry["model_file"])
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # The model class comes from the repository's own model.py for this model.
    OurNet = load_model_class(model_name)
    model = OurNet(config)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, i.e. a
    # malicious checkpoint can run code. Kept for compatibility with the published
    # files; switch to weights_only=True if the checkpoint is a plain tensor state_dict.
    state = torch.load(model_path, map_location=DEVICE, weights_only=False)

    # strict=False tolerates key mismatches, but a silent mismatch leaves layers at
    # their initial values and produces meaningless scores, so report it.
    incompatible = model.load_state_dict(state, strict=False)
    missing = list(getattr(incompatible, "missing_keys", []))
    unexpected = list(getattr(incompatible, "unexpected_keys", []))
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Model %r: %d missing / %d unexpected state_dict keys "
            "(first missing: %s; first unexpected: %s)",
            model_name,
            len(missing),
            len(unexpected),
            missing[:5],
            unexpected[:5],
        )

    model.to(DEVICE).eval()
    _model_cache[model_name] = model
    return model
def predict(image, model_name):
    """Classify *image* as real or forged using the model registered as *model_name*.

    Returns a ``(result_markdown, details_text)`` pair for the two output components.
    Both strings are empty when no image is supplied. Any failure is shown to the user
    as an error message (and logged with its traceback) instead of being raised, so
    the UI keeps working.
    """
    if image is None:
        return "", ""
    try:
        model = load_model(model_name)
        threshold = THRESHOLDS[model_name]["threshold"]
        # Add a batch dimension of 1; convert("RGB") also handles greyscale/RGBA uploads.
        x = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            # forward_det comes from the repository's model code; its second output is
            # treated here as the forgery logit.
            _, det = model.forward_det(x)
        score = torch.sigmoid(det).item()

        is_forged = score > threshold
        label = "🔴 Forged" if is_forged else "✅ Real"
        # Report the sigmoid probability of whichever class was predicted.
        confidence = score if is_forged else 1 - score
        result_text = f"**Prediction:** {label}\n**Confidence:** {confidence*100:.2f}%"
        details = f"Raw Score: {score:.4f}\nThreshold: {threshold:.4f}\nModel: {model_name}"
        return result_text, details
    except Exception as e:
        # UI boundary: keep the app alive, but don't lose the traceback.
        logging.getLogger(__name__).exception("Prediction failed for model %r", model_name)
        return f"❌ Error: {e}", ""
# Only offer models that can run without local files (everything they need is
# fetched from MODEL_REPO).
SUPPORTED_MODELS = [
    "convnext_base",
    "inceptionnext_base",
    "maxvit_base",
]

# Keep the registry's order; anything not in SUPPORTED_MODELS is hidden from the UI.
MODEL_NAMES = [name for name in REGISTRY if name in SUPPORTED_MODELS]
if not MODEL_NAMES:
    # The UI uses MODEL_NAMES[0] as the default selection; fail with a clear message
    # rather than an IndexError during UI construction.
    raise RuntimeError(
        f"None of the supported models {SUPPORTED_MODELS} are listed in the "
        f"registry of {MODEL_REPO!r}"
    )
# Gradio UI: model picker, image upload and a button that runs predict().
with gr.Blocks(title="Face Forgery Detection") as demo:
    gr.Markdown(
        "# 🔍 Face Forgery Detection\n"
        "Detect whether a face image is real or forged using state-of-the-art models."
    )
    model_selector = gr.Dropdown(choices=MODEL_NAMES, value=MODEL_NAMES[0], label="Select Model")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")
        with gr.Column():
            result_output = gr.Markdown(label="Result")
            details_output = gr.Textbox(label="Details", lines=3, interactive=False)

    detect_btn.click(
        predict,
        inputs=[img_input, model_selector],
        outputs=[result_output, details_output],
        # Do not expose this handler as a named API endpoint.
        api_name=False,
    )

demo.launch()