File size: 22,574 Bytes
a0a2528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab9af50
a0a2528
0fddfa1
a0a2528
b49f319
a0a2528
 
 
 
0fddfa1
a0a2528
44a88e8
a0a2528
0fddfa1
 
 
 
 
a0a2528
 
 
 
 
 
0fddfa1
a0a2528
 
 
 
 
 
 
 
 
 
 
d76fff9
 
 
 
a0a2528
 
 
1925b95
 
 
 
 
 
 
 
 
0fddfa1
1925b95
 
 
 
05c5aa5
 
 
 
 
 
de6c079
 
 
a0a2528
a8db175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
031ad70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab9af50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
031ad70
a0a2528
 
 
 
 
a8db175
6f2fc2c
 
 
 
 
 
 
 
 
 
 
 
a8db175
6f2fc2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
467b7ba
 
 
 
 
a0a2528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fddfa1
 
a0a2528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b49f319
 
 
 
 
 
 
a0a2528
 
 
 
 
b49f319
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
0fddfa1
a0a2528
 
0fddfa1
a0a2528
 
b49f319
0fddfa1
a0a2528
0fddfa1
 
a0a2528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8db175
 
 
 
 
 
 
 
a0a2528
 
1d3f735
1925b95
 
1d3f735
 
 
1925b95
1d3f735
 
05c5aa5
 
 
 
 
 
 
1d3f735
 
 
 
1925b95
 
a8db175
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
a8db175
 
 
a0a2528
a8db175
 
a0a2528
 
a8db175
a0a2528
a8db175
 
 
a0a2528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fddfa1
a0a2528
 
 
fbe16e9
0fddfa1
a0a2528
 
 
 
 
 
 
0fddfa1
a0a2528
 
 
 
 
 
 
a8db175
de6c079
a0a2528
 
 
a8db175
 
 
 
de6c079
a8db175
a0a2528
 
0fddfa1
6f2fc2c
 
0fddfa1
 
 
6f2fc2c
a0a2528
 
1b8bf68
a0a2528
0fddfa1
 
a0a2528
1b8bf68
a0a2528
 
 
a8db175
de6c079
a0a2528
6f2fc2c
a0a2528
 
0fddfa1
de6c079
0fddfa1
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
0fddfa1
a0a2528
de6c079
0fddfa1
de6c079
a0a2528
 
0fddfa1
de6c079
 
 
cca48d3
0fddfa1
 
 
de6c079
 
0fddfa1
 
1b8bf68
de6c079
a0a2528
44a88e8
 
a0a2528
 
 
a8db175
 
 
a0a2528
a8db175
a0a2528
 
fbe16e9
 
 
 
6f2fc2c
 
a8db175
6f2fc2c
a0a2528
 
517b307
 
a0a2528
1b8bf68
517b307
a0a2528
 
1b8bf68
a0a2528
 
 
 
 
 
 
 
ab9af50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
1d3f735
a0a2528
 
 
1d3f735
a0a2528
 
1b8bf68
a0a2528
 
 
1b8bf68
a8db175
 
a0a2528
 
 
 
 
1d3f735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a2528
 
 
0fddfa1
a0a2528
 
 
 
 
1b8bf68
 
de6c079
a0a2528
 
 
 
a8db175
 
 
 
 
 
 
a0a2528
a8db175
a0a2528
a8db175
1925b95
 
 
1d3f735
a8db175
a0a2528
a8db175
 
a0a2528
 
 
 
a8db175
a0a2528
 
a8db175
a0a2528
 
 
 
517b307
 
 
 
a0a2528
 
 
 
a8db175
a0a2528
 
 
a8db175
a0a2528
 
 
 
517b307
a0a2528
 
0fddfa1
a0a2528
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
"""Gradio Space for exploring Curia models and CuriaBench datasets.

This application allows users to:

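- Select any available Curia model head (classification or regression).
- Load the matching CuriaBench test split and sample random images, optionally filtered by class.
- Upload custom medical images that match the model's expected orientation (disabled for heads that need a dataset mask or a 3D volume).
- Forward images through the selected model head and visualise class probabilities (or the predicted value for regression heads).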

The space expects an HF token with access to "raidium" resources to be
provided via the HF_TOKEN environment variable (configure it as a secret when
deploying to Hugging Face Spaces).
"""

from __future__ import annotations

import base64
import random
from typing import Any, Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
import traceback

from inference import (
    load_curia_dataset,
    load_id_to_labels,
    infer_image,
)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Selectable model heads as (head key, human-readable label). The list order is
# the order shown in the "Model head" dropdown; the key is the dropdown value
# that every callback receives as ``head``.
HEAD_OPTIONS: List[Tuple[str, str]] = [
    ("abdominal-trauma", "Active Extravasation"),
    ("anatomy-ct", "Anatomy CT"),
    ("anatomy-mri", "Anatomy MRI"),
    ("atlas-stroke", "Atlas Stroke"),
    ("covidx-ct", "COVIDx CT"),
    ("deep-lesion-site", "Deep Lesion Site"),
    ("emidec-classification-mask", "EMIDEC Classification"),
    ("ich", "Intracranial Hemorrhage"),
    ("ixi", "IXI"),
    ("kits", "KiTS"),
    ("kneeMRI", "Knee MRI"),
    ("luna16-3D", "LUNA16 3D"),
    # Disabled: these datasets cannot be shared, and the heads need a mask, so
    # custom uploads cannot be used for them either.
    # ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
    # ("spinal_canal_stenosis", "Spinal Canal Stenosis"),
    # ("subarticular_stenosis", "Subarticular Stenosis"),
    ("oasis", "OASIS"),
]

# Heads whose inference needs a mask taken from the dataset sample; custom
# image upload is disabled for these (see update_upload_component_state).
HEADS_REQUIRING_MASK: set[str] = {
    "anatomy-ct",
    "anatomy-mri",
    "deep-lesion-site",
    "emidec-classification-mask",
    "kits",
    "kneeMRI",
    "luna16-3D",
    "neural_foraminal_narrowing",
    "spinal_canal_stenosis",
    "subarticular_stenosis",
}

# Heads that consume 3D volumes; 2D custom uploads are disabled for these.
HEADS_3D: set[str] = {
    "oasis",
    "luna16-3D",
    "kneeMRI",
}

# Heads that return a single scalar instead of class probabilities; their
# prediction is rendered as a number rather than a probability table.
REGRESSION_HEADS: set[str] = {
    "ixi",
}

# CuriaBench subset key -> label shown in the UI (all are test splits).
DATASET_OPTIONS: Dict[str, str] = {
    "anatomy-ct": "Anatomy CT (test)",
    "anatomy-ct-hard": "Anatomy CT Hard (test)",
    "anatomy-mri": "Anatomy MRI (test)",
    "covidx-ct": "COVIDx CT (test)",
    "deep-lesion-site": "Deep Lesion Site (test)",
    "emidec-classification-mask": "EMIDEC Classification Mask (test)",
    "ixi": "IXI (test)",
    "kits": "KiTS (test)",
    "kneeMRI": "Knee MRI (test)",
    "luna16-3D": "LUNA16 3D (test)",
    "oasis": "OASIS (test)",
}

# Head key -> CuriaBench subset sampled for that head. Heads missing here
# (e.g. "abdominal-trauma", "atlas-stroke", "ich") have no loadable dataset
# sample in this app.
DEFAULT_DATASET_FOR_HEAD: Dict[str, str] = {
    "anatomy-ct": "anatomy-ct",
    "anatomy-mri": "anatomy-mri",
    "covidx-ct": "covidx-ct",
    "deep-lesion-site": "deep-lesion-site",
    "emidec-classification-mask": "emidec-classification-mask",
    "ixi": "ixi",
    "kits": "kits",
    "kneeMRI": "kneeMRI",
    "luna16-3D": "luna16-3D",
    "oasis": "oasis",
}


# Default CT windowing for each dataset, consumed by apply_windowing (callers
# in this file pass the dataset subset name as the lookup key).
# Format: {"window_level": center, "window_width": width} or None for MRI
# CT values are in Hounsfield Units (HU)
DEFAULT_WINDOWINGS: Dict[str, Optional[Dict[str, int]]] = {
    "anatomy-ct": {"window_level": 40, "window_width": 400},
    "anatomy-ct-hard": {"window_level": 40, "window_width": 400},
    "anatomy-mri": None,
    "atlas-stroke": None,
    "covidx-ct": {"window_level": -600, "window_width": 1500},
    "deep-lesion-site": {"window_level": 40, "window_width": 400},
    "emidec-classification-mask": None,
    "ich": {"window_level": 40, "window_width": 80},
    "ixi": None,
    "kits": {"window_level": 40, "window_width": 400},
    "kneeMRI": None,
    "luna16": {"window_level": -600, "window_width": 1500},
    "luna16-3D": {"window_level": -600, "window_width": 1500},
    "oasis": None,
}

# Logo embedded in the hero banner. Relative path, so it is resolved against the
# process working directory; load_logo_data_uri() tolerates the file being absent.
LOGO_PATH = "Logo horizontal medium copie 4_CREME.png"

# Styles for the hero banner (#app-hero) built in build_demo, switching to a
# stacked, centred layout for viewports up to 768px wide.
CUSTOM_CSS = """
.gr-prose { max-width: 900px; }
#app-hero {
    display: flex;
    align-items: center;
    gap: 2.5rem;
    margin-bottom: 1.5rem;
    padding-right: 1.5rem;
}
#app-hero .hero-text {
    flex: 1;
    padding-right: 1rem;
}
#app-hero .hero-text h1 {
    font-size: 2.25rem;
    margin-bottom: 0.5rem;
}
#app-hero .hero-text p {
    margin: 0.25rem 0;
    line-height: 1.5;
}
#app-hero .hero-logo img {
    max-height: 60px;
    width: auto;
    display: block;
}
@media (max-width: 768px) {
    #app-hero {
        flex-direction: column;
        text-align: center;
        padding-right: 0;
    }
    #app-hero .hero-text {
        padding-right: 0;
    }
    #app-hero .hero-text h1,
    #app-hero .hero-text p {
        text-align: center;
    }
    #app-hero .hero-logo img {
        margin: 0 auto 1rem;
    }
}
"""


def load_logo_data_uri() -> str:
    """Return the logo at ``LOGO_PATH`` as a base64 PNG data URI.

    Returns:
        ``"data:image/png;base64,<payload>"``, or an empty string when the logo
        file does not exist (so the banner can simply omit it).
    """
    try:
        with open(LOGO_PATH, "rb") as handle:
            raw_bytes = handle.read()
    except FileNotFoundError:
        return ""
    payload = base64.b64encode(raw_bytes).decode("ascii")
    return "data:image/png;base64," + payload


# Computed once at import time. An empty string means the logo file was not
# found, and build_demo then renders the hero banner without a logo.
LOGO_DATA_URI = load_logo_data_uri()


# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------


def apply_windowing(image: np.ndarray, head: str) -> np.ndarray:
    """Apply the default CT intensity window configured for ``head``.

    Window parameters are looked up in ``DEFAULT_WINDOWINGS``. Keys mapped to
    ``None`` (MRI datasets) and keys missing from the table leave the image
    untouched.

    Args:
        image: Raw image array (e.g. in Hounsfield Units for CT).
        head: Lookup key into ``DEFAULT_WINDOWINGS``. Callers in this module
            pass the dataset subset name.

    Returns:
        For windowed keys, a float32 array clipped to
        ``[level - width / 2, level + width / 2]`` and rescaled to ``[0, 1]``.
        Otherwise the input ``image`` object itself, unchanged.

    Raises:
        ValueError: If the configured window width is not positive.
    """
    windowing = DEFAULT_WINDOWINGS.get(head)

    # No windowing for MRI or unknown datasets
    if windowing is None:
        return image

    window_level = windowing["window_level"]
    window_width = windowing["window_width"]
    if window_width <= 0:
        # A zero width divides by zero below; a negative one inverts the ramp.
        raise ValueError(
            f"Window width for {head!r} must be positive, got {window_width}"
        )

    # Convert window level/width to the min/max values of the window.
    window_min = window_level - window_width / 2
    window_max = window_level + window_width / 2

    # Clip to the window, then normalise linearly to [0, 1].
    windowed = np.clip(image, window_min, window_max)
    windowed = (windowed - window_min) / (window_max - window_min)

    return windowed.astype(np.float32)


def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise an image for display as a uint8, 3-channel array.

    A 3D input is reduced to its middle slice along the last axis, and a Gradio
    info toast tells the user. Non-finite values are replaced before min-max
    normalisation. NaN becomes 0. +inf/-inf are clamped to the largest and
    smallest finite values, so a single infinite pixel cannot crush the contrast
    of the rest of the image. A constant image maps to all zeros.

    Args:
        image: 2D image, or a 3D array whose last axis is sliced.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with values in ``[0, 255]``.
    """
    # if image is 3D, keep the middle slice
    if image.ndim == 3:
        gr.Info("Image is 3D, we display only the middle slice")
        image = image[:, :, image.shape[2] // 2]

    arr = np.array(image, copy=True)
    finite = np.isfinite(arr)
    if not finite.all():
        if finite.any():
            finite_values = arr[finite]
            arr = np.nan_to_num(
                arr,
                nan=0.0,
                posinf=float(finite_values.max()),
                neginf=float(finite_values.min()),
            )
        else:
            # Nothing finite to anchor the range on: render a blank image.
            arr = np.zeros_like(arr)

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    if arr_max - arr_min > 1e-6:
        arr = (arr - arr_min) / (arr_max - arr_min)
    else:
        arr = np.zeros_like(arr)

    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    if arr.ndim == 2:
        # Replicate grayscale into RGB channels for the display widget.
        arr = np.stack([arr, arr, arr], axis=-1)
    return arr


def prepare_mask_tensor(mask: np.ndarray, height: int, width: int) -> Optional[torch.Tensor]:
    """Convert a mask into a stack of non-empty boolean planes of size ``(height, width)``.

    Singleton axes are squeezed first, then the layout is resolved as follows:

    * 2D with ``height * width`` elements -> one plane (row-major reshape);
    * ``(..., height, width)`` -> leading axes flattened into the plane axis;
    * ``(height, width, C)`` -> channel-last, transposed to ``(C, height, width)``;
    * anything else whose size is a multiple of ``height * width`` (including a
      flattened 1D mask) -> reshaped to ``(-1, height, width)``.

    Args:
        mask: Array-like mask; any value > 0 is treated as foreground.
        height: Plane height (rows of the displayed image).
        width: Plane width (columns of the displayed image).

    Returns:
        A ``torch.bool`` tensor of shape ``(N, height, width)`` holding only the
        planes with at least one foreground pixel. Returns ``None`` when the
        layout is incompatible with ``(height, width)`` or every plane is empty.
    """
    arr = np.squeeze(np.asarray(mask))
    plane_size = height * width

    if arr.ndim == 2:
        # Incompatible 2D masks used to raise from reshape; report "no mask" instead.
        if arr.size != plane_size:
            return None
        arr = arr.reshape(1, height, width)
    elif arr.ndim >= 3 and arr.shape[-2:] == (height, width):
        arr = arr.reshape(-1, height, width)
    elif arr.ndim == 3 and arr.shape[:2] == (height, width):
        arr = np.transpose(arr, (2, 0, 1))
    elif plane_size and arr.size and arr.size % plane_size == 0:
        arr = arr.reshape(-1, height, width)
    else:
        return None

    # Keep only planes that contain foreground; empty planes draw nothing.
    non_empty = [
        plane_mask
        for plane_mask in (torch.from_numpy(plane > 0) for plane in arr)
        if plane_mask.any()
    ]
    if not non_empty:
        return None

    return torch.stack(non_empty, dim=0)


def apply_contour_overlay(
    image: np.ndarray,
    mask: Any,
    thickness: int = 1,
    color: Tuple[int, int, int] = (255, 0, 0),
) -> np.ndarray:
    """Outline each mask region on a copy of ``image`` rather than filling it.

    Args:
        image: Display image; its first two dimensions define the mask plane size.
        mask: Raw mask, normalised through ``prepare_mask_tensor``.
        thickness: Contour line thickness passed to ``cv2.drawContours``.
        color: Contour colour passed to ``cv2.drawContours``.

    Returns:
        A new image with the outer contours of every non-empty mask plane drawn.
        Returns the original ``image`` object when no usable mask is found.
    """
    rows, cols = image.shape[:2]
    planes = prepare_mask_tensor(mask, rows, cols)
    if planes is None:
        return image

    canvas = image.copy()
    for plane in planes:
        plane_u8 = plane.numpy().astype(np.uint8)
        # Only outermost contours are drawn, so holes inside a region are not outlined.
        outlines, _hierarchy = cv2.findContours(
            plane_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(canvas, outlines, -1, color, thickness)

    return canvas


def render_image_with_mask_info(image: np.ndarray, mask: Any) -> np.ndarray:
    """Return a display-ready copy of ``image``, with mask contours when possible.

    If drawing the overlay raises, a Gradio warning is shown and the plain
    display image is returned instead of propagating the error.
    """
    base = to_display_image(image)
    if mask is not None:
        try:
            return apply_contour_overlay(base, mask)
        except Exception:
            gr.Warning("Mask provided but could not be visualised.")
    return base


def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Return one random row index, preferring rows whose ``target`` matches.

    Despite the plural name, a single index is returned. Sampling falls back to
    the whole dataset when ``target`` is ``None``, when the dataset has no
    ``target`` column, or when no row carries the requested target.
    """
    if target is not None and "target" in dataset.column_names:
        matching = [
            row_idx
            for row_idx, label in enumerate(dataset["target"])
            if label == target
        ]
        if matching:
            return random.choice(matching)
    return random.randrange(len(dataset))


# ---------------------------------------------------------------------------
# Gradio callbacks
# ---------------------------------------------------------------------------



def update_dataset_display(head: str) -> str:
    """Build the Markdown line naming the dataset paired with ``head``.

    Falls back to the raw subset key when it has no display label, and to
    "not available" when the head has no dataset at all.
    """
    dataset_key = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not dataset_key:
        return "**Dataset:** not available"
    label = DATASET_OPTIONS.get(dataset_key, dataset_key)
    return f"**Dataset:** {label}"


def update_upload_component_state(head: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Toggle the custom-upload widget for the selected head.

    Uploads are blocked for heads that need a dataset mask (checked first) and
    for 3D heads. In that case the info note is shown with the reason.

    Returns:
        ``(info_update, upload_update)`` Gradio updates for the info Markdown
        and the upload ``Image`` component.
    """
    blocking_rules = (
        (
            HEADS_REQUIRING_MASK,
            "⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset.",
        ),
        (
            HEADS_3D,
            "⚠️ Custom image upload is disabled for this task because it requires a 3D image.",
        ),
    )
    for blocked_heads, reason in blocking_rules:
        if head in blocked_heads:
            return gr.update(value=reason, visible=True), gr.update(interactive=False)

    return gr.update(visible=False), gr.update(interactive=True)


def load_dataset_metadata(head: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
    """Prepare the class-filter dropdown, status text and sample button for ``head``.

    Returns:
        ``(dropdown_update, status_message, button_update)``. When the head has
        no dataset, or the dataset fails to load, the dropdown is reset to
        "Random" and both controls are disabled.
    """

    def _disabled(message: str) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
        return (
            gr.update(choices=["Random"], value="Random", interactive=False),
            message,
            gr.update(interactive=False),
        )

    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        return _disabled("No dataset found for this head.")

    # Class labels come from id_to_labels.json (via load_id_to_labels).
    id2label = load_id_to_labels().get(head, {})

    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return _disabled(f"Failed to load dataset: {exc}")

    class_options = ["Random"] + [
        f"{cls_id}: {id2label[cls_id]}" for cls_id in sorted(id2label)
    ]
    return (
        gr.update(choices=class_options, value="Random", interactive=True),
        f"Loaded {subset} ({len(dataset)} test samples)",
        gr.update(interactive=True),
    )


def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a dropdown entry such as ``"3: Kidney"``.

    Returns ``None`` for the "Random" placeholder, for empty input, and for
    anything whose text before the first ``:`` is not an integer.
    """
    is_placeholder = not selection or selection == "Random"
    if is_placeholder:
        return None

    try:
        class_token = selection.partition(":")[0]
        parsed = int(class_token.strip())
    except (AttributeError, ValueError):
        parsed = None
    return parsed


def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Draw one record from ``subset``, preferring rows labelled ``target_id``.

    Returns:
        ``(image, meta)``. ``image`` is the record's image as a float32 array.
        ``meta`` holds the chosen ``index`` and the record's ``target`` and
        ``mask`` (``None`` when absent).
    """
    dataset = load_curia_dataset(subset)
    chosen = pick_random_indices(dataset, target_id)
    row = dataset[chosen]

    pixels = np.array(row["image"]).astype(np.float32)
    mask_array = row.get("mask")
    label = row.get("target")

    return pixels, {"index": chosen, "target": label, "mask": mask_array}


def load_dataset_sample(
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    Dict[str, Any],
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Sample a test record for ``head`` and prepare it for display and inference.

    Args:
        target_selection: Class-filter dropdown value ("Random" or "<id>: <label>").
        head: Selected model head key.

    Returns:
        A 5-tuple for (image display, prediction text, probability table update,
        ground-truth update, image state). The image state holds the raw,
        un-windowed image and mask used for inference. On failure the image and
        state are ``None`` and a Gradio warning is shown.
    """
    subset = DEFAULT_DATASET_FOR_HEAD.get(head)
    if not subset:
        gr.Warning("No dataset found for this head.")
        return None, "", gr.update(visible=False), gr.update(visible=False), None

    try:
        target_id = parse_target_selection(target_selection)
        image, meta = sample_dataset_example(subset, target_id)
        mask = meta.get("mask")

        # Windowing is for display only; the raw image is kept for inference.
        windowed_image = apply_windowing(image, subset)
        # Falls back to the plain image (with a warning) when the mask cannot be
        # drawn, instead of failing the whole sample load.
        display = render_image_with_mask_info(windowed_image, mask)

        target = meta.get("target")
        ground_truth_update = gr.update(value="")
        if target is not None:
            if head in REGRESSION_HEADS:
                # Scalar target: match the one-decimal format of regression predictions.
                ground_truth_text = f"{float(target):.1f}"
            else:
                id2label = load_id_to_labels().get(head, {})
                label_name = id2label.get(target, str(target))
                ground_truth_text = f"{label_name} (class {target})"
            ground_truth_update = gr.update(value=ground_truth_text, visible=True)

        return (
            display,
            "",  # Reset prediction text
            gr.update(visible=False),
            ground_truth_update,
            {"image": image, "mask": mask},  # Raw image for inference
        )
    except Exception as exc:  # pragma: no cover - surfaced in UI
        gr.Warning(f"Failed to load sample: {exc}")
        return None, "", gr.update(visible=False), gr.update(visible=False), None


def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Tabulate class probabilities, highest first.

    Args:
        probs: Per-class probabilities. The tensor is flattened, so a leading
            singleton batch axis (shape ``(1, C)``) is accepted. Element ``i``
            is treated as class ``i``.
        id2label: Class index -> display label. Missing indices are labelled
            with their stringified index.

    Returns:
        DataFrame with columns ``class_id``, ``label``, ``probability``, sorted
        by descending probability, with ties kept in class order, and a fresh
        ``0..N-1`` index. An empty input yields an empty frame with these columns.
    """
    values = probs.detach().cpu().numpy().reshape(-1)
    rows = [
        {"class_id": idx, "label": id2label.get(idx, str(idx)), "probability": float(val)}
        for idx, val in enumerate(values)
    ]
    # Explicit columns so an empty result still has a sortable "probability" column.
    df = pd.DataFrame(rows, columns=["class_id", "label", "probability"])
    return df.sort_values("probability", ascending=False, kind="stable").reset_index(drop=True)


def run_inference(
    image_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, Dict[str, Any]]:
    """Run ``head`` on the stored image and format the result for the UI.

    Args:
        image_state: State dict with the raw ``"image"`` and optional ``"mask"``.
        head: Selected model head key.

    Returns:
        ``(prediction_text, probability_table_update)``. Regression heads yield
        the value with one decimal and a hidden table. Classification heads
        yield the top label with its probability and a visible, sorted table.
        Errors are printed and reported as text.
    """
    if not image_state or "image" not in image_state:
        return "Load a dataset sample or upload an image first.", gr.update(visible=False)

    is_regression = head in REGRESSION_HEADS
    try:
        output = infer_image(
            image_state["image"],
            head,
            image_state.get("mask"),
            return_probs=not is_regression,
        )

        if is_regression:
            # float() accepts Python/NumPy scalars and single-element tensors alike;
            # formatting a shape-(1,) tensor directly raises TypeError.
            return f"{float(output):.1f}", gr.update(visible=False)

        # Labels come from id_to_labels.json; unknown ids are shown as numbers.
        id2label = load_id_to_labels().get(head, {})

        df = format_probabilities(output, id2label)
        if df.empty:
            return "Model returned no class probabilities.", gr.update(visible=False)
        top_row = df.iloc[0]
        prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
        return prediction, gr.update(visible=True, value=df)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        traceback.print_exc()
        return f"Failed to run inference: {exc}", gr.update(visible=False)

def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
    head: str,
) -> Tuple[Optional[np.ndarray], str, str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Turn an uploaded image into a preview plus the raw state used for inference.

    Args:
        image: Image from the upload component, or ``None``.
        head: Selected model head. Unused here; kept so the callback signature
            matches its wired inputs.

    Returns:
        A 6-tuple for (image display, status text, prediction text, probability
        table, ground-truth update, image state). The state holds the float32
        grayscale image with no mask. On failure the display and state are ``None``.
    """
    if image is None:
        return None, "Please upload an image.", "", pd.DataFrame(), gr.update(visible=False), None

    try:
        np_image = np.array(image).astype(np.float32)
        if np_image.ndim == 3:  # colour image
            # Drop a trailing alpha channel (LA / RGBA) so transparency does not
            # leak into the grayscale intensities, then average the channels.
            if np_image.shape[-1] in (2, 4):
                np_image = np_image[..., :-1]
            np_image = np_image.mean(axis=-1)

        # No windowing for uploads: the preview is min-max normalised, and the
        # model receives the un-normalised array stored in the state below.
        display = to_display_image(np_image)

        return (
            display,
            "Image uploaded. Computing predictions...",
            "",
            pd.DataFrame(),
            gr.update(value=""),
            {"image": np_image, "mask": None},
        )
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return None, f"Failed to load image: {exc}", "", pd.DataFrame(), gr.update(value=""), None


# ---------------------------------------------------------------------------
# Interface definition
# ---------------------------------------------------------------------------


def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Layout, top to bottom:
    - a hero banner;
    - the model-head dropdown;
    - a row with dataset-sampling controls and the custom-upload widget;
    - the image preview beside the ground-truth and prediction outputs.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        # Hero banner: optional inline logo (data URI) plus title and usage notes.
        logo_block = ""
        if LOGO_DATA_URI:
            logo_block = f'<div class="hero-logo"><img src="{LOGO_DATA_URI}" alt="Curia logo" /></div>'
        hero_html = f"""
        <div id=\"app-hero\">
            {logo_block}
            <div class=\"hero-text\">
                <h1>Curia Model Playground</h1>
                <p>Experiment with the multi-head Curia models on CuriaBench evaluation data or your own medical images.</p>
                <p>Each head expects a single 2D slice in the Curia-defined plane/orientation (PL axial, IL coronal, IP sagittal) with raw Hounsfield units (CT) or normalised MRI intensities.</p>
            </div>
        </div>
        """
        gr.HTML(hero_html)

        default_head = "kits"
        head_dropdown = gr.Dropdown(
            label="Model head",
            choices=[(label, key) for key, label in HEAD_OPTIONS],
            value=default_head,
        )

        with gr.Row():
            with gr.Column():
                # Same text the head-change handler produces, so the initial
                # render and later updates cannot drift apart.
                dataset_display = gr.Markdown(update_dataset_display(default_head))
                dataset_status = gr.Markdown("Select a model head to load class metadata.")
                class_dropdown = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                dataset_btn = gr.Button("Load dataset sample")

            with gr.Column():
                gr.Markdown("### Upload custom image")
                # Initial upload state for the default head. It mirrors
                # update_upload_component_state, which only runs on head changes:
                # mask-dependent heads and 3D heads cannot take 2D uploads.
                if default_head in HEADS_REQUIRING_MASK:
                    initial_upload_note = (
                        "⚠️ Custom image upload is disabled for this task because it requires a mask from the dataset."
                    )
                elif default_head in HEADS_3D:
                    initial_upload_note = (
                        "⚠️ Custom image upload is disabled for this task because it requires a 3D image."
                    )
                else:
                    initial_upload_note = ""
                upload_disabled = bool(initial_upload_note)
                upload_info_text = gr.Markdown(
                    value=initial_upload_note,
                    visible=upload_disabled,
                )
                upload_component = gr.Image(
                    label="Upload image",
                    image_mode="L",
                    type="numpy",
                    interactive=not upload_disabled,
                )

        gr.Markdown("---")

        status_text = gr.Markdown()
        with gr.Row():
            with gr.Column():
                image_display = gr.Image(label="Image", interactive=False, type="numpy")

            with gr.Column():
                ground_truth_display = gr.Textbox(label="Ground Truth", interactive=False)
                main_prediction = gr.Textbox(label="Prediction", value="", interactive=False)
                prediction_probs = gr.Dataframe(headers=["class_id", "label", "probability"], visible=False)

        # Holds {"image": raw array, "mask": mask or None} for inference.
        image_state = gr.State()

        # Event wiring
        # Initialize class metadata on page load
        demo.load(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        head_dropdown.change(
            fn=update_dataset_display,
            inputs=[head_dropdown],
            outputs=[dataset_display],
        ).then(
            fn=update_upload_component_state,
            inputs=[head_dropdown],
            outputs=[upload_info_text, upload_component],
        ).then(
            fn=load_dataset_metadata,
            inputs=[head_dropdown],
            outputs=[class_dropdown, dataset_status, dataset_btn],
        )

        # Sampling (or uploading) stores the raw image in state, then inference runs.
        dataset_btn.click(
            fn=load_dataset_sample,
            inputs=[class_dropdown, head_dropdown],
            outputs=[
                image_display,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

        upload_component.upload(
            fn=handle_upload_preview,
            inputs=[upload_component, head_dropdown],
            outputs=[
                image_display,
                status_text,
                main_prediction,
                prediction_probs,
                ground_truth_display,
                image_state,
            ],
        ).then(
            fn=run_inference,
            inputs=[image_state, head_dropdown],
            outputs=[main_prediction, prediction_probs],
        )

    return demo


# Build the interface at import time so importers can reach the module-level ``demo``.
demo = build_demo()


def _launch() -> None:
    """Start the Gradio server for the module-level ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch()