bithal26 committed on
Commit
6793ef0
·
verified ·
1 Parent(s): 55c0597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +374 -75
app.py CHANGED
@@ -1,95 +1,394 @@
1
- WEIGHT_FILE = "final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36"
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torch
4
- from torch import nn
5
  from torch.nn.modules.dropout import Dropout
6
  from torch.nn.modules.linear import Linear
7
  from torch.nn.modules.pooling import AdaptiveAvgPool2d
8
- from timm.models.efficientnet import tf_efficientnet_b7_ns
9
  from functools import partial
10
- import re
11
  import gradio as gr
12
- import numpy as np
13
- from torchvision.transforms import Normalize
14
 
15
- # --- 1. MODEL ARCHITECTURE ---
16
- encoder_params = {
17
- "tf_efficientnet_b7.ns_jft_in1k": {
18
- "features": 2560,
19
- "init_op": partial(tf_efficientnet_b7_ns, pretrained=False, drop_path_rate=0.2)
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  class DeepFakeClassifier(nn.Module):
24
- def __init__(self, encoder="tf_efficientnet_b7.ns_jft_in1k", dropout_rate=0.0) -> None:
25
  super().__init__()
26
- self.encoder = encoder_params[encoder]["init_op"]()
27
  self.avg_pool = AdaptiveAvgPool2d((1, 1))
28
- self.dropout = Dropout(dropout_rate)
29
- self.fc = Linear(encoder_params[encoder]["features"], 1)
30
 
31
  def forward(self, x):
32
  x = self.encoder.forward_features(x)
33
  x = self.avg_pool(x).flatten(1)
34
  x = self.dropout(x)
35
- x = self.fc(x)
36
- return x
37
-
38
- print(f"Booting API Worker: Loading {WEIGHT_FILE}...")
39
- device = torch.device('cpu')
40
- model = DeepFakeClassifier(encoder="tf_efficientnet_b7.ns_jft_in1k").to(device)
41
-
42
- checkpoint = torch.load(WEIGHT_FILE, map_location="cpu", weights_only=False)
43
- state_dict = checkpoint.get("state_dict", checkpoint)
44
- model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
45
- model.eval()
46
-
47
- # --- 3. PREPROCESSING UTILS ---
48
- mean = [0.485, 0.456, 0.406]
49
- std = [0.229, 0.224, 0.225]
50
- normalize_transform = Normalize(mean, std)
51
-
52
- # --- 4. API ENDPOINT ---
53
- def predict_tensor(tensor_file):
54
- if tensor_file is None:
55
- return {"error": "No file received"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  try:
57
- # Handle Gradio 4 filepath string formatting safely
58
- file_path = tensor_file if isinstance(tensor_file, str) else tensor_file.name
59
-
60
- # 1. Load the raw numpy array (Bypasses PyTorch security & saves bandwidth)
61
- x_np = np.load(file_path)
62
- x = torch.tensor(x_np).float()
63
-
64
- # 2. Permute NHWC to NCHW
65
- x = x.permute((0, 3, 1, 2))
66
-
67
- # 3. Normalize locally on the worker
68
- for i in range(len(x)):
69
- x[i] = normalize_transform(x[i] / 255.)
70
-
71
- # 4. MINI-BATCHING: Process 8 frames at a time to prevent 16GB RAM Crash
72
- mini_batch_size = 8
73
- all_preds = []
74
-
75
- with torch.no_grad():
76
- for i in range(0, len(x), mini_batch_size):
77
- chunk = x[i:i+mini_batch_size]
78
- # Cast to float16 ONLY if running on local GPU, else stay float32 for HF CPU
79
- chunk = chunk.half() if torch.cuda.is_available() else chunk
80
- y_chunk = model(chunk)
81
- y_chunk = torch.sigmoid(y_chunk.squeeze())
82
-
83
- if y_chunk.dim() == 0:
84
- all_preds.append(float(y_chunk.cpu().numpy()))
85
- else:
86
- all_preds.extend(y_chunk.cpu().numpy().tolist())
87
-
88
- return {"predictions": all_preds}
89
- except Exception as e:
90
- return {"error": str(e)}
91
-
92
- interface = gr.Interface(fn=predict_tensor, inputs=gr.File(label="Input Tensor (.npy)"), outputs=gr.JSON())
93
 
94
  if __name__ == "__main__":
95
- interface.launch()
 
1
# BUG FIX: this assignment sits above the file's import block, so `os` was
# not yet in scope and the Space crashed with NameError at startup.
# Import it here, immediately before first use (re-importing later is a no-op).
import os

# Checkpoint filename is configured per-Space via an env variable so the
# identical app.py can be deployed to every worker unchanged.
WEIGHT_FILE = os.environ.get(
    "WEIGHT_FILE",
    "final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36",  # safe default
)
5
 
6
+ """
7
+ ================================================================================
8
+ VERIDEX — DeepFake Worker Space (Generic Template)
9
+ ─────────────────────────────────────────────────────
10
+ DEPLOY INSTRUCTIONS — zero code changes between workers
11
+ ──────────────────────────────────────────────────────────
12
+ 1. Commit this IDENTICAL app.py to all 7 Worker Spaces.
13
+ 2. Upload each worker's .pt weight file to its Space's files tab.
14
+ 3. In each Space → Settings → Variables, set:
15
+
16
+ WEIGHT_FILE = final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
17
+ MODEL_CLASS = base # or srm / gwap (optional, default: base)
18
+
19
+ 4. That's it. No code edits required.
20
+
21
+ API CONTRACT (called by the Master UI)
22
+ ───────────────────────────────────────
23
+ Input : a .npy file (uint8, shape [N, H, W, 3], HWC, 380×380)
24
+ Output : JSON { "predictions": [float, ...], "n_frames": int }
25
+ OR { "error": "...", "predictions": null }
26
+
27
+ GRADIO VERSION NOTE
28
+ ────────────────────
29
+ HF Spaces force-installs gradio==6.x at build time regardless of what
30
+ requirements.txt pins. This file targets Gradio 6:
31
+ • gr.File input passes a tempfile.SpooledTemporaryFile-backed object
32
+ with a .name attribute in Gradio 6 (not a plain string or dict).
33
+ • allow_flagging is removed (deprecated in Gradio 6; raises a warning
34
+ that can abort startup on strict HF runtime configs).
35
+ ================================================================================
36
+ """
37
+
38
+ import os
39
+ import io
40
+ import re
41
+ import traceback
42
+ import logging
43
+
44
+ import numpy as np
45
  import torch
46
+ import torch.nn as nn
47
  from torch.nn.modules.dropout import Dropout
48
  from torch.nn.modules.linear import Linear
49
  from torch.nn.modules.pooling import AdaptiveAvgPool2d
50
+ from torchvision.transforms import Normalize
51
  from functools import partial
52
+
53
  import gradio as gr
 
 
54
 
55
# ── timm / efficientnet ───────────────────────────────────────────────────────
# Newer timm releases dropped the `tf_efficientnet_b7_ns` alias (per the
# upstream changelog); fall back to create_model with the dotted model name
# so either timm version works. NOTE(review): the fallback partial ignores
# keyword differences only if create_model accepts the same kwargs — confirm
# against the pinned timm version.
try:
    from timm.models.efficientnet import tf_efficientnet_b7_ns
except ImportError:
    # timm 0.9 moved the alias; fall back gracefully
    import timm
    tf_efficientnet_b7_ns = partial(timm.create_model, "tf_efficientnet_b7.ns_jft_in1k")
62
+
63
+
64
# Worker-wide logger; the "[WORKER]" tag distinguishes these lines from
# HF Spaces platform logs.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [WORKER] %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ══════════════════════════════════════════════════════════════════════════════
# ❶ ALL CONFIG IS VIA ENV VARS — set these in each Space's Settings → Variables
#    WEIGHT_FILE : filename of the .pt checkpoint (no extension required)
#    MODEL_CLASS : "base" | "srm" | "gwap"          (default: base)
#    MINI_BATCH  : frames per forward pass           (default: 8)
#    WEIGHTS_DIR : directory containing the .pt file (default: repo root ".")
# ══════════════════════════════════════════════════════════════════════════════

MODEL_CLASS = os.environ.get("MODEL_CLASS", "base")  # "base" | "srm" | "gwap"
MINI_BATCH = int(os.environ.get("MINI_BATCH", "8"))  # frames per forward pass
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", ".")     # dir that contains the .pt
# ══════════════════════════════════════════════════════════════════════════════

# ── ImageNet normalisation constants (standard torchvision values) ────────────
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize_fn = Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# ── EfficientNet-B7 final feature-map channel count (input size of the fc) ────
ENCODER_FEATURES = 2560
87
+
88
+ # ─────────────────────────────────────────────────────────────────────────────
89
+ # Model definitions (identical to deepfake_det.py so checkpoints load clean)
90
+ # ─────────────────────────────────────────────────────────────────────────────
91
+
92
def _make_encoder():
    """Instantiate the EfficientNet-B7 backbone.

    pretrained=False because the full checkpoint (including backbone weights)
    is loaded afterwards in load_model(); drop_path_rate matches training.
    """
    return tf_efficientnet_b7_ns(pretrained=False, drop_path_rate=0.2)
94
+
95
+
96
+ def _setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
97
+ srm_kernel = torch.from_numpy(np.array([
98
+ [[0.,0.,0.,0.,0.],[0.,0.,0.,0.,0.],[0.,1.,-2.,1.,0.],[0.,0.,0.,0.,0.],[0.,0.,0.,0.,0.]],
99
+ [[0.,0.,0.,0.,0.],[0.,-1.,2.,-1.,0.],[0.,2.,-4.,2.,0.],[0.,-1.,2.,-1.,0.],[0.,0.,0.,0.,0.]],
100
+ [[-1.,2.,-2.,2.,-1.],[2.,-6.,8.,-6.,2.],[-2.,8.,-12.,8.,-2.],[2.,-6.,8.,-6.,2.],[-1.,2.,-2.,2.,-1.]],
101
+ ])).float()
102
+ srm_kernel[0] /= 2
103
+ srm_kernel[1] /= 4
104
+ srm_kernel[2] /= 12
105
+ return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
106
+
107
+
108
def _setup_srm_layer(input_channels: int = 3) -> nn.Module:
    """Wrap the fixed SRM kernels in a frozen (non-trainable) 5x5 conv.

    padding=2 keeps the spatial size unchanged, so the layer maps an RGB
    image to a same-sized 3-channel noise-residual map.
    """
    srm = nn.Conv2d(input_channels, 3, kernel_size=5, stride=1, padding=2, bias=False)
    with torch.no_grad():
        # Freeze the handcrafted weights; they are never updated by training.
        srm.weight = nn.Parameter(_setup_srm_weights(input_channels), requires_grad=False)
    return srm
114
+
115
 
116
class DeepFakeClassifier(nn.Module):
    """Baseline deepfake classifier: EfficientNet-B7 backbone + global
    average pooling + a single-logit linear head.

    forward() returns raw logits of shape [B, 1]; callers apply sigmoid.
    """

    def __init__(self, dropout_rate=0.0):
        super().__init__()
        self.encoder = _make_encoder()             # EfficientNet-B7 backbone
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)      # 2560 features -> 1 logit

    def forward(self, x):
        # [B, 3, H, W] -> backbone feature map -> [B, 2560] -> [B, 1]
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        return self.fc(x)
129
+
130
+
131
class DeepFakeClassifierSRM(nn.Module):
    """Deepfake classifier that pre-filters the input through fixed SRM
    noise-residual kernels, so the backbone sees noise patterns rather
    than raw pixels. Same [B, 1] logit output as DeepFakeClassifier.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.encoder = _make_encoder()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.srm_conv = _setup_srm_layer(3)        # frozen handcrafted filters
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)

    def forward(self, x):
        # RGB -> 3-channel noise residual, then the usual backbone + head
        noise = self.srm_conv(x)
        x = self.encoder.forward_features(noise)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        return self.fc(x)
146
+
147
+
148
+ class _GWAP(nn.Module):
149
+ def __init__(self, features: int):
150
+ super().__init__()
151
+ self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
152
+
153
+ def forward(self, x):
154
+ w = self.conv(x).sigmoid().exp()
155
+ w = w / w.sum(dim=[2, 3], keepdim=True)
156
+ return (w * x).sum(dim=[2, 3], keepdim=False)
157
+
158
+
159
class DeepFakeClassifierGWAP(nn.Module):
    """Deepfake classifier using learned global weighted average pooling
    (_GWAP) in place of plain average pooling. Same [B, 1] logit output.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.encoder = _make_encoder()
        self.avg_pool = _GWAP(ENCODER_FEATURES)    # attention pool, not avg
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(ENCODER_FEATURES, 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x)                       # _GWAP already returns [B, C]
        x = self.dropout(x)
        return self.fc(x)
172
+
173
+
174
# Registry mapping the MODEL_CLASS env var to its architecture class.
# Unknown values fall back to "base" in load_model().
_MODEL_MAP = {
    "base": DeepFakeClassifier,
    "srm": DeepFakeClassifierSRM,
    "gwap": DeepFakeClassifierGWAP,
}
179
+
180
+
181
+ # ─────────────────────────────────────────────────────────────────────────────
182
+ # Model loading (runs once at startup)
183
+ # ─────────────────────────────────────────────────────────────────────────────
184
+
185
def load_model() -> "tuple[nn.Module, torch.device]":
    """Build the configured architecture, load its checkpoint, and return it.

    Runs once at startup.

    Returns
    -------
    (model, device) : the eval-mode fp16 model and the device it lives on.
        (FIX: the annotation previously claimed `-> nn.Module` although the
        function has always returned a 2-tuple.)

    Raises
    ------
    FileNotFoundError
        If no weight file is found under WEIGHTS_DIR (with or without a
        common extension appended).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cls = _MODEL_MAP.get(MODEL_CLASS, DeepFakeClassifier)
    model = cls().to(device)

    weight_path = os.path.join(WEIGHTS_DIR, WEIGHT_FILE)
    # Allow common extensions in case the file was renamed on upload.
    if not os.path.exists(weight_path):
        for ext in (".pt", ".pth", ".bin"):
            if os.path.exists(weight_path + ext):
                weight_path = weight_path + ext
                break

    if not os.path.exists(weight_path):
        # Listing the directory makes deployment mistakes obvious in the logs.
        raise FileNotFoundError(
            f"Weight file not found: {weight_path}\n"
            f"Files present in '{WEIGHTS_DIR}': {os.listdir(WEIGHTS_DIR)}"
        )

    logger.info(f"Loading weights from: {weight_path}")

    # PyTorch 2.6+ defaults to weights_only=True, which rejects pickled
    # checkpoints; map_location='cpu' lets a CUDA-saved file load anywhere.
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects — only
    # acceptable because the checkpoint is uploaded by the Space owner.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Strip the "module." prefix added by (Distributed)DataParallel training.
    cleaned = {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}
    model.load_state_dict(cleaned, strict=True)
    model.eval()

    # FP16 halves memory. NOTE(review): fp16 convolutions on CPU need a
    # reasonably recent PyTorch build — confirm against the pinned torch
    # version before relying on this on CPU-only Spaces.
    model = model.half()

    logger.info(f"Model ready — class={cls.__name__}, device={device}, fp16=True")
    return model, device
222
+
223
+
224
# Load the model once at import time. A failure is captured instead of
# crashing the Space so the API can still start and report the traceback to
# callers (predict() checks MODEL is None and returns LOAD_ERROR).
try:
    MODEL, DEVICE = load_model()
    LOAD_ERROR = None
except Exception as exc:
    MODEL = None
    DEVICE = None
    LOAD_ERROR = traceback.format_exc()
    logger.error(f"MODEL LOAD FAILED:\n{LOAD_ERROR}")
232
+
233
+
234
+ # ─────────────────────────────────────────────────────────────────────────────
235
+ # Inference helper
236
+ # ─────────────────────────────────────────────────────────────────────────────
237
+
238
def _preprocess_npy(npy_input) -> torch.Tensor:
    """
    Load a uint8 HWC .npy face-batch, convert to normalised float CHW tensor.

    Gradio version compatibility matrix
    ─────────────────────────────────────
    Gradio 4 : passes a plain string filepath "/tmp/gradio/.../faces.npy"
    Gradio 4 : may wrap in dict {"path": "...", "orig_name": "..."}
    Gradio 6 : passes a tempfile.SpooledTemporaryFile (file-like with .name)
               OR a gradio.FileData dataclass with a .path attribute

    We resolve all four forms to a final file path or file-like object
    that np.load() can consume.

    Raises TypeError (unrecognised input), FileNotFoundError (dangling path)
    or ValueError (wrong array shape).
    """
    npy_path = None   # will hold a string path if resolvable
    file_obj = None   # will hold a file-like if path is unavailable

    # ── Form 1: plain string ──────────────────────────────────────────────────
    if isinstance(npy_input, str):
        npy_path = npy_input

    # ── Form 2: Gradio 4 dict {"path": ..., "orig_name": ...} ────────────────
    elif isinstance(npy_input, dict):
        # Best-effort: if neither known key exists, fall back to the first
        # value in the dict (may not be a path; the existence check below
        # will catch it).
        npy_path = (
            npy_input.get("path")
            or npy_input.get("name")
            or next(iter(npy_input.values()), None)
        )

    # ── Form 3: Gradio 6 dataclass (has .path attribute) ─────────────────────
    elif hasattr(npy_input, "path"):
        npy_path = npy_input.path

    # ── Form 4: file-like object (SpooledTemporaryFile, BytesIO, etc.) ────────
    elif hasattr(npy_input, "read"):
        # Try to get the backing file path first (avoids reading into RAM twice)
        backing = getattr(npy_input, "name", None)
        if backing and isinstance(backing, str) and os.path.exists(backing):
            npy_path = backing
        else:
            file_obj = npy_input

    else:
        raise TypeError(
            f"Cannot resolve npy input of type {type(npy_input)}: {npy_input!r}"
        )

    # ── Load the array ─────────────────────────────────────────────────────────
    def _load(src):
        try:
            return np.load(src, allow_pickle=False)
        except ValueError:
            # Legacy pickled .npy — seek back to start if file-like.
            # SECURITY NOTE(review): allow_pickle=True executes arbitrary
            # pickle payloads; acceptable only if inputs come exclusively
            # from the trusted Master UI — confirm this worker is not public.
            if hasattr(src, "seek"):
                src.seek(0)
            return np.load(src, allow_pickle=True)

    if npy_path is not None:
        if not os.path.exists(npy_path):
            raise FileNotFoundError(f"NPY payload not found at: {npy_path}")
        faces_uint8 = _load(npy_path)
    else:
        faces_uint8 = _load(file_obj)

    # ── Validate shape ─────────────────────────────────────────────────────────
    if faces_uint8.ndim != 4 or faces_uint8.shape[3] != 3:
        raise ValueError(
            f"Expected uint8 array shape (N, H, W, 3), got {faces_uint8.shape}"
        )

    # Convert: uint8 HWC → float32 CHW → normalised
    tensor = torch.from_numpy(faces_uint8).float()   # [N, H, W, 3]
    tensor = tensor.permute(0, 3, 1, 2)              # [N, 3, H, W]
    # Normalise each frame in-place (scale to [0, 1] then ImageNet stats)
    for i in range(tensor.shape[0]):
        tensor[i] = normalize_fn(tensor[i] / 255.0)

    return tensor  # float32, shape [N, 3, H, W]
316
+
317
+
318
def run_inference(tensor: torch.Tensor) -> list:
    """
    Run the model over the pre-processed face tensor in MINI_BATCH-sized
    slices so peak memory stays bounded on 16 GB RAM spaces.

    Returns a flat Python list of per-frame fake-probabilities in [0, 1].
    """
    probabilities = []
    total = tensor.shape[0]

    with torch.no_grad():
        for lo in range(0, total, MINI_BATCH):
            # fp16 cast matches the model's dtype on the target device.
            chunk = tensor[lo : lo + MINI_BATCH].to(DEVICE).half()
            scores = torch.sigmoid(MODEL(chunk).squeeze(-1))  # [B, 1] -> [B]
            probabilities += scores.cpu().float().tolist()

    return probabilities
337
+
338
+
339
+ # ─────────────────────────────────────────────────────────────────────────────
340
+ # Gradio endpoint (headless — no UI blocks, purely an API)
341
+ # ─────────────────────────────────────────────────────────────────────────────
342
+
343
def predict(npy_file) -> dict:
    """
    Gradio API endpoint.

    Parameters
    ----------
    npy_file : str | dict | file-like
        Filepath (or Gradio file wrapper) pointing to the .npy face batch;
        all accepted forms are resolved by _preprocess_npy().

    Returns
    -------
    dict with keys:
        predictions : list[float] | None
        n_frames    : int
        error       : str | None

    Never raises: all failures are returned as {"error": ...} so the
    Master UI always receives parseable JSON.
    """
    # Surface a startup load failure to the caller instead of crashing here.
    if MODEL is None:
        msg = f"Model failed to load at startup:\n{LOAD_ERROR}"
        logger.error(msg)
        return {"predictions": None, "n_frames": 0, "error": msg}

    try:
        tensor = _preprocess_npy(npy_file)
        n_frames = tensor.shape[0]
        predictions = run_inference(tensor)
        logger.info(f"Inference OK frames={n_frames}, mean_pred={np.mean(predictions):.4f}")
        return {"predictions": predictions, "n_frames": n_frames, "error": None}

    except Exception:
        # Ship the full traceback back as data — the remote caller has no
        # other way to see this worker's logs.
        err = traceback.format_exc()
        logger.error(f"Inference failed:\n{err}")
        return {"predictions": None, "n_frames": 0, "error": err}
375
+
376
+
377
+ # ─────────────────────────────────────────────────────────────────────────────
378
+ # Launch
379
+ # ─────────────────────────────────────────────────────────────────────────────
380
+
381
# Headless Gradio app: a single File → JSON endpoint, no custom UI blocks.
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(label="Face batch (.npy)", file_types=[".npy"]),
    outputs=gr.JSON(label="Worker prediction"),
    title=f"VERIDEX Worker — {WEIGHT_FILE}",
    description=(
        "Headless inference worker. "
        "POST a uint8 .npy face-batch; receive per-frame fake probabilities."
    ),
    # allow_flagging removed: deprecated in Gradio 5, gone in Gradio 6
)
 
 
 
 
 
 
 
 
 
392
 
393
if __name__ == "__main__":
    # Bind on all interfaces at 7860, the standard HF Spaces port;
    # show_error surfaces server-side exceptions to the HTTP client.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)