# TrueFrame / app.py
# Deploy: auto GPU fallback + FastAPI /predict
# (commit 8d017ca, verified — NeelakshSaxena)
import os
import sys
import time
import math
import json
import io
from contextlib import nullcontext
from pathlib import Path
import gradio as gr
import torch
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from dotenv import load_dotenv
from PIL import Image, ImageOps
from transformers import AutoProcessor, AutoImageProcessor, AutoModelForImageClassification
# Make bundled helper scripts (e.g. tmos_classifier) importable.
ROOT_DIR = Path(__file__).resolve().parent
SCRIPTS_DIR = ROOT_DIR / "scripts"
if str(SCRIPTS_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPTS_DIR))

# Environment-driven configuration.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token (may be None for public repos)
BASE_MODEL_ID = "llava-hf/llava-1.5-7b-hf"  # base VLM used by the TMOS (GPU) path
ADAPTER_PATH = ROOT_DIR / "final-production-weights" / "best_model"  # local LoRA adapter dir
ADAPTER_REPO_ID = os.getenv("ADAPTER_REPO_ID", "Werrewulf/TMOS-DD")  # remote adapter repo fallback
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "")
CPU_FALLBACK_MODEL_ID = os.getenv("CPU_FALLBACK_MODEL_ID", "DaMsTaR/Detecto-DeepFake_Image_Detector")
# Default to inverting the fallback model's output only for the known Detecto model —
# presumably its label order is flipped relative to ours; TODO confirm against that repo.
DEFAULT_INVERT_FALLBACK = CPU_FALLBACK_MODEL_ID.lower() == "damstar/detecto-deepfake_image_detector"
INVERT_FALLBACK_OUTPUT = os.getenv("INVERT_FALLBACK_OUTPUT", str(DEFAULT_INVERT_FALLBACK)).strip().lower() == "true"
TMOS_PROMPT = "USER: <image>\nIs this video real or produced by AI?\nASSISTANT:"
TARGET_IMAGE_SIZE = 336  # target bounding box for preprocessed images
THRESHOLD = 0.5  # P(fake) decision boundary

