"""Ensemble detector for AI-generated / manipulated images.

Runs three detectors over every image found under ``--input_dir``:

* TruFor (detection score plus localization / confidence maps),
* UniversalFakeDetect (CLIP image features + a trained linear head),
* an EfficientNet metric-learning classifier,

fuses their fake-probabilities with fixed weights, asks a Baseten-hosted VLM
for a textual explanation of suspected fakes, and writes a JSON list of
predictions to ``--output_file``.
"""
import argparse
import base64
import io
import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List

import clip  # openai/CLIP
import numpy as np
import requests
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

# --- Model / config locations (relative to the current working directory) ---
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"
UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
UFD_CLIP_NAME = "ViT-L/14"
EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"

# --- Fusion weights (normalized by their sum in fuse_scores) ---
W_TRUFOR = 0.5
W_UFD = 0.4
W_EFFNET = 0.1

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

BASETEN_VLM_MODEL_ID = "zq8pe88w"
BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
# Name of the environment variable holding the Baseten API key. Credentials
# must never be hard-coded in source; any key previously committed to this
# file should be treated as leaked and revoked.
BASETEN_API_KEY_ENV = "BASETEN_API_KEY"

VLM_FALLBACK_REASONING = (
    "The image has odd textures, and unnatural edges."
)
# Reasoning text for images judged authentic (emitted only with --include_authentic).
AUTHENTIC_REASONING = "It looks natural."


def _pil_to_b64_jpeg(pil: Image.Image, quality: int = 95) -> str:
    """Encode *pil* as an RGB JPEG and return it as a base64 ASCII string."""
    buf = io.BytesIO()
    pil.convert("RGB").save(buf, format="JPEG", quality=quality)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float) -> str:
    """Ask the Baseten-hosted VLM to explain why an image looks AI-generated.

    The request body sent to the Truss model is::

        {
            "authenticity_score": <float, 0 = real, 1 = AI>,
            "image": "<base64-encoded JPEG, no data: prefix>"
        }

    If the decoded response is a dict, the first non-empty string among the
    keys ``output``, ``text``, ``result``, ``prediction`` and ``vlm_reasoning``
    is returned; if none is found, the dict is re-serialized as JSON. Any other
    decoded value is returned as ``str(value)``.

    Args:
        pil: Image to explain.
        authenticity_score: Fused fake-probability for the image.

    Returns:
        The reasoning text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: if the ``BASETEN_API_KEY`` environment variable is unset
            or empty.
        requests.RequestException: on network errors, timeouts, or non-2xx
            responses.
        ValueError: if the response body is not valid JSON.
    """
    api_key = os.environ.get(BASETEN_API_KEY_ENV, "").strip()
    if not api_key:
        raise RuntimeError("Missing BASETEN_API_KEY env var.")

    payload = {
        "authenticity_score": float(authenticity_score),  # 0 real, 1 AI
        "image": _pil_to_b64_jpeg(pil),  # base64 JPEG, no data: prefix
    }
    r = requests.post(
        BASETEN_VLM_URL,
        headers={"Authorization": f"Api-Key {api_key}"},
        json=payload,
        timeout=120,
    )
    r.raise_for_status()
    out = r.json()

    if isinstance(out, dict):
        for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
            v = out.get(k)
            if isinstance(v, str) and v.strip():
                return v.strip()
        return json.dumps(out, ensure_ascii=False)
    return str(out).strip()


# Output width of the CLIP image encoder for each supported backbone name.
CHANNELS = {
    "RN50": 1024,
    "RN101": 512,
    "RN50x4": 640,
    "RN50x16": 768,
    "RN50x64": 1024,
    "ViT-B/32": 512,
    "ViT-B/16": 512,
    "ViT-L/14": 768,
    "ViT-L/14@336px": 768,
}


class CLIPModel(nn.Module):
    """CLIP image encoder followed by a linear classification head.

    ``clip.load`` is called with ``device="cpu"``; ``self.preprocess`` is the
    image transform returned alongside the model.
    """

    def __init__(self, name, num_classes=1):
        super().__init__()
        self.model, self.preprocess = clip.load(name, device="cpu")
        self.fc = nn.Linear(CHANNELS[name], num_classes)

    def forward(self, x, return_feature=False):
        """Return CLIP image features if *return_feature*, else head logits."""
        features = self.model.encode_image(x)
        if return_feature:
            return features
        return self.fc(features)


