cdancette commited on
Commit
a8db175
·
1 Parent(s): 467b7ba

simplify dataset management

Browse files
Files changed (2) hide show
  1. app.py +92 -77
  2. inference.py +2 -3
app.py CHANGED
@@ -73,22 +73,31 @@ HEADS_REQUIRING_MASK: set[str] = {
73
  }
74
 
75
 
76
- DATASET_OPTIONS: Dict[str, Dict[str, Any]] = {
77
- "anatomy-ct": {"label": "Anatomy CT (test)", "head": "anatomy-ct"},
78
- "anatomy-ct-hard": {"label": "Anatomy CT Hard (test)", "head": "anatomy-ct"},
79
- "anatomy-mri": {"label": "Anatomy MRI (test)", "head": "anatomy-mri"},
80
- "covidx-ct": {"label": "COVIDx CT (test)", "head": "covidx-ct"},
81
- "deep-lesion-site": {"label": "Deep Lesion Site (test)", "head": "deep-lesion-site"},
82
- "emidec-classification-mask": {
83
- "label": "EMIDEC Classification Mask (test)",
84
- "head": "emidec-classification-mask",
85
- },
86
- "ixi": {"label": "IXI (test)", "head": "ixi"},
87
- "kits": {"label": "KiTS (test)", "head": "kits"},
88
- "kneeMRI": {"label": "Knee MRI (test)", "head": "kneeMRI"},
89
- "luna16": {"label": "LUNA16 (test)", "head": "luna16-3D"},
90
- "luna16-3D": {"label": "LUNA16 3D (test)", "head": "luna16-3D"},
91
- "oasis": {"label": "OASIS (test)", "head": "oasis"},
 
 
 
 
 
 
 
 
 
92
  }
93
 
94
 
@@ -118,7 +127,7 @@ DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
118
  # ---------------------------------------------------------------------------
119
 
120
 
121
- def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
122
  """Apply CT windowing based on the dataset.
123
 
124
  For CT images, applies window level and width transformation.
@@ -131,7 +140,7 @@ def apply_windowing(image: np.ndarray, subset: str) -> np.ndarray:
131
  Returns:
132
  Windowed image array
133
  """
134
- windowing = DEFAULT_WINDOWINGS.get(subset)
135
 
136
  # No windowing for MRI or unknown datasets
137
  if windowing is None:
@@ -250,21 +259,6 @@ def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
250
  return display
251
 
252
 
253
- def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
254
- target_feature = dataset.features.get("target")
255
- if target_feature and hasattr(target_feature, "names"):
256
- names = list(target_feature.names)
257
- id2label = {i: name for i, name in enumerate(names)}
258
- classes = list(range(len(names)))
259
- return classes, id2label
260
-
261
- # Fall back to generic inspection
262
- targets = dataset["target"] if "target" in dataset.column_names else []
263
- unique = sorted({int(t) for t in targets}) if targets else []
264
- id2label = {i: str(i) for i in unique}
265
- return unique, id2label
266
-
267
-
268
  def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
269
  if "target" not in dataset.column_names:
270
  return random.randrange(len(dataset))
@@ -283,12 +277,14 @@ def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
283
  # ---------------------------------------------------------------------------
284
 
285
 
286
- def update_dataset_from_head(head: str) -> Dict[str, Any]:
287
- # Find the first dataset that matches this head
288
- for dataset_key, meta in DATASET_OPTIONS.items():
289
- if meta["head"] == head:
290
- return gr.update(value=dataset_key)
291
- return gr.update()
 
 
292
 
293
 
294
  def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -306,27 +302,34 @@ def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str,
306
  return info_update, upload_update
307
 
308
 
309
- def load_dataset_metadata(subset: str) -> Tuple[Dict[str, Any], str]:
 
 
 
 
 
 
 
 
 
 
 
310
  try:
311
  dataset = load_curia_dataset(subset)