# Lazily-initialized globals, populated by load_model_and_processor().
model = None
processor = None
inference_device = None
def resolve_inference_device(model_obj) -> torch.device:
    """Choose the device for inference: CUDA when available, otherwise the
    first CUDA entry in the model's HF device map, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mapping = getattr(model_obj, "hf_device_map", None)
    if isinstance(mapping, dict):
        cuda_entries = [
            placement for placement in mapping.values()
            if isinstance(placement, str) and placement.startswith("cuda")
        ]
        if cuda_entries:
            return torch.device(cuda_entries[0])
    return torch.device("cpu")
def find_classifier_weight_tensor(model_obj):
    """Breadth-first search through common wrapper attributes
    (model / base_model / module) for a `classifier` with a `weight`.

    Returns the weight tensor, or None when no classifier head is found.
    Cycles are guarded via object identity.
    """
    seen_ids = set()
    pending = [model_obj]
    while pending:
        node = pending.pop(0)
        if node is None or id(node) in seen_ids:
            continue
        seen_ids.add(id(node))
        head = getattr(node, "classifier", None)
        if head is not None and hasattr(head, "weight"):
            return head.weight
        pending.extend(
            child
            for child in (getattr(node, wrapper, None) for wrapper in ("model", "base_model", "module"))
            if child is not None
        )
    return None
def count_lora_layers(model_obj) -> int:
    """Count submodules that carry both `lora_A` and `lora_B`, i.e. the
    number of LoRA-wrapped layers actually attached to the model."""
    return sum(
        1
        for _name, submodule in model_obj.named_modules()
        if hasattr(submodule, "lora_A") and hasattr(submodule, "lora_B")
    )
def is_tmos_adapter_config(cfg: dict) -> bool:
    """Check that a PEFT adapter config matches the TMOS training recipe:
    classifier head in modules_to_save, rank r=64, and LoRA applied to all
    seven attention/MLP projection modules."""
    required_targets = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    saved = cfg.get("modules_to_save") or []
    if "classifier" not in saved:
        return False
    if cfg.get("r") != 64:
        return False
    return required_targets <= set(cfg.get("target_modules") or [])
def load_local_adapter_config(adapter_dir: Path) -> dict | None:
cfg_path = adapter_dir / "adapter_config.json"
if not cfg_path.exists():
return None
with cfg_path.open("r", encoding="utf-8") as fp:
return json.load(fp)
def load_remote_adapter_config(repo_id: str, subfolder: str) -> dict | None:
    """Fetch a PEFT adapter config from the Hub and return it as a plain dict.

    Any failure (missing repo/subfolder, auth, network) yields None so callers
    can probe several candidate subfolders without exception handling.
    """
    from peft import PeftConfig

    try:
        remote_cfg = PeftConfig.from_pretrained(repo_id, subfolder=subfolder, token=HF_TOKEN)
        return remote_cfg.to_dict()
    except Exception:
        return None
def select_torch_dtype() -> torch.dtype:
    """Best compute dtype for the current hardware: bf16 when the GPU supports
    it, otherwise fp16 on GPU, and full fp32 on CPU."""
    if not torch.cuda.is_available():
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
def load_tmos_model():
    """Load the full TMOS pipeline (LLaVA base + TMOS LoRA adapter) on GPU.

    Prefers the bundled local adapter under ADAPTER_PATH; otherwise probes a
    short list of candidate subfolders in ADAPTER_REPO_ID. Every candidate is
    validated before acceptance: TMOS-compatible config, LoRA layers actually
    attached, and a classifier head whose weights measurably changed. On
    success the adapter is merged into the base weights and the module globals
    `model`, `processor` and `inference_device` are populated.

    Raises:
        RuntimeError: when no GPU is available, the local adapter is present
            but not TMOS-compatible, or no usable remote adapter is found.
    """
    global model, processor, inference_device
    if not torch.cuda.is_available():
        raise RuntimeError(
            "TMOS mode requires GPU hardware. CPU fallback should be used on CPU-only environments."
        )
    # Imported lazily: peft and the project classifier are only needed on GPU.
    from peft import PeftModel
    from tmos_classifier import TMOSClassifier

    adapter_source = None
    # First adapter weight file that exists locally (safetensors preferred).
    local_adapter_file = next(
        (
            candidate
            for candidate in (
                ADAPTER_PATH / "adapter_model.safetensors",
                ADAPTER_PATH / "adapter_model.bin",
            )
            if candidate.exists()
        ),
        None,
    )
    selected_subfolder = ""
    if local_adapter_file is not None:
        adapter_source = str(ADAPTER_PATH)
        local_cfg = load_local_adapter_config(ADAPTER_PATH)
        # A present-but-wrong local adapter is a hard error rather than a
        # silent fallback, so bad deployments surface immediately.
        if local_cfg is None or not is_tmos_adapter_config(local_cfg):
            raise RuntimeError(
                "Local adapter exists but is not TMOS-compatible. Expected modules_to_save=['classifier'], r=64, and TMOS target modules."
            )
    else:
        adapter_source = ADAPTER_REPO_ID
    dtype = select_torch_dtype()
    print(f"Loading TMOS-DD model from {adapter_source} with dtype={dtype}...")
    base_model = TMOSClassifier(
        base_model_id=BASE_MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        token=HF_TOKEN,
    )
    # Snapshot the base classifier weights so we can later verify the adapter
    # actually modified them.
    base_classifier_weight = find_classifier_weight_tensor(base_model)
    base_classifier_snapshot = None
    if base_classifier_weight is not None:
        base_classifier_snapshot = base_classifier_weight.detach().float().cpu().clone()
    peft_kwargs = {"is_trainable": False, "token": HF_TOKEN}
    if adapter_source == ADAPTER_REPO_ID:
        # Try the configured subfolder first, then known historical layouts.
        candidate_subfolders = [
            s for s in [ADAPTER_SUBFOLDER, "multimodal", "multimodal/checkpoint-5", "llava"] if s is not None
        ]
        last_error = None
        for subfolder in candidate_subfolders:
            try:
                remote_cfg = load_remote_adapter_config(adapter_source, subfolder)
                if remote_cfg is None or not is_tmos_adapter_config(remote_cfg):
                    raise ValueError("Adapter config is not TMOS-compatible.")
                current_kwargs = dict(peft_kwargs)
                if subfolder:
                    current_kwargs["subfolder"] = subfolder
                loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **current_kwargs)
                # Sanity checks: LoRA layers attached, classifier head present,
                # classifier weights measurably different from the base.
                lora_layer_count = count_lora_layers(loaded_model)
                if lora_layer_count == 0:
                    raise RuntimeError("Loaded adapter has zero LoRA layers attached.")
                loaded_classifier_weight = find_classifier_weight_tensor(loaded_model)
                if loaded_classifier_weight is None:
                    raise RuntimeError("Classifier head not found after adapter load.")
                if base_classifier_snapshot is not None:
                    classifier_delta = (
                        loaded_classifier_weight.detach().float().cpu() - base_classifier_snapshot
                    ).abs().mean().item()
                    if classifier_delta < 1e-8:
                        raise RuntimeError(
                            "Classifier weights did not change after loading adapter; adapter likely incompatible."
                        )
                # Fold the LoRA weights into the base model for inference.
                model = loaded_model.merge_and_unload()
                selected_subfolder = subfolder
                print(
                    f"Loaded TMOS adapter from repo subfolder: '{subfolder or '.'}' "
                    f"(lora_layers={lora_layer_count})"
                )
                break
            except Exception as exc:
                # Remember the failure and probe the next candidate subfolder.
                last_error = exc
                continue
        else:
            # for/else: no candidate subfolder produced a usable adapter.
            raise RuntimeError(
                "No TMOS-compatible adapter found in remote repo. Upload TMOS production weights with classifier head "
                "(modules_to_save=['classifier'], r=64, 7-target-module LoRA)."
            ) from last_error
    else:
        # Local adapter path: single load with the same zero-LoRA sanity check.
        loaded_model = PeftModel.from_pretrained(base_model, adapter_source, **peft_kwargs)
        lora_layer_count = count_lora_layers(loaded_model)
        if lora_layer_count == 0:
            raise RuntimeError("Local adapter load produced zero LoRA layers attached.")
        model = loaded_model.merge_and_unload()
        print(f"Loaded TMOS local adapter (lora_layers={lora_layer_count})")
    model.eval()
    processor = AutoProcessor.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
    # NOTE(review): presumably pins LLaVA-1.5 vision-tower settings for correct
    # image-token expansion — confirm against the pinned transformers version.
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"
    inference_device = resolve_inference_device(model)
    if adapter_source == ADAPTER_REPO_ID:
        print(f"TMOS-DD ready on {inference_device} using remote subfolder '{selected_subfolder or '.'}'.")
    else:
        print(f"TMOS-DD ready on {inference_device} using local production adapter.")
def load_cpu_fallback_model():
    """Load the lightweight CPU image classifier into the module globals.

    Used when no GPU is present or the TMOS pipeline failed to load.
    Populates `model`, `processor` and `inference_device`.
    """
    global model, processor, inference_device
    print(f"Loading CPU fallback model from {CPU_FALLBACK_MODEL_ID}...")
    processor = AutoImageProcessor.from_pretrained(CPU_FALLBACK_MODEL_ID, token=HF_TOKEN)
    model = AutoModelForImageClassification.from_pretrained(
        CPU_FALLBACK_MODEL_ID,
        torch_dtype=torch.float32,  # full precision on CPU
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    model.to("cpu").eval()
    inference_device = torch.device("cpu")
    print("CPU fallback classifier ready.")
def load_model_and_processor():
    """Initialize (once) and return the global (model, processor, device) triple.

    Prefers the TMOS GPU pipeline; absence of CUDA — or any failure while
    loading TMOS — selects the lightweight CPU classifier instead.
    """
    global model, processor, inference_device
    already_loaded = model is not None and processor is not None and inference_device is not None
    if already_loaded:
        return model, processor, inference_device
    if not torch.cuda.is_available():
        print("No GPU detected -> using CPU fallback")
        load_cpu_fallback_model()
        return model, processor, inference_device
    print("GPU detected -> loading TMOS")
    try:
        load_tmos_model()
    except Exception as exc:
        print(f"TMOS failed: {exc}")
        print("Falling back to CPU model...")
        load_cpu_fallback_model()
    return model, processor, inference_device
def preprocess_image(image: Image.Image) -> Image.Image:
    """Normalize an input image: force RGB and shrink it to fit inside a
    TARGET_IMAGE_SIZE square while keeping the aspect ratio."""
    rgb_image = image.convert("RGB")
    bounds = (TARGET_IMAGE_SIZE, TARGET_IMAGE_SIZE)
    return ImageOps.contain(rgb_image, bounds, method=Image.Resampling.BICUBIC)
def confidence_card(prob_fake: float, label: str) -> str:
    """Render the verdict as a small dark HTML card: headline confidence for
    the predicted label, a fake-probability bar, and the Real/Fake split."""
    if label == "Fake":
        accent = "#ef4444"
        confidence = prob_fake
    else:
        accent = "#10b981"
        confidence = 1.0 - prob_fake
    confidence_pct = confidence * 100.0
    fake_pct = prob_fake * 100.0
    real_pct = (1.0 - prob_fake) * 100.0
    return f"""
