import argparse
import base64
import io
import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import requests
import timm
import clip  # openai/CLIP
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"
UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
UFD_CLIP_NAME = "ViT-L/14"
EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"
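# Fixed fusion weights for the three detectors (TruFor, UniversalFakeDetect,
# EfficientNet); fuse_scores() divides by their sum, so only the ratios matter.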
W_TRUFOR = 0.5
W_UFD = 0.4
W_EFFNET = 0.1
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
BASETEN_VLM_MODEL_ID = "zq8pe88w"
BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
VLM_FALLBACK_REASONING = (
    "The image has odd textures and unnatural edges."
)
def _pil_to_b64_jpeg(pil: Image.Image, quality: int = 95) -> str:
buf = io.BytesIO()
pil.convert("RGB").save(buf, format="JPEG", quality=quality)
return base64.b64encode(buf.getvalue()).decode("utf-8")
def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float) -> str:
"""
Calls your Baseten model. Assumes the Baseten Truss model expects:
{
"authenticity_score": <float>,
"image": "<base64_jpeg>"
}
and returns either a string or a JSON containing a string.
"""
    api_key = os.environ.get("BASETEN_API_KEY")
if not api_key:
raise RuntimeError("Missing BASETEN_API_KEY env var.")
payload = {
"authenticity_score": float(authenticity_score), # 0 real, 1 AI
"image": _pil_to_b64_jpeg(pil), # base64 JPEG, no data: prefix
}
r = requests.post(
BASETEN_VLM_URL,
headers={"Authorization": f"Api-Key {api_key}"},
json=payload,
timeout=120,
)
r.raise_for_status()
out = r.json()
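    # Truss response schemas vary; probe common keys for a usable string
    # before falling back to the raw JSON.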
if isinstance(out, dict):
for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
v = out.get(k)
if isinstance(v, str) and v.strip():
return v.strip()
return json.dumps(out, ensure_ascii=False)
return str(out).strip()
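# CLIP image-feature dimensionality for each supported backbone.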
CHANNELS = {
"RN50": 1024,
"RN101": 512,
"RN50x4": 640,
"RN50x16": 768,
"RN50x64": 1024,
"ViT-B/32": 512,
"ViT-B/16": 512,
"ViT-L/14": 768,
"ViT-L/14@336px": 768,
}
class CLIPModel(nn.Module):
def __init__(self, name, num_classes=1):
super(CLIPModel, self).__init__()
self.model, self.preprocess = clip.load(name, device="cpu")
self.fc = nn.Linear(CHANNELS[name], num_classes)
def forward(self, x, return_feature=False):
features = self.model.encode_image(x)
if return_feature:
return features
return self.fc(features)
class UniversalFakeDetectDetector:
def __init__(self, device: torch.device):
self.device = device
self.model = CLIPModel(UFD_CLIP_NAME, num_classes=1)
self.model.eval()
sd = torch.load(UFD_FC_WEIGHTS_PATH, map_location="cpu")
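        # The checkpoint may be a full state_dict wrapper, an "fc."-prefixed
        # subset, or just the linear layer's {"weight", "bias"} tensors.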
if isinstance(sd, dict) and "state_dict" in sd and isinstance(sd["state_dict"], dict):
sd = sd["state_dict"]
if isinstance(sd, dict) and any(k.startswith("fc.") for k in sd.keys()):
fc_sd = {k.replace("fc.", ""): v for k, v in sd.items() if k.startswith("fc.")}
self.model.fc.load_state_dict(fc_sd, strict=True)
        elif isinstance(sd, dict) and "weight" in sd and "bias" in sd:
            self.model.fc.load_state_dict({"weight": sd["weight"], "bias": sd["bias"]}, strict=True)
else:
raise RuntimeError(
f"[UFD] Unsupported fc checkpoint format. Top keys: {list(sd.keys())[:50] if isinstance(sd, dict) else type(sd)}"
)
self.model.fc.to(self.device)
self.preprocess = self.model.preprocess
@torch.no_grad()
def predict_prob(self, pil: Image.Image) -> float:
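        # CLIP feature extraction stays on CPU (the backbone was loaded
        # there); only the fc head, which lives on self.device, sees the GPU.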
x = self.preprocess(pil.convert("RGB")).unsqueeze(0) # CPU
features = self.model(x, return_feature=True) # CPU
features = features.to(self.device)
logit = self.model.fc(features).view(-1)[0]
return float(torch.sigmoid(logit).item())
class EffNetMetricClassifier(nn.Module):
"""
Must match the architecture used to train best_metric_cls_effnet.pt:
backbone (timm, num_classes=0, global_pool=avg)
proj: Linear -> BN
classifier: Linear(embed_dim -> 2)
"""
def __init__(self, model_name="efficientnet_b0", embed_dim=128, num_classes=2, pretrained=False):
super().__init__()
self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool="avg")
feat_dim = self.backbone.num_features
self.proj = nn.Sequential(
nn.Linear(feat_dim, embed_dim),
nn.BatchNorm1d(embed_dim),
)
self.classifier = nn.Linear(embed_dim, num_classes)
def forward(self, x):
feat = self.backbone(x)
z = self.proj(feat)
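        # L2-normalize for the metric-learning embedding; the classifier
        # consumes the unnormalized projection.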
emb = F.normalize(z, p=2, dim=1)
logits = self.classifier(z)
return emb, logits
def _strip_module_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if any(k.startswith("module.") for k in sd.keys()):
return {k.replace("module.", "", 1): v for k, v in sd.items()}
return sd
class EffNetDetector:
def __init__(self, device: torch.device, ckpt_path: str = EFFNET_CKPT_PATH, img_size: int = 224):
self.device = device
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
if isinstance(ckpt, dict) and "state_dict" in ckpt:
sd = ckpt["state_dict"]
model_name = ckpt.get("model_name", "efficientnet_b0")
embed_dim = int(ckpt.get("embed_dim", 128))
elif isinstance(ckpt, dict):
# fallback: assume ckpt itself is state_dict
sd = ckpt
model_name = "efficientnet_b0"
embed_dim = 128
else:
raise RuntimeError(f"[EffNet] Unsupported checkpoint type: {type(ckpt)}")
sd = _strip_module_prefix(sd)
self.model = EffNetMetricClassifier(model_name=model_name, embed_dim=embed_dim, num_classes=2, pretrained=False)
self.model.load_state_dict(sd, strict=True)
self.model.to(self.device)
self.model.eval()
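        # Standard ImageNet-style eval transform: resize ~15% past the crop
        # size, center-crop, and normalize with ImageNet statistics.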
self.transform = T.Compose([
T.Resize(int(img_size * 1.15)),
T.CenterCrop(img_size),
T.ToTensor(),
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
@torch.no_grad()
def predict_prob(self, pil: Image.Image) -> float:
x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
_, logits = self.model(x)
if logits.shape[-1] == 2:
p1 = torch.softmax(logits, dim=1)[0, 1]
return float(p1.item())
logit = logits.view(-1)[0]
return float(torch.sigmoid(logit).item())
def _add_trufor_to_syspath():
if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
lib_dir = os.path.join(TRUFOR_TRAIN_TEST_DIR, "lib")
if not os.path.isdir(lib_dir):
raise FileNotFoundError(f"Expected TruFor_train_test/lib at: {lib_dir}")
if TRUFOR_TRAIN_TEST_DIR not in sys.path:
sys.path.insert(0, TRUFOR_TRAIN_TEST_DIR)
def _load_trufor_config():
from lib.config import config as cfg
from lib.config import update_config
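    # Minimal stand-in for TruFor's argparse namespace; provides only the
    # attributes that update_config() reads.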
class Args:
def __init__(self, cfg_path: str):
self.cfg = cfg_path
self.opts = []
self.modelDir = ""
self.logDir = ""
self.dataDir = ""
self.prevModelDir = ""
self.gpu = "0"
args = Args(TRUFOR_CFG_PATH)
update_config(cfg, args)
return cfg
def _load_state_dict_from_ckpt(path: str) -> Dict[str, torch.Tensor]:
ckpt = torch.load(path, map_location="cpu", weights_only=False)
if not isinstance(ckpt, dict):
raise RuntimeError(f"[TruFor] checkpoint is not a dict: {type(ckpt)}")
if "state_dict" not in ckpt:
raise KeyError(f"[TruFor] checkpoint missing 'state_dict'. Keys={list(ckpt.keys())}")
sd = ckpt["state_dict"]
if not isinstance(sd, dict):
raise RuntimeError(f"[TruFor] checkpoint['state_dict'] is not a dict: {type(sd)}")
if any(k.startswith("module.") for k in sd.keys()):
sd = {k.replace("module.", "", 1): v for k, v in sd.items()}
return sd
@dataclass
class TruForOutputs:
    score: float           # image-level manipulation probability
    loc_prob: np.ndarray   # per-pixel localization probability map (H, W)
    conf_prob: np.ndarray  # per-pixel confidence map (H, W)
class TruForDetector:
def __init__(self, device: torch.device):
self.device = device
_add_trufor_to_syspath()
cfg = _load_trufor_config()
self.cfg = cfg
from lib.utils import get_model
self.model = get_model(cfg)
sd = _load_state_dict_from_ckpt(TRUFOR_CKPT_PATH)
self.model.load_state_dict(sd, strict=True)
self.model.to(self.device)
self.model.eval()
self.size = tuple(cfg.TRAIN.IMAGE_SIZE) if hasattr(cfg, "TRAIN") else (512, 512)
self.to_tensor = T.ToTensor()
def _prep(self, pil: Image.Image) -> torch.Tensor:
w, h = int(self.size[0]), int(self.size[1])
pil = pil.convert("RGB").resize((w, h), resample=Image.BILINEAR)
x = self.to_tensor(pil)
return x.unsqueeze(0).to(self.device)
@torch.no_grad()
def predict(self, pil: Image.Image) -> TruForOutputs:
x = self._prep(pil)
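        # TruFor's forward pass returns localization logits, confidence
        # logits, a detection logit, and a fourth output unused here.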
out, conf, det, _ = self.model(x)
if det is None:
raise RuntimeError("[TruFor] det is None (no detection head). Your config must include det_head.")
score = float(torch.sigmoid(det.view(-1)[0]).item())
if out.ndim != 4 or out.shape[1] != 2:
raise RuntimeError(f"[TruFor] Expected out shape [B,2,H,W], got {tuple(out.shape)}")
loc_prob = torch.softmax(out, dim=1)[:, 1, :, :].detach().float().cpu().numpy()[0]
if conf is None:
raise RuntimeError("[TruFor] conf is None but config suggests conf_head should exist.")
if conf.ndim != 4 or conf.shape[1] != 1:
raise RuntimeError(f"[TruFor] Expected conf shape [B,1,H,W], got {tuple(conf.shape)}")
conf_prob = torch.sigmoid(conf)[:, 0, :, :].detach().float().cpu().numpy()[0]
return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
def list_images(input_dir: str) -> List[str]:
paths = []
for root, _, files in os.walk(input_dir):
for f in files:
if os.path.splitext(f.lower())[1] in IMG_EXTS:
paths.append(os.path.join(root, f))
return sorted(paths)
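# Weighted average of the three detector scores, clipped to [0, 1];
# higher means more likely AI-generated/manipulated.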
def fuse_scores(trufor_score: float, ufd_score: float, effnet_score: float) -> float:
wsum = (W_TRUFOR + W_UFD + W_EFFNET)
s = (W_TRUFOR * trufor_score + W_UFD * ufd_score + W_EFFNET * effnet_score) / wsum
return float(np.clip(s, 0.0, 1.0))
def manipulation_type_from_maps(tru: TruForOutputs, ufd_prob: float, fused: float) -> str:
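    # Hand-picked thresholds: the fraction of pixels TruFor localizes as
    # manipulated ("area") separates full synthesis, inpainting, splicing,
    # and global filtering, with the CLIP-based UFD score as a tiebreaker.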
if fused < 0.5:
return "none"
area = float((tru.loc_prob > 0.5).mean())
if ufd_prob >= 0.80 and area >= 0.40:
return "full_synthesis"
if area >= 0.12:
return "inpainting"
if area >= 0.03:
return "splicing"
if ufd_prob >= 0.65 and area < 0.03:
return "filter"
return "manipulated"
def save_prob_map_png(prob_hw: np.ndarray, out_path: str) -> None:
prob = np.clip(prob_hw, 0.0, 1.0)
img_u8 = (prob * 255.0).round().astype(np.uint8)
Image.fromarray(img_u8, mode="L").save(out_path)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--input_dir", required=True)
ap.add_argument("--output_file", required=True)
args = ap.parse_args()
for p, name in [
(TRUFOR_TRAIN_TEST_DIR, "TRUFOR_TRAIN_TEST_DIR"),
(TRUFOR_CFG_PATH, "TRUFOR_CFG_PATH"),
(TRUFOR_CKPT_PATH, "TRUFOR_CKPT_PATH"),
(UFD_FC_WEIGHTS_PATH, "UFD_FC_WEIGHTS_PATH"),
(EFFNET_CKPT_PATH, "EFFNET_CKPT_PATH"),
]:
if not p or not os.path.exists(p):
raise FileNotFoundError(f"{name} missing or not found: {p}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trufor = TruForDetector(device=device)
ufd = UniversalFakeDetectDetector(device=device)
    effnet = EffNetDetector(device=device, ckpt_path=EFFNET_CKPT_PATH)
preds: List[Dict[str, Any]] = []
for img_path in list_images(args.input_dir):
img_name = os.path.basename(img_path)
pil = Image.open(img_path)
tru = trufor.predict(pil)
ufd_prob = ufd.predict_prob(pil)
eff_prob = effnet.predict_prob(pil)
fused = fuse_scores(tru.score, ufd_prob, eff_prob)
        if fused < 0.5:
            vlm_reasoning = "It looks natural."
        else:
            try:
                vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
            except Exception:
                vlm_reasoning = VLM_FALLBACK_REASONING
rec: Dict[str, Any] = {
"image_name": img_name,
"authenticity_score": float(fused),
"manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
"vlm_reasoning": vlm_reasoning,
}
preds.append(rec)
with open(args.output_file, "w", encoding="utf-8") as f:
json.dump(preds, f, indent=2)
print(f"Wrote {len(preds)} predictions to {args.output_file}")
if __name__ == "__main__":
main()
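# Example invocation (paths illustrative):
#   python predict.py --input_dir ./images --output_file predictions.json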