cdancette commited on
Commit
0fddfa1
·
1 Parent(s): 1d3f735

update space

Browse files
Files changed (4) hide show
  1. app.py +57 -148
  2. id_to_labels.json +173 -0
  3. inference.py +99 -0
  4. requirements.txt +1 -1
app.py CHANGED
@@ -14,26 +14,24 @@ deploying to Hugging Face Spaces).
14
 
15
  from __future__ import annotations
16
 
17
- import os
18
  import random
19
- from functools import lru_cache
20
- from typing import Any, Dict, List, Optional, Tuple, Union
21
 
22
  import cv2
23
  import gradio as gr
24
  import numpy as np
25
  import pandas as pd
26
  import torch
27
- from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
28
  from PIL import Image
29
- from transformers import (
30
- AutoImageProcessor,
31
- AutoModelForImageClassification,
32
- )
33
-
34
 
35
- HF_REPO_ID = "raidium/curia"
36
- HF_DATASET_ID = "raidium/CuriaBench"
 
 
 
 
 
37
 
38
 
39
  # ---------------------------------------------------------------------------
@@ -41,6 +39,7 @@ HF_DATASET_ID = "raidium/CuriaBench"
41
  # ---------------------------------------------------------------------------
42
 
43
  HEAD_OPTIONS: List[Tuple[str, str]] = [
 
44
  ("anatomy-ct", "Anatomy CT"),
45
  ("anatomy-mri", "Anatomy MRI"),
46
  ("atlas-stroke", "Atlas Stroke"),
@@ -53,6 +52,8 @@ HEAD_OPTIONS: List[Tuple[str, str]] = [
53
  ("kneeMRI", "Knee MRI"),
54
  ("luna16-3D", "LUNA16 3D"),
55
  ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
 
 
56
  ("oasis", "OASIS"),
57
  ]
58
 
@@ -65,6 +66,7 @@ HEADS_REQUIRING_MASK: set[str] = {
65
  "kits",
66
  "kneeMRI",
67
  "luna16-3D",
 
68
  "spinal_canal_stenosis",
69
  "subarticular_stenosis",
70
  }
@@ -115,60 +117,6 @@ DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
115
  # ---------------------------------------------------------------------------
116
 
117
 
118
- @lru_cache(maxsize=1)
119
- def load_processor() -> AutoImageProcessor:
120
- token = os.environ.get("HF_TOKEN")
121
- return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token)
122
-
123
-
124
- @lru_cache(maxsize=len(HEAD_OPTIONS))
125
- def load_model(head: str) -> AutoModelForImageClassification:
126
- token = os.environ.get("HF_TOKEN")
127
- model = AutoModelForImageClassification.from_pretrained(
128
- HF_REPO_ID,
129
- trust_remote_code=True,
130
- subfolder=head,
131
- token=token,
132
- )
133
- model.eval()
134
- return model
135
-
136
-
137
- @lru_cache(maxsize=len(DATASET_OPTIONS))
138
- def load_curia_dataset(subset: str) -> Any:
139
- token = os.environ.get("HF_TOKEN")
140
- ds = load_dataset(
141
- HF_DATASET_ID,
142
- subset,
143
- split="test",
144
- token=token,
145
- )
146
- if isinstance(ds, DatasetDict):
147
- return ds["test"]
148
- return ds
149
-
150
-
151
- def to_numpy_image(image: Any) -> np.ndarray:
152
- """Convert dataset or user-provided imagery to a float32 numpy array."""
153
-
154
- if isinstance(image, np.ndarray):
155
- arr = image
156
- elif isinstance(image, Image.Image):
157
- arr = np.array(image)
158
- else:
159
- # Some datasets provide nested dicts or lists – attempt to coerce.
160
- arr = np.array(image)
161
-
162
- if arr.ndim == 3 and arr.shape[-1] == 3:
163
- # Convert RGB to grayscale by averaging channels
164
- arr = arr.mean(axis=-1)
165
-
166
- if arr.ndim != 2:
167
- raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.")
168
-
169
- return arr.astype(np.float32)
170
-
171
-
172
  def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
