S-Mithilesh commited on
Commit
ca71be4
·
verified ·
1 Parent(s): 3a82a4c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -0
app.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+ from transformers import ViTImageProcessor, ViTForImageClassification
9
+ # from transformers import AutoModelForImageClassification, AutoImageProcessor
10
+
11
+ # -----------------------------
12
+ # CONFIGURATION
13
+ # -----------------------------
14
+
15
+ MODEL_REPO = "SARVM/ViT_Deepfake"
16
+ # HF_TOKEN = os.getenv("HF_TOKEN") # Set in Space secrets or local env
17
+ HF_TOKEN = "hf_xxxxxxxxxxxxxxxxxxxxxxxx" # 🔐 Replace with your actual Hugging Face token
18
+
19
+ print(f"Loading model from {MODEL_REPO}...")
20
+
21
+ processor = ViTImageProcessor.from_pretrained(
22
+ MODEL_REPO,
23
+ token=HF_TOKEN
24
+ )
25
+
26
+ model = ViTForImageClassification.from_pretrained(
27
+ MODEL_REPO,
28
+ token=HF_TOKEN,
29
+ output_attentions=True
30
+ )
31
+
32
+ # processor = AutoImageProcessor.from_pretrained(
33
+ # MODEL_REPO,
34
+ # token=HF_TOKEN
35
+ # )
36
+
37
+ # model = AutoModelForImageClassification.from_pretrained(
38
+ # MODEL_REPO,
39
+ # token=HF_TOKEN
40
+ # )
41
+
42
+ model.eval()
43
+
44
+ # Override labels to REAL / FAKE
45
+ model.config.id2label = {
46
+ 1: "REAL",
47
+ 0: "FAKE"
48
+ }
49
+
50
+ model.config.label2id = {
51
+ "REAL": 1,
52
+ "FAKE": 0
53
+ }
54
+
55
+ # -----------------------------
56
+ # ATTENTION ROLLOUT
57
+ # -----------------------------
58
+
59
+ def compute_attention_rollout(attentions):
60
+ att_mat = torch.stack(attentions).squeeze(1)
61
+ att_mat = att_mat.mean(dim=1)
62
+
63
+ residual_att = torch.eye(att_mat.size(-1))
64
+ aug_att_mat = att_mat + residual_att
65
+ aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)
66
+
67
+ joint_attentions = torch.zeros_like(aug_att_mat)
68
+ joint_attentions[0] = aug_att_mat[0]
69
+
70
+ for n in range(1, aug_att_mat.size(0)):
71
+ joint_attentions[n] = aug_att_mat[n] @ joint_attentions[n - 1]
72
+
73
+ return joint_attentions[-1]
74
+
75
+
76
+ # -----------------------------
77
+ # PREDICTION FUNCTION
78
+ # -----------------------------
79
+
80
+ def predict(image):
81
+ if image is None:
82
+ return None, None, None
83
+
84
+ inputs = processor(images=image, return_tensors="pt")
85
+
86
+ with torch.no_grad():
87
+ outputs = model(**inputs, output_attentions=True)
88
+ logits = outputs.logits
89
+ attentions = outputs.attentions
90
+
91
+ probs = torch.nn.functional.softmax(logits, dim=-1)
92
+ confidence, predicted_class_idx = torch.max(probs, dim=-1)
93
+
94
+ prediction = model.config.id2label[predicted_class_idx.item()]
95
+ confidence_pct = round(confidence.item() * 100, 2)
96
+
97
+ # Attention rollout
98
+ rollout = compute_attention_rollout(attentions)
99
+
100
+ mask = rollout[0, 1:]
101
+ size = int(mask.shape[0] ** 0.5)
102
+ mask = mask.reshape(size, size).cpu().numpy()
103
+
104
+ mask = cv2.resize(mask, image.size)
105
+ mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
106
+
107
+ heatmap = cv2.applyColorMap(
108
+ np.uint8(255 * mask),
109
+ cv2.COLORMAP_JET
110
+ )
111
+
112
+ overlay = cv2.addWeighted(
113
+ np.array(image),
114
+ 0.6,
115
+ heatmap,
116
+ 0.4,
117
+ 0
118
+ )
119
+
120
+ return prediction, f"{confidence_pct}%", overlay
121
+
122
+
123
+ # -----------------------------
124
+ # UI DESIGN
125
+ # -----------------------------
126
+
127
+ custom_css = """
128
+ /* Professional Adaptive Theme */
129
+ :root {
130
+ --primary-blue: #2563eb;
131
+ --hero-text: #0f172a; /* Dark slate for light mode */
132
+ }
133
+
134
+ .dark {
135
+ --hero-text: #f8fafc; /* White for dark mode */
136
+ }
137
+
138
+ /* Background refinement */
139
+ body {
140
+ background-color: var(--background-fill-primary);
141
+ }
142
+
143
+ /* Adaptive Typography */
144
+ .hero {
145
+ text-align: center;
146
+ font-family: 'Inter', sans-serif;
147
+ font-size: 48px;
148
+ font-weight: 800;
149
+ letter-spacing: -0.04em;
150
+ margin-top: 50px;
151
+ /* This variable handles the visibility toggle */
152
+ color: var(--hero-text) !important;
153
+ }
154
+
155
+ .sub {
156
+ text-align: center;
157
+ opacity: 0.7;
158
+ font-size: 14px;
159
+ font-weight: 600;
160
+ letter-spacing: 0.1em;
161
+ text-transform: uppercase;
162
+ margin-bottom: 40px;
163
+ color: var(--body-text-color);
164
+ }
165
+
166
+ /* Professional Container Styling */
167
+ .glass {
168
+ background: var(--block-background-fill) !important;
169
+ border: 1px solid var(--border-color-primary) !important;
170
+ border-radius: 12px !important;
171
+ padding: 24px !important;
172
+ box-shadow: var(--block-shadow);
173
+ transition: all 0.2s ease;
174
+ }
175
+
176
+ .glass:hover {
177
+ border-color: var(--primary-blue) !important;
178
+ box-shadow: 0 4px 20px rgba(37, 99, 235, 0.1);
179
+ }
180
+
181
+ /* Enterprise Button */
182
+ button.primary {
183
+ background: var(--primary-blue) !important;
184
+ color: white !important;
185
+ border: none !important;
186
+ font-weight: 600 !important;
187
+ padding: 12px 24px !important;
188
+ border-radius: 8px !important;
189
+ box-shadow: 0 4px 12px rgba(37, 99, 235, 0.2) !important;
190
+ }
191
+
192
+ button.primary:hover {
193
+ background: #1d4ed8 !important;
194
+ transform: translateY(-1px);
195
+ box-shadow: 0 6px 16px rgba(37, 99, 235, 0.3) !important;
196
+ }
197
+
198
+ /* Label & Input tweaks for clarity */
199
+ .gr-label {
200
+ font-weight: 600 !important;
201
+ font-size: 12px !important;
202
+ text-transform: uppercase;
203
+ color: var(--primary-blue) !important;
204
+ }
205
+ """
206
+
207
+ with gr.Blocks(
208
+ css=custom_css,
209
+ theme=gr.themes.Soft(
210
+ primary_hue="blue",
211
+ font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"]
212
+ )
213
+ ) as demo:
214
+
215
+ gr.Markdown(f"<div class='hero'>FORESIGHT<span style='color:#3b82f6'>.</span></div>")
216
+ gr.Markdown("<div class='sub'>Deep Intelligence Neural Analysis</div>")
217
+
218
+ with gr.Row():
219
+ with gr.Column():
220
+ with gr.Group(elem_classes="glass"):
221
+ input_image = gr.Image(type="pil", label="Source Input")
222
+ run_btn = gr.Button("RUN DIAGNOSTIC", variant="primary")
223
+
224
+ with gr.Column():
225
+ with gr.Group(elem_classes="glass"):
226
+ output_label = gr.Label(label="Classification Verdict")
227
+ output_conf = gr.Textbox(label="Confidence Rating", interactive=False)
228
+ heatmap_output = gr.Image(label="Vulnerability Visualization")
229
+
230
+ run_btn.click(
231
+ fn=predict,
232
+ inputs=input_image,
233
+ outputs=[output_label, output_conf, heatmap_output]
234
+ )
235
+
236
+ if __name__ == "__main__":
237
+ demo.launch()