File size: 13,949 Bytes
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
 
 
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
ba92c89
5f0437a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
import argparse
import base64
import io
import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import requests


# --- Model assets (relative paths resolve against the current working directory) ---
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
TRUFOR_CFG_PATH = TRUFOR_TRAIN_TEST_DIR + "/lib/config/trufor_ph3.yaml"
TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"

UFD_FC_WEIGHTS_PATH = "fc_weights.pth"  # linear head of the CLIP-based UFD detector
UFD_CLIP_NAME = "ViT-L/14"  # must be a key of CHANNELS (used to size the head)

EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"

# --- Score-fusion weights (fuse_scores divides by their sum) ---
W_TRUFOR, W_UFD, W_EFFNET = 0.5, 0.4, 0.1

# Lower-case file extensions that list_images() treats as images.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


# --- Baseten VLM endpoint used to generate textual explanations ---
BASETEN_VLM_MODEL_ID = "zq8pe88w"
BASETEN_VLM_URL = (
    "https://model-" + BASETEN_VLM_MODEL_ID + ".api.baseten.co/development/predict"
)

# Explanation substituted when the VLM request fails.
VLM_FALLBACK_REASONING = "The image has odd textures, and unnatural edges."



def _pil_to_b64_jpeg(pil: Image.Image, quality: int = 95) -> str:
    """Encode *pil* as an RGB JPEG and return it as base64 text (no ``data:`` prefix).

    Args:
        pil: Source image; converted to RGB before encoding.
        quality: JPEG quality passed straight to ``Image.save``.
    """
    with io.BytesIO() as sink:
        pil.convert("RGB").save(sink, format="JPEG", quality=quality)
        jpeg_bytes = sink.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")


def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float) -> str:
    """
    Ask the Baseten-hosted VLM to explain an authenticity score.

    Assumes the Baseten Truss model expects a JSON body of the form::

        {
          "authenticity_score": <float>,   # 0 = real, 1 = AI
          "image": "<base64_jpeg>"          # no "data:" URI prefix
        }

    and returns either a string or a JSON object containing a string.

    The API key is read from the ``BASETEN_API_KEY`` environment variable.
    Credentials must never be committed to source control.

    Args:
        pil: Image to explain; re-encoded as JPEG before upload.
        authenticity_score: Fused detector score.

    Returns:
        The first non-empty string found under one of the known response keys.
        If the response is a dict without such a key, returns the dict as JSON.
        For any other response type, returns ``str(response)``.

    Raises:
        RuntimeError: if ``BASETEN_API_KEY`` is unset or blank.
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection errors or the 120 s timeout.
        ValueError: if the response body is not valid JSON.
    """
    # SECURITY: the key used to be hard-coded here (leaking it to anyone with
    # access to the code and making the emptiness check below dead code).
    # The previously committed key should be rotated.
    api_key = os.environ.get("BASETEN_API_KEY", "").strip()
    if not api_key:
        raise RuntimeError("Missing BASETEN_API_KEY env var.")

    payload = {
        "authenticity_score": float(authenticity_score),  # 0 real, 1 AI
        "image": _pil_to_b64_jpeg(pil),                   # base64 JPEG, no data: prefix
    }

    r = requests.post(
        BASETEN_VLM_URL,
        headers={"Authorization": f"Api-Key {api_key}"},
        json=payload,
        timeout=120,
    )
    r.raise_for_status()
    out = r.json()

    if isinstance(out, dict):
        # Different Truss handlers name the text field differently; take the first hit.
        for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
            v = out.get(k)
            if isinstance(v, str) and v.strip():
                return v.strip()
        return json.dumps(out, ensure_ascii=False)

    return str(out).strip()


import clip  # openai/CLIP

