Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| # from transformers import AutoModelForImageClassification, AutoImageProcessor | |
# -----------------------------
# CONFIGURATION
# -----------------------------
MODEL_REPO = "SARVM/ViT_Deepfake"
# Access token read from the environment (set in Space secrets or local env).
HF_TOKEN = os.getenv("HF_TOKEN")

print(f"Loading model from {MODEL_REPO}...")

# Image preprocessor paired with the checkpoint.
processor = ViTImageProcessor.from_pretrained(MODEL_REPO, token=HF_TOKEN)

# Classifier; attentions are requested so the rollout heatmap can be built.
model = ViTForImageClassification.from_pretrained(
    MODEL_REPO,
    token=HF_TOKEN,
    output_attentions=True,
)
model.eval()  # inference mode (disables dropout etc.)

# Override the checkpoint's generic labels with REAL / FAKE.
model.config.id2label = {1: "REAL", 0: "FAKE"}
model.config.label2id = {"REAL": 1, "FAKE": 0}
| # ----------------------------- | |
| # ATTENTION ROLLOUT | |
| # ----------------------------- | |
def compute_attention_rollout(attentions):
    """Fuse per-layer ViT attention maps into a single token-to-token map.

    Implements attention rollout: average attention over heads, add the
    identity matrix to account for residual connections, re-normalize each
    row, then multiply the per-layer matrices from the first layer up.

    Args:
        attentions: sequence of per-layer attention tensors, each of shape
            (1, num_heads, tokens, tokens).

    Returns:
        Tensor of shape (tokens, tokens): the rolled-out attention after
        the final layer.
    """
    # Drop the batch dim and average across heads -> (layers, tokens, tokens).
    layer_maps = torch.stack(attentions).squeeze(1).mean(dim=1)

    # Residual connections mean every token also attends to itself.
    identity = torch.eye(layer_maps.size(-1))
    augmented = layer_maps + identity
    augmented = augmented / augmented.sum(dim=-1).unsqueeze(-1)

    # Chain the layers: rollout_n = A_n @ A_{n-1} @ ... @ A_0.
    rollout = augmented[0]
    for layer_map in augmented[1:]:
        rollout = layer_map @ rollout
    return rollout
| # ----------------------------- | |
| # PREDICTION FUNCTION | |
| # ----------------------------- | |
# -----------------------------
# PREDICTION FUNCTION
# -----------------------------
def predict(image):
    """Classify an image as REAL or FAKE and build an attention heatmap.

    Args:
        image: PIL.Image input, or None when the widget is empty.

    Returns:
        Tuple (label str, confidence str like "97.5%", overlay ndarray),
        or (None, None, None) when no image was provided.
    """
    if image is None:
        return None, None, None

    # Normalize to 3-channel RGB so the blend below always gets matching
    # array shapes (handles RGBA / grayscale uploads without crashing).
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    logits = outputs.logits
    attentions = outputs.attentions

    probs = torch.nn.functional.softmax(logits, dim=-1)
    confidence, predicted_class_idx = torch.max(probs, dim=-1)
    prediction = model.config.id2label[predicted_class_idx.item()]
    confidence_pct = round(confidence.item() * 100, 2)

    # Attention rollout: row 0 is the CLS token; columns 1+ are the patches.
    rollout = compute_attention_rollout(attentions)
    mask = rollout[0, 1:]
    size = int(mask.shape[0] ** 0.5)  # patches form a square grid
    mask = mask.reshape(size, size).cpu().numpy()
    mask = cv2.resize(mask, image.size)  # PIL size (W, H) matches cv2 dsize order
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # BUG FIX: applyColorMap returns BGR, but np.array(PIL image) is RGB.
    # Convert so high-attention regions render red (not blue) in the UI.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    overlay = cv2.addWeighted(np.array(image), 0.6, heatmap, 0.4, 0)

    return prediction, f"{confidence_pct}%", overlay
| # ----------------------------- | |
| # UI DESIGN | |
| # ----------------------------- | |
# Gradio-injected stylesheet: adaptive light/dark theme, "glass" cards,
# and a primary-blue button treatment. The CSS content is runtime data
# consumed by gr.Blocks(css=...) and is preserved verbatim.
custom_css = """
/* Professional Adaptive Theme */
:root {
--primary-blue: #2563eb;
--hero-text: #0f172a; /* Dark slate for light mode */
}
.dark {
--hero-text: #f8fafc; /* White for dark mode */
}
/* Background refinement */
body {
background-color: var(--background-fill-primary);
}
/* Adaptive Typography */
.hero {
text-align: center;
font-family: 'Inter', sans-serif;
font-size: 48px;
font-weight: 800;
letter-spacing: -0.04em;
margin-top: 50px;
/* This variable handles the visibility toggle */
color: var(--hero-text) !important;
}
.sub {
text-align: center;
opacity: 0.7;
font-size: 14px;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
margin-bottom: 40px;
color: var(--body-text-color);
}
/* Professional Container Styling */
.glass {
background: var(--block-background-fill) !important;
border: 1px solid var(--border-color-primary) !important;
border-radius: 12px !important;
padding: 24px !important;
box-shadow: var(--block-shadow);
transition: all 0.2s ease;
}
.glass:hover {
border-color: var(--primary-blue) !important;
box-shadow: 0 4px 20px rgba(37, 99, 235, 0.1);
}
/* Enterprise Button */
button.primary {
background: var(--primary-blue) !important;
color: white !important;
border: none !important;
font-weight: 600 !important;
padding: 12px 24px !important;
border-radius: 8px !important;
box-shadow: 0 4px 12px rgba(37, 99, 235, 0.2) !important;
}
button.primary:hover {
background: #1d4ed8 !important;
transform: translateY(-1px);
box-shadow: 0 6px 16px rgba(37, 99, 235, 0.3) !important;
}
/* Label & Input tweaks for clarity */
.gr-label {
font-weight: 600 !important;
font-size: 12px !important;
text-transform: uppercase;
color: var(--primary-blue) !important;
}
"""
# -----------------------------
# UI DESIGN
# -----------------------------
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
    ),
) as demo:
    # Header banner; the string has no placeholders, so the stray
    # f-prefix the original carried was removed.
    gr.Markdown("<div class='hero'>FORESIGHT<span style='color:#3b82f6'>.</span></div>")
    gr.Markdown("<div class='sub'>Deep Intelligence Neural Analysis</div>")

    with gr.Row():
        # Left column: upload + trigger.
        with gr.Column():
            with gr.Group(elem_classes="glass"):
                input_image = gr.Image(type="pil", label="Source Input")
                run_btn = gr.Button("RUN DIAGNOSTIC", variant="primary")
        # Right column: verdict, confidence, attention heatmap.
        with gr.Column():
            with gr.Group(elem_classes="glass"):
                output_label = gr.Label(label="Classification Verdict")
                output_conf = gr.Textbox(label="Confidence Rating", interactive=False)
                heatmap_output = gr.Image(label="Vulnerability Visualization")

    # Wire the button to the inference function.
    run_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_conf, heatmap_output],
    )

if __name__ == "__main__":
    demo.launch()