# yuwan0's picture
# Fix huggingface_hub compatibility issue
# 91117ec
import json
import importlib.util
import sys
from pathlib import Path
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository holding the registry, thresholds, per-model
# configs, model code and weights that this app downloads.
MODEL_REPO = "OhMyYuwan/face-forgery-detection"

# JSON is UTF-8 by specification, so decode it explicitly rather than relying
# on the platform's locale encoding.
with open(hf_hub_download(MODEL_REPO, "registry.json"), encoding="utf-8") as f:
    # Mapping of model name -> entry; entries are read for their
    # "config_file" and "model_file" keys when a model is loaded.
    REGISTRY = json.load(f)["models"]

with open(
    hf_hub_download(MODEL_REPO, "optimal_thresholds.json"), encoding="utf-8"
) as f:
    # Mapping of model name -> entry; entries are read for their "threshold"
    # key at prediction time.
    THRESHOLDS = json.load(f)
# Models that have already been built, keyed by model name, so each one is
# constructed at most once per process.
_model_cache = {}

# Prefer the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Image preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each channel with the fixed mean/std below.
# NOTE(review): these look like the usual ImageNet statistics - confirm they
# match what the models were trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_model_class(model_name):
    """Download ``<model_name>/model.py`` from the repo and return its ``OurNet``.

    The file is imported under the module name ``"<model_name>_model"``. The
    module is registered in ``sys.modules`` before it is executed, following
    the importlib recipe for importing a source file directly.

    NOTE(security): this executes Python code downloaded from ``MODEL_REPO``.
    Only point ``MODEL_REPO`` at a repository you trust.

    Args:
        model_name: Name of the model directory inside the repository.

    Returns:
        The ``OurNet`` attribute of the imported module.

    Raises:
        ImportError: If no import spec or loader can be created for the file.
        AttributeError: If the imported module does not define ``OurNet``.
    """
    module_name = f"{model_name}_model"
    model_file = hf_hub_download(MODEL_REPO, f"{model_name}/model.py")

    spec = importlib.util.spec_from_file_location(module_name, model_file)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader;
        # fail with a clear error instead of an AttributeError on None.
        raise ImportError(f"Cannot create an import spec for {model_file!r}")

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module importable: put back whatever
        # was registered before, or drop the entry entirely.
        if previous is not None:
            sys.modules[module_name] = previous
        else:
            sys.modules.pop(module_name, None)
        raise
    return module.OurNet
def load_model(model_name):
    """Return the model for ``model_name``, building and caching it on first use.

    On a cache miss the config and weight files named by the registry entry
    are downloaded, the model class is loaded dynamically, the weights are
    applied non-strictly, and the model is moved to ``DEVICE`` in eval mode.

    Args:
        model_name: Key into ``REGISTRY``.

    Returns:
        The cached model instance.

    Raises:
        KeyError: If ``model_name`` (or one of the expected entry keys) is
            missing from ``REGISTRY``.
    """
    import warnings

    try:
        return _model_cache[model_name]
    except KeyError:
        pass

    entry = REGISTRY[model_name]
    config_path = hf_hub_download(MODEL_REPO, entry["config_file"])
    model_path = hf_hub_download(MODEL_REPO, entry["model_file"])
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Dynamically load model class
    OurNet = load_model_class(model_name)
    model = OurNet(config)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code from the checkpoint. Kept as-is because
    # the checkpoint format is not visible here; switch to weights_only=True
    # if the file is a plain state dict.
    state = torch.load(model_path, map_location=DEVICE, weights_only=False)

    # strict=False tolerates key mismatches, which could silently leave layers
    # at their initial weights - surface any mismatch instead of hiding it.
    load_result = model.load_state_dict(state, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint for {model_name!r} did not match the model exactly: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(DEVICE).eval()
    _model_cache[model_name] = model
    return model
def predict(image, model_name):
    """Classify ``image`` as forged or real with the selected model.

    Args:
        image: Image object exposing ``convert("RGB")`` (the UI supplies a PIL
            image), or ``None`` when nothing was uploaded.
        model_name: Name of the model to run; also the key used to look up
            its decision threshold in ``THRESHOLDS``.

    Returns:
        A ``(result_markdown, details_text)`` tuple. Both strings are empty
        when ``image`` is ``None``. If anything fails, the first element
        holds the error message and the second is empty.
    """
    import logging

    if image is None:
        return "", ""
    try:
        model = load_model(model_name)
        threshold = THRESHOLDS[model_name]["threshold"]
        x = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            _, det = model.forward_det(x)
        score = torch.sigmoid(det).item()

        is_forged = score > threshold
        label = "🔴 Forged" if is_forged else "✅ Real"
        # Report the score itself for a "forged" verdict and its complement
        # for a "real" verdict.
        confidence = score if is_forged else 1 - score
        # Blank line between the two fields: Markdown folds a single newline
        # into a space, which would render both on one line.
        result_text = (
            f"**Prediction:** {label}\n\n**Confidence:** {confidence*100:.2f}%"
        )
        details = (
            f"Raw Score: {score:.4f}\n"
            f"Threshold: {threshold:.4f}\n"
            f"Model: {model_name}"
        )
        return result_text, details
    except Exception as e:
        # UI boundary: show the message to the user, but keep the traceback
        # in the server log instead of discarding it.
        logging.getLogger(__name__).exception(
            "Prediction failed for model %r", model_name
        )
        return f"❌ Error: {e}", ""
# Models offered in the UI; limited to those that do not need local files.
SUPPORTED_MODELS = ["convnext_base", "inceptionnext_base", "maxvit_base"]

# Registry entries that are also supported, kept in registry order.
MODEL_NAMES = [name for name in REGISTRY if name in SUPPORTED_MODELS]
# Gradio UI: a model dropdown, an image upload with a detect button on the
# left, and the prediction plus raw details on the right.
with gr.Blocks(title="Face Forgery Detection") as demo:
    gr.Markdown(
        "# 🔍 Face Forgery Detection\n"
        "Detect whether a face image is real or forged using state-of-the-art models."
    )
    model_selector = gr.Dropdown(
        choices=MODEL_NAMES,
        # Preselect the first model, but do not crash at startup when the
        # registry contains none of the supported models.
        value=MODEL_NAMES[0] if MODEL_NAMES else None,
        label="Select Model",
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Upload Image")
            detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")
        with gr.Column():
            result_output = gr.Markdown(label="Result")
            details_output = gr.Textbox(label="Details", lines=3, interactive=False)
    # Run predict on click; api_name=False keeps this handler out of the
    # generated API.
    detect_btn.click(
        predict,
        inputs=[img_input, model_selector],
        outputs=[result_output, details_output],
        api_name=False,
    )

demo.launch()