# Width of the feature vector returned by encode_image() for each CLIP
# backbone, keyed by the name passed to clip.load(). CLIPModel uses it as the
# in_features of its linear head.
CHANNELS = {
    # ResNet image encoders
    "RN50": 1024,
    "RN101": 512,
    "RN50x4": 640,
    "RN50x16": 768,
    "RN50x64": 1024,
    # Vision Transformer image encoders
    "ViT-B/32": 512,
    "ViT-B/16": 512,
    "ViT-L/14": 768,
    "ViT-L/14@336px": 768,
}


class CLIPModel(nn.Module):
    """
    OpenAI CLIP image encoder topped with a single linear layer.

    The backbone is loaded with ``clip.load(name, device="cpu")``. The head's
    input width is looked up in ``CHANNELS`` by backbone name.
    """

    def __init__(self, name, num_classes=1):
        super().__init__()
        backbone, preprocess = clip.load(name, device="cpu")
        self.model = backbone
        self.preprocess = preprocess
        self.fc = nn.Linear(CHANNELS[name], num_classes)

    def forward(self, x, return_feature=False):
        """Encode *x*; return raw CLIP features or the head's logits."""
        feats = self.model.encode_image(x)
        return feats if return_feature else self.fc(feats)


class UniversalFakeDetectDetector:
    """
    UniversalFakeDetect (UFD): CLIP image encoder plus a trained linear head.

    ``CLIPModel`` loads CLIP with ``device="cpu"``, and only the ``fc`` head is
    moved to ``device``. Preprocessing and feature extraction therefore run on
    the CPU; the features are then moved to ``device`` for the head.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self.model = CLIPModel(UFD_CLIP_NAME, num_classes=1)
        self.model.eval()

        sd = torch.load(UFD_FC_WEIGHTS_PATH, map_location="cpu")
        self.model.fc.load_state_dict(self._extract_fc_state_dict(sd), strict=True)

        self.model.fc.to(self.device)
        self.preprocess = self.model.preprocess

    @staticmethod
    def _extract_fc_state_dict(sd: Any) -> Dict[str, torch.Tensor]:
        """
        Return the linear head's ``weight``/``bias`` from a UFD checkpoint.

        Accepted layouts (optionally wrapped as ``{"state_dict": {...}}``):
          * keys prefixed with ``fc.`` -> that leading prefix is stripped;
          * top-level ``weight`` and ``bias`` -> those two entries are used,
            other keys are ignored.

        Raises:
            RuntimeError: for any other layout.
        """
        if isinstance(sd, dict) and isinstance(sd.get("state_dict"), dict):
            sd = sd["state_dict"]

        if isinstance(sd, dict):
            prefix = "fc."
            if any(k.startswith(prefix) for k in sd.keys()):
                # Strip only the leading prefix; str.replace would also rewrite
                # any later "fc." inside the key.
                return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
            if "weight" in sd and "bias" in sd:
                return {"weight": sd["weight"], "bias": sd["bias"]}

        raise RuntimeError(
            f"[UFD] Unsupported fc checkpoint format. Top keys: {list(sd.keys())[:50] if isinstance(sd, dict) else type(sd)}"
        )

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """Return sigmoid(logit) of the UFD head for *pil*."""
        x = self.preprocess(pil.convert("RGB")).unsqueeze(0)  # CPU
        features = self.model(x, return_feature=True)  # CPU
        features = features.to(self.device)
        logit = self.model.fc(features).view(-1)[0]
        return float(torch.sigmoid(logit).item())



import timm


class EffNetMetricClassifier(nn.Module):
    """
    Metric-learning classifier on a timm backbone.

    Must match the architecture used to train best_metric_cls_effnet.pt:

    * ``backbone``: timm model with ``num_classes=0`` and ``global_pool="avg"``
      (pooled features only);
    * ``proj``: Linear -> BatchNorm1d into ``embed_dim``;
    * ``classifier``: Linear(``embed_dim`` -> ``num_classes``).

    ``forward`` returns ``(embedding, logits)``. The embedding is the
    L2-normalised projection; the logits come from the *un-normalised* projection.
    """

    def __init__(self, model_name="efficientnet_b0", embed_dim=128, num_classes=2, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        self.proj = nn.Sequential(
            nn.Linear(self.backbone.num_features, embed_dim),
            nn.BatchNorm1d(embed_dim),
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        projected = self.proj(self.backbone(x))
        return F.normalize(projected, p=2, dim=1), self.classifier(projected)


def _strip_module_prefix(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Remove the leading ``module.`` that ``nn.DataParallel``/DDP add to state-dict keys.

    Only a *leading* ``module.`` is removed; keys that merely contain it further
    in (e.g. ``encoder.module.weight``) are kept verbatim. If no key starts with
    ``module.``, the input mapping itself is returned unchanged.
    """
    prefix = "module."
    if not any(k.startswith(prefix) for k in sd.keys()):
        return sd
    # str.replace(prefix, "", 1) would also hit a "module." in the middle of a key.
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}


