|
|
import argparse |
|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms as T |
|
|
|
|
|
import requests |
|
|
|
|
|
|
|
|
# --- TruFor assets: repo checkout, config, and detection checkpoint paths ---
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"

TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"

TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"

# --- UniversalFakeDetect: linear-probe weights over a CLIP backbone ---
UFD_FC_WEIGHTS_PATH = "fc_weights.pth"

UFD_CLIP_NAME = "ViT-L/14"

# --- EfficientNet metric-learning classifier checkpoint ---
EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"

# Fusion weights for the three detectors (re-normalized in fuse_scores).
W_TRUFOR = 0.5

W_UFD = 0.4

W_EFFNET = 0.1

# File extensions treated as images when scanning the input directory.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# Baseten-hosted VLM endpoint used to generate natural-language reasoning.
BASETEN_VLM_MODEL_ID = "zq8pe88w"

BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"

# Canned reasoning used when the VLM call fails.
VLM_FALLBACK_REASONING = (
    "The image has odd textures, and unnatural edges."
)
|
|
|
|
|
|
|
|
|
|
|
def _pil_to_b64_jpeg(pil: Image.Image, quality: int = 95) -> str:
    """Encode *pil* as a base64 string of its JPEG bytes (RGB, given quality)."""
    stream = io.BytesIO()
    rgb = pil.convert("RGB")
    rgb.save(stream, format="JPEG", quality=quality)
    raw = stream.getvalue()
    return base64.b64encode(raw).decode("utf-8")
|
|
|
|
|
|
|
|
def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float) -> str:
    """
    Ask the Baseten-hosted VLM for a short natural-language explanation.

    Sends the payload the Baseten Truss model expects:
        {
            "authenticity_score": <float>,
            "image": "<base64_jpeg>"
        }
    and returns either the string the model produced or, for a dict reply
    with no recognizable text field, the whole reply serialized as JSON.

    Raises:
        RuntimeError: if BASETEN_API_KEY is not set in the environment.
        requests.HTTPError: if the endpoint returns a non-2xx status.
    """
    # SECURITY FIX: the key was previously hard-coded in source while the
    # error message below already referred to a BASETEN_API_KEY env var.
    # Never commit credentials; read them from the environment.
    api_key = os.environ.get("BASETEN_API_KEY")
    if not api_key:
        raise RuntimeError("Missing BASETEN_API_KEY env var.")

    payload = {
        "authenticity_score": float(authenticity_score),
        "image": _pil_to_b64_jpeg(pil),
    }

    r = requests.post(
        BASETEN_VLM_URL,
        headers={"Authorization": f"Api-Key {api_key}"},
        json=payload,
        timeout=120,
    )
    r.raise_for_status()
    out = r.json()

    # The Truss model may wrap its text under several common keys.
    if isinstance(out, dict):
        for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
            v = out.get(k)
            if isinstance(v, str) and v.strip():
                return v.strip()
        return json.dumps(out, ensure_ascii=False)

    return str(out).strip()
|
|
|
|
|
|
|
|
import clip |
|
|
|
|
|
# CLIP image-embedding width for each backbone name; used to size the
# linear probe in CLIPModel. NOTE(review): values presumably follow the
# official OpenAI CLIP release — confirm against clip.load output dims.
CHANNELS = {
    "RN50": 1024,
    "RN101": 512,
    "RN50x4": 640,
    "RN50x16": 768,
    "RN50x64": 1024,
    "ViT-B/32": 512,
    "ViT-B/16": 512,
    "ViT-L/14": 768,
    "ViT-L/14@336px": 768,
}
|
|
|
|
|
|
|
|
class CLIPModel(nn.Module):
    """CLIP backbone with a single linear probe over its image features."""

    def __init__(self, name, num_classes=1):
        super(CLIPModel, self).__init__()
        # clip.load returns (model, preprocess transform); keep both around.
        backbone, transform = clip.load(name, device="cpu")
        self.model = backbone
        self.preprocess = transform
        self.fc = nn.Linear(CHANNELS[name], num_classes)

    def forward(self, x, return_feature=False):
        """Return raw CLIP features if return_feature, else probe logits."""
        embedding = self.model.encode_image(x)
        if return_feature:
            return embedding
        return self.fc(embedding)
|
|
|
|
|
|
|
|
class UniversalFakeDetectDetector:
    """UniversalFakeDetect: CLIP features (on CPU) + pretrained linear probe."""

    def __init__(self, device: torch.device):
        self.device = device
        self.model = CLIPModel(UFD_CLIP_NAME, num_classes=1)
        self.model.eval()

        # The checkpoint may be a full training checkpoint or just the fc head.
        sd = torch.load(UFD_FC_WEIGHTS_PATH, map_location="cpu")
        if isinstance(sd, dict) and isinstance(sd.get("state_dict"), dict):
            sd = sd["state_dict"]

        if not isinstance(sd, dict):
            raise RuntimeError(f"[UFD] Unsupported fc checkpoint format. Top keys: {type(sd)}")

        if any(k.startswith("fc.") for k in sd.keys()):
            # Full-model state dict: keep only the linear-probe tensors.
            fc_sd = {k.replace("fc.", ""): v for k, v in sd.items() if k.startswith("fc.")}
            self.model.fc.load_state_dict(fc_sd, strict=True)
        elif "weight" in sd and "bias" in sd:
            # Bare fc head. This also covers the exact {"weight", "bias"} case,
            # which the previous code re-tested in a separate, unreachable branch.
            self.model.fc.load_state_dict({"weight": sd["weight"], "bias": sd["bias"]}, strict=True)
        else:
            raise RuntimeError(
                f"[UFD] Unsupported fc checkpoint format. Top keys: {list(sd.keys())[:50]}"
            )

        # Only the linear head runs on the target device; CLIP stays on CPU.
        self.model.fc.to(self.device)
        self.preprocess = self.model.preprocess

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """Return the probe's sigmoid probability (fake score) for one image."""
        x = self.preprocess(pil.convert("RGB")).unsqueeze(0)
        features = self.model(x, return_feature=True)
        features = features.to(self.device)
        logit = self.model.fc(features).view(-1)[0]
        return float(torch.sigmoid(logit).item())
