NeelakshSaxena committed on
Commit 8d017ca · verified · 1 Parent(s): 513d223

Deploy auto GPU fallback + FastAPI /predict

Files changed (5)
  1. .gitattributes +35 -35
  2. README.md +14 -14
  3. app.py +464 -423
  4. requirements.txt +36 -36
  5. scripts/tmos_classifier.py +216 -216
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
- ---
- title: TrueFrame
- emoji: 🚀
- colorFrom: yellow
- colorTo: indigo
- sdk: gradio
- sdk_version: 6.11.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: A LLaVA-based multimodal classifier with LoRA fine-tuning, d
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: TrueFrame
+ emoji: 🚀
+ colorFrom: yellow
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 6.11.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: A LLaVA-based multimodal classifier with LoRA fine-tuning, d
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,423 +1,464 @@
- import os
- import sys
- import time
- import math
- import json
- from contextlib import nullcontext
- from pathlib import Path
-
- import gradio as gr
- import torch
- from dotenv import load_dotenv
- from PIL import Image, ImageOps
- from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification
-
- ROOT_DIR = Path(__file__).resolve().parent
- SCRIPTS_DIR = ROOT_DIR / "scripts"
-
- if str(SCRIPTS_DIR) not in sys.path:
-     sys.path.insert(0, str(SCRIPTS_DIR))
-
- load_dotenv()
- HF_TOKEN = os.getenv("HF_TOKEN")
-
- BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
- ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
- ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
- ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
-
- MODEL_MODE = os.getenv("MODEL_MODE", "cpu-fallback").strip().lower()
- CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
- DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
- INVERT_FALLBACK_OUTPUT = os.getenv("INVERT_FALLBACK_OUTPUT", str(DEFAULT_INVERT_FALLBACK)).strip().lower() == "true"
-
- TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
- TARGET_IMAGE_SIZE = 336
- THRESHOLD = 0.5
-
- model = None
- processor = None
- inference_device = None
-
-
- def resolve_inference_device(model_obj) -> torch.device:
-     if torch.cuda.is_available():
-         return torch.device("cuda")
-
-     device_map = getattr(model_obj, "hf_device_map", None)
-     if isinstance(device_map, dict):
-         for mapped in device_map.values():
-             if isinstance(mapped, str) and mapped.startswith("cuda"):
-                 return torch.device(mapped)
-     return torch.device("cpu")
-
-
- def find_classifier_weight_tensor(model_obj):
-     visited = set()
-     queue = [model_obj]
-     while queue:
-         current = queue.pop(0)
-         if current is None:
-             continue
-         obj_id = id(current)
-         if obj_id in visited:
-             continue
-         visited.add(obj_id)
-
-         classifier = getattr(current, "classifier", None)
-         if classifier is not None and hasattr(classifier, "weight"):
-             return classifier.weight
-
-         for attr in ("model", "base_model", "module"):
-             nested = getattr(current, attr, None)
-             if nested is not None:
-                 queue.append(nested)
-
-     return None
-
-
- def count_lora_layers(model_obj) -> int:
-     count = 0
-     for _, module in model_obj.named_modules():
-         if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
-             count += 1
-     return count
-
-
- def is_tmos_adapter_config(cfg: dict) -> bool:
-     modules_to_save = cfg.get("modules_to_save") or []
-     target_modules = set(cfg.get("target_modules") or [])
-     required_targets = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
-
-     return (
-         "classifier" in modules_to_save
-         and cfg.get("r") == 64
-         and required_targets.issubset(target_modules)
-     )
-
-
- def load_local_adapter_config(adapter_dir: Path) -> dict | None:
-     cfg_path = adapter_dir / "adapter_config.json"
-     if not cfg_path.exists():
-         return None
-     with cfg_path.open("r", encoding="utf-8") as fp:
-         return json.load(fp)
-
-
- def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
-     from peft import PeftConfig
-
-     try:
-         peft_cfg = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
-         return peft_cfg.to_dict()
-     except Exception:
-         return None
-
-
- def select_torch_dtype() -> torch.dtype:
-     if torch.cuda.is_available():
-         return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-     return torch.float32
-
-
- def load_tmos_model():
-     global model, processor, inference_device
-
-     if not torch.cuda.is_available():
-         raise RuntimeError(
-             "TMOS mode requires GPU hardware. Set MODEL_MODE=cpu-fallback for free-tier CPU execution."
-         )
-
-     from peft import PeftModel
-     from tmos_classifier import TMOSClassifier
-
-     adapter_source = None
-     local_adapter_file = next(
-         (
-             candidate
-             for candidate in (
-                 ADAPTER_PATH / "adapter_model.safetensors",
-                 ADAPTER_PATH / "adapter_model.bin",
-             )
-             if candidate.exists()
-         ),
-         None,
-     )
-
-     selected_subfolder = ""
-
-     if local_adapter_file is not None:
-         adapter_source = str(ADAPTER_PATH)
-         local_cfg = load_local_adapter_config(ADAPTER_PATH)
-         if local_cfg is None or not is_tmos_adapter_config(local_cfg):
-             raise RuntimeError(
-                 "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
-             )
-     else:
-         adapter_source = ADAPTER_REPO_ID
-
-     dtype = select_torch_dtype()
-     print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")
-
-     base_model = TMOSClassifier(
-         base_model_id=BASE_MODEL_ID,
-         torch_dtype=dtype,
-         device_map="auto",
-         token=HF_TOKEN,
-     )
-     base_classifier_weight = find_classifier_weight_tensor(base_model)
-     base_classifier_snapshot = None
-     if base_classifier_weight is not None:
-         base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()
-
-     peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
-     if adapter_source == ADAPTER_REPO_ID:
-         candidate_subfolders = [
-             s for s in [ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"] if s is not None
-         ]
-
-         last_error = None
-         for subfolder in candidate_subfolders:
-             try:
-                 remote_cfg = load_remote_adapter_config(adapter_source, subfolder)
-                 if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
-                     raise ValueError("Adapter config is not TMOS-compatible.")
-
-                 current_kwargs = dict(peft_kwargs)
-                 if subfolder:
-                     current_kwargs["subfolder"] = subfolder
-                 loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **current_kwargs)
-
-                 lora_layer_count = count_lora_layers(loaded_model)
-                 if lora_layer_count == 0:
-                     raise RuntimeError("Loaded adapter has zero LoRA layers attached.")
-
-                 loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
-                 if loaded_classifier_weight is None:
-                     raise RuntimeError("Classifier head not found after adapter load.")
-
-                 if base_classifier_snapshot is not None:
-                     classifier_delta = (
-                         loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
-                     ).abs().mean().item()
-                     if classifier_delta < 1e-8:
-                         raise RuntimeError(
-                             "Classifier weights did not change after loading adapter; adapter likely incompatible."
-                         )
-
-                 model = loaded_model.merge_and_unload()
-                 selected_subfolder = subfolder
-                 print(
-                     f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
-                     f"(lora_layers={lora_layer_count})"
-                 )
-                 break
-             except Exception as exc:
-                 last_error = exc
-                 continue
-         else:
-             raise RuntimeError(
-                 "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
-                 "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
-             ) from last_error
-     else:
-         loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
-         lora_layer_count = count_lora_layers(loaded_model)
-         if lora_layer_count == 0:
-             raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
-         model = loaded_model.merge_and_unload()
-         print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
-
-     model.eval()
-     processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
-     processor.patch_size = 14
-     processor.vision_feature_select_strategy = "default"
-     inference_device = resolve_inference_device(model)
-     if adapter_source == ADAPTER_REPO_ID:
-         print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
-     else:
-         print(f"TMOS-DD ready on {inference_device} using local production adapter.")
-
-
- def load_cpu_fallback_model():
-     global model, processor, inference_device
-     print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
-
-     processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
-     model = AutoModelForImageClassification.from_pretrained(
-         CPU_FALLBACK_MODEL_ID,
-         torch_dtype=torch.float32,
-         low_cpu_mem_usage=True,
-         token=HF_TOKEN,
-     )
-     model.to("cpu").eval()
-     inference_device = torch.device("cpu")
-     print("CPU fallback classifier ready.")
-
-
- def load_model_and_processor():
-     global model, processor, inference_device
-
-     if model is not None and processor is not None and inference_device is not None:
-         return model, processor, inference_device
-
-     if MODEL_MODE == "tmos":
-         load_tmos_model()
-     else:
-         load_cpu_fallback_model()
-
-     return model, processor, inference_device
-
-
- def preprocess_image(image: Image.Image) -> Image.Image:
-     image = image.convert("RGB")
-     return ImageOps.contain(image, (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE), method=Image.Resampling.BICUBIC)
-
-
- def confidence_card(prob_fake: float, label: str) -> str:
-     confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
-     confidence_pct = confidence * 100.0
-     fake_pct = prob_fake * 100.0
-     real_pct = (1.0 - prob_fake) * 100.0
-     accent = "#ef4444" if label == "Fake" else "#10b981"
-
-     return f"""
-     <div style="border:1px solid rgba(255,255,255,0.12); border-radius:16px; padding:16px; background:linear-gradient(135deg, rgba(17,24,39,0.92), rgba(15,23,42,0.96)); color:white;">
-       <div style="font-size:0.85rem; opacity:0.8; letter-spacing:0.04em; text-transform:uppercase; margin-bottom:8px;">Confidence</div>
-       <div style="display:flex; align-items:baseline; gap:10px; margin-bottom:12px;">
-         <div style="font-size:2rem; font-weight:700; color:{accent};">{confidence_pct:.2f}%</div>
-         <div style="font-size:1rem; opacity:0.9;">for <strong>{label}</strong></div>
-       </div>
-       <div style="height:12px; width:100%; background:rgba(255,255,255,0.08); border-radius:999px; overflow:hidden; margin-bottom:10px;">
-         <div style="height:100%; width:{fake_pct:.2f}%; background:linear-gradient(90deg, #f87171, #ef4444);"></div>
-       </div>
-       <div style="display:flex; justify-content:space-between; font-size:0.9rem; opacity:0.95;">
-         <span>Real: {real_pct:.2f}%</span>
-         <span>Fake: {fake_pct:.2f}%</span>
-       </div>
-     </div>
-     """
-
-
- def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
-     probs = torch.softmax(logits.float(), dim=0)
-
-     fake_indices = []
-     real_indices = []
-     for idx in range(len(probs)):
-         label = str(id2label.get(idx, "")).lower()
-         if any(key in label for key in ["fake", "deepfake", "ai", "synthetic"]):
-             fake_indices.append(idx)
-         if any(key in label for key in ["real", "authentic", "genuine"]):
-             real_indices.append(idx)
-
-     if len(probs) == 2 and not fake_indices and not real_indices:
-         fake_indices = [1]
-         real_indices = [0]
-
-     fake_prob = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
-     real_prob = float(probs[real_indices].sum().item()) if real_indices else 0.0
-
-     total = fake_prob + real_prob
-     if total > 0:
-         prob_fake = fake_prob / total
-     else:
-         prob_fake = float(probs.max().item()) if len(probs) == 1 else float(probs[1].item()) if len(probs) > 1 else 0.5
-
-     if INVERT_FALLBACK_OUTPUT:
-         prob_fake = 1.0 - prob_fake
-
-     label = "Fake" if prob_fake >= THRESHOLD else "Real"
-     return prob_fake, label
-
-
- def infer_image(image: Image.Image):
-     try:
-         if image is None:
-             return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
-
-         model_obj, processor_obj, device = load_model_and_processor()
-         prepared_image = preprocess_image(image)
-
-         autocast_context = (
-             torch.autocast(device_type="cuda", dtype=select_torch_dtype())
-             if device.type == "cuda"
-             else nullcontext()
-         )
-
-         start_time = time.perf_counter()
-         with torch.inference_mode(), autocast_context:
-             if MODEL_MODE == "tmos":
-                 inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
-                 inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
-                 outputs = model_obj(
-                     input_ids=inputs["input_ids"],
-                     pixel_values=inputs["pixel_values"],
-                     attention_mask=inputs["attention_mask"],
-                 )
-                 logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
-                 if not math.isfinite(logit):
-                     raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
-                 prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
-                 label = "Fake" if prob_fake >= THRESHOLD else "Real"
-             else:
-                 inputs = processor_obj(images=prepared_image, return_tensors="pt")
-                 inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
-                 outputs = model_obj(**inputs)
-                 logits = outputs.logits.squeeze(0).detach().float().cpu()
-                 id2label = getattr(model_obj.config, "id2label", {}) or {}
-                 prob_fake, label = score_fallback_logits(logits, id2label)
-
-         if device.type == "cuda":
-             torch.cuda.synchronize()
-
-         elapsed_ms = (time.perf_counter() - start_time) * 1000.0
-         if not math.isfinite(prob_fake):
-             raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
-
-         confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
-         return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
-     except Exception as exc:
-         err = f"Inference failed: {type(exc).__name__}: {exc}"
-         err_html = f"<div style='color:#fca5a5; border:1px solid rgba(252,165,165,0.35); padding:10px; border-radius:10px;'>\n<b>Inference error</b><br>{err}</div>"
-         return None, err, None, None, None, err_html
-
-
- load_model_and_processor()
-
- with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
-     demo_description = (
-         "TMOS mode (GPU required) enabled.\n\n"
-         if MODEL_MODE == "tmos"
-         else f"CPU fallback mode using {CPU_FALLBACK_MODEL_ID} (invert_output={INVERT_FALLBACK_OUTPUT}).\n\n"
-     )
-     gr.Markdown(
-         "# TMOS Deepfake Detector\n"
-         + demo_description
-         + "> Warning: runs on free infrastructure, so startup and inference may take time."
-     )
-
-     with gr.Row():
-         image_input = gr.Image(type="pil", label="Upload image")
-         with gr.Column():
-             prediction_output = gr.Textbox(label="Prediction", interactive=False)
-             probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
-             confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
-             latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)
-
-     preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
-     confidence_html = gr.HTML()
-
-     detect_button = gr.Button("Run detection", variant="primary")
-
-     detect_button.click(
-         fn=infer_image,
-         inputs=image_input,
-         outputs=[preview_output, prediction_output, probability_output, confidence_output, latency_output, confidence_html],
-     )
-
- demo.queue(default_concurrency_limit=1, max_size=8)
-
-
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
+ import os
+ import sys
+ import time
+ import math
+ import json
+ import io
+ from contextlib import nullcontext
+ from pathlib import Path
+
+ import gradio as gr
+ import torch
+ from fastapi import FastAPI, UploadFile, File
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from dotenv import load_dotenv
+ from PIL import Image, ImageOps
+ from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification
+
+ ROOT_DIR = Path(__file__).resolve().parent
+ SCRIPTS_DIR = ROOT_DIR / "scripts"
+
+ if str(SCRIPTS_DIR) not in sys.path:
+     sys.path.insert(0, str(SCRIPTS_DIR))
+
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
+ ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
+ ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
+ ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
+
+ CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
+ DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
+ INVERT_FALLBACK_OUTPUT = os.getenv("INVERT_FALLBACK_OUTPUT", str(DEFAULT_INVERT_FALLBACK)).strip().lower() == "true"
+
+ TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
+ TARGET_IMAGE_SIZE = 336
+ THRESHOLD = 0.5
+
+ model = None
+ processor = None
+ inference_device = None
+
+
+ def resolve_inference_device(model_obj) -> torch.device:
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+
+     device_map = getattr(model_obj, "hf_device_map", None)
+     if isinstance(device_map, dict):
+         for mapped in device_map.values():
+             if isinstance(mapped, str) and mapped.startswith("cuda"):
+                 return torch.device(mapped)
+     return torch.device("cpu")
+
+
+ def find_classifier_weight_tensor(model_obj):
+     visited = set()
+     queue = [model_obj]
+     while queue:
+         current = queue.pop(0)
+         if current is None:
+             continue
+         obj_id = id(current)
+         if obj_id in visited:
+             continue
+         visited.add(obj_id)
+
+         classifier = getattr(current, "classifier", None)
+         if classifier is not None and hasattr(classifier, "weight"):
+             return classifier.weight
+
+         for attr in ("model", "base_model", "module"):
+             nested = getattr(current, attr, None)
+             if nested is not None:
+                 queue.append(nested)
+
+     return None
+
+
+ def count_lora_layers(model_obj) -> int:
+     count = 0
+     for _, module in model_obj.named_modules():
+         if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
+             count += 1
+     return count
+
+
+ def is_tmos_adapter_config(cfg: dict) -> bool:
+     modules_to_save = cfg.get("modules_to_save") or []
+     target_modules = set(cfg.get("target_modules") or [])
+     required_targets = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
+
+     return (
+         "classifier" in modules_to_save
+         and cfg.get("r") == 64
+         and required_targets.issubset(target_modules)
+     )
+
+
+ def load_local_adapter_config(adapter_dir: Path) -> dict | None:
+     cfg_path = adapter_dir / "adapter_config.json"
+     if not cfg_path.exists():
+         return None
+     with cfg_path.open("r", encoding="utf-8") as fp:
+         return json.load(fp)
+
+
+ def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
+     from peft import PeftConfig
+
+     try:
+         peft_cfg = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
+         return peft_cfg.to_dict()
+     except Exception:
+         return None
+
+
+ def select_torch_dtype() -> torch.dtype:
+     if torch.cuda.is_available():
+         return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+     return torch.float32
+
+
+ def load_tmos_model():
+     global model, processor, inference_device
+
+     if not torch.cuda.is_available():
+         raise RuntimeError(
+             "TMOS mode requires GPU hardware. CPU fallback should be used on CPU-only environments."
+         )
+
+     from peft import PeftModel
+     from tmos_classifier import TMOSClassifier
+
+     adapter_source = None
+     local_adapter_file = next(
+         (
+             candidate
+             for candidate in (
+                 ADAPTER_PATH / "adapter_model.safetensors",
+                 ADAPTER_PATH / "adapter_model.bin",
+             )
+             if candidate.exists()
+         ),
+         None,
+     )
+
+     selected_subfolder = ""
+
+     if local_adapter_file is not None:
+         adapter_source = str(ADAPTER_PATH)
+         local_cfg = load_local_adapter_config(ADAPTER_PATH)
+         if local_cfg is None or not is_tmos_adapter_config(local_cfg):
+             raise RuntimeError(
+                 "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
+             )
+     else:
+         adapter_source = ADAPTER_REPO_ID
+
+     dtype = select_torch_dtype()
+     print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")
+
+     base_model = TMOSClassifier(
+         base_model_id=BASE_MODEL_ID,
+         torch_dtype=dtype,
+         device_map="auto",
+         token=HF_TOKEN,
+     )
+     base_classifier_weight = find_classifier_weight_tensor(base_model)
+     base_classifier_snapshot = None
+     if base_classifier_weight is not None:
+         base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()
+
+     peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
+     if adapter_source == ADAPTER_REPO_ID:
+         candidate_subfolders = [
+             s for s in [ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"] if s is not None
+         ]
+
+         last_error = None
+         for subfolder in candidate_subfolders:
+             try:
+                 remote_cfg = load_remote_adapter_config(adapter_source, subfolder)
+                 if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
+                     raise ValueError("Adapter config is not TMOS-compatible.")
+
+                 current_kwargs = dict(peft_kwargs)
+                 if subfolder:
+                     current_kwargs["subfolder"] = subfolder
+                 loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **current_kwargs)
+
+                 lora_layer_count = count_lora_layers(loaded_model)
+                 if lora_layer_count == 0:
+                     raise RuntimeError("Loaded adapter has zero LoRA layers attached.")
+
+                 loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
+                 if loaded_classifier_weight is None:
+                     raise RuntimeError("Classifier head not found after adapter load.")
+
+                 if base_classifier_snapshot is not None:
+                     classifier_delta = (
+                         loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
+                     ).abs().mean().item()
+                     if classifier_delta < 1e-8:
+                         raise RuntimeError(
+                             "Classifier weights did not change after loading adapter; adapter likely incompatible."
+                         )
+
+                 model = loaded_model.merge_and_unload()
+                 selected_subfolder = subfolder
+                 print(
+                     f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
+                     f"(lora_layers={lora_layer_count})"
+                 )
+                 break
+             except Exception as exc:
+                 last_error = exc
+                 continue
+         else:
+             raise RuntimeError(
+                 "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
+                 "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
+             ) from last_error
+     else:
+         loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
+         lora_layer_count = count_lora_layers(loaded_model)
+         if lora_layer_count == 0:
+             raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
+         model = loaded_model.merge_and_unload()
+         print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
+
+     model.eval()
+     processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
+     processor.patch_size = 14
+     processor.vision_feature_select_strategy = "default"
+     inference_device = resolve_inference_device(model)
+     if adapter_source == ADAPTER_REPO_ID:
+         print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
+     else:
+         print(f"TMOS-DD ready on {inference_device} using local production adapter.")
+
+
+ def load_cpu_fallback_model():
+     global model, processor, inference_device
+     print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
+
+     processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
+     model = AutoModelForImageClassification.from_pretrained(
+         CPU_FALLBACK_MODEL_ID,
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True,
+         token=HF_TOKEN,
+     )
+     model.to("cpu").eval()
+     inference_device = torch.device("cpu")
+     print("CPU fallback classifier ready.")
+
+
+ def load_model_and_processor():
+     global model, processor, inference_device
+
+     if model is not None and processor is not None and inference_device is not None:
+         return model, processor, inference_device
+
+     if torch.cuda.is_available():
+         print("GPU detected -> loading TMOS")
+         try:
+             load_tmos_model()
+         except Exception as exc:
+             print(f"TMOS failed: {exc}")
+             print("Falling back to CPU model...")
+             load_cpu_fallback_model()
+     else:
+         print("No GPU detected -> using CPU fallback")
+         load_cpu_fallback_model()
+
+     return model, processor, inference_device
+
+
+ def preprocess_image(image: Image.Image) -> Image.Image:
+     image = image.convert("RGB")
+     return ImageOps.contain(image, (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE), method=Image.Resampling.BICUBIC)
+
+
+ def confidence_card(prob_fake: float, label: str) -> str:
+     confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
+     confidence_pct = confidence * 100.0
+     fake_pct = prob_fake * 100.0
+     real_pct = (1.0 - prob_fake) * 100.0
+     accent = "#ef4444" if label == "Fake" else "#10b981"
+
+     return f"""
+     <div style="border:1px solid rgba(255,255,255,0.12); border-radius:16px; padding:16px; background:linear-gradient(135deg, rgba(17,24,39,0.92), rgba(15,23,42,0.96)); color:white;">
+       <div style="font-size:0.85rem; opacity:0.8; letter-spacing:0.04em; text-transform:uppercase; margin-bottom:8px;">Confidence</div>
+       <div style="display:flex; align-items:baseline; gap:10px; margin-bottom:12px;">
+         <div style="font-size:2rem; font-weight:700; color:{accent};">{confidence_pct:.2f}%</div>
+         <div style="font-size:1rem; opacity:0.9;">for <strong>{label}</strong></div>
+       </div>
+       <div style="height:12px; width:100%; background:rgba(255,255,255,0.08); border-radius:999px; overflow:hidden; margin-bottom:10px;">
+         <div style="height:100%; width:{fake_pct:.2f}%; background:linear-gradient(90deg, #f87171, #ef4444);"></div>
+       </div>
+       <div style="display:flex; justify-content:space-between; font-size:0.9rem; opacity:0.95;">
+         <span>Real: {real_pct:.2f}%</span>
+         <span>Fake: {fake_pct:.2f}%</span>
+       </div>
+     </div>
+     """
+
+
+ def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
+     probs = torch.softmax(logits.float(), dim=0)
+
+     fake_indices = []
+     real_indices = []
+     for idx in range(len(probs)):
+         label = str(id2label.get(idx, "")).lower()
+         if any(key in label for key in ["fake", "deepfake", "ai", "synthetic"]):
+             fake_indices.append(idx)
+         if any(key in label for key in ["real", "authentic", "genuine"]):
+             real_indices.append(idx)
+
+     if len(probs) == 2 and not fake_indices and not real_indices:
+         fake_indices = [1]
+         real_indices = [0]
+
+     fake_prob = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
+     real_prob = float(probs[real_indices].sum().item()) if real_indices else 0.0
+
+     total = fake_prob + real_prob
+     if total > 0:
+         prob_fake = fake_prob / total
+     else:
+         prob_fake = float(probs.max().item()) if len(probs) == 1 else float(probs[1].item()) if len(probs) > 1 else 0.5
+
+     if INVERT_FALLBACK_OUTPUT:
+         prob_fake = 1.0 - prob_fake
+
+     label = "Fake" if prob_fake >= THRESHOLD else "Real"
+     return prob_fake, label
+
+
+ def infer_image(image: Image.Image):
+     try:
+         if image is None:
+             return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
+
+         model_obj, processor_obj, device = load_model_and_processor()
+         prepared_image = preprocess_image(image)
+
+         autocast_context = (
+             torch.autocast(device_type="cuda", dtype=select_torch_dtype())
+             if device.type == "cuda"
+             else nullcontext()
+         )
+
+         start_time = time.perf_counter()
+         with torch.inference_mode(), autocast_context:
+             if inference_device.type == "cuda":
+                 inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
+                 inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
+                 outputs = model_obj(
+                     input_ids=inputs["input_ids"],
+                     pixel_values=inputs["pixel_values"],
+                     attention_mask=inputs["attention_mask"],
+                 )
+                 logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
+                 if not math.isfinite(logit):
+                     raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
+                 prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
+                 label = "Fake" if prob_fake >= THRESHOLD else "Real"
+             else:
+                 inputs = processor_obj(images=prepared_image, return_tensors="pt")
+                 inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
+                 outputs = model_obj(**inputs)
+                 logits = outputs.logits.squeeze(0).detach().float().cpu()
+                 id2label = getattr(model_obj.config, "id2label", {}) or {}
+                 prob_fake, label = score_fallback_logits(logits, id2label)
+
+         if device.type == "cuda":
+             torch.cuda.synchronize()
+
+         elapsed_ms = (time.perf_counter() - start_time) * 1000.0
+         if not math.isfinite(prob_fake):
+             raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
+
+         confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
+         return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
+     except Exception as exc:
+         err = f"Inference failed: {type(exc).__name__}: {exc}"
+         err_html = f"<div style='color:#fca5a5; border:1px solid rgba(252,165,165,0.35); padding:10px; border-radius:10px;'>\n<b>Inference error</b><br>{err}</div>"
+         return None, err, None, None, None, err_html
+
+
+
+ api = FastAPI()
+
+ api.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @api.post("/predict")
+ async def predict(file: UploadFile = File(...)):
+     try:
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents)).convert("RGB")
+
+         _, label, prob_fake, confidence, latency, _ = infer_image(image)
+
+         return JSONResponse(
+             {
+                 "verdict": label,
+                 "confidence_percent": confidence,
+                 "p_fake": prob_fake,
+                 "latency_ms": latency,
+             }
+         )
+     except Exception as exc:
+         return JSONResponse({"error": str(exc)}, status_code=500)
+
+
+ load_model_and_processor()
+
+ with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
+     device_label = "GPU (TMOS Model)" if torch.cuda.is_available() else "CPU Fallback Model"
+     gr.Markdown(
+         f"# TMOS Deepfake Detector\n"
+         f"**Running on:** {device_label}\n\n"
+         f"> Warning: runs on free infrastructure, so startup and inference may take time."
+     )
+
+     with gr.Row():
+         image_input = gr.Image(type="pil", label="Upload image")
+         with gr.Column():
+             prediction_output = gr.Textbox(label="Prediction", interactive=False)
+             probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
+             confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
+             latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)
+
+     preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
+     confidence_html = gr.HTML()
+
+     detect_button = gr.Button("Run detection", variant="primary")
+
+     detect_button.click(
+         fn=infer_image,
+         inputs=image_input,
+         outputs=[preview_output, prediction_output, probability_output, confidence_output, latency_output, confidence_html],
+     )
+
+ demo.queue(default_concurrency_limit=1, max_size=8)
+
+ app = gr.mount_gradio_app(api, demo, path="/")
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))
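For reference, a minimal client sketch for the new endpoint. It assumes the app is reachable at http://localhost:7860 (the default PORT above), that a test image named sample.jpg exists, and that the requests package is installed; none of these are pinned by this commit.

import requests

# POST a multipart upload to the /predict route; the form field must be
# named "file" to match predict(file: UploadFile = File(...)) above.
with open("sample.jpg", "rb") as fh:
    resp = requests.post(
        "http://localhost:7860/predict",
        files={"file": ("sample.jpg", fh, "image/jpeg")},
    )
resp.raise_for_status()

# Keys mirror the JSONResponse assembled in the predict() handler.
result = resp.json()
print(result["verdict"], result["confidence_percent"], result["p_fake"], result["latency_ms"])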
requirements.txt CHANGED
@@ -1,36 +1,36 @@
- # Core ML
- numpy>=1.24.0
- python-dotenv
- gradio
- torch>=2.0.0
- torchvision>=0.15.0
- torchaudio>=2.0.0
- torchcodec
- Pillow
-
- # Dependencies
- albumentations>=0.5.2
- datasets
- huggingface_hub
- scikit-learn>=1.3.0
- scikit-image>=0.21.0
- pandas>=2.0.0
- matplotlib>=3.7.0
- seaborn
- transformers==4.36.2
- peft
- accelerate
- diffusers
- opencv-python
-
- # Optional fallback for lower-memory GPU execution
- bitsandbytes
-
- # M2TR specific
- yacs==0.1.8
- nbconvert
- tensorboard==2.20.0
- tqdm==4.67.1
- PyYAML==6.0.3
- simplejson==3.20.2
- fvcore
+ # Core ML
+ numpy>=1.24.0
+ python-dotenv
+ gradio
+ torch>=2.0.0
+ torchvision>=0.15.0
+ torchaudio>=2.0.0
+ torchcodec
+ Pillow
+
+ # Dependencies
+ albumentations>=0.5.2
+ datasets
+ huggingface_hub
+ scikit-learn>=1.3.0
+ scikit-image>=0.21.0
+ pandas>=2.0.0
+ matplotlib>=3.7.0
+ seaborn
+ transformers==4.36.2
+ peft
+ accelerate
+ diffusers
+ opencv-python
+
+ # Optional fallback for lower-memory GPU execution
+ bitsandbytes
+
+ # M2TR specific
+ yacs==0.1.8
+ nbconvert
+ tensorboard==2.20.0
+ tqdm==4.67.1
+ PyYAML==6.0.3
+ simplejson==3.20.2
+ fvcore
scripts/tmos_classifier.py CHANGED
@@ -1,216 +1,216 @@
- """
- TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.
-
- Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)
- for binary deepfake detection (0 = Real, 1 = Fake).
-
- Usage:
-     from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG
-
-     classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
-     classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
-
-     logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)
-     loss = nn.BCEWithLogitsLoss()(logit, label)
- """
-
- import torch
- import torch.nn as nn
- from transformers import LlavaForConditionalGeneration
- from peft import LoraConfig
-
-
- # ─── LoRA Configuration ──────────────────────────────────────────────
- # Massive expansion: r=64 across ALL linear layers in the LLM backbone.
- # We exclude lm_head (we discard it), fc1/fc2/out_proj (CLIP vision),
- # and linear_1/linear_2 (multi-modal projector) from LoRA to keep
- # the vision encoder frozen and only adapt the language transformer.
-
- TMOS_LORA_CONFIG = LoraConfig(
-     r=64,
-     lora_alpha=128,  # 2x rank as a common heuristic
-     target_modules=[
-         "q_proj", "k_proj", "v_proj", "o_proj",
-         "gate_proj", "up_proj", "down_proj",
-     ],
-     lora_dropout=0.1,
-     bias="none",
-     task_type=None,  # Custom classifier — not a causal LM
-     modules_to_save=["classifier"],  # Always train the classification head
- )
-
-
- class TMOSClassifier(nn.Module):
-     """
-     Binary classifier built on the LLaVA transformer backbone.
-
-     Architecture:
-         pixel_values ──► CLIP Vision Tower ──► Multi-Modal Projector ──┐
-                                                                        ├──► LLaMA Transformer ──► last_hidden_state[:, -1, :] ──► classifier ──► logit
-         input_ids ──► Token Embedding ─────────────────────────────────┘
-
-     The lm_head is never used. We extract the final token's hidden state
-     and pass it through a learned nn.Linear(hidden_size, 1) head.
-     """
-
-     def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
-         super().__init__()
-
-         # Load the full LLaVA model (we need vision tower + projector + LLM)
-         self.base = LlavaForConditionalGeneration.from_pretrained(
-             base_model_id,
-             torch_dtype=torch_dtype,
-             low_cpu_mem_usage=True,
-             device_map=device_map,
-             token=token,
-         )
-
-         hidden_size = self.base.config.text_config.hidden_size  # 4096 for 7B
-
-         # Freeze the lm_head — we won't use it, but freezing prevents
-         # wasted gradient computation if PEFT accidentally wraps it.
-         for param in self.base.lm_head.parameters():
-             param.requires_grad = False
-
-         # Keep the classifier head in fp32 for numerical stability.
-         self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
-         nn.init.xavier_uniform_(self.classifier.weight)
-         nn.init.zeros_(self.classifier.bias)
-
-     def forward(
-         self,
-         input_ids=None,
-         pixel_values=None,
-         attention_mask=None,
-         labels=None,  # float tensor of shape (B,) — 0.0=real, 1.0=fake
-         **kwargs,  # absorb extra keys from data collator
-     ):
-         """
-         Single deterministic forward pass → logit + optional BCE loss.
-
-         Returns:
-             dict with keys:
-                 "logit": (B, 1) raw logit
-                 "loss": scalar BCE loss (only if labels provided)
-         """
-         # ── 1. Forward through the LLaVA backbone ──
-         # We call the internal model (vision + projector + LLM) directly,
-         # asking for hidden states, NOT for language-model logits.
-         outputs = self.base.model(
-             input_ids=input_ids,
-             pixel_values=pixel_values,
-             attention_mask=attention_mask,
-             return_dict=True,
-         )
-
-         # last_hidden_state: (B, seq_len, hidden_size)
-         last_hidden_state = outputs.last_hidden_state
-
-         # ── 2. Pool: extract the final non-padded token per sequence ──
-         if attention_mask is not None:
-             # Sum of mask gives the sequence length (excluding padding)
-             # Index of the last real token = seq_lengths - 1
-             seq_lengths = attention_mask.sum(dim=1).long() - 1
-             # Clamp to valid range
-             seq_lengths = seq_lengths.clamp(min=0, max=last_hidden_state.size(1) - 1)
-             # Gather the hidden state at each sequence's last real token
-             pooled = last_hidden_state[
-                 torch.arange(last_hidden_state.size(0), device=last_hidden_state.device),
-                 seq_lengths,
-             ]
-         else:
-             # No mask → just take the last position
-             pooled = last_hidden_state[:, -1, :]
-
-         # Replace non-finite activations defensively before the classifier.
-         pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)
-
-         # Match classifier device to pooled features when model is sharded/offloaded.
-         if self.classifier.weight.device != pooled.device:
-             self.classifier = self.classifier.to(pooled.device)
-
-         # ── 3. Classify ──
-         logit = self.classifier(pooled.float())  # (B, 1)
-         logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
-
-         result = {"logit": logit}
-
-         # ── 4. Loss ──
-         if labels is not None:
-             labels = labels.to(logit.dtype).to(logit.device)
-             if labels.dim() == 1:
-                 labels = labels.unsqueeze(1)  # (B,) → (B, 1)
-             loss_fn = nn.BCEWithLogitsLoss()
-             result["loss"] = loss_fn(logit, labels)
-
-         return result
-
-     def prepare_inputs_for_generation(self, *args, **kwargs):
-         """Stub required by PEFT — we never generate text."""
-         raise NotImplementedError("TMOSClassifier does not support generation.")
-
-     def gradient_checkpointing_enable(self, **kwargs):
-         """Delegate to the base model for HF Trainer compatibility."""
-         self.base.model.gradient_checkpointing_enable(**kwargs)
-
-     @property
-     def config(self):
-         """Expose the base model config for PEFT."""
-         return self.base.config
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
-
-     @property
-     def dtype(self):
-         return next(self.parameters()).dtype
-
-
- # ─── Standalone Test ──────────────────────────────────────────────────
- if __name__ == "__main__":
-     import os
-     from dotenv import load_dotenv
-     load_dotenv()
-     HF_TOKEN = os.getenv("HF_TOKEN")
-
-     print("Testing TMOSClassifier...")
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     clf = TMOSClassifier(
-         base_model_id="llava-hf/llava-1.5-7b-hf",
-         torch_dtype=torch.float16,
-         token=HF_TOKEN,
-     )
-     clf.to(device)
-
-     # Print parameter counts
-     total = sum(p.numel() for p in clf.parameters())
-     trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
-     print(f"Total params: {total:>12,}")
-     print(f"Trainable params: {trainable:>12,}")
-     print(f"Classifier head: {sum(p.numel() for p in clf.classifier.parameters()):,}")
-
-     # Smoke test with dummy input
-     from transformers import AutoProcessor
-     processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=HF_TOKEN)
-     processor.patch_size = 14
-     processor.vision_feature_select_strategy = "default"
-
-     from PIL import Image
-     dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
-     inputs = processor(
-         text="USER: <image>\nIs this real?\nASSISTANT:",
-         images=dummy_img,
-         return_tensors="pt",
-     ).to(device)
-
-     labels = torch.tensor([1.0], device=device)  # fake
-
-     with torch.no_grad():
-         out = clf(**inputs, labels=labels)
-
-     print(f"Logit: {out['logit'].item():.4f}")
-     print(f"Loss: {out['loss'].item():.4f}")
-     print(f"Prob: {torch.sigmoid(out['logit']).item():.4f}")
-     print("Test passed.")
+ """
+ TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.
+
+ Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)
+ for binary deepfake detection (0 = Real, 1 = Fake).
+
+ Usage:
+     from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG
+
+     classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
+     classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
+
+     logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)
+     loss = nn.BCEWithLogitsLoss()(logit, label)
+ """
+
+ import torch
+ import torch.nn as nn
+ from transformers import LlavaForConditionalGeneration
+ from peft import LoraConfig
+
+
+ # ─── LoRA Configuration ──────────────────────────────────────────────
+ # Massive expansion: r=64 across ALL linear layers in the LLM backbone.
+ # We exclude lm_head (we discard it), fc1/fc2/out_proj (CLIP vision),
+ # and linear_1/linear_2 (multi-modal projector) from LoRA to keep
+ # the vision encoder frozen and only adapt the language transformer.
+
+ TMOS_LORA_CONFIG = LoraConfig(
+     r=64,
+     lora_alpha=128,  # 2x rank as a common heuristic
+     target_modules=[
+         "q_proj", "k_proj", "v_proj", "o_proj",
+         "gate_proj", "up_proj", "down_proj",
+     ],
+     lora_dropout=0.1,
+     bias="none",
+     task_type=None,  # Custom classifier — not a causal LM
+     modules_to_save=["classifier"],  # Always train the classification head
+ )
+
+
+ class TMOSClassifier(nn.Module):
+     """
+     Binary classifier built on the LLaVA transformer backbone.
+
+     Architecture:
+         pixel_values ──► CLIP Vision Tower ──► Multi-Modal Projector ──┐
+                                                                        ├──► LLaMA Transformer ──► last_hidden_state[:, -1, :] ──► classifier ──► logit
+         input_ids ──► Token Embedding ─────────────────────────────────┘
+
+     The lm_head is never used. We extract the final token's hidden state
+     and pass it through a learned nn.Linear(hidden_size, 1) head.
+     """
+
+     def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
+         super().__init__()
+
+         # Load the full LLaVA model (we need vision tower + projector + LLM)
+         self.base = LlavaForConditionalGeneration.from_pretrained(
+             base_model_id,
+             torch_dtype=torch_dtype,
+             low_cpu_mem_usage=True,
+             device_map=device_map,
+             token=token,
+         )
+
+         hidden_size = self.base.config.text_config.hidden_size  # 4096 for 7B
+
+         # Freeze the lm_head — we won't use it, but freezing prevents
+         # wasted gradient computation if PEFT accidentally wraps it.
+         for param in self.base.lm_head.parameters():
+             param.requires_grad = False
+
+         # Keep the classifier head in fp32 for numerical stability.
+         self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
+         nn.init.xavier_uniform_(self.classifier.weight)
+         nn.init.zeros_(self.classifier.bias)
+
+     def forward(
+         self,
+         input_ids=None,
+         pixel_values=None,
+         attention_mask=None,
+         labels=None,  # float tensor of shape (B,) — 0.0=real, 1.0=fake
+         **kwargs,  # absorb extra keys from data collator
+     ):
+         """
+         Single deterministic forward pass → logit + optional BCE loss.
+
+         Returns:
+             dict with keys:
+                 "logit": (B, 1) raw logit
+                 "loss": scalar BCE loss (only if labels provided)
+         """
+         # ── 1. Forward through the LLaVA backbone ──
+         # We call the internal model (vision + projector + LLM) directly,
+         # asking for hidden states, NOT for language-model logits.
+         outputs = self.base.model(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             attention_mask=attention_mask,
+             return_dict=True,
+         )
+
+         # last_hidden_state: (B, seq_len, hidden_size)
+         last_hidden_state = outputs.last_hidden_state
+
+         # ── 2. Pool: extract the final non-padded token per sequence ──
+         if attention_mask is not None:
+             # Sum of mask gives the sequence length (excluding padding)
+             # Index of the last real token = seq_lengths - 1
+             seq_lengths = attention_mask.sum(dim=1).long() - 1
+             # Clamp to valid range
+             seq_lengths = seq_lengths.clamp(min=0, max=last_hidden_state.size(1) - 1)
+             # Gather the hidden state at each sequence's last real token
+             pooled = last_hidden_state[
+                 torch.arange(last_hidden_state.size(0), device=last_hidden_state.device),
+                 seq_lengths,
+             ]
+         else:
+             # No mask → just take the last position
+             pooled = last_hidden_state[:, -1, :]
+
+         # Replace non-finite activations defensively before the classifier.
+         pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)
+
+         # Match classifier device to pooled features when model is sharded/offloaded.
+         if self.classifier.weight.device != pooled.device:
+             self.classifier = self.classifier.to(pooled.device)
+
+         # ── 3. Classify ──
+         logit = self.classifier(pooled.float())  # (B, 1)
+         logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
+
+         result = {"logit": logit}
+
+         # ── 4. Loss ──
+         if labels is not None:
+             labels = labels.to(logit.dtype).to(logit.device)
+             if labels.dim() == 1:
+                 labels = labels.unsqueeze(1)  # (B,) → (B, 1)
+             loss_fn = nn.BCEWithLogitsLoss()
+             result["loss"] = loss_fn(logit, labels)
+
+         return result
+
+     def prepare_inputs_for_generation(self, *args, **kwargs):
+         """Stub required by PEFT — we never generate text."""
+         raise NotImplementedError("TMOSClassifier does not support generation.")
+
+     def gradient_checkpointing_enable(self, **kwargs):
+         """Delegate to the base model for HF Trainer compatibility."""
+         self.base.model.gradient_checkpointing_enable(**kwargs)
+
+     @property
+     def config(self):
+         """Expose the base model config for PEFT."""
+         return self.base.config
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self):
+         return next(self.parameters()).dtype
+
+
+ # ─── Standalone Test ──────────────────────────────────────────────────
+ if __name__ == "__main__":
+     import os
+     from dotenv import load_dotenv
+     load_dotenv()
+     HF_TOKEN = os.getenv("HF_TOKEN")
+
+     print("Testing TMOSClassifier...")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     clf = TMOSClassifier(
+         base_model_id="llava-hf/llava-1.5-7b-hf",
+         torch_dtype=torch.float16,
+         token=HF_TOKEN,
+     )
+     clf.to(device)
+
+     # Print parameter counts
+     total = sum(p.numel() for p in clf.parameters())
+     trainable = sum(p.numel() for p in clf.parameters() if p.requires_grad)
+     print(f"Total params: {total:>12,}")
+     print(f"Trainable params: {trainable:>12,}")
+     print(f"Classifier head: {sum(p.numel() for p in clf.classifier.parameters()):,}")
+
+     # Smoke test with dummy input
+     from transformers import AutoProcessor
+     processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=HF_TOKEN)
+     processor.patch_size = 14
+     processor.vision_feature_select_strategy = "default"
+
+     from PIL import Image
+     dummy_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
+     inputs = processor(
+         text="USER: <image>\nIs this real?\nASSISTANT:",
+         images=dummy_img,
+         return_tensors="pt",
+     ).to(device)
+
+     labels = torch.tensor([1.0], device=device)  # fake
+
+     with torch.no_grad():
+         out = clf(**inputs, labels=labels)
+
+     print(f"Logit: {out['logit'].item():.4f}")
+     print(f"Loss: {out['loss'].item():.4f}")
+     print(f"Prob: {torch.sigmoid(out['logit']).item():.4f}")
+     print("Test passed.")