class EffNetDetector:
    """
    Single-image inference wrapper around ``EffNetMetricClassifier``.

    Accepted checkpoint layouts:

    * ``{"state_dict": ..., "model_name": ..., "embed_dim": ...}`` (metadata optional);
    * a bare state dict, assumed to be ``efficientnet_b0`` with ``embed_dim=128``.
    """

    def __init__(self, device: torch.device, ckpt_path: str = EFFNET_CKPT_PATH, img_size: int = 224):
        self.device = device
        # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

        if not isinstance(ckpt, dict):
            raise RuntimeError(f"[EffNet] Unsupported checkpoint type: {type(ckpt)}")

        if "state_dict" in ckpt:
            weights = ckpt["state_dict"]
            arch = ckpt.get("model_name", "efficientnet_b0")
            dim = int(ckpt.get("embed_dim", 128))
        else:
            # No metadata: treat the whole dict as the state dict with default arch.
            weights, arch, dim = ckpt, "efficientnet_b0", 128

        weights = _strip_module_prefix(weights)

        net = EffNetMetricClassifier(model_name=arch, embed_dim=dim, num_classes=2, pretrained=False)
        net.load_state_dict(weights, strict=True)
        net.to(self.device)
        net.eval()
        self.model = net

        # Resize the short side to ~1.15x, centre-crop, then ImageNet mean/std normalisation.
        imagenet_mean = (0.485, 0.456, 0.406)
        imagenet_std = (0.229, 0.224, 0.225)
        self.transform = T.Compose([
            T.Resize(int(img_size * 1.15)),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    @torch.no_grad()
    def predict_prob(self, pil: Image.Image) -> float:
        """
        Return the model's probability for class index 1.

        Uses a softmax for a two-logit head and a sigmoid of the first logit
        otherwise. Class 1 is presumably the "fake" class — TODO confirm
        against the training labels.
        """
        batch = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
        _, logits = self.model(batch)

        if logits.shape[-1] != 2:
            return float(torch.sigmoid(logits.view(-1)[0]).item())
        return float(torch.softmax(logits, dim=1)[0, 1].item())



def _add_trufor_to_syspath():
    """
    Make TruFor's ``lib`` package importable by prepending its root to ``sys.path``.

    Raises:
        FileNotFoundError: if ``TRUFOR_TRAIN_TEST_DIR`` or its ``lib`` subfolder is missing.
    """
    root = TRUFOR_TRAIN_TEST_DIR
    if not os.path.isdir(root):
        raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")

    lib_dir = os.path.join(root, "lib")
    if not os.path.isdir(lib_dir):
        raise FileNotFoundError(f"Expected TruFor_train_test/lib at: {lib_dir}")

    # Idempotent: only insert once, at the front so it wins over other "lib" packages.
    if root in sys.path:
        return
    sys.path.insert(0, root)


def _load_trufor_config():
    """
    Build TruFor's global config from ``TRUFOR_CFG_PATH`` and return it.

    Requires ``_add_trufor_to_syspath()`` to have run so that ``lib`` is importable.
    """
    from lib.config import config as cfg
    from lib.config import update_config

    class _CliArgs:
        """Stand-in for the argparse namespace that update_config() receives."""

        def __init__(self, cfg_path: str):
            self.cfg = cfg_path
            self.opts = []
            # Remaining fields get empty / "0" placeholders; presumably only
            # consulted by TruFor's training code — unverified here.
            self.modelDir = self.logDir = self.dataDir = self.prevModelDir = ""
            self.gpu = "0"

    update_config(cfg, _CliArgs(TRUFOR_CFG_PATH))
    return cfg


def _load_state_dict_from_ckpt(path: str) -> Dict[str, torch.Tensor]:
    """
    Load a TruFor checkpoint from *path* and return its ``state_dict``.

    A leading ``module.`` (from DataParallel/DDP) is stripped from keys, but
    only when it is a prefix. Keys that merely contain ``module.`` are kept.

    Raises:
        RuntimeError: if the checkpoint or its ``state_dict`` entry is not a dict.
        KeyError: if the checkpoint has no ``state_dict`` entry.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if not isinstance(ckpt, dict):
        raise RuntimeError(f"[TruFor] checkpoint is not a dict: {type(ckpt)}")
    if "state_dict" not in ckpt:
        raise KeyError(f"[TruFor] checkpoint missing 'state_dict'. Keys={list(ckpt.keys())}")
    sd = ckpt["state_dict"]
    if not isinstance(sd, dict):
        raise RuntimeError(f"[TruFor] checkpoint['state_dict'] is not a dict: {type(sd)}")

    prefix = "module."
    if any(k.startswith(prefix) for k in sd.keys()):
        # str.replace(prefix, "", 1) would also rewrite keys like "enc.module.w".
        sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}
    return sd


@dataclass
class TruForOutputs:
    """Result of one TruForDetector.predict() call."""

    score: float  # image-level score: sigmoid of the detection-head logit
    loc_prob: np.ndarray  # (H, W) softmax probability of channel 1 of the localisation output, at model output resolution
    conf_prob: np.ndarray  # (H, W) sigmoid of the confidence output, at model output resolution


class TruForDetector:
    """
    Single-image inference wrapper for TruFor (from ``TruFor_train_test``).

    The input is resized to ``cfg.TRAIN.IMAGE_SIZE`` (or 512x512 when the config
    has no ``TRAIN`` section) and scaled to [0, 1] by ``ToTensor``. ``predict``
    returns the detection score plus the localisation and confidence maps.
    """

    def __init__(self, device: torch.device):
        self.device = device
        _add_trufor_to_syspath()
        self.cfg = cfg = _load_trufor_config()

        # Deferred import: ``lib`` only resolves once the TruFor root is on sys.path.
        from lib.utils import get_model

        net = get_model(cfg)
        net.load_state_dict(_load_state_dict_from_ckpt(TRUFOR_CKPT_PATH), strict=True)
        net.to(self.device)
        net.eval()
        self.model = net

        if hasattr(cfg, "TRAIN"):
            self.size = tuple(cfg.TRAIN.IMAGE_SIZE)
        else:
            self.size = (512, 512)
        self.to_tensor = T.ToTensor()

    def _prep(self, pil: Image.Image) -> torch.Tensor:
        """RGB-convert, bilinear-resize to ``self.size`` (width, height), batch and move to device."""
        width = int(self.size[0])
        height = int(self.size[1])
        resized = pil.convert("RGB").resize((width, height), resample=Image.BILINEAR)
        return self.to_tensor(resized).unsqueeze(0).to(self.device)

    @torch.no_grad()
    def predict(self, pil: Image.Image) -> TruForOutputs:
        """
        Run TruFor on *pil*.

        Raises:
            RuntimeError: if the detection or confidence head is missing, or an
                output does not have the expected [B, C, H, W] shape.
        """
        out, conf, det, _ = self.model(self._prep(pil))

        if det is None:
            raise RuntimeError("[TruFor] det is None (no detection head). Your config must include det_head.")
        score = float(torch.sigmoid(det.view(-1)[0]).item())

        if out.ndim != 4 or out.shape[1] != 2:
            raise RuntimeError(f"[TruFor] Expected out shape [B,2,H,W], got {tuple(out.shape)}")
        # First batch item, channel 1 of the two-way softmax -> (H, W).
        loc_prob = torch.softmax(out, dim=1)[0, 1].detach().float().cpu().numpy()

        if conf is None:
            raise RuntimeError("[TruFor] conf is None but config suggests conf_head should exist.")
        if conf.ndim != 4 or conf.shape[1] != 1:
            raise RuntimeError(f"[TruFor] Expected conf shape [B,1,H,W], got {tuple(conf.shape)}")
        conf_prob = torch.sigmoid(conf)[0, 0].detach().float().cpu().numpy()

        return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)


def list_images(input_dir: str) -> List[str]:
    """
    Recursively collect image files under *input_dir*.

    A file qualifies when its lower-cased extension is in ``IMG_EXTS``.
    Returns paths joined onto *input_dir*, sorted lexicographically.
    """
    found = [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, filenames in os.walk(input_dir)
        for name in filenames
        if os.path.splitext(name.lower())[1] in IMG_EXTS
    ]
    found.sort()
    return found


def fuse_scores(trufor_score: float, ufd_score: float, effnet_score: float) -> float:
    """
    Combine the three detector probabilities into one score.

    Computes the mean weighted by ``W_TRUFOR``/``W_UFD``/``W_EFFNET``
    (normalised by their sum) and clips the result to [0, 1].
    """
    weighted = (
        (W_TRUFOR, trufor_score),
        (W_UFD, ufd_score),
        (W_EFFNET, effnet_score),
    )
    numerator = sum(w * s for w, s in weighted)
    denominator = sum(w for w, _ in weighted)
    return float(np.clip(numerator / denominator, 0.0, 1.0))


def manipulation_type_from_maps(tru: TruForOutputs, ufd_prob: float, fused: float) -> str:
    """
    Heuristically label the kind of manipulation.

    Below a fused score of 0.5 the image is labelled ``"none"``. Otherwise
    ``area`` is the fraction of TruFor-localised pixels (``loc_prob > 0.5``),
    and it is combined with the UFD probability via fixed thresholds:

    * ``ufd_prob >= 0.80`` and ``area >= 0.40`` -> ``"full_synthesis"``
    * ``area >= 0.12``                          -> ``"inpainting"``
    * ``area >= 0.03``                          -> ``"splicing"``
    * ``ufd_prob >= 0.65`` and ``area < 0.03``  -> ``"filter"``
    * anything else                             -> ``"manipulated"``
    """
    if fused < 0.5:
        return "none"

    area = float((tru.loc_prob > 0.5).mean())

    if ufd_prob >= 0.80 and area >= 0.40:
        return "full_synthesis"
    for min_area, label in ((0.12, "inpainting"), (0.03, "splicing")):
        if area >= min_area:
            return label
    # The explicit ``area < 0.03`` matters when area is NaN (empty map).
    if ufd_prob >= 0.65 and area < 0.03:
        return "filter"
    return "manipulated"


def save_prob_map_png(prob_hw: np.ndarray, out_path: str) -> None:
    """
    Write a 2-D probability map to *out_path* as an 8-bit grayscale image.

    Values are clipped to [0, 1] and mapped linearly to 0..255. NaN becomes 0
    and +/-inf saturate, instead of reaching ``astype(np.uint8)``, whose result
    for NaN is undefined.

    Raises:
        ValueError: if *prob_hw* is not 2-D.
    """
    prob = np.clip(np.nan_to_num(prob_hw, nan=0.0), 0.0, 1.0)
    img_u8 = (prob * 255.0).round().astype(np.uint8)
    if img_u8.ndim != 2:
        raise ValueError(f"expected a 2-D (H, W) map, got shape {img_u8.shape}")
    # Pillow infers mode "L" for 2-D uint8 data; passing mode= is deprecated
    # since Pillow 11.3.
    Image.fromarray(img_u8).save(out_path)


def main():
    """
    CLI entry point: score every image under ``--input_dir`` and write a JSON report.

    For each image, the TruFor, UFD and EfficientNet probabilities are fused into
    ``authenticity_score``. Images with a fused score >= 0.5 get a Baseten VLM
    explanation, falling back to ``VLM_FALLBACK_REASONING`` on any VLM error.
    Images below 0.5 get "It looks natural.". Every readable image yields one
    record; unreadable files are skipped with a warning on stderr.

    Raises:
        FileNotFoundError: if a model asset or ``--input_dir`` is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_dir", required=True)
    ap.add_argument("--output_file", required=True)
    args = ap.parse_args()

    for p, name in [
        (TRUFOR_TRAIN_TEST_DIR, "TRUFOR_TRAIN_TEST_DIR"),
        (TRUFOR_CFG_PATH, "TRUFOR_CFG_PATH"),
        (TRUFOR_CKPT_PATH, "TRUFOR_CKPT_PATH"),
        (UFD_FC_WEIGHTS_PATH, "UFD_FC_WEIGHTS_PATH"),
        (EFFNET_CKPT_PATH, "EFFNET_CKPT_PATH"),
    ]:
        if not p or not os.path.exists(p):
            raise FileNotFoundError(f"{name} missing or not found: {p}")
    # os.walk() silently yields nothing for a missing directory, which would
    # otherwise produce an empty report instead of an error.
    if not os.path.isdir(args.input_dir):
        raise FileNotFoundError(f"--input_dir is not a directory: {args.input_dir}")

    # Fall back to CPU so the script also runs without a GPU (a hard-coded
    # "cuda" device fails at the first .to(device) on CPU-only hosts).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trufor = TruForDetector(device=device)
    ufd = UniversalFakeDetectDetector(device=device)
    effnet = EffNetDetector(device=device, ckpt_path=EFFNET_CKPT_PATH)

    preds: List[Dict[str, Any]] = []
    for img_path in list_images(args.input_dir):
        img_name = os.path.basename(img_path)
        try:
            # Decode once and release the file handle immediately; every
            # consumer below converts to RGB anyway.
            with Image.open(img_path) as im:
                pil = im.convert("RGB")
        except OSError as exc:  # includes PIL.UnidentifiedImageError / truncated files
            print(f"[warn] skipping unreadable image {img_path}: {exc}", file=sys.stderr)
            continue

        tru = trufor.predict(pil)
        ufd_prob = ufd.predict_prob(pil)
        eff_prob = effnet.predict_prob(pil)

        fused = fuse_scores(tru.score, ufd_prob, eff_prob)

        # Authentic-looking images are still reported (a stray `continue` here
        # used to drop them from the output entirely).
        if fused < 0.5:
            vlm_reasoning = "It looks natural."
        else:
            try:
                vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
            except Exception as exc:  # best-effort: the VLM must not abort the run
                print(f"[warn] VLM reasoning failed for {img_name}: {exc}", file=sys.stderr)
                vlm_reasoning = VLM_FALLBACK_REASONING

        rec: Dict[str, Any] = {
            "image_name": img_name,
            "authenticity_score": float(fused),
            "manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
            "vlm_reasoning": vlm_reasoning,
        }

        preds.append(rec)

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(preds, f, indent=2)

    print(f"Wrote {len(preds)} predictions to {args.output_file}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()