Update predict.py
Browse files- predict.py +12 -69
predict.py
CHANGED
|
@@ -15,11 +15,8 @@ import torch.nn as nn
|
|
| 15 |
import torch.nn.functional as F
|
| 16 |
import torchvision.transforms as T
|
| 17 |
|
| 18 |
-
import requests
|
| 19 |
|
| 20 |
-
# =========================
|
| 21 |
-
# GLOBAL PATHS (YOU SET)
|
| 22 |
-
# =========================
|
| 23 |
|
| 24 |
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
|
| 25 |
TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
|
|
@@ -28,19 +25,15 @@ TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"
|
|
| 28 |
UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
|
| 29 |
UFD_CLIP_NAME = "ViT-L/14"
|
| 30 |
|
| 31 |
-
# NEW: EfficientNet metric+classifier checkpoint
|
| 32 |
EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"
|
| 33 |
|
| 34 |
-
# Weights for fusion
|
| 35 |
W_TRUFOR = 0.5
|
| 36 |
W_UFD = 0.4
|
| 37 |
-
W_EFFNET = 0.1
|
| 38 |
|
| 39 |
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
|
| 40 |
|
| 41 |
-
|
| 42 |
-
# NEW: Baseten VLM (your model)
|
| 43 |
-
# =========================
|
| 44 |
BASETEN_VLM_MODEL_ID = "zq8pe88w"
|
| 45 |
BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
|
| 46 |
|
|
@@ -83,23 +76,16 @@ def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float)
|
|
| 83 |
r.raise_for_status()
|
| 84 |
out = r.json()
|
| 85 |
|
| 86 |
-
# Tolerate different response shapes
|
| 87 |
if isinstance(out, dict):
|
| 88 |
-
# Common keys you might return from the Truss model
|
| 89 |
for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
|
| 90 |
v = out.get(k)
|
| 91 |
if isinstance(v, str) and v.strip():
|
| 92 |
return v.strip()
|
| 93 |
-
# If your Truss returns {"data": "..."} or similar, you'll see it here
|
| 94 |
return json.dumps(out, ensure_ascii=False)
|
| 95 |
|
| 96 |
-
# If it’s a raw string/list/etc.
|
| 97 |
return str(out).strip()
|
| 98 |
|
| 99 |
|
| 100 |
-
# =========================
|
| 101 |
-
# UFD CLIPModel
|
| 102 |
-
# =========================
|
| 103 |
import clip # openai/CLIP
|
| 104 |
|
| 105 |
CHANNELS = {
|
|
@@ -162,10 +148,7 @@ class UniversalFakeDetectDetector:
|
|
| 162 |
return float(torch.sigmoid(logit).item())
|
| 163 |
|
| 164 |
|
| 165 |
-
|
| 166 |
-
# NEW: EfficientNet Metric+Classifier
|
| 167 |
-
# =========================
|
| 168 |
-
# Requires: pip install timm
|
| 169 |
import timm
|
| 170 |
|
| 171 |
|
|
@@ -189,9 +172,9 @@ class EffNetMetricClassifier(nn.Module):
|
|
| 189 |
|
| 190 |
def forward(self, x):
|
| 191 |
feat = self.backbone(x)
|
| 192 |
-
z = self.proj(feat)
|
| 193 |
-
emb = F.normalize(z, p=2, dim=1)
|
| 194 |
-
logits = self.classifier(z)
|
| 195 |
return emb, logits
|
| 196 |
|
| 197 |
|
|
@@ -225,7 +208,6 @@ class EffNetDetector:
|
|
| 225 |
self.model.to(self.device)
|
| 226 |
self.model.eval()
|
| 227 |
|
| 228 |
-
# Match validation preprocessing from training
|
| 229 |
self.transform = T.Compose([
|
| 230 |
T.Resize(int(img_size * 1.15)),
|
| 231 |
T.CenterCrop(img_size),
|
|
@@ -238,19 +220,15 @@ class EffNetDetector:
|
|
| 238 |
x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
|
| 239 |
_, logits = self.model(x)
|
| 240 |
|
| 241 |
-
# If logits is [B,2], use softmax prob of class 1 (AI)
|
| 242 |
if logits.shape[-1] == 2:
|
| 243 |
p1 = torch.softmax(logits, dim=1)[0, 1]
|
| 244 |
return float(p1.item())
|
| 245 |
|
| 246 |
-
# fallback (if someone trained 1-logit head): sigmoid
|
| 247 |
logit = logits.view(-1)[0]
|
| 248 |
return float(torch.sigmoid(logit).item())
|
| 249 |
|
| 250 |
|
| 251 |
-
|
| 252 |
-
# TruFor
|
| 253 |
-
# =========================
|
| 254 |
def _add_trufor_to_syspath():
|
| 255 |
if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
|
| 256 |
raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
|
|
@@ -348,9 +326,6 @@ class TruForDetector:
|
|
| 348 |
return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
|
| 349 |
|
| 350 |
|
| 351 |
-
# =========================
|
| 352 |
-
# Mask saving + fusion
|
| 353 |
-
# =========================
|
| 354 |
def list_images(input_dir: str) -> List[str]:
|
| 355 |
paths = []
|
| 356 |
for root, _, files in os.walk(input_dir):
|
|
@@ -393,14 +368,6 @@ def main():
|
|
| 393 |
ap = argparse.ArgumentParser()
|
| 394 |
ap.add_argument("--input_dir", required=True)
|
| 395 |
ap.add_argument("--output_file", required=True)
|
| 396 |
-
ap.add_argument("--threshold", type=float, default=0.5)
|
| 397 |
-
ap.add_argument("--only_flagged", action="store_true")
|
| 398 |
-
|
| 399 |
-
ap.add_argument("--mask_dir", default="", help="If set, save TruFor loc/conf maps as PNGs into this folder.")
|
| 400 |
-
ap.add_argument("--save_conf", action="store_true", help="If set, also save TruFor confidence maps.")
|
| 401 |
-
ap.add_argument("--print_scores", action="store_true", help="If set, print TruFor/UFD/EffNet/Fused scores per image.")
|
| 402 |
-
|
| 403 |
-
ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
|
| 404 |
args = ap.parse_args()
|
| 405 |
|
| 406 |
for p, name in [
|
|
@@ -413,10 +380,7 @@ def main():
|
|
| 413 |
if not p or not os.path.exists(p):
|
| 414 |
raise FileNotFoundError(f"{name} missing or not found: {p}")
|
| 415 |
|
| 416 |
-
|
| 417 |
-
os.makedirs(args.mask_dir, exist_ok=True)
|
| 418 |
-
|
| 419 |
-
device = torch.device(args.device)
|
| 420 |
|
| 421 |
trufor = TruForDetector(device=device)
|
| 422 |
ufd = UniversalFakeDetectDetector(device=device)
|
|
@@ -425,32 +389,17 @@ def main():
|
|
| 425 |
preds: List[Dict[str, Any]] = []
|
| 426 |
for img_path in list_images(args.input_dir):
|
| 427 |
img_name = os.path.basename(img_path)
|
| 428 |
-
stem = os.path.splitext(img_name)[0]
|
| 429 |
pil = Image.open(img_path)
|
| 430 |
|
| 431 |
tru = trufor.predict(pil)
|
| 432 |
ufd_prob = ufd.predict_prob(pil)
|
| 433 |
-
eff_prob = effnet.predict_prob(pil)
|
| 434 |
|
| 435 |
fused = fuse_scores(tru.score, ufd_prob, eff_prob)
|
| 436 |
|
| 437 |
-
if args.print_scores:
|
| 438 |
-
print(
|
| 439 |
-
f"{img_name}\tTruFor={tru.score:.4f}\tUFD={ufd_prob:.4f}\tEffNet={eff_prob:.4f}\tFused={fused:.4f}",
|
| 440 |
-
flush=True,
|
| 441 |
-
)
|
| 442 |
-
|
| 443 |
-
if args.mask_dir:
|
| 444 |
-
loc_path = os.path.join(args.mask_dir, f"{stem}_trufor_loc.png")
|
| 445 |
-
save_prob_map_png(tru.loc_prob, loc_path)
|
| 446 |
-
|
| 447 |
-
if args.save_conf:
|
| 448 |
-
conf_path = os.path.join(args.mask_dir, f"{stem}_trufor_conf.png")
|
| 449 |
-
save_prob_map_png(tru.conf_prob, conf_path)
|
| 450 |
-
|
| 451 |
-
# NEW: VLM reasoning from Baseten
|
| 452 |
if fused < 0.5:
|
| 453 |
vlm_reasoning = "It looks natural."
|
|
|
|
| 454 |
else:
|
| 455 |
try:
|
| 456 |
vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
|
|
@@ -462,15 +411,9 @@ def main():
|
|
| 462 |
"authenticity_score": float(fused),
|
| 463 |
"manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
|
| 464 |
"vlm_reasoning": vlm_reasoning,
|
| 465 |
-
"debug": {
|
| 466 |
-
"trufor_score": float(tru.score),
|
| 467 |
-
"ufd_score": float(ufd_prob),
|
| 468 |
-
"effnet_score": float(eff_prob),
|
| 469 |
-
},
|
| 470 |
}
|
| 471 |
|
| 472 |
-
|
| 473 |
-
preds.append(rec)
|
| 474 |
|
| 475 |
with open(args.output_file, "w", encoding="utf-8") as f:
|
| 476 |
json.dump(preds, f, indent=2)
|
|
|
|
| 15 |
import torch.nn.functional as F
|
| 16 |
import torchvision.transforms as T
|
| 17 |
|
| 18 |
+
import requests
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
|
| 22 |
TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
|
|
|
|
| 25 |
UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
|
| 26 |
UFD_CLIP_NAME = "ViT-L/14"
|
| 27 |
|
|
|
|
| 28 |
EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"
|
| 29 |
|
|
|
|
| 30 |
W_TRUFOR = 0.5
|
| 31 |
W_UFD = 0.4
|
| 32 |
+
W_EFFNET = 0.1
|
| 33 |
|
| 34 |
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
|
| 35 |
|
| 36 |
+
|
|
|
|
|
|
|
| 37 |
BASETEN_VLM_MODEL_ID = "zq8pe88w"
|
| 38 |
BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
|
| 39 |
|
|
|
|
| 76 |
r.raise_for_status()
|
| 77 |
out = r.json()
|
| 78 |
|
|
|
|
| 79 |
if isinstance(out, dict):
|
|
|
|
| 80 |
for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
|
| 81 |
v = out.get(k)
|
| 82 |
if isinstance(v, str) and v.strip():
|
| 83 |
return v.strip()
|
|
|
|
| 84 |
return json.dumps(out, ensure_ascii=False)
|
| 85 |
|
|
|
|
| 86 |
return str(out).strip()
|
| 87 |
|
| 88 |
|
|
|
|
|
|
|
|
|
|
| 89 |
import clip # openai/CLIP
|
| 90 |
|
| 91 |
CHANNELS = {
|
|
|
|
| 148 |
return float(torch.sigmoid(logit).item())
|
| 149 |
|
| 150 |
|
| 151 |
+
|
|
|
|
|
|
|
|
|
|
| 152 |
import timm
|
| 153 |
|
| 154 |
|
|
|
|
| 172 |
|
| 173 |
def forward(self, x):
|
| 174 |
feat = self.backbone(x)
|
| 175 |
+
z = self.proj(feat)
|
| 176 |
+
emb = F.normalize(z, p=2, dim=1)
|
| 177 |
+
logits = self.classifier(z)
|
| 178 |
return emb, logits
|
| 179 |
|
| 180 |
|
|
|
|
| 208 |
self.model.to(self.device)
|
| 209 |
self.model.eval()
|
| 210 |
|
|
|
|
| 211 |
self.transform = T.Compose([
|
| 212 |
T.Resize(int(img_size * 1.15)),
|
| 213 |
T.CenterCrop(img_size),
|
|
|
|
| 220 |
x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
|
| 221 |
_, logits = self.model(x)
|
| 222 |
|
|
|
|
| 223 |
if logits.shape[-1] == 2:
|
| 224 |
p1 = torch.softmax(logits, dim=1)[0, 1]
|
| 225 |
return float(p1.item())
|
| 226 |
|
|
|
|
| 227 |
logit = logits.view(-1)[0]
|
| 228 |
return float(torch.sigmoid(logit).item())
|
| 229 |
|
| 230 |
|
| 231 |
+
|
|
|
|
|
|
|
| 232 |
def _add_trufor_to_syspath():
|
| 233 |
if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
|
| 234 |
raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
|
|
|
|
| 326 |
return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
|
| 327 |
|
| 328 |
|
|
|
|
|
|
|
|
|
|
| 329 |
def list_images(input_dir: str) -> List[str]:
|
| 330 |
paths = []
|
| 331 |
for root, _, files in os.walk(input_dir):
|
|
|
|
| 368 |
ap = argparse.ArgumentParser()
|
| 369 |
ap.add_argument("--input_dir", required=True)
|
| 370 |
ap.add_argument("--output_file", required=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
args = ap.parse_args()
|
| 372 |
|
| 373 |
for p, name in [
|
|
|
|
| 380 |
if not p or not os.path.exists(p):
|
| 381 |
raise FileNotFoundError(f"{name} missing or not found: {p}")
|
| 382 |
|
| 383 |
+
device = torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
trufor = TruForDetector(device=device)
|
| 386 |
ufd = UniversalFakeDetectDetector(device=device)
|
|
|
|
| 389 |
preds: List[Dict[str, Any]] = []
|
| 390 |
for img_path in list_images(args.input_dir):
|
| 391 |
img_name = os.path.basename(img_path)
|
|
|
|
| 392 |
pil = Image.open(img_path)
|
| 393 |
|
| 394 |
tru = trufor.predict(pil)
|
| 395 |
ufd_prob = ufd.predict_prob(pil)
|
| 396 |
+
eff_prob = effnet.predict_prob(pil)
|
| 397 |
|
| 398 |
fused = fuse_scores(tru.score, ufd_prob, eff_prob)
|
| 399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
if fused < 0.5:
|
| 401 |
vlm_reasoning = "It looks natural."
|
| 402 |
+
continue
|
| 403 |
else:
|
| 404 |
try:
|
| 405 |
vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
|
|
|
|
| 411 |
"authenticity_score": float(fused),
|
| 412 |
"manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
|
| 413 |
"vlm_reasoning": vlm_reasoning,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
}
|
| 415 |
|
| 416 |
+
preds.append(rec)
|
|
|
|
| 417 |
|
| 418 |
with open(args.output_file, "w", encoding="utf-8") as f:
|
| 419 |
json.dump(preds, f, indent=2)
|