class UniversalFakeDetectDetector:
    """UniversalFakeDetect: CLIP features scored by a trained linear head.

    The CLIP backbone stays on the CPU; only the small ``fc`` head is moved to
    ``device``, and features are transferred to it before scoring.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self.model = CLIPModel(UFD_CLIP_NAME, num_classes=1)
        self.model.eval()

        sd = torch.load(UFD_FC_WEIGHTS_PATH, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd and isinstance(sd["state_dict"], dict):
            sd = sd["state_dict"]

        if isinstance(sd, dict) and any(k.startswith("fc.") for k in sd.keys()):
            # Checkpoint of the whole CLIPModel: keep only the head's tensors.
            prefix = "fc."
            fc_sd = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
            self.model.fc.load_state_dict(fc_sd, strict=True)
        elif isinstance(sd, dict) and "weight" in sd and "bias" in sd:
            # Bare nn.Linear state dict; any extra keys are ignored.
            self.model.fc.load_state_dict({"weight": sd["weight"], "bias": sd["bias"]}, strict=True)
        else:
            raise RuntimeError(
                f"[UFD] Unsupported fc checkpoint format. Top keys: {list(sd.keys())[:50] if isinstance(sd, dict) else type(sd)}"
            )

        self.model.fc.to(self.device)
        self.preprocess = self.model.preprocess

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """Return the sigmoid of the head's single logit for *pil*."""
        x = self.preprocess(pil.convert("RGB")).unsqueeze(0)  # CPU
        features = self.model(x, return_feature=True)  # CPU
        features = features.to(self.device)
        logit = self.model.fc(features).view(-1)[0]
        return float(torch.sigmoid(logit).item())