<div style="border:1px solid rgba(255,255,255,0.12); border-radius:16px; padding:16px; background:linear-gradient(135deg, rgba(17,24,39,0.92), rgba(15,23,42,0.96)); color:white;">
<div style="font-size:0.85rem; opacity:0.8; letter-spacing:0.04em; text-transform:uppercase; margin-bottom:8px;">Confidence</div>
<div style="display:flex; align-items:baseline; gap:10px; margin-bottom:12px;">
<div style="font-size:2rem; font-weight:700; color:{accent};">{confidence_pct:.2f}%</div>
<div style="font-size:1rem; opacity:0.9;">for <strong>{label}</strong></div>
</div>
<div style="height:12px; width:100%; background:rgba(255,255,255,0.08); border-radius:999px; overflow:hidden; margin-bottom:10px;">
<div style="height:100%; width:{fake_pct:.2f}%; background:linear-gradient(90deg, #f87171, #ef4444);"></div>
</div>
<div style="display:flex; justify-content:space-between; font-size:0.9rem; opacity:0.95;">
<span>Real: {real_pct:.2f}%</span>
<span>Fake: {fake_pct:.2f}%</span>
</div>
</div>
"""
def score_fallback_logits(logits: torch.Tensor, id2label: dict) -> tuple[float, str]:
probs = torch.softmax(logits.float(), dim=0)
fake_indices = []
real_indices = []
for idx in range(len(probs)):
label = str(id2label.get(idx, "")).lower()
if any(key in label for key in ["fake", "deepfake", "ai", "synthetic"]):
fake_indices.append(idx)
if any(key in label for key in ["real", "authentic", "genuine"]):
real_indices.append(idx)
if len(probs) == 2 and not fake_indices and not real_indices:
fake_indices = [1]
real_indices = [0]
fake_prob = float(probs[fake_indices].sum().item()) if fake_indices else 0.0
real_prob = float(probs[real_indices].sum().item()) if real_indices else 0.0
total = fake_prob + real_prob
if total > 0:
prob_fake = fake_prob / total
else:
prob_fake = float(probs.max().item()) if len(probs) == 1 else float(probs[1].item()) if len(probs) > 1 else 0.5
if INVERT_FALLBACK_OUTPUT:
prob_fake = 1.0 - prob_fake
label = "Fake" if prob_fake >= THRESHOLD else "Real"
return prob_fake, label
def infer_image(image: Image.Image):
    """Run deepfake detection on one PIL image.

    Returns a 6-tuple matching the Gradio outputs:
    (preview image, label, P(fake), confidence %, latency ms, HTML card).
    On any failure the numeric slots are None and the label slot carries the
    error text (the broad except keeps the UI responsive instead of crashing).
    """
    try:
        if image is None:
            return None, "Error: please upload an image.", None, None, None, "<div style='color:#f87171;'>Please upload an image before running detection.</div>"
        model_obj, processor_obj, device = load_model_and_processor()
        prepared_image = preprocess_image(image)
        # Mixed precision only on GPU; the CPU path runs in full fp32.
        autocast_context = (
            torch.autocast(device_type="cuda", dtype=select_torch_dtype())
            if device.type == "cuda"
            else nullcontext()
        )
        start_time = time.perf_counter()
        with torch.inference_mode(), autocast_context:
            # BUGFIX: branch on the local `device` returned by the loader, not
            # the module global `inference_device`, so the selected pipeline
            # always agrees with the device the tensors are moved to.
            if device.type == "cuda":
                # TMOS path: multimodal prompt + image through the custom head.
                inputs = processor_obj(text=TMOS_PROMPT, images=prepared_image, return_tensors="pt", padding=True)
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    attention_mask=inputs["attention_mask"],
                )
                logit = float(outputs["logit"].squeeze().detach().float().cpu().item())
                if not math.isfinite(logit):
                    raise gr.Error("Model produced a non-finite logit (NaN/Inf). Please retry.")
                prob_fake = float(torch.sigmoid(torch.tensor(logit)).item())
                label = "Fake" if prob_fake >= THRESHOLD else "Real"
            else:
                # CPU fallback: plain image classifier, keyword-matched labels.
                inputs = processor_obj(images=prepared_image, return_tensors="pt")
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = model_obj(**inputs)
                logits = outputs.logits.squeeze(0).detach().float().cpu()
                id2label = getattr(model_obj.config, "id2label", {}) or {}
                prob_fake, label = score_fallback_logits(logits, id2label)
        if device.type == "cuda":
            torch.cuda.synchronize()  # include pending GPU work in the timing
        elapsed_ms = (time.perf_counter() - start_time) * 1000.0
        if not math.isfinite(prob_fake):
            raise gr.Error("Model produced a non-finite probability (NaN/Inf). Please retry.")
        confidence = prob_fake if label == "Fake" else 1.0 - prob_fake
        return prepared_image, label, round(prob_fake, 6), round(confidence * 100.0, 2), round(elapsed_ms, 2), confidence_card(prob_fake, label)
    except Exception as exc:
        err = f"Inference failed: {type(exc).__name__}: {exc}"
        err_html = f"<div style='color:#fca5a5; border:1px solid rgba(252,165,165,0.35); padding:10px; border-radius:10px;'>\n<b>Inference error</b><br>{err}</div>"
        return None, err, None, None, None, err_html
# Public REST API; the Gradio UI is mounted onto this FastAPI app further down.
api = FastAPI()
# NOTE(review): wide-open CORS (any origin, with credentials) — acceptable for
# a public demo Space, but tighten allow_origins before any sensitive use.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@api.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Success JSON: {verdict, confidence_percent, p_fake, latency_ms}.
    Failure JSON: {error} with HTTP 500 (decoding or inference error).
    """
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        _, label, prob_fake, confidence, latency, _ = infer_image(image)
        # BUGFIX: infer_image signals failure by returning None metrics with the
        # error text in the label slot; previously that was returned as a 200
        # success whose "verdict" was the error message. Surface it as an error.
        if prob_fake is None:
            return JSONResponse({"error": label}, status_code=500)
        return JSONResponse(
            {
                "verdict": label,
                "confidence_percent": confidence,
                "p_fake": prob_fake,
                "latency_ms": latency,
            }
        )
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)
# Warm the model at import time so the first request doesn't pay the load cost.
load_model_and_processor()

