bithal26 commited on
Commit
2e2c690
Β·
verified Β·
1 Parent(s): 86e6cf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +374 -75
app.py CHANGED
@@ -1,94 +1,393 @@
1
- WEIGHT_FILE = "final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
- from torch import nn
4
  from torch.nn.modules.dropout import Dropout
5
  from torch.nn.modules.linear import Linear
6
  from torch.nn.modules.pooling import AdaptiveAvgPool2d
7
- from timm.models.efficientnet import tf_efficientnet_b7_ns
8
  from functools import partial
9
- import re
10
  import gradio as gr
11
- import numpy as np
12
- from torchvision.transforms import Normalize
13
 
14
- # --- 1. MODEL ARCHITECTURE ---
15
- encoder_params = {
16
- "tf_efficientnet_b7.ns_jft_in1k": {
17
- "features": 2560,
18
- "init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
19
- }
20
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  class DeepFakeClassifier(nn.Module):
23
- def __init__(self, encoder="tf_efficientnet_b7.ns_jft_in1k", dropout_rate=0.0) -> None:
24
  super().__init__()
25
- self.encoder = encoder_params[encoder]["init_op"]()
26
  self.avg_pool = AdaptiveAvgPool2d((1, 1))
27
- self.dropout = Dropout(dropout_rate)
28
- self.fc = Linear(encoder_params[encoder]["features"], 1)
29
 
30
  def forward(self, x):
31
  x = self.encoder.forward_features(x)
32
  x = self.avg_pool(x).flatten(1)
33
  x = self.dropout(x)
34
- x = self.fc(x)
35
- return x
36
-
37
- print(f"Booting API Worker: Loading {WEIGHT_FILE}...")
38
- device = torch.device('cpu')
39
- model = DeepFakeClassifier(encoder="tf_efficientnet_b7.ns_jft_in1k").to(device)
40
-
41
- checkpoint = torch.load(WEIGHT_FILE, map_location="cpu", weights_only=False)
42
- state_dict = checkpoint.get("state_dict", checkpoint)
43
- model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
44
- model.eval()
45
-
46
- # --- 3. PREPROCESSING UTILS ---
47
- mean = [0.485, 0.456, 0.406]
48
- std = [0.229, 0.224, 0.225]
49
- normalize_transform = Normalize(mean, std)
50
-
51
- # --- 4. API ENDPOINT ---
52
- def predict_tensor(tensor_file):
53
- if tensor_file is None:
54
- return {"error": "No file received"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  try:
56
- # Handle Gradio 4 filepath string formatting safely
57
- file_path = tensor_file if isinstance(tensor_file, str) else tensor_file.name
58
-
59
- # 1. Load the raw numpy array (Bypasses PyTorch security & saves bandwidth)
60
- x_np = np.load(file_path)
61
- x = torch.tensor(x_np).float()
62
-
63
- # 2. Permute NHWC to NCHW
64
- x = x.permute((0, 3, 1, 2))
65
-
66
- # 3. Normalize locally on the worker
67
- for i in range(len(x)):
68
- x[i] = normalize_transform(x[i] / 255.)
69
-
70
- # 4. MINI-BATCHING: Process 8 frames at a time to prevent 16GB RAM Crash
71
- mini_batch_size = 8
72
- all_preds = []
73
-
74
- with torch.no_grad():
75
- for i in range(0, len(x), mini_batch_size):
76
- chunk = x[i:i+mini_batch_size]
77
- # Cast to float16 ONLY if running on local GPU, else stay float32 for HF CPU
78
- chunk = chunk.half() if torch.cuda.is_available() else chunk
79
- y_chunk = model(chunk)
80
- y_chunk = torch.sigmoid(y_chunk.squeeze())
81
-
82
- if y_chunk.dim() == 0:
83
- all_preds.append(float(y_chunk.cpu().numpy()))
84
- else:
85
- all_preds.extend(y_chunk.cpu().numpy().tolist())
86
-
87
- return {"predictions": all_preds}
88
- except Exception as e:
89
- return {"error": str(e)}
90
-
91
- interface = gr.Interface(fn=predict_tensor, inputs=gr.File(label="Input Tensor (.npy)"), outputs=gr.JSON())
92
 
93
  if __name__ == "__main__":
94
- interface.launch()
 
1
+
2
+ """
3
+ ================================================================================
4
+ VERIDEX — DeepFake Worker Space (Generic Template)
5
+ ─────────────────────────────────────────────────────
6
+ DEPLOY INSTRUCTIONS β€” zero code changes between workers
7
+ ──────────────────────────────────────────────────────────
8
+ 1. Commit this IDENTICAL app.py to all 7 Worker Spaces.
9
+ 2. Upload each worker's .pt weight file to its Space's files tab.
10
+ 3. In each Space β†’ Settings β†’ Variables, set:
11
+
12
+ WEIGHT_FILE = final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
13
+ MODEL_CLASS = base # or srm / gwap (optional, default: base)
14
+
15
+ 4. That's it. No code edits required.
16
+
17
+ API CONTRACT (called by the Master UI)
18
+ ───────────────────────────────────────
19
+ Input : a .npy file (uint8, shape [N, H, W, 3], HWC, 380×380)
20
+ Output : JSON { "predictions": [float, ...], "n_frames": int }
21
+ OR { "error": "...", "predictions": null }
22
+
23
+ GRADIO VERSION NOTE
24
+ ────────────────────
25
+ HF Spaces force-installs gradio==6.x at build time regardless of what
26
+ requirements.txt pins. This file targets Gradio 6:
27
+ β€’ gr.File input passes a tempfile.SpooledTemporaryFile-backed object
28
+ with a .name attribute in Gradio 6 (not a plain string or dict).
29
+ β€’ allow_flagging is removed (deprecated in Gradio 6; raises a warning
30
+ that can abort startup on strict HF runtime configs).
31
+ ================================================================================
32
+ """
33
+
34
+ import os
35
+ import io
36
+ import re
37
+ import traceback
38
+ import logging
39
+
40
+ import numpy as np
41
  import torch
42
+ import torch.nn as nn
43
  from torch.nn.modules.dropout import Dropout
44
  from torch.nn.modules.linear import Linear
45
  from torch.nn.modules.pooling import AdaptiveAvgPool2d
46
+ from torchvision.transforms import Normalize
47
  from functools import partial
48
+
49
  import gradio as gr
 
 
50
 
51
# ── timm / efficientnet ──────────────────────────────────────────────────────
# Newer timm releases removed the `tf_efficientnet_b7_ns` function alias; the
# same architecture is still reachable through timm.create_model with the
# canonical model name, so expose a drop-in replacement via functools.partial.
try:
    from timm.models.efficientnet import tf_efficientnet_b7_ns
except ImportError:
    # timm >= 0.9 moved the alias; fall back gracefully
    import timm
    tf_efficientnet_b7_ns = partial(timm.create_model, "tf_efficientnet_b7.ns_jft_in1k")
58
+
59
+
60
# Worker-level logger; HF Spaces captures stdout/stderr, so basicConfig suffices.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [WORKER] %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ─── All runtime config comes from env vars (Space Settings → Variables) ─────
#   WEIGHT_FILE : filename of the .pt checkpoint (extension optional)
#   MODEL_CLASS : "base" | "srm" | "gwap"            (default: base)
#   MINI_BATCH  : frames per forward pass            (default: 8)
#   WEIGHTS_DIR : directory containing the .pt file  (default: repo root ".")
WEIGHT_FILE = os.environ.get(
    "WEIGHT_FILE",
    "final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40",  # safe default
)
MODEL_CLASS = os.environ.get("MODEL_CLASS", "base")  # "base" | "srm" | "gwap"
MINI_BATCH = int(os.environ.get("MINI_BATCH", "8"))  # frames per forward pass
MINI_BATCH  # frames per forward pass; small to fit 16 GB-RAM CPU Spaces
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", ".")  # dir that contains the .pt

# ── ImageNet normalisation (must match the training-time preprocessing) ──────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize_fn = Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# ── EfficientNet-B7 final feature-map channel count (input width of the fc head)
ENCODER_FEATURES = 2560
86
+
87
+ # ─────────────────────────────────────────────────────────────────────────────
88
+ # Model definitions (identical to deepfake_det.py so checkpoints load clean)
89
+ # ─────────────────────────────────────────────────────────────────────────────
90
+
91
def _make_encoder():
    """Instantiate the EfficientNet-B7 backbone (randomly initialised; real
    weights are loaded later from the checkpoint)."""
    backbone = tf_efficientnet_b7_ns(pretrained=False, drop_path_rate=0.2)
    return backbone
93
+
94
+
95
+ def _setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
96
+ srm_kernel = torch.from_numpy(np.array([
97
+ [[0.,0.,0.,0.,0.],[0.,0.,0.,0.,0.],[0.,1.,-2.,1.,0.],[0.,0.,0.,0.,0.],[0.,0.,0.,0.,0.]],
98
+ [[0.,0.,0.,0.,0.],[0.,-1.,2.,-1.,0.],[0.,2.,-4.,2.,0.],[0.,-1.,2.,-1.,0.],[0.,0.,0.,0.,0.]],
99
+ [[-1.,2.,-2.,2.,-1.],[2.,-6.,8.,-6.,2.],[-2.,8.,-12.,8.,-2.],[2.,-6.,8.,-6.,2.],[-1.,2.,-2.,2.,-1.]],
100
+ ])).float()
101
+ srm_kernel[0] /= 2
102
+ srm_kernel[1] /= 4
103
+ srm_kernel[2] /= 12
104
+ return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
105
+
106
+
107
def _setup_srm_layer(input_channels: int = 3) -> nn.Module:
    """Wrap the fixed SRM kernels in a frozen, bias-free 5x5 conv layer."""
    layer = nn.Conv2d(input_channels, 3, kernel_size=5, stride=1, padding=2, bias=False)
    # Freeze the weights: SRM filters are hand-crafted, never trained.
    with torch.no_grad():
        layer.weight = nn.Parameter(_setup_srm_weights(input_channels), requires_grad=False)
    return layer
113
+
114
 
115
class DeepFakeClassifier(nn.Module):
    """Plain EfficientNet-B7 binary deepfake classifier (single-logit head)."""

    def __init__(self, dropout_rate=0.0):
        super().__init__()
        self.encoder = _make_encoder()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)

    def forward(self, x):
        """Return raw logits of shape [B, 1] for a batch of normalised frames."""
        features = self.encoder.forward_features(x)
        pooled = self.avg_pool(features).flatten(1)
        return self.fc(self.dropout(pooled))
128
+
129
+
130
class DeepFakeClassifierSRM(nn.Module):
    """EfficientNet-B7 classifier fed with SRM noise residuals instead of RGB."""

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.encoder = _make_encoder()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.srm_conv = _setup_srm_layer(3)
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)

    def forward(self, x):
        """Extract SRM noise maps, then classify; returns logits of shape [B, 1]."""
        residuals = self.srm_conv(x)
        features = self.encoder.forward_features(residuals)
        pooled = self.avg_pool(features).flatten(1)
        return self.fc(self.dropout(pooled))
