# Commit ccd4124 (verified) by MITTop: "Upload app (1).py"
"""
DeepFake Detector — HuggingFace Gradio Space
=============================================
Architecture : ConvNeXt-Base + custom classifier head
Weights : ARPAN2026/dfake-hcnext (auto-downloaded on first run)
Classes : Real (0) | Fake (1)
"""
import os
import urllib.request
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
import gradio as gr
from PIL import Image
# ============================================================
# CONFIG
# ============================================================
# Checkpoint location on the HuggingFace Hub and the file name it is cached
# under. MODEL_PATH is a relative path, so it resolves against the process's
# current working directory.
MODEL_URL = "https://huggingface.co/ARPAN2026/dfake-hcnext/resolve/main/best_model_New.pth"
MODEL_PATH = "best_model_New.pth"
# Side length (pixels) of the square image the network receives.
IMG_SIZE = 224
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Preprocessing applied to every upload: resize to IMG_SIZE x IMG_SIZE,
# convert to a float tensor, then normalise each of the 3 channels with
# mean = std = 0.5. Stated to match the training pipeline — keep it in sync
# with whatever the checkpoint was trained with.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
# ============================================================
# MODEL DEFINITION
# ============================================================
class DeepfakeModel(nn.Module):
    """Two-class (Real / Fake) image classifier.

    A timm ``convnext_base`` trunk, created without pretrained weights and with
    ``num_classes=0``, feeds a small MLP head:
    LayerNorm -> Linear(dim, 256) -> GELU -> Dropout(0.4) -> Linear(256, 2).

    The attribute names (``backbone``, ``classifier``) and the order of the
    layers inside ``classifier`` determine the state-dict keys, so they must
    stay in sync with the saved checkpoint.
    """

    def __init__(self) -> None:
        super().__init__()
        # Weights are restored later from the checkpoint, hence pretrained=False.
        self.backbone = timm.create_model("convnext_base", pretrained=False, num_classes=0)
        width = self.backbone.num_features
        head_layers = [
            nn.LayerNorm(width),
            nn.Linear(width, 256),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(256, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw two-class logits for the batch ``x``."""
        fmap = self.backbone.forward_features(x)
        # A 4-D feature map (B, C, H, W) is averaged over its spatial positions
        # to (B, C); any other rank is handed to the head unchanged.
        pooled = fmap.flatten(2).mean(-1) if fmap.ndim == 4 else fmap
        return self.classifier(pooled)
# ============================================================
# UTILITIES
# ============================================================
def download_weights(url=None, path=None) -> None:
    """Download model weights from HuggingFace Hub if not already present.

    The transfer is written to ``<path>.part`` and only renamed onto ``path``
    once it has completed. An interrupted download therefore never leaves a
    truncated checkpoint that the existence check below would mistake for a
    finished one on the next start-up.

    Parameters
    ----------
    url : str, optional
        Source URL; defaults to ``MODEL_URL``.
    path : str or os.PathLike, optional
        Destination file; defaults to ``MODEL_PATH``.

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``); the
    partial ``.part`` file is removed before re-raising.
    """
    url = MODEL_URL if url is None else url
    path = MODEL_PATH if path is None else path

    if os.path.exists(path):
        print(f"[INFO] Weights already found at '{path}' — skipping download.")
        return

    print(f"[INFO] Downloading weights from:\n {url}")
    tmp_path = f"{path}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic rename: `path` only ever appears fully written.
        os.replace(tmp_path, path)
    except BaseException:
        # Includes KeyboardInterrupt: drop the partial file so the next run retries.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("[INFO] Download complete.")
def load_model() -> DeepfakeModel:
    """Instantiate DeepfakeModel, load saved weights, and set to eval mode.

    The checkpoint at ``MODEL_PATH`` is passed straight to
    ``load_state_dict``, i.e. it is expected to be a plain state dict.

    Returns
    -------
    DeepfakeModel
        The model on ``DEVICE``, switched to eval mode (dropout disabled).
    """
    net = DeepfakeModel().to(DEVICE)
    # The checkpoint is downloaded over the network, and torch.load is
    # pickle-based. weights_only=True restricts unpickling to tensors and
    # plain containers, so the file cannot execute arbitrary code.
    # Requires PyTorch >= 1.13; it is the default from PyTorch 2.6.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()
    print(f"[INFO] Model ready on {DEVICE}.")
    return net
# ============================================================
# INFERENCE
# ============================================================
def predict(model: nn.Module, image: "Image.Image | None"):
    """
    Run inference on a single PIL image.

    Parameters
    ----------
    model : torch.nn.Module
        Loaded, eval-mode model (normally a ``DeepfakeModel``). It must return
        two logits per image, ordered [Real, Fake].
    image : PIL.Image.Image | None
        Image uploaded by the user; ``None`` when nothing was uploaded.

    Returns
    -------
    label_dict : dict[str, float]
        Mapping of class name → probability, rounded to 4 decimals (consumed
        by gr.Label). ``{"Error": 1.0}`` when ``image`` is ``None``.
    verdict_md : str
        Markdown-formatted verdict string.
    """
    if image is None:
        return {"Error": 1.0}, "⚠️ Please upload an image first."

    # convert("RGB") gives the 3 channels TRANSFORM normalises;
    # unsqueeze(0) adds the batch axis.
    tensor = TRANSFORM(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    real_prob = float(probs[0])
    fake_prob = float(probs[1])
    confidence = max(real_prob, fake_prob) * 100

    # An exact tie falls through to the "real" verdict.
    if fake_prob > real_prob:
        verdict_md = f"## 🔴 DEEPFAKE DETECTED\n**Confidence:** {confidence:.1f}%"
    else:
        verdict_md = f"## 🟢 LIKELY REAL\n**Confidence:** {confidence:.1f}%"

    label_dict = {
        "Real": round(real_prob, 4),
        "Fake": round(fake_prob, 4),
    }
    return label_dict, verdict_md
# ============================================================
# CUSTOM CSS (dark forensic theme)
# ============================================================
# Stylesheet handed to gr.Blocks(css=...) in build_ui.
# - Fonts (Share Tech Mono, Syne) are pulled from Google Fonts via @import,
#   so the visitor's browser must be able to reach fonts.googleapis.com.
# - Colours/radius are declared once as custom properties in :root and reused
#   via var(...). `--danger` and `--safe` are declared but no rule below
#   references them.
# - `h1.title-heading` and `p.subtitle` style the header HTML that build_ui
#   injects with gr.HTML.
# - `footer { display: none }` hides every <footer> element on the page.
# NOTE(review): selectors such as `.gr-box`, `.gr-panel`, `.gr-form`,
# `.gr-label ...` and the hashed-looking `.svelte-1n8nu59` target Gradio's
# internal DOM. They appear version-specific and may silently stop matching
# after a Gradio upgrade — confirm against the installed Gradio version.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Share+Tech+Mono&family=Syne:wght@400;700;800&display=swap');
:root {
--bg: #0a0c10;
--surface: #111318;
--border: #1e2330;
--accent: #00e5ff;
--danger: #ff3b5c;
--safe: #00e676;
--text: #d0d8f0;
--muted: #5a6480;
--radius: 8px;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'Syne', sans-serif !important;
color: var(--text) !important;
}
h1.title-heading {
font-family: 'Syne', sans-serif;
font-weight: 800;
font-size: 2.4rem;
letter-spacing: -0.02em;
background: linear-gradient(90deg, var(--accent), #7b61ff);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin: 0;
}
p.subtitle {
color: var(--muted);
font-family: 'Share Tech Mono', monospace;
font-size: 0.85rem;
margin-top: 4px;
letter-spacing: 0.08em;
}
.gr-box, .gr-panel, .gr-form {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
}
.gr-image, .svelte-1n8nu59 {
border: 2px dashed var(--border) !important;
border-radius: var(--radius) !important;
background: #0d0f14 !important;
}
button.primary {
background: var(--accent) !important;
color: #000 !important;
font-family: 'Syne', sans-serif !important;
font-weight: 700 !important;
border: none !important;
border-radius: var(--radius) !important;
letter-spacing: 0.05em;
}
button.secondary {
background: transparent !important;
border: 1px solid var(--border) !important;
color: var(--muted) !important;
font-family: 'Syne', sans-serif !important;
border-radius: var(--radius) !important;
}
.gr-markdown h2 {
font-family: 'Syne', sans-serif;
font-size: 1.4rem;
font-weight: 700;
margin: 0 0 4px;
}
.gr-label .wrap {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
}
.gr-label .label-wrap span {
font-family: 'Share Tech Mono', monospace !important;
color: var(--text) !important;
}
.gr-label .bar {
background: linear-gradient(90deg, var(--accent), #7b61ff) !important;
}
footer { display: none !important; }
"""
# ============================================================
# GRADIO UI BUILDER
# ============================================================
def build_ui(model: DeepfakeModel) -> gr.Blocks:
    """
    Construct and return the Gradio Blocks interface.

    Layout: a header, then two columns, optional example images, and a footer.
    The left column holds the upload box, the Analyze/Clear buttons and a
    model-info card. The right column holds the verdict markdown and the
    class-probability label.

    Parameters
    ----------
    model : DeepfakeModel
        Pre-loaded, eval-mode model, captured by the Analyze click handler.

    Returns
    -------
    gr.Blocks
        The assembled Gradio app (not yet launched).
    """
    # Shown in the verdict panel before the first analysis and after Clear.
    verdict_placeholder = "*Upload an image and click **Analyze** to begin.*"
    # Add one [path] row per example file in the repo to enable the examples
    # panel, e.g. [["examples/real1.jpg"], ["examples/fake1.jpg"]].
    example_images = []

    def _predict_wrapper(image: Image.Image):
        """Closure wrapper — captures `model` from the outer scope."""
        return predict(model, image)

    def _reset_outputs():
        """Empty the probability label and restore the verdict placeholder."""
        return None, verdict_placeholder

    with gr.Blocks(css=CSS, title="DeepFake Detector") as demo:
        # ── Header ──────────────────────────────────────────────────
        gr.HTML("""
        <div style="text-align:center; padding:32px 0 16px;">
            <h1 class='title-heading'>DEEPFAKE DETECTOR</h1>
            <p class='subtitle'>
                ConvNeXt-Base &nbsp;·&nbsp; Trained on RVF Faces &nbsp;·&nbsp; Hackathon Edition
            </p>
        </div>
        """)

        # ── Main two-column layout ───────────────────────────────────
        with gr.Row():
            # Left column — upload + controls + model info
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Face Image",
                    height=320,
                )
                with gr.Row():
                    submit_btn = gr.Button("🔍 Analyze", variant="primary")
                    clear_btn = gr.ClearButton(
                        components=[image_input],
                        value="✕ Clear",
                    )
                gr.HTML("""
                <div style="margin-top:12px; padding:12px 16px;
                            background:#0d0f14; border:1px solid #1e2330;
                            border-radius:8px; font-family:'Share Tech Mono',monospace;
                            font-size:0.78rem; color:#5a6480; line-height:1.9;">
                    <b style="color:#00e5ff;">MODEL</b> &nbsp;&nbsp; ConvNeXt-Base + custom head<br>
                    <b style="color:#00e5ff;">TRAINED</b>&nbsp; Real vs Fake Faces (80/20 split)<br>
                    <b style="color:#00e5ff;">INPUT</b> &nbsp;&nbsp; 224 × 224 · RGB · normalised<br>
                    <b style="color:#00e5ff;">CLASSES</b>&nbsp; Real &nbsp;|&nbsp; Fake
                </div>
                """)

            # Right column — verdict + probability bars
            with gr.Column(scale=1):
                verdict_output = gr.Markdown(
                    value=verdict_placeholder,
                    label="Verdict",
                )
                label_output = gr.Label(
                    num_top_classes=2,
                    label="Class Probabilities",
                )

        # ── Example images (only rendered when at least one is configured) ──
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=image_input,
                label="Example Images",
            )

        # ── Event wiring ─────────────────────────────────────────────
        submit_btn.click(
            fn=_predict_wrapper,
            inputs=image_input,
            outputs=[label_output, verdict_output],
        )
        # The ClearButton itself only empties `image_input`. Also reset the
        # result panels so a stale verdict is not left beside an empty upload.
        clear_btn.click(
            fn=_reset_outputs,
            outputs=[label_output, verdict_output],
        )

        # ── Footer ───────────────────────────────────────────────────
        gr.HTML("""
        <div style="text-align:center; padding:24px 0 8px;
                    font-family:'Share Tech Mono',monospace;
                    font-size:0.75rem; color:#2a3050;">
            Built with ❤ &nbsp;·&nbsp; Gradio &nbsp;·&nbsp; HuggingFace Spaces &nbsp;·&nbsp; PyTorch
        </div>
        """)

    return demo
# ============================================================
# MAIN
# ============================================================
def main() -> None:
    """
    Entry point for the Space.

    Runs four steps, in this order:
    1. ``download_weights()``, which skips the fetch when the weights file
       already exists.
    2. ``load_model()``.
    3. ``build_ui(model)``.
    4. ``launch()`` on the resulting Blocks app.
    """
    download_weights()
    app = build_ui(load_model())
    app.launch()


if __name__ == "__main__":
    main()