cdancette committed on
Commit
de6c079
·
1 Parent(s): fbe16e9

regression

Browse files
Files changed (2) hide show
  1. app.py +17 -11
  2. inference.py +6 -2
app.py CHANGED
@@ -77,6 +77,9 @@ HEADS_3D = {
77
  "kneeMRI",
78
  }
79
 
 
 
 
80
 
81
  DATASET_OPTIONS: Dict[str, str] = {
82
  "anatomy-ct": "Anatomy CT (test)",
@@ -380,7 +383,7 @@ def load_dataset_sample(
380
  ) -> Tuple[
381
  Optional[np.ndarray],
382
  str,
383
- pd.DataFrame,
384
  Dict[str, Any],
385
  Optional[Dict[str, Any]],
386
  ]:
@@ -388,12 +391,11 @@ def load_dataset_sample(
388
  subset = DEFAULT_DATASET_FOR_HEAD.get(head)
389
  if not subset:
390
  gr.Warning("No dataset found for this head.")
391
- return None, "", pd.DataFrame(), gr.update(visible=False), None
392
 
393
  try:
394
  target_id = parse_target_selection(target_selection)
395
  image, meta = sample_dataset_example(subset, target_id)
396
-
397
  # Apply windowing only for display, keep raw image for model inference
398
  windowed_image = apply_windowing(image, subset)
399
  display = to_display_image(windowed_image)
@@ -412,13 +414,13 @@ def load_dataset_sample(
412
  return (
413
  display,
414
  "", # Reset prediction text
415
- pd.DataFrame(),
416
  ground_truth_update,
417
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
418
  )
419
  except Exception as exc: # pragma: no cover - surfaced in UI
420
  gr.Warning(f"Failed to load sample: {exc}")
421
- return None, "", pd.DataFrame(), gr.update(visible=False), None
422
 
423
 
424
  def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
@@ -437,21 +439,25 @@ def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.Da
437
  def run_inference(
438
  image_state: Optional[Dict[str, Any]],
439
  head: str,
440
- ) -> Tuple[str, pd.DataFrame]:
441
  if not image_state or "image" not in image_state:
442
- return "Load a dataset sample or upload an image first.", pd.DataFrame()
443
 
444
  try:
445
  image = image_state["image"]
446
- probs = infer_image(image, head, image_state.get("mask"))
 
 
 
447
 
448
  # Use id_to_labels.json mapping, fall back to model config if not available
449
  id2label = load_id_to_labels().get(head, {})
450
- df = format_probabilities(probs, id2label)
 
451
  top_row = df.iloc[0]
452
  prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
453
  result_text = f"**Prediction:** {prediction}"
454
- return result_text, df
455
  except Exception as exc: # pragma: no cover - surfaced in UI
456
  traceback.print_exc()
457
  return f"Failed to run inference: {exc}", gr.update(visible=False)
@@ -553,7 +559,7 @@ def build_demo() -> gr.Blocks:
553
  ground_truth_display = gr.Markdown(visible=False)
554
  gr.Markdown("### Predictions")
555
  main_prediction = gr.Markdown()
556
- prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
557
 
558
  image_state = gr.State()
559
 
 
77
  "kneeMRI",
78
  }
79
 
80
+ REGRESSION_HEADS = {
81
+ "ixi",
82
+ }
83
 
84
  DATASET_OPTIONS: Dict[str, str] = {
85
  "anatomy-ct": "Anatomy CT (test)",
 
383
  ) -> Tuple[
384
  Optional[np.ndarray],
385
  str,
386
+ Dict[str, Any],
387
  Dict[str, Any],
388
  Optional[Dict[str, Any]],
389
  ]:
 
391
  subset = DEFAULT_DATASET_FOR_HEAD.get(head)
392
  if not subset:
393
  gr.Warning("No dataset found for this head.")
394
+ return None, "", gr.update(visible=False), gr.update(visible=False), None
395
 
396
  try:
397
  target_id = parse_target_selection(target_selection)
398
  image, meta = sample_dataset_example(subset, target_id)
 
399
  # Apply windowing only for display, keep raw image for model inference
400
  windowed_image = apply_windowing(image, subset)
401
  display = to_display_image(windowed_image)
 
414
  return (
415
  display,
416
  "", # Reset prediction text
417
+ gr.update(visible=False),
418
  ground_truth_update,
419
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
420
  )
421
  except Exception as exc: # pragma: no cover - surfaced in UI
422
  gr.Warning(f"Failed to load sample: {exc}")
423
+ return None, "", gr.update(visible=False), gr.update(visible=False), None
424
 
425
 
426
  def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
 
439
  def run_inference(
440
  image_state: Optional[Dict[str, Any]],
441
  head: str,
442
+ ) -> Tuple[str, Dict[str, Any]]:
443
  if not image_state or "image" not in image_state:
444
+ return "Load a dataset sample or upload an image first.", gr.update(visible=False)
445
 
446
  try:
447
  image = image_state["image"]
448
+ output = infer_image(image, head, image_state.get("mask"), return_probs=head not in REGRESSION_HEADS)
449
+
450
+ if head in REGRESSION_HEADS:
451
+ return f"**Prediction:** {output:.3f}", gr.update(visible=False)
452
 
453
  # Use id_to_labels.json mapping, fall back to model config if not available
454
  id2label = load_id_to_labels().get(head, {})
455
+
456
+ df = format_probabilities(output, id2label)
457
  top_row = df.iloc[0]
458
  prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
459
  result_text = f"**Prediction:** {prediction}"
460
+ return result_text, gr.update(visible=True, value=df)
461
  except Exception as exc: # pragma: no cover - surfaced in UI
462
  traceback.print_exc()
463
  return f"Failed to run inference: {exc}", gr.update(visible=False)
 
559
  ground_truth_display = gr.Markdown(visible=False)
560
  gr.Markdown("### Predictions")
561
  main_prediction = gr.Markdown()
562
+ prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"], visible=False)
563
 
564
  image_state = gr.State()
565
 
inference.py CHANGED
@@ -147,6 +147,7 @@ def infer_image(
147
  image: np.ndarray,
148
  head: str,
149
  mask: Any | None = None,
 
150
  ) -> torch.Tensor:
151
  processor = load_processor()
152
  model = load_model(head)
@@ -157,5 +158,8 @@ def infer_image(
157
  processed["mask"] = mask_tensor
158
  outputs = model(**processed)
159
  logits = outputs["logits"]
160
- probs = torch.nn.functional.softmax(logits[0], dim=-1)
161
- return probs
 
 
 
 
147
  image: np.ndarray,
148
  head: str,
149
  mask: Any | None = None,
150
+ return_probs: bool = True,
151
  ) -> torch.Tensor:
152
  processor = load_processor()
153
  model = load_model(head)
 
158
  processed["mask"] = mask_tensor
159
  outputs = model(**processed)
160
  logits = outputs["logits"]
161
+ if return_probs:
162
+ probs = torch.nn.functional.softmax(logits[0], dim=-1)
163
+ return probs
164
+ else:
165
+ return logits[0].squeeze()