145
+
146
+
147
+ class _GWAP(nn.Module):
148
+ def __init__(self, features: int):
149
+ super().__init__()
150
+ self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
151
+
152
+ def forward(self, x):
153
+ w = self.conv(x).sigmoid().exp()
154
+ w = w / w.sum(dim=[2, 3], keepdim=True)
155
+ return (w * x).sum(dim=[2, 3], keepdim=False)
156
+
157
+
158
class DeepFakeClassifierGWAP(nn.Module):
    """EfficientNet-B7 classifier with a Global Weighted Average Pooling head."""

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.encoder = _make_encoder()
        self.avg_pool = _GWAP(ENCODER_FEATURES)
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)

    def forward(self, x):
        """Return raw logits of shape [B, 1]."""
        features = self.encoder.forward_features(x)
        pooled = self.avg_pool(features)
        return self.fc(self.dropout(pooled))
171
+
172
+
173
# Maps the MODEL_CLASS env-var value to its classifier implementation;
# load_model() falls back to "base" for unknown values.
_MODEL_MAP = {
    "base": DeepFakeClassifier,
    "srm": DeepFakeClassifierSRM,
    "gwap": DeepFakeClassifierGWAP,
}
178
+
179
+
180
+ # ─────────────────────────────────────────────────────────────────────────────
181
+ # Model loading (runs once at startup)
182
+ # ─────────────────────────────────���───────────────────────────────────────────
183
+
184
def load_model() -> "tuple[nn.Module, torch.device]":
    """Build the configured classifier, load its checkpoint, and return it.

    Runs once at startup. Configuration comes from the module-level env-var
    constants (MODEL_CLASS, WEIGHTS_DIR, WEIGHT_FILE).

    Returns
    -------
    (model, device) : the fp16, eval-mode model and the device it lives on.
        NOTE: the original annotation said ``-> nn.Module`` but a 2-tuple has
        always been returned; the annotation is corrected here.

    Raises
    ------
    FileNotFoundError
        If the weight file (with or without a common extension) is absent.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cls = _MODEL_MAP.get(MODEL_CLASS, DeepFakeClassifier)
    model = cls().to(device)

    weight_path = os.path.join(WEIGHTS_DIR, WEIGHT_FILE)
    # Allow common extensions in case the file was renamed on upload.
    if not os.path.exists(weight_path):
        for ext in (".pt", ".pth", ".bin"):
            candidate = weight_path + ext
            if os.path.exists(candidate):
                weight_path = candidate
                break

    if not os.path.exists(weight_path):
        raise FileNotFoundError(
            f"Weight file not found: {weight_path}\n"
            f"Files present in '{WEIGHTS_DIR}': {os.listdir(WEIGHTS_DIR)}"
        )

    logger.info("Loading weights from: %s", weight_path)

    # PyTorch 2.6+ defaults to weights_only=True; these checkpoints are full
    # pickles, so weights_only=False is required. map_location='cpu' makes the
    # load machine-independent regardless of how the checkpoint was saved.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Strip the "module." prefix added by DataParallel / DistributedDataParallel.
    cleaned = {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}
    model.load_state_dict(cleaned, strict=True)
    model.eval()

    # FP16 halves memory; run_inference() casts inputs to half to match.
    model = model.half()

    logger.info("Model ready — class=%s, device=%s, fp16=True", cls.__name__, device)
    return model, device
221
+
222
+
223
# Load the model once at import time so every request shares one copy.
# On failure the Space still boots: MODEL stays None and predict() returns
# the captured traceback to callers instead of crashing the worker.
try:
    MODEL, DEVICE = load_model()
    LOAD_ERROR = None
except Exception as exc:  # noqa: F841 — full traceback captured below
    MODEL = None
    DEVICE = None
    LOAD_ERROR = traceback.format_exc()
    logger.error(f"MODEL LOAD FAILED:\n{LOAD_ERROR}")
231
+
232
+
233
+ # ─────────────────────────────────────────────────────────────────────────────
234
+ # Inference helper
235
+ # ─────────────────────────────────────────────────────────────────────────────
236
+
237
def _preprocess_npy(npy_input) -> torch.Tensor:
    """
    Load a uint8 HWC .npy face-batch and return a normalised float CHW tensor.

    Gradio version compatibility matrix (order of the checks below matters:
    string, then dict, then `.path` attribute, then file-like):
      - Gradio 4 : passes a plain string filepath "/tmp/gradio/.../faces.npy"
      - Gradio 4 : may wrap in dict {"path": "...", "orig_name": "..."}
      - Gradio 6 : passes a tempfile.SpooledTemporaryFile (file-like with .name)
                   OR a gradio.FileData dataclass with a .path attribute

    All four forms are resolved to a file path or file-like object that
    np.load() can consume.

    Raises TypeError for unrecognised inputs, FileNotFoundError for dangling
    paths, and ValueError for arrays that are not shaped (N, H, W, 3).
    """
    npy_path = None   # will hold a string path if resolvable
    file_obj = None   # will hold a file-like if path is unavailable

    # Form 1: plain string filepath.
    if isinstance(npy_input, str):
        npy_path = npy_input

    # Form 2: Gradio 4 dict {"path": ..., "orig_name": ...}.
    elif isinstance(npy_input, dict):
        # Last resort takes the first dict value — assumes it is a path;
        # NOTE(review): could pick a non-path value for exotic payloads.
        npy_path = (
            npy_input.get("path")
            or npy_input.get("name")
            or next(iter(npy_input.values()), None)
        )

    # Form 3: Gradio 6 dataclass (has .path attribute).
    elif hasattr(npy_input, "path"):
        npy_path = npy_input.path

    # Form 4: file-like object (SpooledTemporaryFile, BytesIO, etc.).
    elif hasattr(npy_input, "read"):
        # Prefer the backing file path (avoids reading into RAM twice).
        backing = getattr(npy_input, "name", None)
        if backing and isinstance(backing, str) and os.path.exists(backing):
            npy_path = backing
        else:
            file_obj = npy_input

    else:
        raise TypeError(
            f"Cannot resolve npy input of type {type(npy_input)}: {npy_input!r}"
        )

    # Load the array; retry with allow_pickle for legacy object-dtype .npy files.
    def _load(src):
        try:
            return np.load(src, allow_pickle=False)
        except ValueError:
            # Legacy pickled .npy — rewind first if src is a file-like object.
            if hasattr(src, "seek"):
                src.seek(0)
            return np.load(src, allow_pickle=True)

    if npy_path is not None:
        if not os.path.exists(npy_path):
            raise FileNotFoundError(f"NPY payload not found at: {npy_path}")
        faces_uint8 = _load(npy_path)
    else:
        faces_uint8 = _load(file_obj)

    # Validate shape: must be a batch of HWC RGB frames.
    if faces_uint8.ndim != 4 or faces_uint8.shape[3] != 3:
        raise ValueError(
            f"Expected uint8 array shape (N, H, W, 3), got {faces_uint8.shape}"
        )

    # Convert: uint8 HWC -> float32 CHW -> ImageNet-normalised.
    tensor = torch.from_numpy(faces_uint8).float()   # [N, H, W, 3]
    tensor = tensor.permute(0, 3, 1, 2)              # [N, 3, H, W]
    # Normalise each frame in place (scale to [0, 1] first).
    for i in range(tensor.shape[0]):
        tensor[i] = normalize_fn(tensor[i] / 255.0)

    return tensor  # float32, shape [N, 3, H, W]
315
+
316
+
317
def run_inference(tensor: torch.Tensor) -> list:
    """Forward-pass the preprocessed face tensor through the model.

    Frames are processed in MINI_BATCH-sized chunks to keep peak memory
    bounded on small (16 GB RAM) Spaces. Returns a flat list with one
    sigmoid fake-probability in [0, 1] per frame.
    """
    probabilities: list = []
    total_frames = tensor.shape[0]

    with torch.no_grad():
        for offset in range(0, total_frames, MINI_BATCH):
            chunk = tensor[offset : offset + MINI_BATCH]
            # fp16 cast matches the model's dtype (see load_model()).
            chunk = chunk.to(DEVICE).half()

            logits = MODEL(chunk)                       # [B, 1]
            scores = torch.sigmoid(logits.squeeze(-1))  # [B]
            probabilities.extend(scores.cpu().float().tolist())

    return probabilities
336
+
337
+
338
+ # ─────────────────────────────────────────────────────────────────────────────
339
+ # Gradio endpoint (headless β€” no UI blocks, purely an API)
340
+ # ─────────────────────────────────────────────────────────────────────────────
341
+
342
def predict(npy_file) -> dict:
    """Gradio API endpoint: score a .npy face batch.

    Parameters
    ----------
    npy_file : str | dict
        Filepath (or Gradio file dict) pointing to the .npy face batch.

    Returns
    -------
    dict
        Keys: ``predictions`` (list[float] | None), ``n_frames`` (int),
        ``error`` (str | None). Errors never raise — they are reported in
        the payload so the Master UI can surface them.
    """
    # Startup failure: report the captured traceback instead of crashing.
    if MODEL is None:
        msg = f"Model failed to load at startup:\n{LOAD_ERROR}"
        logger.error(msg)
        return {"predictions": None, "n_frames": 0, "error": msg}

    try:
        faces = _preprocess_npy(npy_file)
        frame_count = faces.shape[0]
        scores = run_inference(faces)
        logger.info(f"Inference OK — frames={frame_count}, mean_pred={np.mean(scores):.4f}")
        return {"predictions": scores, "n_frames": frame_count, "error": None}

    except Exception:
        err = traceback.format_exc()
        logger.error(f"Inference failed:\n{err}")
        return {"predictions": None, "n_frames": 0, "error": err}
374
+
375
+
376
+ # ──────────────────────────────────────────────────────────────────────���──────
377
+ # Launch
378
+ # ─────────────────────────────────────────────────────────────────────────────
379
+
380
# Headless Gradio app: no real UI — exists so the Master UI can POST payloads.
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(label="Face batch (.npy)", file_types=[".npy"]),
    outputs=gr.JSON(label="Worker prediction"),
    title=f"VERIDEX Worker — {WEIGHT_FILE}",
    description=(
        "Headless inference worker. "
        "POST a uint8 .npy face-batch; receive per-frame fake probabilities."
    ),
    # allow_flagging removed: deprecated in Gradio 5, gone in Gradio 6
)
 
 
 
 
 
 
 
 
 
391
 
392
if __name__ == "__main__":
    # 0.0.0.0:7860 is the host/port HF Spaces probes for the app server.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)