update space
Browse files- app.py +57 -148
- id_to_labels.json +173 -0
- inference.py +99 -0
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -14,26 +14,24 @@ deploying to Hugging Face Spaces).
|
|
| 14 |
|
| 15 |
from __future__ import annotations
|
| 16 |
|
| 17 |
-
import os
|
| 18 |
import random
|
| 19 |
-
from
|
| 20 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
|
| 22 |
import cv2
|
| 23 |
import gradio as gr
|
| 24 |
import numpy as np
|
| 25 |
import pandas as pd
|
| 26 |
import torch
|
| 27 |
-
from datasets import Dataset
|
| 28 |
from PIL import Image
|
| 29 |
-
from transformers import (
|
| 30 |
-
AutoImageProcessor,
|
| 31 |
-
AutoModelForImageClassification,
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
| 36 |
-
HF_DATASET_ID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
# ---------------------------------------------------------------------------
|
|
@@ -41,6 +39,7 @@ HF_DATASET_ID = "raidium/CuriaBench"
|
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
|
| 43 |
HEAD_OPTIONS: List[Tuple[str, str]] = [
|
|
|
|
| 44 |
("anatomy-ct", "Anatomy CT"),
|
| 45 |
("anatomy-mri", "Anatomy MRI"),
|
| 46 |
("atlas-stroke", "Atlas Stroke"),
|
|
@@ -53,6 +52,8 @@ HEAD_OPTIONS: List[Tuple[str, str]] = [
|
|
| 53 |
("kneeMRI", "Knee MRI"),
|
| 54 |
("luna16-3D", "LUNA16 3D"),
|
| 55 |
("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
|
|
|
|
|
|
|
| 56 |
("oasis", "OASIS"),
|
| 57 |
]
|
| 58 |
|
|
@@ -65,6 +66,7 @@ HEADS_REQUIRING_MASK: set[str] = {
|
|
| 65 |
"kits",
|
| 66 |
"kneeMRI",
|
| 67 |
"luna16-3D",
|
|
|
|
| 68 |
"spinal_canal_stenosis",
|
| 69 |
"subarticular_stenosis",
|
| 70 |
}
|
|
@@ -115,60 +117,6 @@ DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
|
|
| 115 |
# ---------------------------------------------------------------------------
|
| 116 |
|
| 117 |
|
| 118 |
-
@lru_cache(maxsize=1)
|
| 119 |
-
def load_processor() -> AutoImageProcessor:
|
| 120 |
-
token = os.environ.get("HF_TOKEN")
|
| 121 |
-
return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
@lru_cache(maxsize=len(HEAD_OPTIONS))
|
| 125 |
-
def load_model(head: str) -> AutoModelForImageClassification:
|
| 126 |
-
token = os.environ.get("HF_TOKEN")
|
| 127 |
-
model = AutoModelForImageClassification.from_pretrained(
|
| 128 |
-
HF_REPO_ID,
|
| 129 |
-
trust_remote_code=True,
|
| 130 |
-
subfolder=head,
|
| 131 |
-
token=token,
|
| 132 |
-
)
|
| 133 |
-
model.eval()
|
| 134 |
-
return model
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
@lru_cache(maxsize=len(DATASET_OPTIONS))
|
| 138 |
-
def load_curia_dataset(subset: str) -> Any:
|
| 139 |
-
token = os.environ.get("HF_TOKEN")
|
| 140 |
-
ds = load_dataset(
|
| 141 |
-
HF_DATASET_ID,
|
| 142 |
-
subset,
|
| 143 |
-
split="test",
|
| 144 |
-
token=token,
|
| 145 |
-
)
|
| 146 |
-
if isinstance(ds, DatasetDict):
|
| 147 |
-
return ds["test"]
|
| 148 |
-
return ds
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def to_numpy_image(image: Any) -> np.ndarray:
|
| 152 |
-
"""Convert dataset or user-provided imagery to a float32 numpy array."""
|
| 153 |
-
|
| 154 |
-
if isinstance(image, np.ndarray):
|
| 155 |
-
arr = image
|
| 156 |
-
elif isinstance(image, Image.Image):
|
| 157 |
-
arr = np.array(image)
|
| 158 |
-
else:
|
| 159 |
-
# Some datasets provide nested dicts or lists – attempt to coerce.
|
| 160 |
-
arr = np.array(image)
|
| 161 |
-
|
| 162 |
-
if arr.ndim == 3 and arr.shape[-1] == 3:
|
| 163 |
-
# Convert RGB to grayscale by averaging channels
|
| 164 |
-
arr = arr.mean(axis=-1)
|
| 165 |
-
|
| 166 |
-
if arr.ndim != 2:
|
| 167 |
-
raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.")
|
| 168 |
-
|
| 169 |
-
return arr.astype(np.float32)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
|
| 173 |
"""Apply CT windowing based on the dataset.
|
| 174 |
|
|
@@ -223,26 +171,8 @@ def to_display_image(image: np.ndarray) -> np.ndarray:
|
|
| 223 |
return arr
|
| 224 |
|
| 225 |
|
| 226 |
-
def
|
| 227 |
-
|
| 228 |
-
return None
|
| 229 |
-
|
| 230 |
-
try:
|
| 231 |
-
arr = np.array(mask)
|
| 232 |
-
except Exception:
|
| 233 |
-
return None
|
| 234 |
-
|
| 235 |
-
if arr.size == 0:
|
| 236 |
-
return None
|
| 237 |
-
return arr
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]:
|
| 241 |
-
mask_array = coerce_mask_array(mask)
|
| 242 |
-
if mask_array is None:
|
| 243 |
-
return None
|
| 244 |
-
|
| 245 |
-
arr = np.squeeze(mask_array)
|
| 246 |
if arr.ndim == 2:
|
| 247 |
arr = arr.reshape(1, height, width)
|
| 248 |
else:
|
|
@@ -301,16 +231,17 @@ def apply_contour_overlay(
|
|
| 301 |
return output
|
| 302 |
|
| 303 |
|
| 304 |
-
def render_image_with_mask_info(image: np.ndarray, mask: Any) ->
|
| 305 |
display = to_display_image(image)
|
| 306 |
if mask is None:
|
| 307 |
-
return display
|
| 308 |
|
| 309 |
try:
|
| 310 |
overlaid = apply_contour_overlay(display, mask)
|
| 311 |
-
return overlaid
|
| 312 |
except Exception:
|
| 313 |
-
|
|
|
|
| 314 |
|
| 315 |
|
| 316 |
def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
|
|
@@ -341,39 +272,6 @@ def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
|
|
| 341 |
return random.choice(indices)
|
| 342 |
|
| 343 |
|
| 344 |
-
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
|
| 345 |
-
"""Return a dataframe sorted by probability desc."""
|
| 346 |
-
|
| 347 |
-
values = probs.detach().cpu().numpy()
|
| 348 |
-
rows = [
|
| 349 |
-
{"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
|
| 350 |
-
for idx, val in enumerate(values)
|
| 351 |
-
]
|
| 352 |
-
df = pd.DataFrame(rows)
|
| 353 |
-
df.sort_values("probability", ascending=False, inplace=True)
|
| 354 |
-
return df
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
def infer_image(
|
| 358 |
-
image: np.ndarray,
|
| 359 |
-
head: str,
|
| 360 |
-
) -> Tuple[str, pd.DataFrame]:
|
| 361 |
-
processor = load_processor()
|
| 362 |
-
model = load_model(head)
|
| 363 |
-
with torch.no_grad():
|
| 364 |
-
processed = processor(images=image, return_tensors="pt")
|
| 365 |
-
outputs = model(**processed)
|
| 366 |
-
print(outputs)
|
| 367 |
-
logits = outputs["logits"]
|
| 368 |
-
probs = torch.nn.functional.softmax(logits[0], dim=-1)
|
| 369 |
-
|
| 370 |
-
id2label = model.config.id2label or {}
|
| 371 |
-
df = format_probabilities(probs, id2label)
|
| 372 |
-
top_row = df.iloc[0]
|
| 373 |
-
prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
|
| 374 |
-
return prediction, df
|
| 375 |
-
|
| 376 |
-
|
| 377 |
# ---------------------------------------------------------------------------
|
| 378 |
# Gradio callbacks
|
| 379 |
# ---------------------------------------------------------------------------
|
|
@@ -439,12 +337,12 @@ def parse_target_selection(selection: str) -> Optional[int]:
|
|
| 439 |
def sample_dataset_example(
|
| 440 |
subset: str,
|
| 441 |
target_id: Optional[int],
|
| 442 |
-
) -> Tuple[np.ndarray,
|
| 443 |
dataset = load_curia_dataset(subset)
|
| 444 |
index = pick_random_indices(dataset, target_id)
|
| 445 |
record = dataset[index]
|
| 446 |
image = to_numpy_image(record["image"])
|
| 447 |
-
mask_array =
|
| 448 |
|
| 449 |
meta = {
|
| 450 |
"index": index,
|
|
@@ -452,7 +350,7 @@ def sample_dataset_example(
|
|
| 452 |
"mask": mask_array,
|
| 453 |
}
|
| 454 |
|
| 455 |
-
return image,
|
| 456 |
|
| 457 |
|
| 458 |
def load_dataset_sample(
|
|
@@ -461,57 +359,69 @@ def load_dataset_sample(
|
|
| 461 |
head: str,
|
| 462 |
) -> Tuple[
|
| 463 |
Optional[np.ndarray],
|
| 464 |
-
str,
|
| 465 |
pd.DataFrame,
|
| 466 |
Dict[str, Any],
|
| 467 |
Optional[Dict[str, Any]],
|
| 468 |
]:
|
| 469 |
try:
|
| 470 |
target_id = parse_target_selection(target_selection)
|
| 471 |
-
image,
|
| 472 |
|
| 473 |
# Apply windowing only for display, keep raw image for model inference
|
| 474 |
windowed_image = apply_windowing(image, subset)
|
| 475 |
-
display
|
|
|
|
|
|
|
| 476 |
|
| 477 |
target = meta.get("target")
|
| 478 |
-
meta_text = caption
|
| 479 |
-
if target is not None:
|
| 480 |
-
meta_text += f" | target={target}"
|
| 481 |
-
status = "Image loaded. Click 'Run inference' to compute predictions."
|
| 482 |
-
if mask_msg:
|
| 483 |
-
status += f" {mask_msg}"
|
| 484 |
-
meta_text = status + "\n\n" + meta_text
|
| 485 |
-
|
| 486 |
# Generate ground truth display
|
| 487 |
ground_truth_update = gr.update(visible=False)
|
| 488 |
if target is not None:
|
| 489 |
-
|
| 490 |
-
id2label =
|
| 491 |
label_name = id2label.get(target, str(target))
|
| 492 |
ground_truth_update = gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)
|
| 493 |
|
| 494 |
return (
|
| 495 |
display,
|
| 496 |
-
meta_text,
|
| 497 |
pd.DataFrame(),
|
| 498 |
ground_truth_update,
|
| 499 |
{"image": image, "mask": meta.get("mask")}, # Store raw image for inference
|
| 500 |
)
|
| 501 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
|
| 505 |
def run_inference(
|
| 506 |
-
|
| 507 |
head: str,
|
| 508 |
) -> Tuple[str, pd.DataFrame]:
|
| 509 |
-
if not
|
| 510 |
return "Load a dataset sample or upload an image first.", pd.DataFrame()
|
| 511 |
|
| 512 |
try:
|
| 513 |
-
image =
|
| 514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
result_text = f"**Prediction:** {prediction}"
|
| 516 |
return result_text, df
|
| 517 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
|
@@ -605,7 +515,7 @@ def build_demo() -> gr.Blocks:
|
|
| 605 |
gr.Markdown("---")
|
| 606 |
|
| 607 |
infer_btn = gr.Button("Run inference", variant="primary")
|
| 608 |
-
|
| 609 |
with gr.Row():
|
| 610 |
with gr.Column():
|
| 611 |
image_display = gr.Image(label="Image", interactive=False, type="numpy")
|
|
@@ -613,7 +523,7 @@ def build_demo() -> gr.Blocks:
|
|
| 613 |
|
| 614 |
with gr.Column():
|
| 615 |
gr.Markdown("### Predictions")
|
| 616 |
-
|
| 617 |
prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
|
| 618 |
|
| 619 |
image_state = gr.State()
|
|
@@ -640,7 +550,6 @@ def build_demo() -> gr.Blocks:
|
|
| 640 |
inputs=[dataset_dropdown, class_dropdown, head_dropdown],
|
| 641 |
outputs=[
|
| 642 |
image_display,
|
| 643 |
-
status_text,
|
| 644 |
prediction_probs,
|
| 645 |
ground_truth_display,
|
| 646 |
image_state,
|
|
@@ -662,7 +571,7 @@ def build_demo() -> gr.Blocks:
|
|
| 662 |
infer_btn.click(
|
| 663 |
fn=run_inference,
|
| 664 |
inputs=[image_state, head_dropdown],
|
| 665 |
-
outputs=[
|
| 666 |
)
|
| 667 |
|
| 668 |
gr.Markdown(
|
|
|
|
| 14 |
|
| 15 |
from __future__ import annotations
|
| 16 |
|
|
|
|
| 17 |
import random
|
| 18 |
+
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
| 19 |
|
| 20 |
import cv2
|
| 21 |
import gradio as gr
|
| 22 |
import numpy as np
|
| 23 |
import pandas as pd
|
| 24 |
import torch
|
| 25 |
+
from datasets import Dataset
|
| 26 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
from inference import (
|
| 29 |
+
HF_DATASET_ID,
|
| 30 |
+
load_curia_dataset,
|
| 31 |
+
load_id_to_labels,
|
| 32 |
+
to_numpy_image,
|
| 33 |
+
infer_image,
|
| 34 |
+
)
|
| 35 |
|
| 36 |
|
| 37 |
# ---------------------------------------------------------------------------
|
|
|
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
|
| 41 |
HEAD_OPTIONS: List[Tuple[str, str]] = [
|
| 42 |
+
("abdominal-trauma", "Active Extravasation"),
|
| 43 |
("anatomy-ct", "Anatomy CT"),
|
| 44 |
("anatomy-mri", "Anatomy MRI"),
|
| 45 |
("atlas-stroke", "Atlas Stroke"),
|
|
|
|
| 52 |
("kneeMRI", "Knee MRI"),
|
| 53 |
("luna16-3D", "LUNA16 3D"),
|
| 54 |
("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
|
| 55 |
+
("spinal_canal_stenosis", "Spinal Canal Stenosis"),
|
| 56 |
+
("subarticular_stenosis", "Subarticular Stenosis"),
|
| 57 |
("oasis", "OASIS"),
|
| 58 |
]
|
| 59 |
|
|
|
|
| 66 |
"kits",
|
| 67 |
"kneeMRI",
|
| 68 |
"luna16-3D",
|
| 69 |
+
"neural_foraminal_narrowing",
|
| 70 |
"spinal_canal_stenosis",
|
| 71 |
"subarticular_stenosis",
|
| 72 |
}
|
|
|
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
|
| 121 |
"""Apply CT windowing based on the dataset.
|
| 122 |
|
|
|
|
| 171 |
return arr
|
| 172 |
|
| 173 |
|
| 174 |
+
def prepare_mask_tensor(mask: np.ndarray, height: int, width: int) -> Optional[torch.Tensor]:
|
| 175 |
+
arr = np.squeeze(mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
if arr.ndim == 2:
|
| 177 |
arr = arr.reshape(1, height, width)
|
| 178 |
else:
|
|
|
|
| 231 |
return output
|
| 232 |
|
| 233 |
|
| 234 |
+
def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
|
| 235 |
display = to_display_image(image)
|
| 236 |
if mask is None:
|
| 237 |
+
return display
|
| 238 |
|
| 239 |
try:
|
| 240 |
overlaid = apply_contour_overlay(display, mask)
|
| 241 |
+
return overlaid
|
| 242 |
except Exception:
|
| 243 |
+
gr.Warning("Mask provided but could not be visualised.")
|
| 244 |
+
return display
|
| 245 |
|
| 246 |
|
| 247 |
def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
|
|
|
|
| 272 |
return random.choice(indices)
|
| 273 |
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
# ---------------------------------------------------------------------------
|
| 276 |
# Gradio callbacks
|
| 277 |
# ---------------------------------------------------------------------------
|
|
|
|
| 337 |
def sample_dataset_example(
|
| 338 |
subset: str,
|
| 339 |
target_id: Optional[int],
|
| 340 |
+
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
| 341 |
dataset = load_curia_dataset(subset)
|
| 342 |
index = pick_random_indices(dataset, target_id)
|
| 343 |
record = dataset[index]
|
| 344 |
image = to_numpy_image(record["image"])
|
| 345 |
+
mask_array = record.get("mask")
|
| 346 |
|
| 347 |
meta = {
|
| 348 |
"index": index,
|
|
|
|
| 350 |
"mask": mask_array,
|
| 351 |
}
|
| 352 |
|
| 353 |
+
return image, meta
|
| 354 |
|
| 355 |
|
| 356 |
def load_dataset_sample(
|
|
|
|
| 359 |
head: str,
|
| 360 |
) -> Tuple[
|
| 361 |
Optional[np.ndarray],
|
|
|
|
| 362 |
pd.DataFrame,
|
| 363 |
Dict[str, Any],
|
| 364 |
Optional[Dict[str, Any]],
|
| 365 |
]:
|
| 366 |
try:
|
| 367 |
target_id = parse_target_selection(target_selection)
|
| 368 |
+
image, meta = sample_dataset_example(subset, target_id)
|
| 369 |
|
| 370 |
# Apply windowing only for display, keep raw image for model inference
|
| 371 |
windowed_image = apply_windowing(image, subset)
|
| 372 |
+
display = to_display_image(windowed_image)
|
| 373 |
+
if meta.get("mask") is not None:
|
| 374 |
+
display = apply_contour_overlay(display, meta.get("mask"))
|
| 375 |
|
| 376 |
target = meta.get("target")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
# Generate ground truth display
|
| 378 |
ground_truth_update = gr.update(visible=False)
|
| 379 |
if target is not None:
|
| 380 |
+
# Use id_to_labels.json mapping
|
| 381 |
+
id2label = load_id_to_labels().get(head, {})
|
| 382 |
label_name = id2label.get(target, str(target))
|
| 383 |
ground_truth_update = gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)
|
| 384 |
|
| 385 |
return (
|
| 386 |
display,
|
|
|
|
| 387 |
pd.DataFrame(),
|
| 388 |
ground_truth_update,
|
| 389 |
{"image": image, "mask": meta.get("mask")}, # Store raw image for inference
|
| 390 |
)
|
| 391 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
| 392 |
+
gr.Warning(f"Failed to load sample: {exc}")
|
| 393 |
+
return None, pd.DataFrame(), gr.update(visible=False), None
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
|
| 397 |
+
"""Return a dataframe sorted by probability desc."""
|
| 398 |
+
|
| 399 |
+
values = probs.detach().cpu().numpy()
|
| 400 |
+
rows = [
|
| 401 |
+
{"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
|
| 402 |
+
for idx, val in enumerate(values)
|
| 403 |
+
]
|
| 404 |
+
df = pd.DataFrame(rows)
|
| 405 |
+
df.sort_values("probability", ascending=False, inplace=True)
|
| 406 |
+
return df
|
| 407 |
|
| 408 |
|
| 409 |
def run_inference(
|
| 410 |
+
image_state: Optional[Dict[str, Any]],
|
| 411 |
head: str,
|
| 412 |
) -> Tuple[str, pd.DataFrame]:
|
| 413 |
+
if not image_state or "image" not in image_state:
|
| 414 |
return "Load a dataset sample or upload an image first.", pd.DataFrame()
|
| 415 |
|
| 416 |
try:
|
| 417 |
+
image = image_state["image"]
|
| 418 |
+
probs = infer_image(image, head)
|
| 419 |
+
|
| 420 |
+
# Use id_to_labels.json mapping, fall back to model config if not available
|
| 421 |
+
id2label = load_id_to_labels().get(head, {})
|
| 422 |
+
df = format_probabilities(probs, id2label)
|
| 423 |
+
top_row = df.iloc[0]
|
| 424 |
+
prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
|
| 425 |
result_text = f"**Prediction:** {prediction}"
|
| 426 |
return result_text, df
|
| 427 |
except Exception as exc: # pragma: no cover - surfaced in UI
|
|
|
|
| 515 |
gr.Markdown("---")
|
| 516 |
|
| 517 |
infer_btn = gr.Button("Run inference", variant="primary")
|
| 518 |
+
status_text = gr.Markdown()
|
| 519 |
with gr.Row():
|
| 520 |
with gr.Column():
|
| 521 |
image_display = gr.Image(label="Image", interactive=False, type="numpy")
|
|
|
|
| 523 |
|
| 524 |
with gr.Column():
|
| 525 |
gr.Markdown("### Predictions")
|
| 526 |
+
main_prediction = gr.Markdown()
|
| 527 |
prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
|
| 528 |
|
| 529 |
image_state = gr.State()
|
|
|
|
| 550 |
inputs=[dataset_dropdown, class_dropdown, head_dropdown],
|
| 551 |
outputs=[
|
| 552 |
image_display,
|
|
|
|
| 553 |
prediction_probs,
|
| 554 |
ground_truth_display,
|
| 555 |
image_state,
|
|
|
|
| 571 |
infer_btn.click(
|
| 572 |
fn=run_inference,
|
| 573 |
inputs=[image_state, head_dropdown],
|
| 574 |
+
outputs=[main_prediction, prediction_probs],
|
| 575 |
)
|
| 576 |
|
| 577 |
gr.Markdown(
|
id_to_labels.json
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"covidx-ct": {
|
| 3 |
+
"0": "normal",
|
| 4 |
+
"1": "pneumonia",
|
| 5 |
+
"2": "COVID-19"
|
| 6 |
+
},
|
| 7 |
+
"deep-lesion-site": {
|
| 8 |
+
"0": "abdomen",
|
| 9 |
+
"1": "bone",
|
| 10 |
+
"2": "kidney",
|
| 11 |
+
"3": "liver",
|
| 12 |
+
"4": "lung",
|
| 13 |
+
"5": "mediastinum",
|
| 14 |
+
"6": "pelvis",
|
| 15 |
+
"7": "soft_tissue"
|
| 16 |
+
},
|
| 17 |
+
"kits": {
|
| 18 |
+
"0": "lesion_kidney",
|
| 19 |
+
"1": "cyst"
|
| 20 |
+
},
|
| 21 |
+
"kneeMRI": {
|
| 22 |
+
"0": "healthy",
|
| 23 |
+
"1": "partially injured",
|
| 24 |
+
"2": "completely ruptured"
|
| 25 |
+
},
|
| 26 |
+
"luna16-3D": {
|
| 27 |
+
"0": "benign",
|
| 28 |
+
"1": "malignant"
|
| 29 |
+
},
|
| 30 |
+
"oasis": {
|
| 31 |
+
"0": "nondemented",
|
| 32 |
+
"1": "signs of dementia"
|
| 33 |
+
},
|
| 34 |
+
"abdominal-trauma": {
|
| 35 |
+
"0": "no active extravasation",
|
| 36 |
+
"1": "active extravasation"
|
| 37 |
+
},
|
| 38 |
+
"ich": {
|
| 39 |
+
"0": "no hemorrhage",
|
| 40 |
+
"1": "intracranial hemorrhage"
|
| 41 |
+
},
|
| 42 |
+
"neural_foraminal_narrowing": {
|
| 43 |
+
"0": "normal/mild",
|
| 44 |
+
"1": "moderate",
|
| 45 |
+
"2": "severe"
|
| 46 |
+
},
|
| 47 |
+
"spinal_canal_stenosis": {
|
| 48 |
+
"0": "normal/mild",
|
| 49 |
+
"1": "moderate",
|
| 50 |
+
"2": "severe"
|
| 51 |
+
},
|
| 52 |
+
"subarticular_stenosis": {
|
| 53 |
+
"0": "normal/mild",
|
| 54 |
+
"1": "moderate",
|
| 55 |
+
"2": "severe"
|
| 56 |
+
},
|
| 57 |
+
"anatomy-ct": {
|
| 58 |
+
"0": "adrenal_gland_left",
|
| 59 |
+
"1": "adrenal_gland_right",
|
| 60 |
+
"2": "aorta",
|
| 61 |
+
"3": "autochthon_left",
|
| 62 |
+
"4": "autochthon_right",
|
| 63 |
+
"5": "brain",
|
| 64 |
+
"6": "clavicula_left",
|
| 65 |
+
"7": "clavicula_right",
|
| 66 |
+
"8": "colon",
|
| 67 |
+
"9": "duodenum",
|
| 68 |
+
"10": "esophagus",
|
| 69 |
+
"11": "face",
|
| 70 |
+
"12": "femur_left",
|
| 71 |
+
"13": "femur_right",
|
| 72 |
+
"14": "gallbladder",
|
| 73 |
+
"15": "gluteus_maximus_left",
|
| 74 |
+
"16": "gluteus_maximus_right",
|
| 75 |
+
"17": "gluteus_medius_left",
|
| 76 |
+
"18": "gluteus_medius_right",
|
| 77 |
+
"19": "gluteus_minimus_left",
|
| 78 |
+
"20": "gluteus_minimus_right",
|
| 79 |
+
"21": "heart",
|
| 80 |
+
"22": "hip_left",
|
| 81 |
+
"23": "hip_right",
|
| 82 |
+
"24": "humerus_left",
|
| 83 |
+
"25": "humerus_right",
|
| 84 |
+
"26": "iliac_artery_left",
|
| 85 |
+
"27": "iliac_artery_right",
|
| 86 |
+
"28": "iliac_vena_left",
|
| 87 |
+
"29": "iliac_vena_right",
|
| 88 |
+
"30": "iliopsoas_left",
|
| 89 |
+
"31": "iliopsoas_right",
|
| 90 |
+
"32": "inferior_vena_cava",
|
| 91 |
+
"33": "kidney_left",
|
| 92 |
+
"34": "kidney_right",
|
| 93 |
+
"35": "liver",
|
| 94 |
+
"36": "lung_left",
|
| 95 |
+
"37": "lung_right",
|
| 96 |
+
"38": "pancreas",
|
| 97 |
+
"39": "portal_vein",
|
| 98 |
+
"40": "pulmonary_artery",
|
| 99 |
+
"41": "rib_left",
|
| 100 |
+
"42": "rib_right",
|
| 101 |
+
"43": "sacrum",
|
| 102 |
+
"44": "scapula_left",
|
| 103 |
+
"45": "scapula_right",
|
| 104 |
+
"46": "small_bowel",
|
| 105 |
+
"47": "spleen",
|
| 106 |
+
"48": "stomach",
|
| 107 |
+
"49": "trachea",
|
| 108 |
+
"50": "urinary_bladder",
|
| 109 |
+
"51": "vertebrae_cervical",
|
| 110 |
+
"52": "vertebrae_dorsal",
|
| 111 |
+
"53": "vertebrae_lumbar"
|
| 112 |
+
},
|
| 113 |
+
"anatomy-mri": {
|
| 114 |
+
"0": "adrenal_gland_left",
|
| 115 |
+
"1": "adrenal_gland_right",
|
| 116 |
+
"2": "aorta",
|
| 117 |
+
"3": "autochthon_left",
|
| 118 |
+
"4": "autochthon_right",
|
| 119 |
+
"5": "brain",
|
| 120 |
+
"6": "colon",
|
| 121 |
+
"7": "duodenum",
|
| 122 |
+
"8": "esophagus",
|
| 123 |
+
"9": "femur_left",
|
| 124 |
+
"10": "femur_right",
|
| 125 |
+
"11": "fibula",
|
| 126 |
+
"12": "gallbladder",
|
| 127 |
+
"13": "gluteus_maximus_left",
|
| 128 |
+
"14": "gluteus_maximus_right",
|
| 129 |
+
"15": "gluteus_medius_left",
|
| 130 |
+
"16": "gluteus_medius_right",
|
| 131 |
+
"17": "heart",
|
| 132 |
+
"18": "hip_left",
|
| 133 |
+
"19": "hip_right",
|
| 134 |
+
"20": "humerus_left",
|
| 135 |
+
"21": "humerus_right",
|
| 136 |
+
"22": "iliac_artery_left",
|
| 137 |
+
"23": "iliac_artery_right",
|
| 138 |
+
"24": "iliac_vena_left",
|
| 139 |
+
"25": "iliac_vena_right",
|
| 140 |
+
"26": "iliopsoas_left",
|
| 141 |
+
"27": "iliopsoas_right",
|
| 142 |
+
"28": "inferior_vena_cava",
|
| 143 |
+
"29": "intervertebral_discs",
|
| 144 |
+
"30": "kidney_left",
|
| 145 |
+
"31": "kidney_right",
|
| 146 |
+
"32": "liver",
|
| 147 |
+
"33": "lung_left",
|
| 148 |
+
"34": "lung_right",
|
| 149 |
+
"35": "pancreas",
|
| 150 |
+
"36": "portal_vein_and_splenic_vein",
|
| 151 |
+
"37": "prostate",
|
| 152 |
+
"38": "quadriceps_femoris_left",
|
| 153 |
+
"39": "quadriceps_femoris_right",
|
| 154 |
+
"40": "sacrum",
|
| 155 |
+
"41": "sartorius_left",
|
| 156 |
+
"42": "sartorius_right",
|
| 157 |
+
"43": "small_bowel",
|
| 158 |
+
"44": "spinal_cord",
|
| 159 |
+
"45": "spleen",
|
| 160 |
+
"46": "stomach",
|
| 161 |
+
"47": "thigh_medial_compartment_left",
|
| 162 |
+
"48": "thigh_medial_compartment_right",
|
| 163 |
+
"49": "thigh_posterior_compartment_left",
|
| 164 |
+
"50": "thigh_posterior_compartment_right",
|
| 165 |
+
"51": "tibia",
|
| 166 |
+
"52": "urinary_bladder",
|
| 167 |
+
"53": "vertebrae"
|
| 168 |
+
},
|
| 169 |
+
"emidec-classification-mask": {
|
| 170 |
+
"0": "healthy",
|
| 171 |
+
"1": "infarction"
|
| 172 |
+
}
|
| 173 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model and dataset loading, inference, and label extraction functions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoImageProcessor,
|
| 17 |
+
AutoModelForImageClassification,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
HF_REPO_ID = "raidium/curia"
|
| 22 |
+
HF_DATASET_ID = "raidium/CuriaBench"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@lru_cache(maxsize=1)
|
| 26 |
+
def load_id_to_labels() -> Dict[str, Dict[str, str]]:
|
| 27 |
+
"""Load the id_to_labels.json mapping file."""
|
| 28 |
+
json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
|
| 29 |
+
with open(json_path, "r") as f:
|
| 30 |
+
data = json.load(f)
|
| 31 |
+
# convert string keys to integers
|
| 32 |
+
for head in data:
|
| 33 |
+
data[head] = {int(k): v for k, v in data[head].items()}
|
| 34 |
+
return data
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@lru_cache(maxsize=1)
|
| 38 |
+
def load_processor() -> AutoImageProcessor:
|
| 39 |
+
token = os.environ.get("HF_TOKEN")
|
| 40 |
+
return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@lru_cache(maxsize=None)
|
| 44 |
+
def load_model(head: str) -> AutoModelForImageClassification:
|
| 45 |
+
token = os.environ.get("HF_TOKEN")
|
| 46 |
+
model = AutoModelForImageClassification.from_pretrained(
|
| 47 |
+
HF_REPO_ID,
|
| 48 |
+
trust_remote_code=True,
|
| 49 |
+
subfolder=head,
|
| 50 |
+
token=token,
|
| 51 |
+
)
|
| 52 |
+
model.eval()
|
| 53 |
+
return model
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@lru_cache(maxsize=None)
|
| 57 |
+
def load_curia_dataset(subset: str) -> Any:
|
| 58 |
+
token = os.environ.get("HF_TOKEN")
|
| 59 |
+
ds = load_dataset(
|
| 60 |
+
HF_DATASET_ID,
|
| 61 |
+
subset,
|
| 62 |
+
split="test",
|
| 63 |
+
token=token,
|
| 64 |
+
)
|
| 65 |
+
if isinstance(ds, DatasetDict):
|
| 66 |
+
return ds["test"]
|
| 67 |
+
return ds
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def to_numpy_image(image: Any) -> np.ndarray:
|
| 71 |
+
"""Convert dataset or user-provided imagery to a float32 numpy array."""
|
| 72 |
+
|
| 73 |
+
if isinstance(image, np.ndarray):
|
| 74 |
+
arr = image
|
| 75 |
+
elif isinstance(image, Image.Image):
|
| 76 |
+
arr = np.array(image)
|
| 77 |
+
else:
|
| 78 |
+
# Some datasets provide nested dicts or lists – attempt to coerce.
|
| 79 |
+
arr = np.array(image)
|
| 80 |
+
|
| 81 |
+
if arr.ndim == 3 and arr.shape[-1] == 3:
|
| 82 |
+
# Convert RGB to grayscale by averaging channels
|
| 83 |
+
arr = arr.mean(axis=-1)
|
| 84 |
+
|
| 85 |
+
return arr.astype(np.float32)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def infer_image(
|
| 89 |
+
image: np.ndarray,
|
| 90 |
+
head: str,
|
| 91 |
+
) -> torch.Tensor:
|
| 92 |
+
processor = load_processor()
|
| 93 |
+
model = load_model(head)
|
| 94 |
+
with torch.no_grad():
|
| 95 |
+
processed = processor(images=image, return_tensors="pt")
|
| 96 |
+
outputs = model(**processed)
|
| 97 |
+
logits = outputs["logits"]
|
| 98 |
+
probs = torch.nn.functional.softmax(logits[0], dim=-1)
|
| 99 |
+
return probs
|
requirements.txt
CHANGED
|
@@ -6,4 +6,4 @@ pandas>=2.2.0
|
|
| 6 |
numpy>=1.26.0
|
| 7 |
pillow>=10.2.0
|
| 8 |
opencv-python>=4.8.0
|
| 9 |
-
|
|
|
|
| 6 |
numpy>=1.26.0
|
| 7 |
pillow>=10.2.0
|
| 8 |
opencv-python>=4.8.0
|
| 9 |
+
torchvision
|