NeelakshSaxena commited on
Commit
b3b4bd5
Β·
1 Parent(s): b9d0270

Replace tiny LLaVA with stable CPU fallback mode

Browse files
Files changed (1) hide show
  1. app.py +79 -99
app.py CHANGED
@@ -10,7 +10,7 @@ import gradio as gr
10
  import torch
11
  from dotenv import load_dotenv
12
  from PIL import Image, ImageOps
13
- from transformers import AutoProcessor, LlavaForConditionalGeneration
14
 
15
  ROOT_DIR = Path(__file__).resolve().parent
16
  SCRIPTS_DIR = ROOT_DIR / "scripts"
@@ -25,25 +25,23 @@ BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
25
  ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
26
  ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
27
  ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
28
- MODEL_MODE = os.getenv("MODEL_MODE", "tiny").strip().lower()
29
- TINY_BASE_MODEL_ID = os.getenv("TINY_BASE_MODEL_ID", "bczhou/tiny-llava-v1-hf")
 
 
30
  TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
31
- TINY_PROMPT = "Answer with one word only: Real or Fake."
32
  TARGET_IMAGE_SIZE = 336
33
  THRESHOLD = 0.5
34
 
35
  model = None
36
  processor = None
37
  inference_device = None
38
- inference_mode = MODEL_MODE
39
 
40
 
41
  def resolve_inference_device(model_obj) -> torch.device:
42
  if torch.cuda.is_available():
43
  return torch.device("cuda")
44
 
45
- # With device_map='auto', some parameters can live on 'meta' while offloaded.
46
- # For CPU Spaces, inputs must stay on CPU.
47
  device_map = getattr(model_obj, "hf_device_map", None)
48
  if isinstance(device_map, dict):
49
  for mapped in device_map.values():
@@ -117,36 +115,16 @@ def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
117
  def select_torch_dtype() -> torch.dtype:
118
  if torch.cuda.is_available():
119
  return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
120
- # float16 on CPU is numerically unstable for this model and can produce NaNs.
121
  return torch.float32
122
 
123
 
124
- def load_model_and_processor():
125
  global model, processor, inference_device
126
 
127
- if model is not None and processor is not None and inference_device is not None:
128
- return model, processor, inference_device
129
-
130
- if MODEL_MODE == "tiny":
131
- dtype = select_torch_dtype()
132
- print(f"Loading low-memory tiny model from {TINY_BASE_MODEL_ID} with dtype={dtype}...")
133
-
134
- model = LlavaForConditionalGeneration.from_pretrained(
135
- TINY_BASE_MODEL_ID,
136
- torch_dtype=dtype,
137
- low_cpu_mem_usage=True,
138
- device_map="auto",
139
- token=HF_TOKEN,
140
  )
141
- model.eval()
142
-
143
- processor = AutoProcessor.from_pretrained(TINY_BASE_MODEL_ID, token=HF_TOKEN)
144
- processor.patch_size = 14
145
- processor.vision_feature_select_strategy = "default"
146
-
147
- inference_device = resolve_inference_device(model)
148
- print(f"TMOS-DD tiny fallback ready on {inference_device}.")
149
- return model, processor, inference_device
150
 
151
  from peft import PeftModel
152
  from tmos_classifier import TMOSClassifier
@@ -225,7 +203,6 @@ def load_model_and_processor():
225
  "Classifier weights did not change after loading adapter; adapter likely incompatible."
226
  )
227
 
228
- # Merge LoRA into the base network so inference always uses adapted weights.
229
  model = loaded_model.merge_and_unload()
230
  selected_subfolder = subfolder
