"""Gradio Space: ViT-based deepfake detector with attention-rollout heatmaps.

Loads a fine-tuned ViT classifier from the Hugging Face Hub, classifies an
uploaded image as REAL or FAKE, and visualizes which patches drove the
decision via attention rollout (Abnar & Zuidema, 2020).
"""

import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image  # noqa: F401  (kept: PIL types flow through gradio's type="pil")
from transformers import ViTForImageClassification, ViTImageProcessor

# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_REPO = "SARVM/ViT_Deepfake"
HF_TOKEN = os.getenv("HF_TOKEN")  # Set in Space secrets or local env; None is fine for public repos.

print(f"Loading model from {MODEL_REPO}...")

processor = ViTImageProcessor.from_pretrained(MODEL_REPO, token=HF_TOKEN)
model = ViTForImageClassification.from_pretrained(
    MODEL_REPO,
    token=HF_TOKEN,
    output_attentions=True,  # needed so forward() returns per-layer attention maps
)
model.eval()

# Override labels to REAL / FAKE (index 1 = REAL, index 0 = FAKE).
model.config.id2label = {1: "REAL", 0: "FAKE"}
model.config.label2id = {"REAL": 1, "FAKE": 0}


# -----------------------------
# ATTENTION ROLLOUT
# -----------------------------
def compute_attention_rollout(attentions):
    """Collapse per-layer ViT attentions into a single token-to-token map.

    Implements attention rollout: average heads per layer, add the identity
    (residual connection), re-normalize rows, then multiply the per-layer
    matrices together so attention is propagated through all layers.

    Args:
        attentions: tuple of per-layer tensors of shape
            (batch=1, heads, tokens, tokens), as returned by the model
            with ``output_attentions=True``.

    Returns:
        Tensor of shape (tokens, tokens) — the joint attention map of the
        final layer; row 0 is the [CLS] token's attention over all tokens.
    """
    # (layers, 1, heads, T, T) -> drop batch -> average over heads -> (layers, T, T)
    att_mat = torch.stack(attentions).squeeze(1)
    att_mat = att_mat.mean(dim=1)

    # Account for residual connections, then renormalize each row.
    # (device= keeps this correct if the model ever runs off-CPU.)
    residual_att = torch.eye(att_mat.size(-1), device=att_mat.device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Chain the layer matrices: joint[n] = layer_n @ joint[n-1].
    joint_attentions = torch.zeros_like(aug_att_mat)
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = aug_att_mat[n] @ joint_attentions[n - 1]

    return joint_attentions[-1]


# -----------------------------
# PREDICTION FUNCTION
# -----------------------------
def predict(image):
    """Classify *image* as REAL/FAKE and build an attention-overlay heatmap.

    Args:
        image: PIL.Image from the Gradio input (or None when cleared).

    Returns:
        (label_str, confidence_str, overlay_rgb_ndarray), or (None, None, None)
        when no image was provided.
    """
    if image is None:
        return None, None, None

    # Normalize to 3-channel RGB so the overlay blend below always has a
    # shape-compatible array (handles RGBA, grayscale, palette uploads).
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    logits = outputs.logits
    attentions = outputs.attentions

    probs = torch.nn.functional.softmax(logits, dim=-1)
    confidence, predicted_class_idx = torch.max(probs, dim=-1)
    prediction = model.config.id2label[predicted_class_idx.item()]
    confidence_pct = round(confidence.item() * 100, 2)

    # Attention rollout: take the [CLS] row, drop the [CLS] column, and
    # reshape the remaining patch scores into a square grid.
    rollout = compute_attention_rollout(attentions)
    mask = rollout[0, 1:]
    size = int(mask.shape[0] ** 0.5)
    mask = mask.reshape(size, size).cpu().numpy()

    # cv2.resize takes (width, height), which matches PIL's image.size.
    mask = cv2.resize(mask, image.size)
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # applyColorMap emits BGR; convert to RGB before blending with the RGB
    # image so Gradio (which renders RGB) shows the JET colors correctly.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    overlay = cv2.addWeighted(np.array(image), 0.6, heatmap, 0.4, 0)

    return prediction, f"{confidence_pct}%", overlay


# -----------------------------
# UI DESIGN
# -----------------------------
custom_css = """
/* Professional Adaptive Theme */
:root {
    --primary-blue: #2563eb;
    --hero-text: #0f172a; /* Dark slate for light mode */
}
.dark {
    --hero-text: #f8fafc; /* White for dark mode */
}

/* Background refinement */
body {
    background-color: var(--background-fill-primary);
}

/* Adaptive Typography */
.hero {
    text-align: center;
    font-family: 'Inter', sans-serif;
    font-size: 48px;
    font-weight: 800;
    letter-spacing: -0.04em;
    margin-top: 50px;
    /* This variable handles the visibility toggle */
    color: var(--hero-text) !important;
}
.sub {
    text-align: center;
    opacity: 0.7;
    font-size: 14px;
    font-weight: 600;
    letter-spacing: 0.1em;
    text-transform: uppercase;
    margin-bottom: 40px;
    color: var(--body-text-color);
}

/* Professional Container Styling */
.glass {
    background: var(--block-background-fill) !important;
    border: 1px solid var(--border-color-primary) !important;
    border-radius: 12px !important;
    padding: 24px !important;
    box-shadow: var(--block-shadow);
    transition: all 0.2s ease;
}
.glass:hover {
    border-color: var(--primary-blue) !important;
    box-shadow: 0 4px 20px rgba(37, 99, 235, 0.1);
}

/* Enterprise Button */
button.primary {
    background: var(--primary-blue) !important;
    color: white !important;
    border: none !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    box-shadow: 0 4px 12px rgba(37, 99, 235, 0.2) !important;
}
button.primary:hover {
    background: #1d4ed8 !important;
    transform: translateY(-1px);
    box-shadow: 0 6px 16px rgba(37, 99, 235, 0.3) !important;
}

/* Label & Input tweaks for clarity */
.gr-label {
    font-weight: 600 !important;
    font-size: 12px !important;
    text-transform: uppercase;
    color: var(--primary-blue) !important;
}
"""

with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
    ),
) as demo:
    # NOTE(review): .hero/.sub are defined in custom_css but were attached to
    # nothing; elem_classes wires them up — confirm against the intended layout.
    gr.Markdown("FORESIGHT.", elem_classes="hero")
    gr.Markdown("Deep Intelligence Neural Analysis", elem_classes="sub")

    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes="glass"):
                input_image = gr.Image(type="pil", label="Source Input")
                run_btn = gr.Button("RUN DIAGNOSTIC", variant="primary")

        with gr.Column():
            with gr.Group(elem_classes="glass"):
                output_label = gr.Label(label="Classification Verdict")
                output_conf = gr.Textbox(label="Confidence Rating", interactive=False)
                heatmap_output = gr.Image(label="Vulnerability Visualization")

    run_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_conf, heatmap_output],
    )

if __name__ == "__main__":
    demo.launch()