Jatin-tec committed on
Commit
aa30915
·
1 Parent(s): 655bd0e

Enhance AI detector initialization and improve TruFor result handling

Browse files
Files changed (2) hide show
  1. app.py +84 -47
  2. trufor_runner.py +6 -46
app.py CHANGED
@@ -1,23 +1,54 @@
 
 
1
  import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from typing import Dict, Optional, Tuple
5
- from transformers import AutoImageProcessor, SiglipForImageClassification
 
 
 
6
 
7
  from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError
8
 
9
- MODEL_ID = "Ateeqq/ai-vs-human-image-detector"
 
 
10
 
11
- # Use GPU when available so large batches stay responsive.
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
13
 
14
  try:
15
- processor = AutoImageProcessor.from_pretrained(MODEL_ID)
16
- model = SiglipForImageClassification.from_pretrained(MODEL_ID)
17
- model.to(device)
 
 
 
 
 
 
 
 
18
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
19
  except Exception as exc: # pragma: no cover - surface loading issues early.
20
- raise RuntimeError(f"Failed to load model from {MODEL_ID}") from exc
 
 
 
 
21
 
22
  try:
23
  TRUFOR_ENGINE: Optional[TruForEngine] = TruForEngine(device="cpu")
@@ -28,64 +59,72 @@ except TruForUnavailableError as exc:
28
 
29
 
30
  def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
31
- """Run the Hugging Face detector and return confidences with a readable summary."""
 
 
 
 
 
32
  if image is None:
33
- empty_scores = {label: 0.0 for label in model.config.id2label.values()}
34
  return empty_scores, "No image provided."
35
 
36
  image = image.convert("RGB")
37
- inputs = processor(images=image, return_tensors="pt").to(device)
38
 
39
  with torch.no_grad():
40
- logits = model(**inputs).logits
41
-
42
- probabilities = torch.softmax(logits, dim=-1)[0]
43
- scores = {
44
- model.config.id2label[idx]: float(probabilities[idx])
45
- for idx in range(probabilities.size(0))
46
- }
47
 
48
- top_idx = int(probabilities.argmax().item())
49
- top_label = model.config.id2label[top_idx]
50
- top_score = scores[top_label]
51
- summary = f"**Predicted Label:** {top_label} \
52
- **Confidence:** {top_score:.4f}"
 
 
 
 
 
 
 
 
 
 
53
 
54
  return scores, summary
55
 
56
 
57
- def analyze_trufor(image: Image.Image) -> Tuple[str, Optional[Image.Image], Optional[Image.Image]]:
58
  """Run TruFor inference when available, otherwise return diagnostics."""
59
  if TRUFOR_ENGINE is None:
60
- return TRUFOR_STATUS, None, None
61
 
62
  if image is None:
63
- return "Upload an image to run TruFor.", None, None
64
 
65
  try:
66
  result: TruForResult = TRUFOR_ENGINE.infer(image)
67
  except TruForUnavailableError as exc:
68
- return str(exc), None, None
69
 
70
- summary_lines = []
71
- if result.score is not None:
72
- summary_lines.append(f"**Tamper Score:** {result.score:.4f}")
73
- extras_dict = result.raw_scores.copy()
74
- if result.score is not None:
75
- extras_dict.pop("tamper_score", None)
76
- if extras_dict:
77
- extras = " ".join(f"{key}: {value:.4f}" for key, value in extras_dict.items())
78
- summary_lines.append(f"`{extras}`")
79
- if not summary_lines:
80
- summary_lines.append("TruFor returned no scores for this image.")
81
 
82
- return "\n".join(summary_lines), result.map_overlay, result.confidence_overlay
83
 
84
 
85
- def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image], Optional[Image.Image]]:
86
  ai_scores, ai_summary = analyze_ai_vs_human(image)
87
- trufor_summary, tamper_overlay, conf_overlay = analyze_trufor(image)
88
- return ai_scores, ai_summary, trufor_summary, tamper_overlay, conf_overlay
89
 
90
 
91
  with gr.Blocks() as demo:
