DaniilOr committed on
Commit
ba92c89
·
verified ·
1 Parent(s): 5f0437a

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +12 -69
predict.py CHANGED
@@ -15,11 +15,8 @@ import torch.nn as nn
15
  import torch.nn.functional as F
16
  import torchvision.transforms as T
17
 
18
- import requests # NEW: for Baseten VLM calls
19
 
20
- # =========================
21
- # GLOBAL PATHS (YOU SET)
22
- # =========================
23
 
24
  TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
25
  TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
@@ -28,19 +25,15 @@ TRUFOR_CKPT_PATH = "weights/trufor.pth.tar"
28
  UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
29
  UFD_CLIP_NAME = "ViT-L/14"
30
 
31
- # NEW: EfficientNet metric+classifier checkpoint
32
  EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"
33
 
34
- # Weights for fusion
35
  W_TRUFOR = 0.5
36
  W_UFD = 0.4
37
- W_EFFNET = 0.1 # NEW
38
 
39
  IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
40
 
41
- # =========================
42
- # NEW: Baseten VLM (your model)
43
- # =========================
44
  BASETEN_VLM_MODEL_ID = "zq8pe88w"
45
  BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
46
 
@@ -83,23 +76,16 @@ def get_vlm_reasoning_from_baseten(pil: Image.Image, authenticity_score: float)
83
  r.raise_for_status()
84
  out = r.json()
85
 
86
- # Tolerate different response shapes
87
  if isinstance(out, dict):
88
- # Common keys you might return from the Truss model
89
  for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
90
  v = out.get(k)
91
  if isinstance(v, str) and v.strip():
92
  return v.strip()
93
- # If your Truss returns {"data": "..."} or similar, you'll see it here
94
  return json.dumps(out, ensure_ascii=False)
95
 
96
- # If it’s a raw string/list/etc.
97
  return str(out).strip()
98
 
99
 
100
- # =========================
101
- # UFD CLIPModel
102
- # =========================
103
  import clip # openai/CLIP
104
 