173
  """Apply CT windowing based on the dataset.
174
 
@@ -223,26 +171,8 @@ def to_display_image(image: np.ndarray) -> np.ndarray:
223
  return arr
224
 
225
 
226
- def coerce_mask_array(mask: Any) -> Optional[np.ndarray]:
227
- if mask is None:
228
- return None
229
-
230
- try:
231
- arr = np.array(mask)
232
- except Exception:
233
- return None
234
-
235
- if arr.size == 0:
236
- return None
237
- return arr
238
-
239
-
240
- def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]:
241
- mask_array = coerce_mask_array(mask)
242
- if mask_array is None:
243
- return None
244
-
245
- arr = np.squeeze(mask_array)
246
  if arr.ndim == 2:
247
  arr = arr.reshape(1, height, width)
248
  else:
@@ -301,16 +231,17 @@ def apply_contour_overlay(
301
  return output
302
 
303
 
304
- def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarray, Optional[str]]:
305
  display = to_display_image(image)
306
  if mask is None:
307
- return display, None
308
 
309
  try:
310
  overlaid = apply_contour_overlay(display, mask)
311
- return overlaid, ""
312
  except Exception:
313
- return display, "Mask provided but could not be visualised."
 
314
 
315
 
316
  def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
@@ -341,39 +272,6 @@ def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
341
  return random.choice(indices)
342
 
343
 
344
- def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
345
- """Return a dataframe sorted by probability desc."""
346
-
347
- values = probs.detach().cpu().numpy()
348
- rows = [
349
- {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
350
- for idx, val in enumerate(values)
351
- ]
352
- df = pd.DataFrame(rows)
353
- df.sort_values("probability", ascending=False, inplace=True)
354
- return df
355
-
356
-
357
- def infer_image(
358
- image: np.ndarray,
359
- head: str,
360
- ) -> Tuple[str, pd.DataFrame]:
361
- processor = load_processor()
362
- model = load_model(head)
363
- with torch.no_grad():
364
- processed = processor(images=image, return_tensors="pt")
365
- outputs = model(**processed)
366
- print(outputs)
367
- logits = outputs["logits"]
368
- probs = torch.nn.functional.softmax(logits[0], dim=-1)
369
-
370
- id2label = model.config.id2label or {}
371
- df = format_probabilities(probs, id2label)
372
- top_row = df.iloc[0]
373
- prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
374
- return prediction, df
375
-
376
-
377
  # ---------------------------------------------------------------------------
378
  # Gradio callbacks
379
  # ---------------------------------------------------------------------------
@@ -439,12 +337,12 @@ def parse_target_selection(selection: str) -> Optional[int]:
439
  def sample_dataset_example(
440
  subset: str,
441
  target_id: Optional[int],
442
- ) -> Tuple[np.ndarray, str, Dict[str, Any]]:
443
  dataset = load_curia_dataset(subset)
444
  index = pick_random_indices(dataset, target_id)
445
  record = dataset[index]
446
  image = to_numpy_image(record["image"])
447
- mask_array = coerce_mask_array(record.get("mask"))
448
 
449
  meta = {
450
  "index": index,
@@ -452,7 +350,7 @@ def sample_dataset_example(
452
  "mask": mask_array,
453
  }
454
 
455
- return image, f"Sample #{index}", meta
456
 
457
 
458
  def load_dataset_sample(
@@ -461,57 +359,69 @@ def load_dataset_sample(
461
  head: str,
462
  ) -> Tuple[
463
  Optional[np.ndarray],
464
- str,
465
  pd.DataFrame,
466
  Dict[str, Any],
467
  Optional[Dict[str, Any]],
468
  ]:
469
  try:
470
  target_id = parse_target_selection(target_selection)
471
- image, caption, meta = sample_dataset_example(subset, target_id)
472
 
473
  # Apply windowing only for display, keep raw image for model inference
474
  windowed_image = apply_windowing(image, subset)
475
- display, mask_msg = render_image_with_mask_info(windowed_image, meta.get("mask"))
 
 
476
 
477
  target = meta.get("target")
478
- meta_text = caption
479
- if target is not None:
480
- meta_text += f" | target={target}"
481
- status = "Image loaded. Click 'Run inference' to compute predictions."
482
- if mask_msg:
483
- status += f" {mask_msg}"
484
- meta_text = status + "\n\n" + meta_text
485
-
486
  # Generate ground truth display
487
  ground_truth_update = gr.update(visible=False)
488
  if target is not None:
489
- model = load_model(head)
490
- id2label = model.config.id2label or {}
491
  label_name = id2label.get(target, str(target))
492
  ground_truth_update = gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)
493
 
494
  return (
495
  display,
496
- meta_text,
497
  pd.DataFrame(),
498
  ground_truth_update,
499
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
500
  )
501
  except Exception as exc: # pragma: no cover - surfaced in UI
502
- return None, f"Failed to load sample: {exc}", pd.DataFrame(), gr.update(visible=False), None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
 
505
  def run_inference(
506
- sample_state: Optional[Dict[str, Any]],
507
  head: str,
508
  ) -> Tuple[str, pd.DataFrame]:
509
- if not sample_state or "image" not in sample_state:
510
  return "Load a dataset sample or upload an image first.", pd.DataFrame()
511
 
512
  try:
513
- image = sample_state["image"]
514
- prediction, df = infer_image(image, head)
 
 
 
 
 
 
515
  result_text = f"**Prediction:** {prediction}"
516
  return result_text, df
517
  except Exception as exc: # pragma: no cover - surfaced in UI
@@ -605,7 +515,7 @@ def build_demo() -> gr.Blocks:
605
  gr.Markdown("---")
606
 
607
  infer_btn = gr.Button("Run inference", variant="primary")
608
-
609
  with gr.Row():
610
  with gr.Column():
611
  image_display = gr.Image(label="Image", interactive=False, type="numpy")
@@ -613,7 +523,7 @@ def build_demo() -> gr.Blocks:
613
 
614
  with gr.Column():
615
  gr.Markdown("### Predictions")
616
- status_text = gr.Markdown()
617
  prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
618
 
619
  image_state = gr.State()
@@ -640,7 +550,6 @@ def build_demo() -> gr.Blocks:
640
  inputs=[dataset_dropdown, class_dropdown, head_dropdown],
641
  outputs=[
642
  image_display,
643
- status_text,
644
  prediction_probs,
645
  ground_truth_display,
646
  image_state,
@@ -662,7 +571,7 @@ def build_demo() -> gr.Blocks:
662
  infer_btn.click(
663
  fn=run_inference,
664
  inputs=[image_state, head_dropdown],
665
- outputs=[status_text, prediction_probs],
666
  )
667
 
668
  gr.Markdown(
 
14
 
15
  from __future__ import annotations
16
 
 
17
  import random
18
+ from typing import Any, Dict, List, Optional, Tuple
 
19
 
20
  import cv2
21
  import gradio as gr
22
  import numpy as np
23
  import pandas as pd
24
  import torch
25
+ from datasets import Dataset
26
  from PIL import Image
 
 
 
 
 
27
 
28
+ from inference import (
29
+ HF_DATASET_ID,
30
+ load_curia_dataset,
31
+ load_id_to_labels,
32
+ to_numpy_image,
33
+ infer_image,
34
+ )
35
 
36
 
37
  # ---------------------------------------------------------------------------
 
39
  # ---------------------------------------------------------------------------
40
 
41
  HEAD_OPTIONS: List[Tuple[str, str]] = [
42
+ ("abdominal-trauma", "Active Extravasation"),
43
  ("anatomy-ct", "Anatomy CT"),
44
  ("anatomy-mri", "Anatomy MRI"),
45
  ("atlas-stroke", "Atlas Stroke"),
 
52
  ("kneeMRI", "Knee MRI"),
53
  ("luna16-3D", "LUNA16 3D"),
54
  ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
55
+ ("spinal_canal_stenosis", "Spinal Canal Stenosis"),
56
+ ("subarticular_stenosis", "Subarticular Stenosis"),
57
  ("oasis", "OASIS"),
58
  ]
59
 
 
66
  "kits",
67
  "kneeMRI",
68
  "luna16-3D",
69
+ "neural_foraminal_narrowing",
70
  "spinal_canal_stenosis",
71
  "subarticular_stenosis",
72
  }
 
117
  # ---------------------------------------------------------------------------
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
121
  """Apply CT windowing based on the dataset.