@@ -93,7 +132,7 @@ with gr.Blocks() as demo:
93
  """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
94
  )
95
 
96
- status_box = gr.Markdown(f"`{TRUFOR_STATUS}`")
97
 
98
  image_input = gr.Image(label="Input Image", type="pil")
99
  analyze_button = gr.Button("Analyze", variant="primary", size="sm")
@@ -101,18 +140,16 @@ with gr.Blocks() as demo:
101
  with gr.Tabs():
102
  with gr.TabItem("AI vs Human"):
103
  ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
104
- ai_summary_output = gr.Markdown("Upload an image to view the prediction.")
105
  with gr.TabItem("TruFor Forgery Detection"):
106
  trufor_summary_output = gr.Markdown("Configure TruFor assets to enable tamper analysis.")
107
- tamper_overlay_output = gr.Image(label="Tamper Heatmap", type="pil", interactive=False)
108
- conf_overlay_output = gr.Image(label="Confidence Heatmap", type="pil", interactive=False)
109
 
110
  output_components = [
111
  ai_label_output,
112
  ai_summary_output,
113
  trufor_summary_output,
114
  tamper_overlay_output,
115
- conf_overlay_output,
116
  ]
117
 
118
  analyze_button.click(
 
1
+ import os
2
+
3
  import gradio as gr
4
  import torch
5
  from PIL import Image
6
  from typing import Dict, Optional, Tuple
7
+ from torchvision import transforms
8
+ from timm import create_model
9
+ from huggingface_hub import hf_hub_download
10
+ from huggingface_hub.errors import GatedRepoError, HfHubHTTPError
11
 
12
  from trufor_runner import TruForEngine, TruForResult, TruForUnavailableError
13
 
14
+ IMG_SIZE = 380
15
+ LABEL_MAPPING = {0: "human", 1: "ai"}
16
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ transform: Optional[transforms.Compose]
19
+ model: Optional[torch.nn.Module]
20
+ MODEL_STATUS: str
21
 
22
# Download and prepare the EfficientNet-B4 AI-image detector; on any failure,
# leave `model`/`transform` unset (None) and record a human-readable status.
try:
    hf_token = os.getenv("HF_TOKEN")
    checkpoint_path = hf_hub_download(
        repo_id="Dafilab/ai-image-detector",
        filename="pytorch_model.pth",
        token=hf_token,
    )
    transform = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE + 20),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
            # Standard ImageNet normalisation constants.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    model = create_model("efficientnet_b4", pretrained=False, num_classes=len(LABEL_MAPPING))
    # NOTE(review): torch.load unpickles the checkpoint — assumes the Hub repo
    # is trusted; confirm before pointing this at other repos.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    MODEL_STATUS = "AI detector ready."
except GatedRepoError:
    # The repo is gated and the supplied token (if any) lacks access.
    transform, model = None, None
    MODEL_STATUS = (
        "AI detector requires approved Hugging Face access. Configure HF_TOKEN with a permitted token."
    )
except (HfHubHTTPError, OSError) as exc:
    # Network/HTTP or local filesystem trouble while fetching the checkpoint.
    transform, model = None, None
    MODEL_STATUS = f"AI detector unavailable: {exc}"
except Exception as exc:  # pragma: no cover - surface loading issues early.
    transform, model = None, None
    MODEL_STATUS = f"AI detector failed to initialize: {exc}"

# Initial text for the AI tab: the status when loading failed, a prompt otherwise.
AI_INTRO_SUMMARY = MODEL_STATUS if model is None else "Upload an image to view the prediction."
52
 
53
  try:
54
  TRUFOR_ENGINE: Optional[TruForEngine] = TruForEngine(device="cpu")
 
59
 
60
 
61
def analyze_ai_vs_human(image: Image.Image) -> Tuple[Dict[str, float], str]:
    """Run the EfficientNet-based detector and return confidences with a readable summary.

    Args:
        image: Uploaded picture, or ``None`` when nothing has been provided.

    Returns:
        A mapping from label (``"human"``/``"ai"``) to probability, and a
        Markdown summary string.
    """
    empty_scores = {label: 0.0 for label in LABEL_MAPPING.values()}

    # The detector may have failed to initialise at import time; surface the
    # recorded status instead of crashing.
    if model is None or transform is None:
        return empty_scores, MODEL_STATUS

    if image is None:
        return empty_scores, "No image provided."

    image = image.convert("RGB")
    inputs = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(inputs)

    probabilities = torch.softmax(logits, dim=1)[0]
    ordered_scores = sorted(
        ((LABEL_MAPPING[idx], float(probabilities[idx])) for idx in LABEL_MAPPING),
        key=lambda item: item[1],
        reverse=True,
    )
    scores = dict(ordered_scores)

    top_label, top_score = ordered_scores[0]
    second_label, second_score = ordered_scores[1]
    # Fix: the previous backslash line-continuation inside the f-string escaped
    # the newline, joining "Predicted Label" and "Confidence" into one run of
    # text. Two trailing spaces before "\n" give an explicit Markdown hard break.
    summary = (
        f"**Predicted Label:** {top_label}  \n"
        f"**Confidence:** {top_score:.2%}\n"
        f"`{top_label}: {top_score:.2%} | {second_label}: {second_score:.2%}`"
    )

    return scores, summary
94
 
95
 
96
def analyze_trufor(image: Image.Image) -> Tuple[str, Optional[Image.Image]]:
    """Run TruFor inference when available, otherwise return diagnostics."""
    # Engine never initialised — report why.
    if TRUFOR_ENGINE is None:
        return TRUFOR_STATUS, None

    if image is None:
        return "Upload an image to run TruFor.", None

    try:
        outcome: TruForResult = TRUFOR_ENGINE.infer(image)
    except TruForUnavailableError as exc:
        return str(exc), None

    tamper_score = outcome.score
    if tamper_score is None:
        return "TruFor returned no prediction for this image.", outcome.map_overlay

    # Scores at or above 0.5 are reported as tampering (adjust as needed);
    # confidence is expressed relative to the reported verdict.
    altered_threshold = 0.5
    if tamper_score >= altered_threshold:
        verdict, certainty = "Altered", tamper_score
    else:
        verdict, certainty = "Not Altered", 1.0 - tamper_score

    report = f"**Prediction:** {verdict}\n**Confidence:** {certainty:.2%}"

    return report, outcome.map_overlay
122
 
123
 
124
def analyze_image(image: Image.Image) -> Tuple[Dict[str, float], str, str, Optional[Image.Image]]:
    """Fan the uploaded image out to both detectors and collect their outputs."""
    classifier_scores, classifier_summary = analyze_ai_vs_human(image)
    forgery_summary, forgery_overlay = analyze_trufor(image)
    return classifier_scores, classifier_summary, forgery_summary, forgery_overlay
128
 
129
 
130
  with gr.Blocks() as demo:
 
132
  """# Image Authenticity Workbench\nUpload an image to compare the AI-vs-human classifier with the TruFor forgery detector."""
133
  )
134
 
135
+ status_box = gr.Markdown(f"`TruFor: {TRUFOR_STATUS}`\n`AI Detector: {MODEL_STATUS}`")
136
 
137
  image_input = gr.Image(label="Input Image", type="pil")
138
  analyze_button = gr.Button("Analyze", variant="primary", size="sm")
 
140
  with gr.Tabs():
141
  with gr.TabItem("AI vs Human"):
142
  ai_label_output = gr.Label(label="Prediction", num_top_classes=2)
143
+ ai_summary_output = gr.Markdown(AI_INTRO_SUMMARY)
144
  with gr.TabItem("TruFor Forgery Detection"):
145
  trufor_summary_output = gr.Markdown("Configure TruFor assets to enable tamper analysis.")
146
+ tamper_overlay_output = gr.Image(label="Altered Regions Map", type="pil", interactive=False)
 
147
 
148
  output_components = [
149
  ai_label_output,
150
  ai_summary_output,
151
  trufor_summary_output,
152
  tamper_overlay_output,
 
153
  ]
154
 
155
  analyze_button.click(
trufor_runner.py CHANGED
@@ -23,8 +23,6 @@ class TruForUnavailableError(RuntimeError):
23
  class TruForResult:
24
  score: Optional[float]
25
  map_overlay: Optional[Image.Image]
26
- confidence_overlay: Optional[Image.Image]
27
- raw_scores: Dict[str, float]
28
 
29
 
30
  class TruForEngine:
@@ -151,35 +149,15 @@ class TruForEngine:
151
  def _infer_native(self, image: Image.Image) -> TruForResult:
152
  outputs = self.native_model.predict(image)
153
 
154
- overlays: Dict[str, Optional[Image.Image]] = {"map": None, "conf": None}
155
  try:
156
- overlays["map"] = self._apply_heatmap(image, outputs.tamper_map)
157
  except Exception as exc: # pragma: no cover - visualisation fallback
158
  LOGGER.debug("Failed to build tamper heatmap: %s", exc)
159
 
160
- if outputs.confidence_map is not None:
161
- try:
162
- overlays["conf"] = self._apply_heatmap(image, outputs.confidence_map)
163
- except Exception as exc: # pragma: no cover
164
- LOGGER.debug("Failed to build confidence heatmap: %s", exc)
165
-
166
- raw_scores: Dict[str, float] = {
167
- "tamper_mean": float(np.mean(outputs.tamper_map)),
168
- "tamper_max": float(np.max(outputs.tamper_map)),
169
- }
170
-
171
- if outputs.confidence_map is not None:
172
- raw_scores["confidence_mean"] = float(np.mean(outputs.confidence_map))
173
- raw_scores["confidence_max"] = float(np.max(outputs.confidence_map))
174
-
175
- if outputs.detection_score is not None:
176
- raw_scores["tamper_score"] = float(outputs.detection_score)
177
-
178
  return TruForResult(
179
  score=outputs.detection_score,
180
- map_overlay=overlays["map"],
181
- confidence_overlay=overlays["conf"],
182
- raw_scores=raw_scores,
183
  )
184
 
185
  def _infer_docker(self, image: Image.Image) -> TruForResult:
@@ -257,35 +235,17 @@ class TruForEngine:
257
 
258
  data = np.load(npz_files[0], allow_pickle=False)
259
  tamper_map = data.get("map")
260
- conf_map = data.get("conf")
261
  score = float(data["score"]) if "score" in data.files else None
262
 
263
- overlays: Dict[str, Optional[Image.Image]] = {"map": None, "conf": None}
264
  try:
265
- overlays["map"] = self._apply_heatmap(image, tamper_map) if tamper_map is not None else None
266
  except Exception as exc: # pragma: no cover
267
  LOGGER.debug("Failed to build tamper heatmap: %s", exc)
268
 
269
- try:
270
- overlays["conf"] = self._apply_heatmap(image, conf_map) if conf_map is not None else None
271
- except Exception as exc: # pragma: no cover
272
- LOGGER.debug("Failed to build confidence heatmap: %s", exc)
273
-
274
- raw_scores: Dict[str, float] = {}
275
- if score is not None:
276
- raw_scores["tamper_score"] = score
277
- if tamper_map is not None:
278
- raw_scores["tamper_mean"] = float(np.mean(tamper_map))
279
- raw_scores["tamper_max"] = float(np.max(tamper_map))
280
- if conf_map is not None:
281
- raw_scores["confidence_mean"] = float(np.mean(conf_map))
282
- raw_scores["confidence_max"] = float(np.max(conf_map))
283
-
284
  return TruForResult(
285
  score=score,
286
- map_overlay=overlays["map"],
287
- confidence_overlay=overlays["conf"],
288
- raw_scores=raw_scores,
289
  )
290
 
291
  @staticmethod
 
23
  class TruForResult:
24
  score: Optional[float]
25
  map_overlay: Optional[Image.Image]
 
 
26
 
27
 
28
  class TruForEngine:
 
149
  def _infer_native(self, image: Image.Image) -> TruForResult:
150
  outputs = self.native_model.predict(image)
151
 
152
+ map_overlay = None
153
  try:
154
+ map_overlay = self._apply_heatmap(image, outputs.tamper_map)
155
  except Exception as exc: # pragma: no cover - visualisation fallback
156
  LOGGER.debug("Failed to build tamper heatmap: %s", exc)
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  return TruForResult(
159
  score=outputs.detection_score,
160
+ map_overlay=map_overlay,
 
 
161
  )
162
 
163
  def _infer_docker(self, image: Image.Image) -> TruForResult:
 
235
 
236
  data = np.load(npz_files[0], allow_pickle=False)
237
  tamper_map = data.get("map")
 
238
  score = float(data["score"]) if "score" in data.files else None
239
 
240
+ map_overlay = None
241
  try:
242
+ map_overlay = self._apply_heatmap(image, tamper_map) if tamper_map is not None else None
243
  except Exception as exc: # pragma: no cover
244
  LOGGER.debug("Failed to build tamper heatmap: %s", exc)
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  return TruForResult(
247
  score=score,
248
+ map_overlay=map_overlay,
 
 
249
  )
250
 
251
  @staticmethod