|
|
|
|
|
|
|
|
|
|
|
import timm |
|
|
|
|
|
|
|
|
class EffNetMetricClassifier(nn.Module):
    """
    EfficientNet backbone with a metric-learning projection head.

    Must match the architecture used to train best_metric_cls_effnet.pt:
        backbone (timm, num_classes=0, global_pool=avg)
        proj: Linear -> BN
        classifier: Linear(embed_dim -> 2)
    """

    def __init__(self, model_name="efficientnet_b0", embed_dim=128, num_classes=2, pretrained=False):
        super().__init__()
        # num_classes=0 + global_pool="avg" makes timm return pooled features.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool="avg")
        feat_dim = self.backbone.num_features
        self.proj = nn.Sequential(
            nn.Linear(feat_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return (l2-normalized embedding, classification logits)."""
        feat = self.backbone(x)
        z = self.proj(feat)
        # NOTE: logits come from the un-normalized projection z; only the
        # returned embedding is l2-normalized.
        emb = F.normalize(z, p=2, dim=1)
        logits = self.classifier(z)
        return emb, logits
|
|
|
|
|
|
|
|
def _strip_module_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
|
if any(k.startswith("module.") for k in sd.keys()): |
|
|
return {k.replace("module.", "", 1): v for k, v in sd.items()} |
|
|
return sd |
|
|
|
|
|
|
|
|
class EffNetDetector:
    """Loads an EffNetMetricClassifier checkpoint and scores single images."""

    def __init__(self, device: torch.device, ckpt_path: str = EFFNET_CKPT_PATH, img_size: int = 224):
        self.device = device
        # weights_only=False unpickles arbitrary objects — only load trusted
        # checkpoints (needed here since the file may hold non-tensor metadata).
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            # Training checkpoint: weights plus architecture metadata.
            sd = ckpt["state_dict"]
            model_name = ckpt.get("model_name", "efficientnet_b0")
            embed_dim = int(ckpt.get("embed_dim", 128))
        elif isinstance(ckpt, dict):

            # Bare state dict: assume the default architecture.
            sd = ckpt
            model_name = "efficientnet_b0"
            embed_dim = 128
        else:
            raise RuntimeError(f"[EffNet] Unsupported checkpoint type: {type(ckpt)}")

        sd = _strip_module_prefix(sd)

        self.model = EffNetMetricClassifier(model_name=model_name, embed_dim=embed_dim, num_classes=2, pretrained=False)
        self.model.load_state_dict(sd, strict=True)
        self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet-style eval preprocessing: resize ~15% larger,
        # center-crop, then normalize with ImageNet mean/std.
        self.transform = T.Compose([
            T.Resize(int(img_size * 1.15)),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """Return a fake-probability in [0, 1] for one PIL image."""
        x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
        _, logits = self.model(x)

        # Two-class head: softmax probability of class index 1.
        if logits.shape[-1] == 2:
            p1 = torch.softmax(logits, dim=1)[0, 1]
            return float(p1.item())

        # Single-logit fallback: sigmoid of the lone logit.
        logit = logits.view(-1)[0]
        return float(torch.sigmoid(logit).item())
|
|
|
|
|
|
|
|
|
|
|
def _add_trufor_to_syspath():
    """Make the TruFor repo importable; expects TruFor_train_test/lib on disk."""
    root = TRUFOR_TRAIN_TEST_DIR
    if not os.path.isdir(root):
        raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
    lib_dir = os.path.join(root, "lib")
    if not os.path.isdir(lib_dir):
        raise FileNotFoundError(f"Expected TruFor_train_test/lib at: {lib_dir}")
    # Prepend so TruFor's lib package wins over any same-named module.
    if root not in sys.path:
        sys.path.insert(0, root)
|
|
|
|
|
|
|
|
def _load_trufor_config():
    """Load and return the TruFor YACS config for the phase-3 test setup."""
    # Imported lazily: lib.* only resolves after _add_trufor_to_syspath().
    from lib.config import config as cfg
    from lib.config import update_config

    class Args:
        """Minimal stand-in for the argparse namespace update_config expects."""

        def __init__(self, cfg_path: str):
            self.cfg = cfg_path
            self.opts = []
            self.modelDir = ""
            self.logDir = ""
            self.dataDir = ""
            self.prevModelDir = ""
            self.gpu = "0"

    update_config(cfg, Args(TRUFOR_CFG_PATH))
    return cfg
|
|
|
|
|
|
|
|
def _load_state_dict_from_ckpt(path: str) -> Dict[str, torch.Tensor]: |
|
|
ckpt = torch.load(path, map_location="cpu", weights_only=False) |
|
|
if not isinstance(ckpt, dict): |
|
|
raise RuntimeError(f"[TruFor] checkpoint is not a dict: {type(ckpt)}") |
|
|
if "state_dict" not in ckpt: |
|
|
raise KeyError(f"[TruFor] checkpoint missing 'state_dict'. Keys={list(ckpt.keys())}") |
|
|
sd = ckpt["state_dict"] |
|
|
if not isinstance(sd, dict): |
|
|
raise RuntimeError(f"[TruFor] checkpoint['state_dict'] is not a dict: {type(sd)}") |
|
|
if any(k.startswith("module.") for k in sd.keys()): |
|
|
sd = {k.replace("module.", "", 1): v for k, v in sd.items()} |
|
|
return sd |
|
|
|
|
|
|
|
|
@dataclass
class TruForOutputs:
    """Container for one TruFor inference result."""

    # Image-level forgery score from the detection head (sigmoid output, 0..1).
    score: float
    # Per-pixel manipulation probability map; 2-D array matching model output size.
    loc_prob: np.ndarray
    # Per-pixel confidence map accompanying loc_prob; same shape.
    conf_prob: np.ndarray
|
|
|
|
|
|
|
|
class TruForDetector:
    """Runs TruFor: pixel-level localization plus image-level detection."""

    def __init__(self, device: torch.device):
        self.device = device
        _add_trufor_to_syspath()
        cfg = _load_trufor_config()
        self.cfg = cfg

        # Imported lazily: lib.* only resolves after _add_trufor_to_syspath().
        from lib.utils import get_model

        self.model = get_model(cfg)
        sd = _load_state_dict_from_ckpt(TRUFOR_CKPT_PATH)
        self.model.load_state_dict(sd, strict=True)
        self.model.to(self.device)
        self.model.eval()

        # Input size from the training config; fall back to 512x512 if absent.
        self.size = tuple(cfg.TRAIN.IMAGE_SIZE) if hasattr(cfg, "TRAIN") else (512, 512)
        self.to_tensor = T.ToTensor()

    def _prep(self, pil: Image.Image) -> torch.Tensor:
        """Resize to the model input size and return a [1,3,H,W] tensor on device."""
        w, h = int(self.size[0]), int(self.size[1])
        pil = pil.convert("RGB").resize((w, h), resample=Image.BILINEAR)
        x = self.to_tensor(pil)
        return x.unsqueeze(0).to(self.device)

    @torch.no_grad()
    def predict(self, pil: Image.Image) -> TruForOutputs:
        """Run TruFor on one image and return score + localization/confidence maps."""
        x = self._prep(pil)
        # Model returns (localization logits, confidence logits, detection
        # logit, extra) — the extra value is unused here.
        out, conf, det, _ = self.model(x)

        if det is None:
            raise RuntimeError("[TruFor] det is None (no detection head). Your config must include det_head.")

        # Image-level forgery probability from the detection logit.
        score = float(torch.sigmoid(det.view(-1)[0]).item())

        if out.ndim != 4 or out.shape[1] != 2:
            raise RuntimeError(f"[TruFor] Expected out shape [B,2,H,W], got {tuple(out.shape)}")
        # Channel 1 is taken as the "manipulated" class probability per pixel.
        loc_prob = torch.softmax(out, dim=1)[:, 1, :, :].detach().float().cpu().numpy()[0]

        if conf is None:
            raise RuntimeError("[TruFor] conf is None but config suggests conf_head should exist.")
        if conf.ndim != 4 or conf.shape[1] != 1:
            raise RuntimeError(f"[TruFor] Expected conf shape [B,1,H,W], got {tuple(conf.shape)}")
        conf_prob = torch.sigmoid(conf)[:, 0, :, :].detach().float().cpu().numpy()[0]

        return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
|
|
|
|
|
|
|
|
def list_images(input_dir: str) -> List[str]:
    """Recursively collect image file paths under *input_dir*, sorted."""
    found: List[str] = []
    for dirpath, _, filenames in os.walk(input_dir):
        found.extend(
            os.path.join(dirpath, fname)
            for fname in filenames
            if os.path.splitext(fname.lower())[1] in IMG_EXTS
        )
    return sorted(found)
|
|
|
|
|
|
|
|
def fuse_scores(trufor_score: float, ufd_score: float, effnet_score: float) -> float:
    """Weighted average of the three detector probabilities, clipped to [0, 1]."""
    weighted = (
        W_TRUFOR * trufor_score
        + W_UFD * ufd_score
        + W_EFFNET * effnet_score
    )
    total_weight = W_TRUFOR + W_UFD + W_EFFNET
    return float(np.clip(weighted / total_weight, 0.0, 1.0))
|
|
|
|
|
|
|
|
def manipulation_type_from_maps(tru: TruForOutputs, ufd_prob: float, fused: float) -> str:
    """Heuristically label the manipulation type from map area and UFD score."""
    if fused < 0.5:
        return "none"

    # Fraction of pixels TruFor flags as manipulated.
    area = float((tru.loc_prob > 0.5).mean())

    if ufd_prob >= 0.80 and area >= 0.40:
        return "full_synthesis"
    if area >= 0.12:
        return "inpainting"
    if area >= 0.03:
        return "splicing"
    if ufd_prob >= 0.65 and area < 0.03:
        return "filter"
    return "manipulated"
|
|
|
|
|
|
|
|
def save_prob_map_png(prob_hw: np.ndarray, out_path: str) -> None:
    """Write a probability map (values in [0, 1]) as an 8-bit grayscale PNG."""
    clipped = np.clip(prob_hw, 0.0, 1.0)
    scaled = (clipped * 255.0).round().astype(np.uint8)
    Image.fromarray(scaled, mode="L").save(out_path)
|
|
|
|
|
|
|
|
def main():
    """CLI entry: score every image in --input_dir, write predictions JSON to --output_file."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_dir", required=True)
    ap.add_argument("--output_file", required=True)
    args = ap.parse_args()

    # Fail fast if any required model asset is missing.
    for p, name in [
        (TRUFOR_TRAIN_TEST_DIR, "TRUFOR_TRAIN_TEST_DIR"),
        (TRUFOR_CFG_PATH, "TRUFOR_CFG_PATH"),
        (TRUFOR_CKPT_PATH, "TRUFOR_CKPT_PATH"),
        (UFD_FC_WEIGHTS_PATH, "UFD_FC_WEIGHTS_PATH"),
        (EFFNET_CKPT_PATH, "EFFNET_CKPT_PATH"),
    ]:
        if not p or not os.path.exists(p):
            raise FileNotFoundError(f"{name} missing or not found: {p}")

    # BUG FIX: was torch.device("cuda"), which crashes on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trufor = TruForDetector(device=device)
    ufd = UniversalFakeDetectDetector(device=device)
    effnet = EffNetDetector(device=device, ckpt_path=EFFNET_CKPT_PATH)

    preds: List[Dict[str, Any]] = []
    for img_path in list_images(args.input_dir):
        img_name = os.path.basename(img_path)
        pil = Image.open(img_path)

        tru = trufor.predict(pil)
        ufd_prob = ufd.predict_prob(pil)
        eff_prob = effnet.predict_prob(pil)

        fused = fuse_scores(tru.score, ufd_prob, eff_prob)

        if fused < 0.5:
            # BUG FIX: the previous version `continue`d right after this
            # assignment, discarding the reasoning and silently dropping
            # authentic images from the output JSON entirely.
            vlm_reasoning = "It looks natural."
        else:
            try:
                vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
            except Exception:
                # Best-effort: fall back to canned reasoning if the VLM fails.
                vlm_reasoning = VLM_FALLBACK_REASONING

        rec: Dict[str, Any] = {
            "image_name": img_name,
            "authenticity_score": float(fused),
            "manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
            "vlm_reasoning": vlm_reasoning,
        }

        preds.append(rec)

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(preds, f, indent=2)

    print(f"Wrote {len(preds)} predictions to {args.output_file}")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|