122
 
 
171
  return arr
172
 
173
 
174
+ def prepare_mask_tensor(mask: np.ndarray, height: int, width: int) -> Optional[torch.Tensor]:
175
+ arr = np.squeeze(mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  if arr.ndim == 2:
177
  arr = arr.reshape(1, height, width)
178
  else:
 
231
  return output
232
 
233
 
234
+ def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
235
  display = to_display_image(image)
236
  if mask is None:
237
+ return display
238
 
239
  try:
240
  overlaid = apply_contour_overlay(display, mask)
241
+ return overlaid
242
  except Exception:
243
+ gr.Warning("Mask provided but could not be visualised.")
244
+ return display
245
 
246
 
247
  def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
 
272
  return random.choice(indices)
273
 
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  # ---------------------------------------------------------------------------
276
  # Gradio callbacks
277
  # ---------------------------------------------------------------------------
 
337
  def sample_dataset_example(
338
  subset: str,
339
  target_id: Optional[int],
340
+ ) -> Tuple[np.ndarray, Dict[str, Any]]:
341
  dataset = load_curia_dataset(subset)
342
  index = pick_random_indices(dataset, target_id)
343
  record = dataset[index]
344
  image = to_numpy_image(record["image"])
345
+ mask_array = record.get("mask")
346
 
347
  meta = {
348
  "index": index,
 
350
  "mask": mask_array,
351
  }
352
 
353
+ return image, meta
354
 
355
 
356
  def load_dataset_sample(
 
359
  head: str,
360
  ) -> Tuple[
361
  Optional[np.ndarray],
 
362
  pd.DataFrame,
363
  Dict[str, Any],
364
  Optional[Dict[str, Any]],
365
  ]:
366
  try:
367
  target_id = parse_target_selection(target_selection)
368
+ image, meta = sample_dataset_example(subset, target_id)
369
 
370
  # Apply windowing only for display, keep raw image for model inference
371
  windowed_image = apply_windowing(image, subset)
372
+ display = to_display_image(windowed_image)
373
+ if meta.get("mask") is not None:
374
+ display = apply_contour_overlay(display, meta.get("mask"))
375
 
376
  target = meta.get("target")
 
 
 
 
 
 
 
 
377
  # Generate ground truth display
378
  ground_truth_update = gr.update(visible=False)
379
  if target is not None:
380
+ # Use id_to_labels.json mapping
381
+ id2label = load_id_to_labels().get(head, {})
382
  label_name = id2label.get(target, str(target))
383
  ground_truth_update = gr.update(value=f"**Ground Truth:** {label_name} (class {target})", visible=True)
384
 
385
  return (
386
  display,
 
387
  pd.DataFrame(),
388
  ground_truth_update,
389
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
390
  )
391
  except Exception as exc: # pragma: no cover - surfaced in UI
392
+ gr.Warning(f"Failed to load sample: {exc}")
393
+ return None, pd.DataFrame(), gr.update(visible=False), None
394
+
395
+
396
+ def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
397
+ """Return a dataframe sorted by probability desc."""
398
+
399
+ values = probs.detach().cpu().numpy()
400
+ rows = [
401
+ {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
402
+ for idx, val in enumerate(values)
403
+ ]
404
+ df = pd.DataFrame(rows)
405
+ df.sort_values("probability", ascending=False, inplace=True)
406
+ return df
407
 
408
 
409
  def run_inference(
410
+ image_state: Optional[Dict[str, Any]],
411
  head: str,
412
  ) -> Tuple[str, pd.DataFrame]:
413
+ if not image_state or "image" not in image_state:
414
  return "Load a dataset sample or upload an image first.", pd.DataFrame()
415
 
416
  try:
417
+ image = image_state["image"]
418
+ probs = infer_image(image, head)
419
+
420
+ # Use id_to_labels.json mapping, fall back to model config if not available
421
+ id2label = load_id_to_labels().get(head, {})
422
+ df = format_probabilities(probs, id2label)
423
+ top_row = df.iloc[0]
424
+ prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
425
  result_text = f"**Prediction:** {prediction}"
426
  return result_text, df
427
  except Exception as exc: # pragma: no cover - surfaced in UI
 
515
  gr.Markdown("---")
516
 
517
  infer_btn = gr.Button("Run inference", variant="primary")
518
+ status_text = gr.Markdown()
519
  with gr.Row():
520
  with gr.Column():
521
  image_display = gr.Image(label="Image", interactive=False, type="numpy")
 
523
 
524
  with gr.Column():
525
  gr.Markdown("### Predictions")
526
+ main_prediction = gr.Markdown()
527
  prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
528
 
529
  image_state = gr.State()
 
550
  inputs=[dataset_dropdown, class_dropdown, head_dropdown],
551
  outputs=[
552
  image_display,
 
553
  prediction_probs,
554
  ground_truth_display,
555
  image_state,
 
571
  infer_btn.click(
572
  fn=run_inference,
573
  inputs=[image_state, head_dropdown],
574
+ outputs=[main_prediction, prediction_probs],
575
  )
576
 
577
  gr.Markdown(
id_to_labels.json ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "covidx-ct": {
3
+ "0": "normal",
4
+ "1": "pneumonia",
5
+ "2": "COVID-19"
6
+ },
7
+ "deep-lesion-site": {
8
+ "0": "abdomen",
9
+ "1": "bone",
10
+ "2": "kidney",
11
+ "3": "liver",
12
+ "4": "lung",
13
+ "5": "mediastinum",
14
+ "6": "pelvis",
15
+ "7": "soft_tissue"
16
+ },
17
+ "kits": {
18
+ "0": "lesion_kidney",
19
+ "1": "cyst"
20
+ },
21
+ "kneeMRI": {
22
+ "0": "healthy",
23
+ "1": "partially injured",
24
+ "2": "completely ruptured"
25
+ },
26
+ "luna16-3D": {
27
+ "0": "benign",
28
+ "1": "malignant"
29
+ },
30
+ "oasis": {
31
+ "0": "nondemented",
32
+ "1": "signs of dementia"
33
+ },
34
+ "abdominal-trauma": {
35
+ "0": "no active extravasation",
36
+ "1": "active extravasation"
37
+ },
38
+ "ich": {
39
+ "0": "no hemorrhage",
40
+ "1": "intracranial hemorrhage"
41
+ },
42
+ "neural_foraminal_narrowing": {
43
+ "0": "normal/mild",
44
+ "1": "moderate",
45
+ "2": "severe"
46
+ },
47
+ "spinal_canal_stenosis": {
48
+ "0": "normal/mild",
49
+ "1": "moderate",
50
+ "2": "severe"
51
+ },
52
+ "subarticular_stenosis": {
53
+ "0": "normal/mild",
54
+ "1": "moderate",
55
+ "2": "severe"
56
+ },
57
+ "anatomy-ct": {
58
+ "0": "adrenal_gland_left",
59
+ "1": "adrenal_gland_right",
60
+ "2": "aorta",
61
+ "3": "autochthon_left",
62
+ "4": "autochthon_right",
63
+ "5": "brain",
64
+ "6": "clavicula_left",
65
+ "7": "clavicula_right",
66
+ "8": "colon",
67
+ "9": "duodenum",
68
+ "10": "esophagus",
69
+ "11": "face",
70
+ "12": "femur_left",
71
+ "13": "femur_right",
72
+ "14": "gallbladder",
73
+ "15": "gluteus_maximus_left",
74
+ "16": "gluteus_maximus_right",
75
+ "17": "gluteus_medius_left",
76
+ "18": "gluteus_medius_right",
77
+ "19": "gluteus_minimus_left",
78
+ "20": "gluteus_minimus_right",
79
+ "21": "heart",
80
+ "22": "hip_left",
81
+ "23": "hip_right",
82
+ "24": "humerus_left",
83
+ "25": "humerus_right",
84
+ "26": "iliac_artery_left",
85
+ "27": "iliac_artery_right",
86
+ "28": "iliac_vena_left",
87
+ "29": "iliac_vena_right",
88
+ "30": "iliopsoas_left",
89
+ "31": "iliopsoas_right",
90
+ "32": "inferior_vena_cava",
91
+ "33": "kidney_left",
92
+ "34": "kidney_right",
93
+ "35": "liver",
94
+ "36": "lung_left",
95
+ "37": "lung_right",
96
+ "38": "pancreas",
97
+ "39": "portal_vein",
98
+ "40": "pulmonary_artery",
99
+ "41": "rib_left",
100
+ "42": "rib_right",
101
+ "43": "sacrum",
102
+ "44": "scapula_left",
103
+ "45": "scapula_right",
104
+ "46": "small_bowel",
105
+ "47": "spleen",
106
+ "48": "stomach",
107
+ "49": "trachea",
108
+ "50": "urinary_bladder",
109
+ "51": "vertebrae_cervical",
110
+ "52": "vertebrae_dorsal",
111
+ "53": "vertebrae_lumbar"
112
+ },
113
+ "anatomy-mri": {
114
+ "0": "adrenal_gland_left",
115
+ "1": "adrenal_gland_right",
116
+ "2": "aorta",
117
+ "3": "autochthon_left",
118
+ "4": "autochthon_right",
119
+ "5": "brain",
120
+ "6": "colon",
121
+ "7": "duodenum",
122
+ "8": "esophagus",
123
+ "9": "femur_left",
124
+ "10": "femur_right",
125
+ "11": "fibula",
126
+ "12": "gallbladder",
127
+ "13": "gluteus_maximus_left",
128
+ "14": "gluteus_maximus_right",
129
+ "15": "gluteus_medius_left",
130
+ "16": "gluteus_medius_right",
131
+ "17": "heart",
132
+ "18": "hip_left",
133
+ "19": "hip_right",
134
+ "20": "humerus_left",
135
+ "21": "humerus_right",
136
+ "22": "iliac_artery_left",
137
+ "23": "iliac_artery_right",
138
+ "24": "iliac_vena_left",
139
+ "25": "iliac_vena_right",
140
+ "26": "iliopsoas_left",
141
+ "27": "iliopsoas_right",
142
+ "28": "inferior_vena_cava",
143
+ "29": "intervertebral_discs",
144
+ "30": "kidney_left",
145
+ "31": "kidney_right",
146
+ "32": "liver",
147
+ "33": "lung_left",
148
+ "34": "lung_right",
149
+ "35": "pancreas",
150
+ "36": "portal_vein_and_splenic_vein",
151
+ "37": "prostate",
152
+ "38": "quadriceps_femoris_left",
153
+ "39": "quadriceps_femoris_right",
154
+ "40": "sacrum",
155
+ "41": "sartorius_left",
156
+ "42": "sartorius_right",
157
+ "43": "small_bowel",
158
+ "44": "spinal_cord",
159
+ "45": "spleen",
160
+ "46": "stomach",
161
+ "47": "thigh_medial_compartment_left",
162
+ "48": "thigh_medial_compartment_right",
163
+ "49": "thigh_posterior_compartment_left",
164
+ "50": "thigh_posterior_compartment_right",
165
+ "51": "tibia",
166
+ "52": "urinary_bladder",
167
+ "53": "vertebrae"
168
+ },
169
+ "emidec-classification-mask": {
170
+ "0": "healthy",
171
+ "1": "infarction"
172
+ }
173
+ }
inference.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model and dataset loading, inference, and label extraction functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from functools import lru_cache
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from datasets import Dataset, DatasetDict, load_dataset
14
+ from PIL import Image
15
+ from transformers import (
16
+ AutoImageProcessor,
17
+ AutoModelForImageClassification,
18
+ )
19
+
20
+
21
+ HF_REPO_ID = "raidium/curia"
22
+ HF_DATASET_ID = "raidium/CuriaBench"
23
+
24
+
25
+ @lru_cache(maxsize=1)
26
+ def load_id_to_labels() -> Dict[str, Dict[str, str]]:
27
+ """Load the id_to_labels.json mapping file."""
28
+ json_path = os.path.join(os.path.dirname(__file__), "id_to_labels.json")
29
+ with open(json_path, "r") as f:
30
+ data = json.load(f)
31
+ # convert string keys to integers
32
+ for head in data:
33
+ data[head] = {int(k): v for k, v in data[head].items()}
34
+ return data
35
+
36
+
37
+ @lru_cache(maxsize=1)
38
+ def load_processor() -> AutoImageProcessor:
39
+ token = os.environ.get("HF_TOKEN")
40
+ return AutoImageProcessor.from_pretrained(HF_REPO_ID, trust_remote_code=True, token=token)
41
+
42
+
43
+ @lru_cache(maxsize=None)
44
+ def load_model(head: str) -> AutoModelForImageClassification:
45
+ token = os.environ.get("HF_TOKEN")
46
+ model = AutoModelForImageClassification.from_pretrained(
47
+ HF_REPO_ID,
48
+ trust_remote_code=True,
49
+ subfolder=head,
50
+ token=token,
51
+ )
52
+ model.eval()
53
+ return model
54
+
55
+
56
+ @lru_cache(maxsize=None)
57
+ def load_curia_dataset(subset: str) -> Any:
58
+ token = os.environ.get("HF_TOKEN")
59
+ ds = load_dataset(
60
+ HF_DATASET_ID,
61
+ subset,
62
+ split="test",
63
+ token=token,
64
+ )
65
+ if isinstance(ds, DatasetDict):
66
+ return ds["test"]
67
+ return ds
68
+
69
+
70
+ def to_numpy_image(image: Any) -> np.ndarray:
71
+ """Convert dataset or user-provided imagery to a float32 numpy array."""
72
+
73
+ if isinstance(image, np.ndarray):
74
+ arr = image
75
+ elif isinstance(image, Image.Image):
76
+ arr = np.array(image)
77
+ else:
78
+ # Some datasets provide nested dicts or lists – attempt to coerce.
79
+ arr = np.array(image)
80
+
81
+ if arr.ndim == 3 and arr.shape[-1] == 3:
82
+ # Convert RGB to grayscale by averaging channels
83
+ arr = arr.mean(axis=-1)
84
+
85
+ return arr.astype(np.float32)
86
+
87
+
88
+ def infer_image(
89
+ image: np.ndarray,
90
+ head: str,
91
+ ) -> torch.Tensor:
92
+ processor = load_processor()
93
+ model = load_model(head)
94
+ with torch.no_grad():
95
+ processed = processor(images=image, return_tensors="pt")
96
+ outputs = model(**processed)
97
+ logits = outputs["logits"]
98
+ probs = torch.nn.functional.softmax(logits[0], dim=-1)
99
+ return probs
requirements.txt CHANGED
@@ -6,4 +6,4 @@ pandas>=2.2.0
6
  numpy>=1.26.0
7
  pillow>=10.2.0
8
  opencv-python>=4.8.0
9
-
 
6
  numpy>=1.26.0
7
  pillow>=10.2.0
8
  opencv-python>=4.8.0
9
+ torchvision