with gr.Blocks(title="TMOS Deepfake Detector", theme=gr.themes.Soft()) as demo:
    device_label = "GPU (TMOS Model)" if torch.cuda.is_available() else "CPU Fallback Model"
    gr.Markdown(
        f"# TMOS Deepfake Detector\n"
        f"**Running on:** {device_label}\n\n"
        f"> Warning: runs on free infrastructure, so startup and inference may take time."
    )
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction", interactive=False)
            probability_output = gr.Number(label="P(fake)", interactive=False, precision=6)
            confidence_output = gr.Number(label="Confidence (%)", interactive=False, precision=2)
            latency_output = gr.Number(label="Latency (ms)", interactive=False, precision=2)
    preview_output = gr.Image(label="Processed image passed to the model", interactive=False)
    confidence_html = gr.HTML()
    detect_button = gr.Button("Run detection", variant="primary")
    # Output order must match infer_image's 6-tuple exactly.
    detect_button.click(
        fn=infer_image,
        inputs=image_input,
        outputs=[preview_output, prediction_output, probability_output, confidence_output, latency_output, confidence_html],
    )

# Serialize inference (single concurrent worker) and cap the waiting queue.
demo.queue(default_concurrency_limit=1, max_size=8)
# Mount the Gradio UI at "/" on the FastAPI app; uvicorn serves `app`.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))