# TMOS Deepfake Detector — Hugging Face Space entry point.
# NOTE(review): the scraped page header read "Spaces: / Runtime error /
# Runtime error" — platform residue from the Space status page, not source code.
| import os | |
| import sys | |
| import time | |
| import math | |
| import json | |
| import io | |
| from contextlib import nullcontext | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from dotenv import load_dotenv | |
| from PIL import Image, ImageOps | |
| from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification | |
# --- Paths and environment --------------------------------------------------
ROOT_DIR = Path(__file__).resolve().parent
SCRIPTS_DIR = ROOT_DIR / "scripts"
# Make local helper modules (e.g. tmos_classifier, imported in load_tmos_model)
# importable regardless of the working directory.
if str(SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPTS_DIR))

# Pull HF_TOKEN and the other overrides below from a .env file if present.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# --- Model configuration ----------------------------------------------------
BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
# Preferred adapter source: local production weights; falls back to the Hub repo.
ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"
ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
# The Detecto fallback's output is inverted by default before thresholding
# (see score_fallback_logits) — presumably because that model's label order is
# reversed relative to ours; TODO confirm against the model card.
# Override with INVERT_FALLBACK_OUTPUT=true/false.
DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
INVERT_FALLBACK_OUTPUT = os.getenv("INVERT_FALLBACK_OUTPUT", str(DEFAULT_INVERT_FALLBACK)).strip().lower() == "true"

# Prompt fed to the LLaVA-based TMOS classifier for every image.
TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
TARGET_IMAGE_SIZE = 336  # images are downscaled to fit this square (see preprocess_image)
THRESHOLD = 0.5  # p(fake) >= THRESHOLD -> "Fake"