231
  print(
@@ -250,16 +227,43 @@ def load_model_and_processor():
250
  print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
251
 
252
  model.eval()
253
-
254
  processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
255
  processor.patch_size = 14
256
  processor.vision_feature_select_strategy = "default"
257
-
258
  inference_device = resolve_inference_device(model)
259
  if adapter_source == ADAPTER_REPO_ID:
260
  print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
261
  else:
262
  print(f"TMOS-DD ready on {inference_device} using local production adapter.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  return model, processor, inference_device
264
 
265
 
@@ -293,28 +297,33 @@ def confidence_card(prob_fake: float, label: str) -> str:
293
  """
294
 
295
 
296
- def score_binary_logits(logits: torch.Tensor, tokenizer) -> tuple[float, str]:
297
- fake_ids = tokenizer(" Fake", add_special_tokens=False).input_ids or tokenizer("Fake", add_special_tokens=False).input_ids
298
- real_ids = tokenizer(" Real", add_special_tokens=False).input_ids or tokenizer("Real", add_special_tokens=False).input_ids
299
 
300
- if not fake_ids or not real_ids:
301
- return 0.5, "Real"
 
 
 
 
 
 
302
 
303
- fake_id = fake_ids[0]
304
- real_id = real_ids[0]
 
305
 
306
- candidate_logits = torch.tensor([logits[real_id].item(), logits[fake_id].item()], dtype=torch.float32)
307
- probs = torch.softmax(candidate_logits, dim=0)
308
- prob_fake = float(probs[1].item())
309
- label = "Fake" if prob_fake >= THRESHOLD else "Real"
310
- return prob_fake, label
311
 
 
 
 
 
 
312
 
313
- def build_prompt(processor_obj, user_text: str) -> str:
314
- if MODEL_MODE == "tiny":
315
- # Tiny LLaVA forward expects an explicit image placeholder token.
316
- return f"<image>\n{user_text}"
317
- return user_text
318
 
319
 
320
  def infer_image(image: Image.Image):
@@ -323,27 +332,7 @@ def infer_image(image: Image.Image):
323
  return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
324
 
325
  model_obj, processor_obj, device = load_model_and_processor()
326
-
327
  prepared_image = preprocess_image(image)
328
- prompt_text = build_prompt(processor_obj, TINY_PROMPT if MODEL_MODE == "tiny" else TMOS_PROMPT)
329
- inputs = processor_obj(text=prompt_text, images=prepared_image, return_tensors="pt", padding=True)
330
- inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
331
-
332
- if MODEL_MODE == "tiny":
333
- image_token_id = getattr(model_obj.config, "image_token_index", None)
334
- if image_token_id is not None and "input_ids" in inputs:
335
- image_token_count = int((inputs["input_ids"] == image_token_id).sum().item())
336
- if image_token_count == 0:
337
- # Defensive recovery if tokenizer template/path strips the image placeholder.
338
- fallback_prompt = f"<image>\n{TINY_PROMPT}"
339
- inputs = processor_obj(text=fallback_prompt, images=prepared_image, return_tensors="pt", padding=True)
340
- inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
341
-
342
- # Keep pixel dtype stable across CPU/GPU to avoid backend kernel errors.
343
- if "pixel_values" in inputs:
344
- inputs["pixel_values"] = inputs["pixel_values"].to(
345
- dtype=select_torch_dtype() if device.type == "cuda" else torch.float32
346
- )
347
 
348
  autocast_context = (
349
  torch.autocast(device_type="cuda", dtype=select_torch_dtype())
@@ -353,44 +342,35 @@ def infer_image(image: Image.Image):
353
 
354
  start_time = time.perf_counter()
355
  with torch.inference_mode(), autocast_context:
356
- if MODEL_MODE == "tiny":
357
- outputs = model_obj(
358
- input_ids=inputs["input_ids"],
359
- pixel_values=inputs["pixel_values"],
360
- attention_mask=inputs.get("attention_mask"),
361
- return_dict=True,
362
- )
363
- else:
364
  outputs = model_obj(
365
  input_ids=inputs["input_ids"],
366
  pixel_values=inputs["pixel_values"],
367
  attention_mask=inputs["attention_mask"],
368
  )
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  if device.type == "cuda":
371
  torch.cuda.synchronize()
372
 
373
  elapsed_ms = (time.perf_counter() - start_time) * 1000.0
374
-
375
- if MODEL_MODE == "tiny":
376
- next_token_logits = outputs.logits[:, -1, :].squeeze(0).detach().float().cpu()
377
- prob_fake, label = score_binary_logits(next_token_logits, processor_obj.tokenizer)
378
- logit = float(torch.logit(torch.tensor(prob_fake, dtype=torch.float32), eps=1e-6).item())
379
- else:
380
- logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
381
- if not math.isfinite(logit):
382
- raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
383
-
384
- prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
385
- if not math.isfinite(prob_fake):
386
- raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
387
-
388
- label = "Fake" if prob_fake >= THRESHOLD else "Real"
389
-
390
  if not math.isfinite(prob_fake):
391
  raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
392
- confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
393
 
 
394
  return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
395
  except Exception as exc:
396
  err = f"Inference failed: {type(exc).__name__}: {exc}"
@@ -402,9 +382,9 @@ load_model_and_processor()
402
 
403
  with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
404
  demo_description = (
405
- "Low-memory fallback mode using tiny LLaVA for stable Space execution.\n\n"
406
- if MODEL_MODE == "tiny"
407
- else "Research demo for image-based deepfake detection with a deterministic classification head on top of LLaVA-1.5-7B.\n\n"
408
  )
409
  gr.Markdown(
410
  "# TMOS Deepfake Detector\n"
 
10
  import torch
11
  from dotenv import load_dotenv
12
  from PIL import Image, ImageOps
13
+ from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification
14
 
15
  ROOT_DIR = Path(__file__).resolve().parent
16
  SCRIPTS_DIR = ROOT_DIR / "scripts"
 
25
  ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
26
  ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
27
  ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
28
+
29
+ MODEL_MODE = os.getenv("MODEL_MODE", "cpu-fallback").strip().lower()
30
+ CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
31
+
32
  TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
 
33
  TARGET_IMAGE_SIZE = 336
34
  THRESHOLD = 0.5
35
 
36
  model = None
37
  processor = None
38
  inference_device = None
 
39
 
40
 
41
  def resolve_inference_device(model_obj) -> torch.device:
42
  if torch.cuda.is_available():
43
  return torch.device("cuda")
44
 
 
 
45
  device_map = getattr(model_obj, "hf_device_map", None)
46
  if isinstance(device_map, dict):
47
  for mapped in device_map.values():
 
115
  def select_torch_dtype() -> torch.dtype:
116
  if torch.cuda.is_available():
117
  return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
118
  return torch.float32
119
 
120
 
121
+ def load_tmos_model():
122
  global model, processor, inference_device
123
 
124
+ if not torch.cuda.is_available():
125
+ raise RuntimeError(
126
+ "TMOS mode requires GPU hardware. Set MODEL_MODE=cpu-fallback for free-tier CPU execution."
 
 
 
 
 
 
 
 
 
 
127
  )
 
 
 
 
 
 
 
 
 
128
 
129
  from peft import PeftModel
130
  from tmos_classifier import TMOSClassifier
 
203
  "Classifier weights did not change after loading adapter; adapter likely incompatible."
204
  )
205
 
 
206
  model = loaded_model.merge_and_unload()
207
  selected_subfolder = subfolder
208
  print(
 
227
  print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
228
 
229
  model.eval()
 
230
  processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
231
  processor.patch_size = 14
232
  processor.vision_feature_select_strategy = "default"
 
233
  inference_device = resolve_inference_device(model)
234
  if adapter_source == ADAPTER_REPO_ID:
235
  print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
236
  else:
237
  print(f"TMOS-DD ready on {inference_device} using local production adapter.")
238
+
239
+
240
+ def load_cpu_fallback_model():
241
+ global model, processor, inference_device
242
+ print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
243
+
244
+ processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
245
+ model = AutoModelForImageClassification.from_pretrained(
246
+ CPU_FALLBACK_MODEL_ID,
247
+ torch_dtype=torch.float32,
248
+ low_cpu_mem_usage=True,
249
+ token=HF_TOKEN,
250
+ )
251
+ model.to("cpu").eval()
252
+ inference_device = torch.device("cpu")
253
+ print("CPU fallback classifier ready.")
254
+
255
+
256
+ def load_model_and_processor():
257
+ global model, processor, inference_device
258
+
259
+ if model is not None and processor is not None and inference_device is not None:
260
+ return model, processor, inference_device
261
+
262
+ if MODEL_MODE == "tmos":
263
+ load_tmos_model()
264
+ else:
265
+ load_cpu_fallback_model()
266
+
267
  return model, processor, inference_device
268
 
269
 
 
297
  """
298
 
299
 
300
+ def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
301
+ probs = torch.softmax(logits.float(), dim=0)
 
302
 
303
+ fake_indices = []
304
+ real_indices = []
305
+ for idx in range(len(probs)):
306
+ label = str(id2label.get(idx, "")).lower()
307
+ if any(key in label for key in ["fake", "deepfake", "ai", "synthetic"]):
308
+ fake_indices.append(idx)
309
+ if any(key in label for key in ["real", "authentic", "genuine"]):
310
+ real_indices.append(idx)
311
 
312
+ if len(probs) == 2 and not fake_indices and not real_indices:
313
+ fake_indices = [1]
314
+ real_indices = [0]
315
 
316
+ fake_prob = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
317
+ real_prob = float(probs[real_indices].sum().item()) if real_indices else 0.0
 
 
 
318
 
319
+ total = fake_prob + real_prob
320
+ if total > 0:
321
+ prob_fake = fake_prob / total
322
+ else:
323
+ prob_fake = float(probs.max().item()) if len(probs) == 1 else float(probs[1].item()) if len(probs) > 1 else 0.5
324
 
325
+ label = "Fake" if prob_fake >= THRESHOLD else "Real"
326
+ return prob_fake, label
 
 
 
327
 
328
 
329
  def infer_image(image: Image.Image):
 
332
  return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
333
 
334
  model_obj, processor_obj, device = load_model_and_processor()
 
335
  prepared_image = preprocess_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  autocast_context = (
338
  torch.autocast(device_type="cuda", dtype=select_torch_dtype())
 
342
 
343
  start_time = time.perf_counter()
344
  with torch.inference_mode(), autocast_context:
345
+ if MODEL_MODE == "tmos":
346
+ inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
347
+ inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
 
 
 
 
 
348
  outputs = model_obj(
349
  input_ids=inputs["input_ids"],
350
  pixel_values=inputs["pixel_values"],
351
  attention_mask=inputs["attention_mask"],
352
  )
353
+ logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
354
+ if not math.isfinite(logit):
355
+ raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
356
+ prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
357
+ label = "Fake" if prob_fake >= THRESHOLD else "Real"
358
+ else:
359
+ inputs = processor_obj(images=prepared_image, return_tensors="pt")
360
+ inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
361
+ outputs = model_obj(**inputs)
362
+ logits = outputs.logits.squeeze(0).detach().float().cpu()
363
+ id2label = getattr(model_obj.config, "id2label", {}) or {}
364
+ prob_fake, label = score_fallback_logits(logits, id2label)
365
 
366
  if device.type == "cuda":
367
  torch.cuda.synchronize()
368
 
369
  elapsed_ms = (time.perf_counter() - start_time) * 1000.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  if not math.isfinite(prob_fake):
371
  raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
 
372
 
373
+ confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
374
  return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
375
  except Exception as exc:
376
  err = f"Inference failed: {type(exc).__name__}: {exc}"
 
382
 
383
  with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
384
  demo_description = (
385
+ "TMOS mode (GPU required) enabled.\n\n"
386
+ if MODEL_MODE == "tmos"
387
+ else f"CPU fallback mode using {CPU_FALLBACK_MODEL_ID}.\n\n"
388
  )
389
  gr.Markdown(
390
  "# TMOS Deepfake Detector\n"