vaniv commited on
Commit
99712d5
·
verified ·
1 Parent(s): c3f0b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -80,20 +80,23 @@ def _hf_predict_proba(pil_img: Image.Image) -> float:
80
  Returns P(Deepfake) in [0,1] using the ViT classifier.
81
  """
82
  inputs = _hf_processor(images=pil_img.convert("RGB"), return_tensors="pt")
83
- logits = _hf_model(**inputs).logits # (1, C)
 
 
84
  if logits.shape[-1] == 1:
85
- # Unlikely for this model, but handle binary-sigmoid heads
86
  return torch.sigmoid(logits.squeeze(0))[0].item()
87
 
88
- probs = torch.softmax(logits.squeeze(0), dim=-1).cpu().numpy()
 
89
  if _DEEP_IDX is not None and 0 <= _DEEP_IDX < probs.shape[0]:
90
  return float(probs[_DEEP_IDX])
91
 
92
- # Binary softmax fallback: assume index 1 = deepfake
93
  if probs.shape[0] == 2:
94
  return float(probs[1])
95
 
96
- # Last resort: take the highest class prob (not ideal, but safe)
97
  return float(probs.max())
98
 
99
  # -------------------- Output card --------------------
 
80
  Returns P(Deepfake) in [0,1] using the ViT classifier.
81
  """
82
  inputs = _hf_processor(images=pil_img.convert("RGB"), return_tensors="pt")
83
+ with torch.no_grad():
84
+ logits = _hf_model(**inputs).logits # (1, C)
85
+
86
  if logits.shape[-1] == 1:
87
+ # Binary sigmoid head
88
  return torch.sigmoid(logits.squeeze(0))[0].item()
89
 
90
+ # Softmax head
91
+ probs = torch.softmax(logits.squeeze(0), dim=-1).detach().cpu().numpy()
92
  if _DEEP_IDX is not None and 0 <= _DEEP_IDX < probs.shape[0]:
93
  return float(probs[_DEEP_IDX])
94
 
95
+ # Binary fallback
96
  if probs.shape[0] == 2:
97
  return float(probs[1])
98
 
99
+ # Last resort: take max
100
  return float(probs.max())
101
 
102
  # -------------------- Output card --------------------