class EffNetMetricClassifier(nn.Module):
    """
    Must match the architecture used to train best_metric_cls_effnet.pt:
      backbone (timm, num_classes=0, global_pool=avg)
      proj: Linear -> BN
      classifier: Linear(embed_dim -> 2)

    ``forward`` returns ``(emb, logits)``: the L2-normalized projection and the
    classifier logits computed from the *un-normalized* projection.
    """

    def __init__(self, model_name="efficientnet_b0", embed_dim=128, num_classes=2, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool="avg")
        feat_dim = self.backbone.num_features
        self.proj = nn.Sequential(
            nn.Linear(feat_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        feat = self.backbone(x)
        z = self.proj(feat)
        emb = F.normalize(z, p=2, dim=1)
        logits = self.classifier(z)
        return emb, logits


def _strip_module_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Drop a leading ``module.`` from state-dict keys.

    That prefix typically appears when a model was saved while wrapped in
    ``nn.DataParallel`` / DDP. Only true prefixes are removed; keys containing
    ``module.`` elsewhere are left untouched. Returns *sd* itself when no key
    carries the prefix.
    """
    prefix = "module."
    if not any(k.startswith(prefix) for k in sd.keys()):
        return sd
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}


class EffNetDetector:
    """EfficientNet metric-learning classifier wrapped for single-image scoring."""

    def __init__(self, device: torch.device, ckpt_path: str = EFFNET_CKPT_PATH, img_size: int = 224):
        self.device = device

        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            sd = ckpt["state_dict"]
            model_name = ckpt.get("model_name", "efficientnet_b0")
            embed_dim = int(ckpt.get("embed_dim", 128))
        elif isinstance(ckpt, dict):
            # fallback: assume ckpt itself is state_dict
            sd = ckpt
            model_name = "efficientnet_b0"
            embed_dim = 128
        else:
            raise RuntimeError(f"[EffNet] Unsupported checkpoint type: {type(ckpt)}")

        sd = _strip_module_prefix(sd)
        self.model = EffNetMetricClassifier(model_name=model_name, embed_dim=embed_dim, num_classes=2, pretrained=False)
        self.model.load_state_dict(sd, strict=True)
        self.model.to(self.device)
        self.model.eval()

        # Resize the shorter side to 1.15 * img_size, center-crop to img_size,
        # then apply ImageNet mean/std normalization.
        self.transform = T.Compose([
            T.Resize(int(img_size * 1.15)),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """Return P(class 1) for a 2-way head, else sigmoid of the first logit."""
        x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
        _, logits = self.model(x)
        if logits.shape[-1] == 2:
            p1 = torch.softmax(logits, dim=1)[0, 1]
            return float(p1.item())
        logit = logits.view(-1)[0]
        return float(torch.sigmoid(logit).item())


def _add_trufor_to_syspath():
    """Make the TruFor ``lib`` package importable by prepending its root to sys.path.

    Raises:
        FileNotFoundError: if the TruFor directory or its ``lib`` subdirectory
            is missing.
    """
    if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
        raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
    lib_dir = os.path.join(TRUFOR_TRAIN_TEST_DIR, "lib")
    if not os.path.isdir(lib_dir):
        raise FileNotFoundError(f"Expected TruFor_train_test/lib at: {lib_dir}")
    if TRUFOR_TRAIN_TEST_DIR not in sys.path:
        sys.path.insert(0, TRUFOR_TRAIN_TEST_DIR)


def _load_trufor_config():
    """Load TruFor's global config object and update it from TRUFOR_CFG_PATH.

    Requires ``_add_trufor_to_syspath`` to have been called first.
    """
    from lib.config import config as cfg
    from lib.config import update_config

    class Args:
        # Stand-in for the CLI namespace that ``update_config`` reads
        # (presumably mirrors TruFor's own argparse options — verify there).
        def __init__(self, cfg_path: str):
            self.cfg = cfg_path
            self.opts = []
            self.modelDir = ""
            self.logDir = ""
            self.dataDir = ""
            self.prevModelDir = ""
            self.gpu = "0"

    args = Args(TRUFOR_CFG_PATH)
    update_config(cfg, args)
    return cfg


def _load_state_dict_from_ckpt(path: str) -> Dict[str, torch.Tensor]:
    """Load the ``state_dict`` entry of a TruFor checkpoint, ``module.``-stripped.

    Raises:
        RuntimeError: if the checkpoint or its ``state_dict`` is not a dict.
        KeyError: if the checkpoint has no ``state_dict`` key.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if not isinstance(ckpt, dict):
        raise RuntimeError(f"[TruFor] checkpoint is not a dict: {type(ckpt)}")
    if "state_dict" not in ckpt:
        raise KeyError(f"[TruFor] checkpoint missing 'state_dict'. Keys={list(ckpt.keys())}")
    sd = ckpt["state_dict"]
    if not isinstance(sd, dict):
        raise RuntimeError(f"[TruFor] checkpoint['state_dict'] is not a dict: {type(sd)}")
    return _strip_module_prefix(sd)


@dataclass
class TruForOutputs:
    # Sigmoid of the detection-head logit.
    score: float
    # (H, W) softmax probability of localization class 1 (used as "manipulated").
    loc_prob: np.ndarray
    # (H, W) sigmoid of the confidence-head output.
    conf_prob: np.ndarray


class TruForDetector:
    """TruFor model loaded from the TruFor_train_test codebase for inference."""

    def __init__(self, device: torch.device):
        self.device = device
        _add_trufor_to_syspath()
        cfg = _load_trufor_config()
        self.cfg = cfg

        from lib.utils import get_model

        self.model = get_model(cfg)
        sd = _load_state_dict_from_ckpt(TRUFOR_CKPT_PATH)
        self.model.load_state_dict(sd, strict=True)
        self.model.to(self.device)
        self.model.eval()

        # Input size is treated as (width, height) by _prep — TODO confirm
        # against TruFor's cfg.TRAIN.IMAGE_SIZE convention.
        self.size = tuple(cfg.TRAIN.IMAGE_SIZE) if hasattr(cfg, "TRAIN") else (512, 512)
        self.to_tensor = T.ToTensor()

    def _prep(self, pil: Image.Image) -> torch.Tensor:
        """Resize to ``self.size`` (bilinear) and return a [1, 3, H, W] tensor on device."""
        w, h = int(self.size[0]), int(self.size[1])
        pil = pil.convert("RGB").resize((w, h), resample=Image.BILINEAR)
        x = self.to_tensor(pil)
        return x.unsqueeze(0).to(self.device)

    @torch.no_grad()
    def predict(self, pil: Image.Image) -> TruForOutputs:
        """Run TruFor on *pil*, returning the detection score and per-pixel maps.

        Raises:
            RuntimeError: if the detection or confidence head is missing, or an
                output does not have the expected shape.
        """
        x = self._prep(pil)
        out, conf, det, _ = self.model(x)

        if det is None:
            raise RuntimeError("[TruFor] det is None (no detection head). Your config must include det_head.")
        score = float(torch.sigmoid(det.view(-1)[0]).item())

        if out.ndim != 4 or out.shape[1] != 2:
            raise RuntimeError(f"[TruFor] Expected out shape [B,2,H,W], got {tuple(out.shape)}")
        loc_prob = torch.softmax(out, dim=1)[:, 1, :, :].detach().float().cpu().numpy()[0]

        if conf is None:
            raise RuntimeError("[TruFor] conf is None but config suggests conf_head should exist.")
        if conf.ndim != 4 or conf.shape[1] != 1:
            raise RuntimeError(f"[TruFor] Expected conf shape [B,1,H,W], got {tuple(conf.shape)}")
        conf_prob = torch.sigmoid(conf)[:, 0, :, :].detach().float().cpu().numpy()[0]

        return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)


def list_images(input_dir: str) -> List[str]:
    """Recursively collect image paths under *input_dir* (extension match is case-insensitive), sorted."""
    paths = []
    for root, _, files in os.walk(input_dir):
        for f in files:
            if os.path.splitext(f.lower())[1] in IMG_EXTS:
                paths.append(os.path.join(root, f))
    return sorted(paths)


def fuse_scores(trufor_score: float, ufd_score: float, effnet_score: float) -> float:
    """Weighted mean of the three fake-probabilities, clipped to [0, 1]."""
    wsum = (W_TRUFOR + W_UFD + W_EFFNET)
    s = (W_TRUFOR * trufor_score + W_UFD * ufd_score + W_EFFNET * effnet_score) / wsum
    return float(np.clip(s, 0.0, 1.0))


def manipulation_type_from_maps(tru: TruForOutputs, ufd_prob: float, fused: float) -> str:
    """Heuristically label the manipulation type.

    Uses the fused score, the UFD probability, and the fraction of pixels whose
    TruFor localization probability exceeds 0.5. Returns ``"none"`` for images
    with ``fused < 0.5``.
    """
    if fused < 0.5:
        return "none"
    area = float((tru.loc_prob > 0.5).mean())
    if ufd_prob >= 0.80 and area >= 0.40:
        return "full_synthesis"
    if area >= 0.12:
        return "inpainting"
    if area >= 0.03:
        return "splicing"
    # Here area < 0.03 is guaranteed by the checks above.
    if ufd_prob >= 0.65:
        return "filter"
    return "manipulated"


def save_prob_map_png(prob_hw: np.ndarray, out_path: str) -> None:
    """Save a 2-D probability map as an 8-bit grayscale PNG (0..1 -> 0..255)."""
    prob = np.clip(prob_hw, 0.0, 1.0)
    img_u8 = (prob * 255.0).round().astype(np.uint8)
    # A 2-D uint8 array maps to mode "L"; the explicit `mode` argument is deprecated in Pillow.
    Image.fromarray(img_u8).save(out_path)


def _check_required_paths() -> None:
    """Fail fast if any model, config, or checkpoint path is missing."""
    for p, name in [
        (TRUFOR_TRAIN_TEST_DIR, "TRUFOR_TRAIN_TEST_DIR"),
        (TRUFOR_CFG_PATH, "TRUFOR_CFG_PATH"),
        (TRUFOR_CKPT_PATH, "TRUFOR_CKPT_PATH"),
        (UFD_FC_WEIGHTS_PATH, "UFD_FC_WEIGHTS_PATH"),
        (EFFNET_CKPT_PATH, "EFFNET_CKPT_PATH"),
    ]:
        if not p or not os.path.exists(p):
            raise FileNotFoundError(f"{name} missing or not found: {p}")


def _load_rgb(img_path: str) -> Image.Image:
    """Decode *img_path* fully as RGB and close the underlying file handle."""
    with Image.open(img_path) as im:
        return im.convert("RGB")


def main():
    """CLI entry point: score every image in --input_dir and write JSON predictions."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_dir", required=True)
    ap.add_argument("--output_file", required=True)
    ap.add_argument(
        "--include_authentic",
        action="store_true",
        help="Also write records for images whose fused score is below 0.5 "
             "(by default only suspected fakes are written).",
    )
    args = ap.parse_args()
    if not os.path.isdir(args.input_dir):
        ap.error(f"--input_dir is not a directory: {args.input_dir}")

    _check_required_paths()

    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trufor = TruForDetector(device=device)
    ufd = UniversalFakeDetectDetector(device=device)
    effnet = EffNetDetector(device=device, ckpt_path=EFFNET_CKPT_PATH)

    preds: List[Dict[str, Any]] = []
    for img_path in list_images(args.input_dir):
        img_name = os.path.basename(img_path)
        try:
            pil = _load_rgb(img_path)
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            print(f"[warn] skipping unreadable image {img_path}: {exc}", file=sys.stderr)
            continue

        tru = trufor.predict(pil)
        ufd_prob = ufd.predict_prob(pil)
        eff_prob = effnet.predict_prob(pil)
        fused = fuse_scores(tru.score, ufd_prob, eff_prob)

        if fused < 0.5:
            # Authentic images are omitted unless explicitly requested.
            if not args.include_authentic:
                continue
            vlm_reasoning = AUTHENTIC_REASONING
        else:
            try:
                vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
            except Exception as exc:
                # The explanation is best-effort: report the failure and use a
                # canned reason instead of aborting the whole batch.
                print(f"[warn] VLM reasoning failed for {img_name}: {exc}", file=sys.stderr)
                vlm_reasoning = VLM_FALLBACK_REASONING

        rec: Dict[str, Any] = {
            "image_name": img_name,
            "authenticity_score": float(fused),
            "manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
            "vlm_reasoning": vlm_reasoning,
        }
        preds.append(rec)

    os.makedirs(os.path.dirname(os.path.abspath(args.output_file)), exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(preds, f, indent=2)

    print(f"Wrote {len(preds)} predictions to {args.output_file}")


if __name__ == "__main__":
    main()