# Lazily populated by load_model_and_processor(); None until first use.
model = None
processor = None
inference_device = None
def resolve_inference_device(model_obj) -> torch.device:
    """Pick the device inference should run on for *model_obj*.

    Prefers CUDA outright when available; otherwise honours any CUDA
    placement recorded in an accelerate-style ``hf_device_map``; falls
    back to CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    device_map = getattr(model_obj, "hf_device_map", None)
    if isinstance(device_map, dict):
        # Take the first CUDA entry in map order, mirroring a linear scan.
        cuda_entries = [
            entry
            for entry in device_map.values()
            if isinstance(entry, str) and entry.startswith("cuda")
        ]
        if cuda_entries:
            return torch.device(cuda_entries[0])
    return torch.device("cpu")
def find_classifier_weight_tensor(model_obj):
    """Breadth-first search for a ``classifier`` submodule with a ``weight``.

    Unwraps the usual nesting attributes (``model``, ``base_model``,
    ``module``) added by PEFT / DDP-style wrappers. Returns the weight
    tensor, or None when no classifier head is reachable.
    """
    seen_ids = set()
    pending = [model_obj]
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1
        if node is None or id(node) in seen_ids:
            continue
        seen_ids.add(id(node))
        head = getattr(node, "classifier", None)
        if head is not None and hasattr(head, "weight"):
            return head.weight
        # Queue up nested wrappers for the next BFS level.
        pending.extend(
            child
            for child in (getattr(node, attr, None) for attr in ("model", "base_model", "module"))
            if child is not None
        )
    return None
def count_lora_layers(model_obj) -> int:
    """Count submodules carrying both LoRA projection matrices (A and B)."""
    return sum(
        1
        for _name, submodule in model_obj.named_modules()
        if hasattr(submodule, "lora_A") and hasattr(submodule, "lora_B")
    )
def is_tmos_adapter_config(cfg: dict) -> bool:
    """Return True when *cfg* matches the TMOS production adapter layout.

    Required: the classifier head in ``modules_to_save``, LoRA rank 64,
    and all seven attention/MLP projections among the target modules.
    """
    saved_modules = cfg.get("modules_to_save") or []
    targets = set(cfg.get("target_modules") or [])
    tmos_targets = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    if "classifier" not in saved_modules:
        return False
    if cfg.get("r") != 64:
        return False
    return tmos_targets <= targets
| def load_local_adapter_config(adapter_dir: Path) -> dict | None: | |
| cfg_path = adapter_dir / "adapter_config.json" | |
| if not cfg_path.exists(): | |
| return None | |
| with cfg_path.open("r", encoding="utf-8") as fp: | |
| return json.load(fp) | |
def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
    """Fetch a PEFT adapter config from the Hub as a plain dict.

    Any failure (missing repo, bad subfolder, network error) is treated
    as "no usable config" and yields None rather than raising.
    """
    from peft import PeftConfig

    try:
        remote_cfg = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
        return remote_cfg.to_dict()
    except Exception:
        return None
def select_torch_dtype() -> torch.dtype:
    """Best dtype for the hardware: bf16 > fp16 on GPU, fp32 on CPU."""
    if not torch.cuda.is_available():
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16
def load_tmos_model():
    """Load the TMOS-DD LLaVA classifier (base model + LoRA adapter) onto GPU.

    Adapter resolution order: local production weights under ADAPTER_PATH,
    else a set of candidate subfolders in the remote Hub repo. Each remote
    candidate is validated (TMOS-compatible config, LoRA layers actually
    attached, classifier head changed by the adapter) before being accepted.
    Populates the module-level ``model``, ``processor`` and
    ``inference_device`` globals.

    Raises:
        RuntimeError: no GPU, an incompatible local adapter, or no usable
            remote adapter candidate.
    """
    global model, processor, inference_device
    if not torch.cuda.is_available():
        raise RuntimeError(
            "TMOS mode requires GPU hardware. CPU fallback should be used on CPU-only environments."
        )
    # Imported lazily: peft is heavy and tmos_classifier lives in scripts/
    # (added to sys.path at module import time).
    from peft import PeftModel
    from tmos_classifier import TMOSClassifier
    adapter_source = None
    # First adapter weight file that exists locally, or None.
    local_adapter_file = next(
        (
            candidate
            for candidate in (
                ADAPTER_PATH / "adapter_model.safetensors",
                ADAPTER_PATH / "adapter_model.bin",
            )
            if candidate.exists()
        ),
        None,
    )
    selected_subfolder = ""
    if local_adapter_file is not None:
        adapter_source = str(ADAPTER_PATH)
        # A local adapter must carry a TMOS-compatible config; refuse to
        # silently load something else.
        local_cfg = load_local_adapter_config(ADAPTER_PATH)
        if local_cfg is None or not is_tmos_adapter_config(local_cfg):
            raise RuntimeError(
                "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
            )
    else:
        adapter_source = ADAPTER_REPO_ID
    dtype = select_torch_dtype()
    print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")
    base_model = TMOSClassifier(
        base_model_id=BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )
    # Snapshot the base classifier weights so we can verify the adapter
    # actually modified the head after loading.
    base_classifier_weight = find_classifier_weight_tensor(base_model)
    base_classifier_snapshot = None
    if base_classifier_weight is not None:
        base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()
    peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
    if adapter_source == ADAPTER_REPO_ID:
        # Try the configured subfolder first, then known historical layouts.
        candidate_subfolders = [
            s for s in [ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"] if s is not None
        ]
        last_error = None
        for subfolder in candidate_subfolders:
            try:
                remote_cfg = load_remote_adapter_config(adapter_source, subfolder)
                if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
                    raise ValueError("Adapter config is not TMOS-compatible.")
                current_kwargs = dict(peft_kwargs)
                if subfolder:
                    current_kwargs["subfolder"] = subfolder
                loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **current_kwargs)
                # Sanity check 1: LoRA layers must actually be attached.
                lora_layer_count = count_lora_layers(loaded_model)
                if lora_layer_count == 0:
                    raise RuntimeError("Loaded adapter has zero LoRA layers attached.")
                # Sanity check 2: the classifier head must exist and differ
                # from the base snapshot (i.e. the adapter supplied weights).
                loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
                if loaded_classifier_weight is None:
                    raise RuntimeError("Classifier head not found after adapter load.")
                if base_classifier_snapshot is not None:
                    classifier_delta = (
                        loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
                    ).abs().mean().item()
                    if classifier_delta < 1e-8:
                        raise RuntimeError(
                            "Classifier weights did not change after loading adapter; adapter likely incompatible."
                        )
                # Fold the adapter into the base weights for faster inference.
                model = loaded_model.merge_and_unload()
                selected_subfolder = subfolder
                print(
                    f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
                    f"(lora_layers={lora_layer_count})"
                )
                break
            except Exception as exc:
                # Remember the failure and try the next candidate subfolder.
                last_error = exc
                continue
        else:
            # for/else: no candidate succeeded.
            raise RuntimeError(
                "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
                "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
            ) from last_error
    else:
        # Local adapter path was already validated above; just load and merge.
        loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
        lora_layer_count = count_lora_layers(loaded_model)
        if lora_layer_count == 0:
            raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
        model = loaded_model.merge_and_unload()
        print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
    model.eval()
    processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    # NOTE(review): presumably matches the LLaVA-1.5 vision tower (14px
    # patches, default feature selection) — confirm against the model config.
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"
    inference_device = resolve_inference_device(model)
    if adapter_source == ADAPTER_REPO_ID:
        print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
    else:
        print(f"TMOS-DD ready on {inference_device} using local production adapter.")
def load_cpu_fallback_model():
    """Load the lightweight image classifier used when no GPU is available.

    Populates the module-level ``model``, ``processor`` and
    ``inference_device`` globals with a plain HF image-classification
    pipeline pinned to CPU/float32.
    """
    global model, processor, inference_device
    print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
    processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
    model = AutoModelForImageClassification.from_pretrained(
        CPU_FALLBACK_MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    model.to("cpu").eval()
    inference_device = torch.device("cpu")
    print("CPU fallback classifier ready.")
def load_model_and_processor():
    """Return ``(model, processor, device)``, loading them on first call.

    Prefers the TMOS GPU model; any failure there (or no GPU at all)
    falls back to the CPU classifier so the app always starts. Results
    are cached in module globals — subsequent calls are cheap.
    """
    global model, processor, inference_device
    # Already initialised: reuse the cached objects.
    if model is not None and processor is not None and inference_device is not None:
        return model, processor, inference_device
    if torch.cuda.is_available():
        print("GPU detected -> loading TMOS")
        try:
            load_tmos_model()
        except Exception as exc:
            # Best-effort: never leave the app without a model.
            print(f"TMOS failed: {exc}")
            print("Falling back to CPU model...")
            load_cpu_fallback_model()
    else:
        print("No GPU detected -> using CPU fallback")
        load_cpu_fallback_model()
    return model, processor, inference_device
def preprocess_image(image: Image.Image) -> Image.Image:
    """Convert to RGB and shrink to fit TARGET_IMAGE_SIZE, preserving aspect."""
    rgb_image = image.convert("RGB")
    bounds = (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE)
    return ImageOps.contain(rgb_image, bounds, method=Image.Resampling.BICUBIC)
def confidence_card(prob_fake: float, label: str) -> str:
    """Render the verdict as a small dark-themed HTML card.

    Shows the confidence in the predicted *label* (p(fake) for "Fake",
    1 - p(fake) otherwise) plus a bar of the raw fake probability.
    """
    is_fake = label == "Fake"
    shown_confidence = prob_fake if is_fake else 1.0 - prob_fake
    confidence_pct = shown_confidence * 100.0
    fake_pct = prob_fake * 100.0
    real_pct = (1.0 - prob_fake) * 100.0
    # Red accent for fake verdicts, green for real.
    accent = "#ef4444" if is_fake else "#10b981"
    return f"""