105
  CHANNELS = {
@@ -162,10 +148,7 @@ class UniversalFakeDetectDetector:
162
  return float(torch.sigmoid(logit).item())
163
 
164
 
165
- # =========================
166
- # NEW: EfficientNet Metric+Classifier
167
- # =========================
168
- # Requires: pip install timm
169
  import timm
170
 
171
 
@@ -189,9 +172,9 @@ class EffNetMetricClassifier(nn.Module):
189
 
190
  def forward(self, x):
191
  feat = self.backbone(x)
192
- z = self.proj(feat) # unnormalized projected features
193
- emb = F.normalize(z, p=2, dim=1) # embeddings (not used here, but kept for completeness)
194
- logits = self.classifier(z) # 2-class logits
195
  return emb, logits
196
 
197
 
@@ -225,7 +208,6 @@ class EffNetDetector:
225
  self.model.to(self.device)
226
  self.model.eval()
227
 
228
- # Match validation preprocessing from training
229
  self.transform = T.Compose([
230
  T.Resize(int(img_size * 1.15)),
231
  T.CenterCrop(img_size),
@@ -238,19 +220,15 @@ class EffNetDetector:
238
  x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
239
  _, logits = self.model(x)
240
 
241
- # If logits is [B,2], use softmax prob of class 1 (AI)
242
  if logits.shape[-1] == 2:
243
  p1 = torch.softmax(logits, dim=1)[0, 1]
244
  return float(p1.item())
245
 
246
- # fallback (if someone trained 1-logit head): sigmoid
247
  logit = logits.view(-1)[0]
248
  return float(torch.sigmoid(logit).item())
249
 
250
 
251
- # =========================
252
- # TruFor
253
- # =========================
254
  def _add_trufor_to_syspath():
255
  if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
256
  raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
@@ -348,9 +326,6 @@ class TruForDetector:
348
  return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
349
 
350
 
351
- # =========================
352
- # Mask saving + fusion
353
- # =========================
354
  def list_images(input_dir: str) -> List[str]:
355
  paths = []
356
  for root, _, files in os.walk(input_dir):
@@ -393,14 +368,6 @@ def main():
393
  ap = argparse.ArgumentParser()
394
  ap.add_argument("--input_dir", required=True)
395
  ap.add_argument("--output_file", required=True)
396
- ap.add_argument("--threshold", type=float, default=0.5)
397
- ap.add_argument("--only_flagged", action="store_true")
398
-
399
- ap.add_argument("--mask_dir", default="", help="If set, save TruFor loc/conf maps as PNGs into this folder.")
400
- ap.add_argument("--save_conf", action="store_true", help="If set, also save TruFor confidence maps.")
401
- ap.add_argument("--print_scores", action="store_true", help="If set, print TruFor/UFD/EffNet/Fused scores per image.")
402
-
403
- ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
404
  args = ap.parse_args()
405
 
406
  for p, name in [
@@ -413,10 +380,7 @@ def main():
413
  if not p or not os.path.exists(p):
414
  raise FileNotFoundError(f"{name} missing or not found: {p}")
415
 
416
- if args.mask_dir:
417
- os.makedirs(args.mask_dir, exist_ok=True)
418
-
419
- device = torch.device(args.device)
420
 
421
  trufor = TruForDetector(device=device)
422
  ufd = UniversalFakeDetectDetector(device=device)
@@ -425,32 +389,17 @@ def main():
425
  preds: List[Dict[str, Any]] = []
426
  for img_path in list_images(args.input_dir):
427
  img_name = os.path.basename(img_path)
428
- stem = os.path.splitext(img_name)[0]
429
  pil = Image.open(img_path)
430
 
431
  tru = trufor.predict(pil)
432
  ufd_prob = ufd.predict_prob(pil)
433
- eff_prob = effnet.predict_prob(pil) # NEW
434
 
435
  fused = fuse_scores(tru.score, ufd_prob, eff_prob)
436
 
437
- if args.print_scores:
438
- print(
439
- f"{img_name}\tTruFor={tru.score:.4f}\tUFD={ufd_prob:.4f}\tEffNet={eff_prob:.4f}\tFused={fused:.4f}",
440
- flush=True,
441
- )
442
-
443
- if args.mask_dir:
444
- loc_path = os.path.join(args.mask_dir, f"{stem}_trufor_loc.png")
445
- save_prob_map_png(tru.loc_prob, loc_path)
446
-
447
- if args.save_conf:
448
- conf_path = os.path.join(args.mask_dir, f"{stem}_trufor_conf.png")
449
- save_prob_map_png(tru.conf_prob, conf_path)
450
-
451
- # NEW: VLM reasoning from Baseten
452
  if fused < 0.5:
453
  vlm_reasoning = "It looks natural."
 
454
  else:
455
  try:
456
  vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
@@ -462,15 +411,9 @@ def main():
462
  "authenticity_score": float(fused),
463
  "manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
464
  "vlm_reasoning": vlm_reasoning,
465
- "debug": {
466
- "trufor_score": float(tru.score),
467
- "ufd_score": float(ufd_prob),
468
- "effnet_score": float(eff_prob),
469
- },
470
  }
471
 
472
- if (not args.only_flagged) or (fused >= args.threshold):
473
- preds.append(rec)
474
 
475
  with open(args.output_file, "w", encoding="utf-8") as f:
476
  json.dump(preds, f, indent=2)
 
15
  import torch.nn.functional as F
16
  import torchvision.transforms as T
17
 
18
+ import requests
19
 
 
 
 
20
 
21
  TRUFOR_TRAIN_TEST_DIR = "TruFor_train_test"
22
  TRUFOR_CFG_PATH = "TruFor_train_test/lib/config/trufor_ph3.yaml"
 
25
  UFD_FC_WEIGHTS_PATH = "fc_weights.pth"
26
  UFD_CLIP_NAME = "ViT-L/14"
27
 
 
28
  EFFNET_CKPT_PATH = "best_metric_cls_effnet.pt"
29
 
 
30
  W_TRUFOR = 0.5
31
  W_UFD = 0.4
32
+ W_EFFNET = 0.1
33
 
34
  IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
35
 
36
+
 
 
37
  BASETEN_VLM_MODEL_ID = "zq8pe88w"
38
  BASETEN_VLM_URL = f"https://model-{BASETEN_VLM_MODEL_ID}.api.baseten.co/development/predict"
39
 
 
76
  r.raise_for_status()
77
  out = r.json()
78
 
 
79
  if isinstance(out, dict):
 
80
  for k in ("output", "text", "result", "prediction", "vlm_reasoning"):
81
  v = out.get(k)
82
  if isinstance(v, str) and v.strip():
83
  return v.strip()
 
84
  return json.dumps(out, ensure_ascii=False)
85
 
 
86
  return str(out).strip()
87
 
88
 
 
 
 
89
  import clip # openai/CLIP
90
 
91
  CHANNELS = {
 
148
  return float(torch.sigmoid(logit).item())
149
 
150
 
151
+
 
 
 
152
  import timm
153
 
154
 
 
172
 
173
  def forward(self, x):
174
  feat = self.backbone(x)
175
+ z = self.proj(feat)
176
+ emb = F.normalize(z, p=2, dim=1)
177
+ logits = self.classifier(z)
178
  return emb, logits
179
 
180
 
 
208
  self.model.to(self.device)
209
  self.model.eval()
210
 
 
211
  self.transform = T.Compose([
212
  T.Resize(int(img_size * 1.15)),
213
  T.CenterCrop(img_size),
 
220
  x = self.transform(pil.convert("RGB")).unsqueeze(0).to(self.device)
221
  _, logits = self.model(x)
222
 
 
223
  if logits.shape[-1] == 2:
224
  p1 = torch.softmax(logits, dim=1)[0, 1]
225
  return float(p1.item())
226
 
 
227
  logit = logits.view(-1)[0]
228
  return float(torch.sigmoid(logit).item())
229
 
230
 
231
+
 
 
232
  def _add_trufor_to_syspath():
233
  if not os.path.isdir(TRUFOR_TRAIN_TEST_DIR):
234
  raise FileNotFoundError(f"TRUFOR_TRAIN_TEST_DIR not found: {TRUFOR_TRAIN_TEST_DIR}")
 
326
  return TruForOutputs(score=score, loc_prob=loc_prob, conf_prob=conf_prob)
327
 
328
 
 
 
 
329
  def list_images(input_dir: str) -> List[str]:
330
  paths = []
331
  for root, _, files in os.walk(input_dir):
 
368
  ap = argparse.ArgumentParser()
369
  ap.add_argument("--input_dir", required=True)
370
  ap.add_argument("--output_file", required=True)
 
 
 
 
 
 
 
 
371
  args = ap.parse_args()
372
 
373
  for p, name in [
 
380
  if not p or not os.path.exists(p):
381
  raise FileNotFoundError(f"{name} missing or not found: {p}")
382
 
383
+ device = torch.device("cuda")
 
 
 
384
 
385
  trufor = TruForDetector(device=device)
386
  ufd = UniversalFakeDetectDetector(device=device)
 
389
  preds: List[Dict[str, Any]] = []
390
  for img_path in list_images(args.input_dir):
391
  img_name = os.path.basename(img_path)
 
392
  pil = Image.open(img_path)
393
 
394
  tru = trufor.predict(pil)
395
  ufd_prob = ufd.predict_prob(pil)
396
+ eff_prob = effnet.predict_prob(pil)
397
 
398
  fused = fuse_scores(tru.score, ufd_prob, eff_prob)
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  if fused < 0.5:
401
  vlm_reasoning = "It looks natural."
402
+ continue
403
  else:
404
  try:
405
  vlm_reasoning = get_vlm_reasoning_from_baseten(pil, fused)
 
411
  "authenticity_score": float(fused),
412
  "manipulation_type": manipulation_type_from_maps(tru, ufd_prob, fused),
413
  "vlm_reasoning": vlm_reasoning,
 
 
 
 
 
414
  }
415
 
416
+ preds.append(rec)
 
417
 
418
  with open(args.output_file, "w", encoding="utf-8") as f:
419
  json.dump(preds, f, indent=2)