regression
Browse files- app.py +17 -11
- inference.py +6 -2
app.py
CHANGED
|
@@ -77,6 +77,9 @@ HEADS_3D = {
|
|
| 77 |
"kneeMRI",
|
| 78 |
}
|
| 79 |
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
DATASET_OPTIONS: Dict[str, str] = {
|
| 82 |
"anatomy-ct": "Anatomy CT (test)",
|
|
@@ -380,7 +383,7 @@ def load_dataset_sample(
|
|
| 380 |
) -> Tuple[
|
| 381 |
Optional[np.ndarray],
|
| 382 |
str,
|
| 383 |
-
|
| 384 |
Dict[str, Any],
|
| 385 |
Optional[Dict[str, Any]],
|
| 386 |
]:
|
|
@@ -388,12 +391,11 @@ def load_dataset_sample(
|
|
| 388 |
subset = DEFAULT_DATASET_FOR_HEAD.get(head)
|
| 389 |
if not subset:
|
| 390 |
gr.Warning("No dataset found for this head.")
|
| 391 |
-
return None, "",
|
| 392 |
|
| 393 |
try:
|
| 394 |
target_id = parse_target_selection(target_selection)
|
| 395 |
image, meta = sample_dataset_example(subset, target_id)
|
| 396 |
-
|
| 397 |
# Apply windowing only for display, keep raw image for model inference
|
| 398 |
windowed_image = apply_windowing(image, subset)
|
| 399 |
display = to_display_image(windowed_image)
|
|
@@ -412,13 +414,13 @@ def load_dataset_sample(
|
|
| 412 |
return (
|
| 413 |
display,
|
| 414 |
"", # Reset prediction text
|
| 415 |
-
|
| 416 |
ground_truth_update,
|
| 417 |
{"image": image, "mask": meta.get("mask")}, # Store raw image for inference
|
| 418 |
)
|
| 419 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 420 |
gr.Warning(f"Failed to load sample: {exc}")
|
| 421 |
-
return None, "",
|
| 422 |
|
| 423 |
|
| 424 |
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
|
|
@@ -437,21 +439,25 @@ def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.Da
|
|
| 437 |
def run_inference(
|
| 438 |
image_state: Optional[Dict[str, Any]],
|
| 439 |
head: str,
|
| 440 |
-
) -> Tuple[str,
|
| 441 |
if not image_state or "image" not in image_state:
|
| 442 |
-
return "Load a dataset sample or upload an image first.",
|
| 443 |
|
| 444 |
try:
|
| 445 |
image = image_state["image"]
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
# Use id_to_labels.json mapping, fall back to model config if not available
|
| 449 |
id2label = load_id_to_labels().get(head, {})
|
| 450 |
-
|
|
|
|
| 451 |
top_row = df.iloc[0]
|
| 452 |
prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
|
| 453 |
result_text = f"**Prediction:** {prediction}"
|
| 454 |
-
return result_text, df
|
| 455 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 456 |
traceback.print_exc()
|
| 457 |
return f"Failed to run inference: {exc}", gr.update(visible=False)
|
|
@@ -553,7 +559,7 @@ def build_demo() -> gr.Blocks:
|
|
| 553 |
ground_truth_display = gr.Markdown(visible=False)
|
| 554 |
gr.Markdown("### Predictions")
|
| 555 |
main_prediction = gr.Markdown()
|
| 556 |
-
prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
|
| 557 |
|
| 558 |
image_state = gr.State()
|
| 559 |
|
|
|
|
| 77 |
"kneeMRI",
|
| 78 |
}
|
| 79 |
|
| 80 |
+
REGRESSION_HEADS = {
|
| 81 |
+
"ixi",
|
| 82 |
+
}
|
| 83 |
|
| 84 |
DATASET_OPTIONS: Dict[str, str] = {
|
| 85 |
"anatomy-ct": "Anatomy CT (test)",
|
|
|
|
| 383 |
) -> Tuple[
|
| 384 |
Optional[np.ndarray],
|
| 385 |
str,
|
| 386 |
+
Dict[str, Any],
|
| 387 |
Dict[str, Any],
|
| 388 |
Optional[Dict[str, Any]],
|
| 389 |
]:
|
|
|
|
| 391 |
subset = DEFAULT_DATASET_FOR_HEAD.get(head)
|
| 392 |
if not subset:
|
| 393 |
gr.Warning("No dataset found for this head.")
|
| 394 |
+
return None, "", gr.update(visible=False), gr.update(visible=False), None
|
| 395 |
|
| 396 |
try:
|
| 397 |
target_id = parse_target_selection(target_selection)
|
| 398 |
image, meta = sample_dataset_example(subset, target_id)
|
|
|
|
| 399 |
# Apply windowing only for display, keep raw image for model inference
|
| 400 |
windowed_image = apply_windowing(image, subset)
|
| 401 |
display = to_display_image(windowed_image)
|
|
|
|
| 414 |
return (
|
| 415 |
display,
|
| 416 |
"", # Reset prediction text
|
| 417 |
+
gr.update(visible=False),
|
| 418 |
ground_truth_update,
|
| 419 |
{"image": image, "mask": meta.get("mask")}, # Store raw image for inference
|
| 420 |
)
|
| 421 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 422 |
gr.Warning(f"Failed to load sample: {exc}")
|
| 423 |
+
return None, "", gr.update(visible=False), gr.update(visible=False), None
|
| 424 |
|
| 425 |
|
| 426 |
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
|
|
|
|
| 439 |
def run_inference(
|
| 440 |
image_state: Optional[Dict[str, Any]],
|
| 441 |
head: str,
|
| 442 |
+
) -> Tuple[str, Dict[str, Any]]:
|
| 443 |
if not image_state or "image" not in image_state:
|
| 444 |
+
return "Load a dataset sample or upload an image first.", gr.update(visible=False)
|
| 445 |
|
| 446 |
try:
|
| 447 |
image = image_state["image"]
|
| 448 |
+
output = infer_image(image, head, image_state.get("mask"), return_probs=head not in REGRESSION_HEADS)
|
| 449 |
+
|
| 450 |
+
if head in REGRESSION_HEADS:
|
| 451 |
+
return f"**Prediction:** {output:.3f}", gr.update(visible=False)
|
| 452 |
|
| 453 |
# Use id_to_labels.json mapping, fall back to model config if not available
|
| 454 |
id2label = load_id_to_labels().get(head, {})
|
| 455 |
+
|
| 456 |
+
df = format_probabilities(output, id2label)
|
| 457 |
top_row = df.iloc[0]
|
| 458 |
prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
|
| 459 |
result_text = f"**Prediction:** {prediction}"
|
| 460 |
+
return result_text, gr.update(visible=True, value=df)
|
| 461 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 462 |
traceback.print_exc()
|
| 463 |
return f"Failed to run inference: {exc}", gr.update(visible=False)
|
|
|
|
| 559 |
ground_truth_display = gr.Markdown(visible=False)
|
| 560 |
gr.Markdown("### Predictions")
|
| 561 |
main_prediction = gr.Markdown()
|
| 562 |
+
prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"], visible=False)
|
| 563 |
|
| 564 |
image_state = gr.State()
|
| 565 |
|
inference.py
CHANGED
|
@@ -147,6 +147,7 @@ def infer_image(
|
|
| 147 |
image: np.ndarray,
|
| 148 |
head: str,
|
| 149 |
mask: Any | None = None,
|
|
|
|
| 150 |
) -> torch.Tensor:
|
| 151 |
processor = load_processor()
|
| 152 |
model = load_model(head)
|
|
@@ -157,5 +158,8 @@ def infer_image(
|
|
| 157 |
processed["mask"] = mask_tensor
|
| 158 |
outputs = model(**processed)
|
| 159 |
logits = outputs["logits"]
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
image: np.ndarray,
|
| 148 |
head: str,
|
| 149 |
mask: Any | None = None,
|
| 150 |
+
return_probs: bool = True,
|
| 151 |
) -> torch.Tensor:
|
| 152 |
processor = load_processor()
|
| 153 |
model = load_model(head)
|
|
|
|
| 158 |
processed["mask"] = mask_tensor
|
| 159 |
outputs = model(**processed)
|
| 160 |
logits = outputs["logits"]
|
| 161 |
+
if return_probs:
|
| 162 |
+
probs = torch.nn.functional.softmax(logits[0], dim=-1)
|
| 163 |
+
return probs
|
| 164 |
+
else:
|
| 165 |
+
return logits[0].squeeze()
|