<div style="border:1px solid rgba(255,255,255,0.12); border-radius:16px; padding:16px; background:linear-gradient(135deg, rgba(17,24,39,0.92), rgba(15,23,42,0.96)); color:white;">
<div style="font-size:0.85rem; opacity:0.8; letter-spacing:0.04em; text-transform:uppercase; margin-bottom:8px;">Confidence</div>
<div style="display:flex; align-items:baseline; gap:10px; margin-bottom:12px;">
<div style="font-size:2rem; font-weight:700; color:{accent};">{confidence_pct:.2f}%</div>
<div style="font-size:1rem; opacity:0.9;">for <strong>{label}</strong></div>
</div>
<div style="height:12px; width:100%; background:rgba(255,255,255,0.08); border-radius:999px; overflow:hidden; margin-bottom:10px;">
<div style="height:100%; width:{fake_pct:.2f}%; background:linear-gradient(90deg, #f87171, #ef4444);"></div>
</div>
<div style="display:flex; justify-content:space-between; font-size:0.9rem; opacity:0.95;">
<span>Real: {real_pct:.2f}%</span>
<span>Fake: {fake_pct:.2f}%</span>
</div>
</div>
"""
def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
    """Map a fallback classifier's logits to ``(p_fake, verdict)``.

    Label names are matched against fake/real keyword lists; the matched
    probability mass is renormalised to produce p(fake). The result is
    optionally inverted (INVERT_FALLBACK_OUTPUT) before thresholding.
    """
    probs = torch.softmax(logits.float(), dim=0)
    fake_keywords = ("fake", "deepfake", "ai", "synthetic")
    real_keywords = ("real", "authentic", "genuine")
    fake_indices: list[int] = []
    real_indices: list[int] = []
    for idx in range(len(probs)):
        class_name = str(id2label.get(idx, "")).lower()
        if any(keyword in class_name for keyword in fake_keywords):
            fake_indices.append(idx)
        if any(keyword in class_name for keyword in real_keywords):
            real_indices.append(idx)
    # Unlabelled two-class head: assume index 1 is "fake"
    # (assumption — confirm against the fallback model's config).
    if len(probs) == 2 and not fake_indices and not real_indices:
        fake_indices = [1]
        real_indices = [0]
    fake_mass = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
    real_mass = float(probs[real_indices].sum().item()) if real_indices else 0.0
    matched_mass = fake_mass + real_mass
    if matched_mass > 0:
        prob_fake = fake_mass / matched_mass
    elif len(probs) == 1:
        prob_fake = float(probs.max().item())
    elif len(probs) > 1:
        prob_fake = float(probs[1].item())
    else:
        prob_fake = 0.5
    if INVERT_FALLBACK_OUTPUT:
        prob_fake = 1.0 - prob_fake
    verdict = "Fake" if prob_fake >= THRESHOLD else "Real"
    return prob_fake, verdict
def infer_image(image: Image.Image):
    """Run deepfake detection on one PIL image.

    Returns a 6-tuple matching the Gradio outputs:
    ``(preview_image, label, p_fake, confidence_percent, latency_ms, html_card)``.
    On any failure the numeric/preview slots are None and the message
    slots carry the error text.
    """
    try:
        if image is None:
            return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
        model_obj, processor_obj, device = load_model_and_processor()
        prepared_image = preprocess_image(image)
        # Mixed precision only on GPU; CPU fallback runs in plain float32.
        autocast_context = (
            torch.autocast(device_type="cuda", dtype=select_torch_dtype())
            if device.type == "cuda"
            else nullcontext()
        )
        start_time = time.perf_counter()
        with torch.inference_mode(), autocast_context:
            # Fixed: branch previously consulted the module global
            # `inference_device` — use the local `device` consistently
            # (same value, returned by load_model_and_processor).
            if device.type == "cuda":
                # TMOS path: multimodal prompt + image, single fake-logit head.
                inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    attention_mask=inputs["attention_mask"],
                )
                logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
                if not math.isfinite(logit):
                    raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
                prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
                label = "Fake" if prob_fake >= THRESHOLD else "Real"
            else:
                # CPU fallback path: plain image classifier with labelled logits.
                inputs = processor_obj(images=prepared_image, return_tensors="pt")
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(**inputs)
                logits = outputs.logits.squeeze(0).detach().float().cpu()
                id2label = getattr(model_obj.config, "id2label", {}) or {}
                prob_fake, label = score_fallback_logits(logits, id2label)
        if device.type == "cuda":
            # Drain pending CUDA work so the latency reading is honest.
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start_time) * 1000.0
        if not math.isfinite(prob_fake):
            raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
        confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
        return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
    except Exception as exc:
        # NOTE(review): gr.Error raised above is also caught here, so it is
        # rendered as an inline error card rather than a Gradio toast.
        err = f"Inference failed: {type(exc).__name__}: {exc}"
        err_html = f"<div style='color:#fca5a5; border:1px solid rgba(252,165,165,0.35); padding:10px; border-radius:10px;'>\n<b>Inference error</b><br>{err}</div>"
        return None, err, None, None, None, err_html
# FastAPI application hosting the JSON API; the Gradio UI is mounted onto it
# later via gr.mount_gradio_app.
api = FastAPI()
# Wide-open CORS so browser clients on any origin can call the API.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): the route decorator was absent in the scraped source, so this
# coroutine was never registered with FastAPI and the endpoint was unreachable.
# Registered at POST /predict — confirm if the original path differed.
@api.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Accept an uploaded image and return the detection verdict as JSON.

    Response body: ``verdict``, ``confidence_percent``, ``p_fake``,
    ``latency_ms``. Unexpected failures return a 500 with an ``error`` field.
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        # infer_image returns (preview, label, p_fake, confidence%, latency, html);
        # only the scalar fields are exposed over the API. If inference failed
        # internally, `label` carries the error text and the numbers are None.
        _, label, prob_fake, confidence, latency, _ = infer_image(image)
        return JSONResponse(
            {
                "verdict": label,
                "confidence_percent": confidence,
                "p_fake": prob_fake,
                "latency_ms": latency,
            }
        )
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
# Warm the model at import time so the first request doesn't pay the load cost.
load_model_and_processor()

# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
    device_label = "GPU (TMOS Model)" if torch.cuda.is_available() else "CPU Fallback Model"
    gr.Markdown(
        f"# TMOS Deepfake Detector\n"
        f"**Running on:** {device_label}\n\n"
        f"> Warning: runs on free infrastructure, so startup and inference may take time."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
            confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
            latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)
    preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
    confidence_html = gr.HTML()
    detect_button = gr.Button("Run detection", variant="primary")
    # Output order mirrors infer_image's 6-tuple return.
    detect_button.click(
        fn=infer_image,
        inputs=image_input,
        outputs=[preview_output, prediction_output, probability_output, confidence_output, latency_output, confidence_html],
    )

# Single-request concurrency with a small waiting queue (free-tier hardware).
demo.queue(default_concurrency_limit=1, max_size=8)
# Serve the Gradio UI at "/" on top of the FastAPI app (which carries the API).
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))