312
  except Exception as exc: # pragma: no cover - surfaced in UI
313
- dropdown = gr.update(choices=["Random"], value="Random")
314
- return dropdown, f"Failed to load dataset: {exc}"
315
-
316
- classes, id2label = dataset_class_metadata(dataset)
317
- if not classes:
318
- dropdown = gr.update(
319
- choices=["Random"],
320
- value="Random",
321
- )
322
- return dropdown, "No class metadata detected; sampling at random"
323
 
 
 
324
  options = [
325
  "Random",
326
- *[f"{cls_id}: {id2label.get(cls_id, str(cls_id))}" for cls_id in classes],
327
  ]
328
- dropdown = gr.update(choices=options, value="Random")
329
- return dropdown, f"Loaded {subset} ({len(dataset)} test samples)"
 
330
 
331
 
332
  def parse_target_selection(selection: str) -> Optional[int]:
@@ -348,6 +351,7 @@ def sample_dataset_example(
348
  index = pick_random_indices(dataset, target_id)
349
  record = dataset[index]
350
  image = to_numpy_image(record["image"])
 
351
  mask_array = record.get("mask")
352
 
353
  meta = {
@@ -360,15 +364,21 @@ def sample_dataset_example(
360
 
361
 
362
  def load_dataset_sample(
363
- subset: str,
364
  target_selection: str,
365
  head: str,
366
  ) -> Tuple[
367
  Optional[np.ndarray],
 
368
  pd.DataFrame,
369
  Dict[str, Any],
370
  Optional[Dict[str, Any]],
371
  ]:
 
 
 
 
 
 
372
  try:
373
  target_id = parse_target_selection(target_selection)
374
  image, meta = sample_dataset_example(subset, target_id)
@@ -390,13 +400,14 @@ def load_dataset_sample(
390
 
391
  return (
392
  display,
 
393
  pd.DataFrame(),
394
  ground_truth_update,
395
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
396
  )
397
  except Exception as exc: # pragma: no cover - surfaced in UI
398
  gr.Warning(f"Failed to load sample: {exc}")
399
- return None, pd.DataFrame(), gr.update(visible=False), None
400
 
401
 
402
  def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
@@ -436,27 +447,28 @@ def run_inference(
436
 
437
  def handle_upload_preview(
438
  image: np.ndarray | Image.Image | None,
439
- subset: str,
440
- ) -> Tuple[Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
 
441
  if image is None:
442
- return None, "Please upload an image.", pd.DataFrame(), gr.update(visible=False), None
443
 
444
  try:
445
  np_image = to_numpy_image(image)
446
 
447
  # Apply windowing only for display, keep raw image for model inference
448
- windowed_image = apply_windowing(np_image, subset)
449
- display = to_display_image(windowed_image)
450
 
451
  return (
452
  display,
453
  "Image uploaded. Click 'Run inference' to compute predictions.",
 
454
  pd.DataFrame(),
455
  gr.update(visible=False),
456
  {"image": np_image, "mask": None}, # Store raw image for inference
457
  )
458
  except Exception as exc: # pragma: no cover - surfaced in UI
459
- return None, f"Failed to load image: {exc}", pd.DataFrame(), gr.update(visible=False), None
460
 
461
 
462
  # ---------------------------------------------------------------------------
@@ -490,12 +502,8 @@ def build_demo() -> gr.Blocks:
490
  with gr.Row():
491
  with gr.Column():
492
  gr.Markdown("### Load dataset sample")
493
- dataset_dropdown = gr.Dropdown(
494
- label="CuriaBench subset",
495
- choices=[(meta["label"], key) for key, meta in DATASET_OPTIONS.items()],
496
- value="kits",
497
- )
498
- dataset_status = gr.Markdown("Select a dataset to load class metadata.")
499
  class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
500
  dataset_btn = gr.Button("Load dataset sample")
501
 
@@ -525,9 +533,9 @@ def build_demo() -> gr.Blocks:
525
  with gr.Row():
526
  with gr.Column():
527
  image_display = gr.Image(label="Image", interactive=False, type="numpy")
528
- ground_truth_display = gr.Markdown(visible=False)
529
 
530
  with gr.Column():
 
531
  gr.Markdown("### Predictions")
532
  main_prediction = gr.Markdown()
533
  prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
@@ -535,27 +543,33 @@ def build_demo() -> gr.Blocks:
535
  image_state = gr.State()
536
 
537
  # Event wiring
 
 
 
 
 
 
 
538
  head_dropdown.change(
539
- fn=update_dataset_from_head,
540
  inputs=[head_dropdown],
541
- outputs=[dataset_dropdown],
542
  ).then(
543
  fn=update_upload_component_state,
544
  inputs=[head_dropdown],
545
  outputs=[upload_info_text, upload_component],
546
- )
547
-
548
- dataset_dropdown.change(
549
  fn=load_dataset_metadata,
550
- inputs=[dataset_dropdown],
551
- outputs=[class_dropdown, dataset_status],
552
  )
553
 
554
  dataset_btn.click(
555
  fn=load_dataset_sample,
556
- inputs=[dataset_dropdown, class_dropdown, head_dropdown],
557
  outputs=[
558
  image_display,
 
559
  prediction_probs,
560
  ground_truth_display,
561
  image_state,
@@ -564,10 +578,11 @@ def build_demo() -> gr.Blocks:
564
 
565
  upload_component.upload(
566
  fn=handle_upload_preview,
567
- inputs=[upload_component, dataset_dropdown],
568
  outputs=[
569
  image_display,
570
  status_text,
 
571
  prediction_probs,
572
  ground_truth_display,
573
  image_state,
 
73
  }
74
 
75
 
76
+ DATASET_OPTIONS: Dict[str, str] = {
77
+ "anatomy-ct": "Anatomy CT (test)",
78
+ "anatomy-ct-hard": "Anatomy CT Hard (test)",
79
+ "anatomy-mri": "Anatomy MRI (test)",
80
+ "covidx-ct": "COVIDx CT (test)",
81
+ "deep-lesion-site": "Deep Lesion Site (test)",
82
+ "emidec-classification-mask": "EMIDEC Classification Mask (test)",
83
+ "ixi": "IXI (test)",
84
+ "kits": "KiTS (test)",
85
+ "kneeMRI": "Knee MRI (test)",
86
+ "luna16-3D": "LUNA16 3D (test)",
87
+ "oasis": "OASIS (test)",
88
+ }
89
+
90
+ DEFAULT_DATASET_FOR_HEAD: Dict[str, str] = {
91
+ "anatomy-ct": "anatomy-ct",
92
+ "anatomy-mri": "anatomy-mri",
93
+ "covidx-ct": "covidx-ct",
94
+ "deep-lesion-site": "deep-lesion-site",
95
+ "emidec-classification-mask": "emidec-classification-mask",
96
+ "ixi": "ixi",
97
+ "kits": "kits",
98
+ "kneeMRI": "kneeMRI",
99
+ "luna16-3D": "luna16-3D",
100
+ "oasis": "oasis",
101
  }
102
 
103
 
 
127
  # ---------------------------------------------------------------------------
128
 
129
 
130
+ def apply_windowing(image: np.ndarray, head: str) -> np.ndarray:
131
  """Apply CT windowing based on the dataset.
132
 
133
  For CT images, applies window level and width transformation.
 
140
  Returns:
141
  Windowed image array
142
  """
143
+ windowing = DEFAULT_WINDOWINGS.get(head)
144
 
145
  # No windowing for MRI or unknown datasets
146
  if windowing is None:
 
259
  return display
260
 
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
263
  if "target" not in dataset.column_names:
264
  return random.randrange(len(dataset))
 
277
  # ---------------------------------------------------------------------------
278
 
279
 
280
+
281
+ def update_dataset_display(head: str) -> str:
282
+ """Update the dataset name display based on the selected head."""
283
+ dataset_key = DEFAULT_DATASET_FOR_HEAD.get(head)
284
+ if dataset_key:
285
+ dataset_label = DATASET_OPTIONS.get(dataset_key, dataset_key)
286
+ return f"**Dataset:** {dataset_label}"
287
+ return "**Dataset:** not available"
288
 
289
 
290
  def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
302
  return info_update, upload_update
303
 
304
 
305
+ def load_dataset_metadata(head: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
306
+ """Load dataset metadata based on the selected head."""
307
+ subset = DEFAULT_DATASET_FOR_HEAD.get(head)
308
+ if not subset:
309
+ dropdown = gr.update(choices=["Random"], value="Random", interactive=False)
310
+ button = gr.update(interactive=False)
311
+ return dropdown, "No dataset found for this head.", button
312
+
313
+ # Load class labels from id_to_labels.json
314
+ id2label = load_id_to_labels().get(head, {})
315
+
316
+
317
  try:
318
  dataset = load_curia_dataset(subset)
319
  except Exception as exc: # pragma: no cover - surfaced in UI
320
+ dropdown = gr.update(choices=["Random"], value="Random", interactive=False)
321
+ button = gr.update(interactive=False)
322
+ return dropdown, f"Failed to load dataset: {exc}", button
 
 
 
 
 
 
 
323
 
324
+ # Build dropdown options from id_to_labels.json
325
+ classes = sorted(id2label.keys())
326
  options = [
327
  "Random",
328
+ *[f"{cls_id}: {id2label[cls_id]}" for cls_id in classes],
329
  ]
330
+ dropdown = gr.update(choices=options, value="Random", interactive=True)
331
+ button = gr.update(interactive=True)
332
+ return dropdown, f"Loaded {subset} ({len(dataset)} test samples)", button
333
 
334
 
335
  def parse_target_selection(selection: str) -> Optional[int]:
 
351
  index = pick_random_indices(dataset, target_id)
352
  record = dataset[index]
353
  image = to_numpy_image(record["image"])
354
+ print(image.shape)
355
  mask_array = record.get("mask")
356
 
357
  meta = {
 
364
 
365
 
366
  def load_dataset_sample(
 
367
  target_selection: str,
368
  head: str,
369
  ) -> Tuple[
370
  Optional[np.ndarray],
371
+ str,
372
  pd.DataFrame,
373
  Dict[str, Any],
374
  Optional[Dict[str, Any]],
375
  ]:
376
+ """Load a dataset sample based on the selected head."""
377
+ subset = DEFAULT_DATASET_FOR_HEAD.get(head)
378
+ if not subset:
379
+ gr.Warning("No dataset found for this head.")
380
+ return None, "", pd.DataFrame(), gr.update(visible=False), None
381
+
382
  try:
383
  target_id = parse_target_selection(target_selection)
384
  image, meta = sample_dataset_example(subset, target_id)
 
400
 
401
  return (
402
  display,
403
+ "", # Reset prediction text
404
  pd.DataFrame(),
405
  ground_truth_update,
406
  {"image": image, "mask": meta.get("mask")}, # Store raw image for inference
407
  )
408
  except Exception as exc: # pragma: no cover - surfaced in UI
409
  gr.Warning(f"Failed to load sample: {exc}")
410
+ return None, "", pd.DataFrame(), gr.update(visible=False), None
411
 
412
 
413
  def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
 
447
 
448
  def handle_upload_preview(
449
  image: np.ndarray | Image.Image | None,
450
+ head: str,
451
+ ) -> Tuple[Optional[np.ndarray], str, str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
452
+ """Handle image upload preview, deriving dataset from head."""
453
  if image is None:
454
+ return None, "Please upload an image.", "", pd.DataFrame(), gr.update(visible=False), None
455
 
456
  try:
457
  np_image = to_numpy_image(image)
458
 
459
  # Apply windowing only for display, keep raw image for model inference
460
+ display = to_display_image(np_image)
 
461
 
462
  return (
463
  display,
464
  "Image uploaded. Click 'Run inference' to compute predictions.",
465
+ "", # Reset prediction text
466
  pd.DataFrame(),
467
  gr.update(visible=False),
468
  {"image": np_image, "mask": None}, # Store raw image for inference
469
  )
470
  except Exception as exc: # pragma: no cover - surfaced in UI
471
+ return None, f"Failed to load image: {exc}", "", pd.DataFrame(), gr.update(visible=False), None
472
 
473
 
474
  # ---------------------------------------------------------------------------
 
502
  with gr.Row():
503
  with gr.Column():
504
  gr.Markdown("### Load dataset sample")
505
+ dataset_display = gr.Markdown(f"**Dataset:** {DATASET_OPTIONS.get(DEFAULT_DATASET_FOR_HEAD.get(default_head, ''), 'Unknown')}")
506
+ dataset_status = gr.Markdown("Select a model head to load class metadata.")
 
 
 
 
507
  class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
508
  dataset_btn = gr.Button("Load dataset sample")
509
 
 
533
  with gr.Row():
534
  with gr.Column():
535
  image_display = gr.Image(label="Image", interactive=False, type="numpy")
 
536
 
537
  with gr.Column():
538
+ ground_truth_display = gr.Markdown(visible=False)
539
  gr.Markdown("### Predictions")
540
  main_prediction = gr.Markdown()
541
  prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"])
 
543
  image_state = gr.State()
544
 
545
  # Event wiring
546
+ # Initialize on page load
547
+ demo.load(
548
+ fn=load_dataset_metadata,
549
+ inputs=[head_dropdown],
550
+ outputs=[class_dropdown, dataset_status, dataset_btn],
551
+ )
552
+
553
  head_dropdown.change(
554
+ fn=update_dataset_display,
555
  inputs=[head_dropdown],
556
+ outputs=[dataset_display],
557
  ).then(
558
  fn=update_upload_component_state,
559
  inputs=[head_dropdown],
560
  outputs=[upload_info_text, upload_component],
561
+ ).then(
 
 
562
  fn=load_dataset_metadata,
563
+ inputs=[head_dropdown],
564
+ outputs=[class_dropdown, dataset_status, dataset_btn],
565
  )
566
 
567
  dataset_btn.click(
568
  fn=load_dataset_sample,
569
+ inputs=[class_dropdown, head_dropdown],
570
  outputs=[
571
  image_display,
572
+ main_prediction,
573
  prediction_probs,
574
  ground_truth_display,
575
  image_state,
 
578
 
579
  upload_component.upload(
580
  fn=handle_upload_preview,
581
+ inputs=[upload_component, head_dropdown],
582
  outputs=[
583
  image_display,
584
  status_text,
585
+ main_prediction,
586
  prediction_probs,
587
  ground_truth_display,
588
  image_state,
inference.py CHANGED
@@ -5,12 +5,11 @@ from __future__ import annotations
5
  import json
6
  import os
7
  from functools import lru_cache
8
- from typing import Any, Dict, List, Optional, Tuple
9
 
10
  import numpy as np
11
- import pandas as pd
12
  import torch
13
- from datasets import Dataset, DatasetDict, load_dataset
14
  from PIL import Image
15
  from torchvision import transforms
16
  from torchvision.transforms import functional as TF
 
5
  import json
6
  import os
7
  from functools import lru_cache
8
+ from typing import Any, Dict, Optional
9
 
10
  import numpy as np
 
11
  import torch
12
+ from datasets import DatasetDict, load_dataset
13
  from PIL import Image
14
  from torchvision import transforms
15
  from torchvision.transforms import functional as TF