cdancette committed on
Commit
a0a2528
·
1 Parent(s): a1b3940

first commit

Browse files
Files changed (2) hide show
  1. app.py +569 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gradio Space for exploring Curia models and CuriaBench datasets.
2
+
3
+ This application allows users to:
4
+
5
+ - Select any available Curia classification head.
6
+ - Load the matching CuriaBench test split and sample random images per class.
7
+ - Upload custom medical images that match the model's expected orientation.
8
+ - Forward images through the selected model head and visualise class probabilities.
9
+
10
+ The space expects an HF token with access to "raidium" resources to be
11
+ provided via the HF_TOKEN environment variable (configure it as a secret when
12
+ deploying to Hugging Face Spaces).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import random
19
+ from functools import lru_cache
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import gradio as gr
23
+ import numpy as np
24
+ import pandas as pd
25
+ import torch
26
+ from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
27
+ from PIL import Image
28
+ from transformers import (
29
+ AutoImageProcessor,
30
+ AutoModelForImageClassification,
31
+ )
32
+ from torchvision.utils import draw_segmentation_masks
33
+
34
+
35
# Hub model repository queried for the image processor and, via one subfolder
# per head, for every classification checkpoint.
HF_REPO_ID = "raidium/curia"
# Hub dataset repository; each key of DATASET_OPTIONS is loaded from it as a
# configuration name.
HF_DATASET_ID = "raidium/CuriaBench"
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Configuration
41
+ # ---------------------------------------------------------------------------
42
+
43
# (head key, display label) pairs. The key is passed as the ``subfolder`` when
# a model is loaded; the label is what the "Model head" dropdown shows.
HEAD_OPTIONS: List[Tuple[str, str]] = [
    ("anatomy-ct", "Anatomy CT"),
    ("anatomy-mri", "Anatomy MRI"),
    ("atlas-stroke", "Atlas Stroke"),
    ("covidx-ct", "COVIDx CT"),
    ("deep-lesion-site", "Deep Lesion Site"),
    ("emidec-classification-mask", "EMIDEC Classification"),
    ("ich", "Intracranial Hemorrhage"),
    ("ixi", "IXI"),
    ("kits", "KiTS"),
    ("kneeMRI", "Knee MRI"),
    ("luna16-3D", "LUNA16 3D"),
    ("neural_foraminal_narrowing", "Neural Foraminal Narrowing"),
    ("oasis", "OASIS"),
    ("spinal_canal_stenosis", "Spinal Canal Stenosis"),
    ("subarticular_stenosis", "Subarticular Stenosis"),
]
60
+
61
+
62
# Maps a dataset configuration name to its dropdown label ("label") and to the
# model head it pairs with ("head"). When the head selection changes, the first
# entry whose "head" matches is selected; heads without any entry here (for
# example "ich" or "atlas-stroke") leave the dataset dropdown untouched.
DATASET_OPTIONS: Dict[str, Dict[str, Any]] = {
    "anatomy-ct": {"label": "Anatomy CT (test)", "head": "anatomy-ct"},
    "anatomy-ct-hard": {"label": "Anatomy CT Hard (test)", "head": "anatomy-ct"},
    "anatomy-mri": {"label": "Anatomy MRI (test)", "head": "anatomy-mri"},
    "covidctset": {"label": "COVID CT Set (test)", "head": "covidx-ct"},
    "covidx-ct": {"label": "COVIDx CT (test)", "head": "covidx-ct"},
    "deep-lesion-site": {"label": "Deep Lesion Site (test)", "head": "deep-lesion-site"},
    "emidec-classification-mask": {
        "label": "EMIDEC Classification Mask (test)",
        "head": "emidec-classification-mask",
    },
    "ixi": {"label": "IXI (test)", "head": "ixi"},
    "kits": {"label": "KiTS (test)", "head": "kits"},
    "kneeMRI": {"label": "Knee MRI (test)", "head": "kneeMRI"},
    "luna16": {"label": "LUNA16 (test)", "head": "luna16-3D"},
    "luna16-3D": {"label": "LUNA16 3D (test)", "head": "luna16-3D"},
    "oasis": {"label": "OASIS (test)", "head": "oasis"},
}
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Utility helpers
84
+ # ---------------------------------------------------------------------------
85
+
86
+
87
def resolve_token() -> Optional[str]:
    """Return the Hugging Face token from the ``HF_TOKEN`` environment variable.

    Surrounding whitespace is stripped (pasted secrets often carry a trailing
    newline), and an unset, empty or blank variable yields ``None`` so callers
    pass "no token" instead of an empty string.

    Returns:
        The token string, or ``None`` when no usable token is configured.
    """

    token = os.environ.get("HF_TOKEN", "").strip()
    return token or None
91
+
92
+
93
@lru_cache(maxsize=1)
def load_processor() -> AutoImageProcessor:
    """Fetch the image processor from ``HF_REPO_ID``; the result is memoised."""
    return AutoImageProcessor.from_pretrained(
        HF_REPO_ID,
        trust_remote_code=True,
        token=resolve_token(),
    )
97
+
98
+
99
@lru_cache(maxsize=len(HEAD_OPTIONS))
def load_model(head: str) -> AutoModelForImageClassification:
    """Load the classification model stored under subfolder ``head``.

    The model is switched to eval mode before being returned, and results are
    memoised per head name.
    """
    classifier = AutoModelForImageClassification.from_pretrained(
        HF_REPO_ID,
        subfolder=head,
        token=resolve_token(),
        trust_remote_code=True,
    )
    classifier.eval()
    return classifier
110
+
111
+
112
@lru_cache(maxsize=len(DATASET_OPTIONS))
def load_curia_dataset(subset: str) -> Any:
    """Return the ``test`` split of configuration ``subset`` (memoised per subset)."""
    loaded = load_dataset(
        HF_DATASET_ID,
        subset,
        split="test",
        token=resolve_token(),
    )
    # Unwrap defensively in case a dict of splits comes back instead of a split.
    return loaded["test"] if isinstance(loaded, DatasetDict) else loaded
124
+
125
+
126
def to_numpy_image(image: Any) -> np.ndarray:
    """Convert dataset or user-provided imagery to a 2D float32 numpy array.

    Channel-last inputs are reduced to one plane: a single channel is
    unwrapped, RGB is averaged, and RGBA is averaged over its colour channels
    with the alpha channel ignored.

    Args:
        image: A numpy array, a PIL image, or any object ``np.asarray`` can
            coerce (for example nested lists).

    Returns:
        A new float32 array of shape (H, W).

    Raises:
        ValueError: If the input cannot be reduced to a single 2D slice.
    """

    # np.asarray covers ndarrays (no copy), PIL images and nested sequences.
    arr = np.asarray(image)

    if arr.ndim == 3:
        channels = arr.shape[-1]
        if channels == 1:
            arr = arr[..., 0]
        elif channels == 3:
            # Convert RGB to grayscale by averaging channels
            arr = arr.mean(axis=-1)
        elif channels == 4:
            # RGBA: alpha carries no intensity information, so drop it.
            arr = arr[..., :3].mean(axis=-1)

    if arr.ndim != 2:
        raise ValueError("Expected a 2D image (H, W). Please provide a single axial/coronal/sagittal slice.")

    # astype always copies, so the caller's array is never aliased.
    return arr.astype(np.float32)
145
+
146
+
147
def to_display_image(image: np.ndarray) -> np.ndarray:
    """Normalise image for display purposes (uint8, 3-channel).

    Intensities are min-max scaled to 0..255. NaNs are replaced by 0 and
    infinities are clamped to the finite min/max of the image, so a single
    infinite pixel cannot stretch the range and flatten the rest to black.

    Args:
        image: Array of pixel intensities; the input is not modified.

    Returns:
        A uint8 array. 2D inputs are replicated to shape (H, W, 3); inputs of
        any other rank keep their shape.
    """

    arr = np.array(image, copy=True)
    finite = np.isfinite(arr)
    if not finite.all():
        if finite.any():
            finite_min = arr[finite].min()
            finite_max = arr[finite].max()
        else:
            finite_min = finite_max = 0.0
        # Without explicit posinf/neginf, nan_to_num substitutes the dtype's
        # largest magnitudes, which would dominate the min-max scaling below.
        arr = np.nan_to_num(arr, nan=0.0, posinf=finite_max, neginf=finite_min)

    arr_min = float(arr.min())
    arr_max = float(arr.max())
    if arr_max - arr_min > 1e-6:
        arr = (arr - arr_min) / (arr_max - arr_min)
    else:
        # Constant image: nothing to scale, render as black.
        arr = np.zeros_like(arr)

    arr = (arr * 255).clip(0, 255).astype(np.uint8)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    return arr
165
+
166
+
167
def coerce_mask_array(mask: Any) -> Optional[np.ndarray]:
    """Best-effort conversion of ``mask`` to a numpy array.

    Returns ``None`` when the mask is missing, cannot be converted, or holds
    no elements.
    """
    if mask is not None:
        try:
            candidate = np.array(mask)
        except Exception:
            candidate = None
        if candidate is not None and candidate.size > 0:
            return candidate
    return None
179
+
180
+
181
def prepare_mask_tensor(mask: Any, height: int, width: int) -> Optional[torch.Tensor]:
    """Turn an arbitrary mask into a stack of non-empty boolean layers.

    Accepted layouts (after dropping singleton axes): a single (H, W) plane,
    any array whose last two axes are (H, W), a channels-last (H, W, C)
    array, and - as a last resort - any array whose element count is a
    multiple of H * W, which is reshaped as-is.

    Args:
        mask: Mask data in any form ``coerce_mask_array`` accepts.
        height: Height of the image the mask belongs to.
        width: Width of the image the mask belongs to.

    Returns:
        A bool tensor of shape (K, H, W) holding only the layers that contain
        at least one positive value, or ``None`` when the mask is missing, has
        an incompatible shape, or is empty everywhere.
    """
    mask_array = coerce_mask_array(mask)
    if mask_array is None or height <= 0 or width <= 0:
        return None

    arr = np.squeeze(mask_array)
    plane = (height, width)

    if arr.shape[-2:] == plane:
        # (H, W) or (..., H, W): flatten every leading axis into layers.
        stack = arr.reshape(-1, height, width)
    elif arr.ndim == 3 and arr.shape[:2] == plane:
        # Channels-last (H, W, C): move the channel axis to the front.
        stack = np.moveaxis(arr, -1, 0)
    elif arr.size % (height * width) == 0:
        # Unknown layout with a compatible element count: reinterpret as-is.
        stack = arr.reshape(-1, height, width)
    else:
        return None

    layers = torch.from_numpy(np.ascontiguousarray(stack > 0))
    # Drop layers without any foreground pixel so they are not drawn.
    non_empty = layers.flatten(1).any(dim=1)
    if not bool(non_empty.any()):
        return None
    return layers[non_empty]
215
+
216
+
217
def apply_mask_overlay(image: np.ndarray, mask: Any) -> np.ndarray:
    """Blend the mask layers onto ``image`` in translucent red.

    Args:
        image: Display image of shape (H, W, 3); expected to be uint8.
        mask: Mask data in any form ``prepare_mask_tensor`` accepts.

    Returns:
        A uint8 (H, W, 3) array with the overlay applied, or ``image`` itself
        when no usable mask layer is available.
    """
    height, width = image.shape[:2]
    mask_tensor = prepare_mask_tensor(mask, height, width)
    if mask_tensor is None:
        return image

    # Stay in uint8: older torchvision releases only accept uint8 images here,
    # and it avoids a float round trip.
    pixels = np.ascontiguousarray(image, dtype=np.uint8)
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1)
    # A single colour tuple applies to every layer; a one-element list is
    # rejected as soon as the mask has more than one layer.
    overlaid = draw_segmentation_masks(img_tensor, mask_tensor, colors=(255, 0, 0), alpha=0.4)
    return overlaid.permute(1, 2, 0).contiguous().numpy()
227
+
228
+
229
def render_image_with_mask_info(image: np.ndarray, mask: Any) -> Tuple[np.ndarray, Optional[str]]:
    """Build the display image, overlaying ``mask`` when one is supplied.

    Returns:
        ``(display, note)`` where ``note`` is ``None`` if no mask was given,
        an empty string if the overlay step ran without error, and a warning
        message if the overlay raised (the plain image is returned then).
    """
    rendered = to_display_image(image)
    note: Optional[str] = None
    if mask is not None:
        try:
            rendered = apply_mask_overlay(rendered, mask)
            note = ""
        except Exception:
            note = "Mask provided but could not be visualised."
    return rendered, note
239
+
240
+
241
def dataset_class_metadata(dataset: Dataset) -> Tuple[List[int], Dict[int, str]]:
    """Return the class ids of ``dataset`` and a mapping from id to label.

    The ``target`` feature's ``names`` are used when it exposes them.
    Otherwise the distinct values of the ``target`` column are collected and
    labelled with their own string form; rows whose target is ``None`` are
    skipped.

    Returns:
        ``(class_ids, id2label)``; both are empty when the dataset has no
        usable ``target`` information.
    """
    target_feature = dataset.features.get("target")
    if target_feature is not None and hasattr(target_feature, "names"):
        id2label = dict(enumerate(target_feature.names))
        return list(id2label), id2label

    # Fall back to generic inspection
    if "target" not in dataset.column_names:
        return [], {}
    # Unlabeled rows (None) would make int() raise, so leave them out.
    unique = sorted({int(t) for t in dataset["target"] if t is not None})
    return unique, {i: str(i) for i in unique}
254
+
255
+
256
def pick_random_indices(dataset: Dataset, target: Optional[int]) -> int:
    """Pick one random row index, restricted to rows labelled ``target`` if possible.

    A uniformly random index over the whole dataset is returned when no
    target is requested, when the dataset has no ``target`` column, or when
    no row carries the requested target.
    """
    if target is not None and "target" in dataset.column_names:
        matches = [pos for pos, label in enumerate(dataset["target"]) if label == target]
        if matches:
            return random.choice(matches)
    return random.randrange(len(dataset))
267
+
268
+
269
def format_probabilities(probs: torch.Tensor, id2label: Dict[int, str]) -> pd.DataFrame:
    """Tabulate class probabilities, most likely class first.

    Each row holds ``class_id``, ``label`` (the id as a string when
    ``id2label`` has no entry for it) and ``probability``.
    """

    scores = probs.detach().cpu().numpy()
    records = []
    for class_id, score in enumerate(scores):
        records.append(
            {
                "class_id": class_id,
                "label": id2label.get(class_id, str(class_id)),
                "probability": float(score),
            }
        )
    table = pd.DataFrame(records)
    return table.sort_values("probability", ascending=False)
280
+
281
+
282
def infer_image(
    image: np.ndarray,
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Run ``image`` through the selected head and summarise the result.

    Args:
        image: Image array handed to the processor.
        head: Key of the model head to load.

    Returns:
        ``(prediction, table)`` where ``prediction`` reads
        ``"<label> (p=<probability>)"`` for the most likely class and
        ``table`` lists every class sorted by descending probability.
    """
    processor = load_processor()
    model = load_model(head)
    with torch.no_grad():
        processed = processor(images=image, return_tensors="pt")
        outputs = model(**processed)
    logits = outputs["logits"]
    # Only the first item of the batch is scored.
    probs = torch.nn.functional.softmax(logits[0], dim=-1)

    id2label = model.config.id2label or {}
    df = format_probabilities(probs, id2label)
    top_row = df.iloc[0]
    prediction = f"{top_row['label']} (p={top_row['probability']:.3f})"
    return prediction, df
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # Gradio callbacks
304
+ # ---------------------------------------------------------------------------
305
+
306
+
307
def update_dataset_from_head(head: str) -> Dict[str, Any]:
    """Select the first dataset registered for ``head``, if there is one.

    Returns a no-op update when no entry of ``DATASET_OPTIONS`` uses ``head``.
    """
    match = next(
        (key for key, meta in DATASET_OPTIONS.items() if meta["head"] == head),
        None,
    )
    if match is None:
        return gr.update()
    return gr.update(value=match)
313
+
314
+
315
def load_dataset_metadata(subset: str) -> Tuple[Dict[str, Any], str]:
    """Refresh the class-filter dropdown for ``subset``.

    Returns:
        ``(dropdown_update, status)``. The dropdown always offers "Random"
        first, followed by one ``"<id>: <label>"`` entry per detected class;
        the status reports success, missing class metadata, or the load error.
    """
    try:
        dataset = load_curia_dataset(subset)
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return gr.update(choices=["Random"], value="Random"), f"Failed to load dataset: {exc}"

    classes, id2label = dataset_class_metadata(dataset)
    choices = ["Random"]
    choices.extend(f"{cls_id}: {id2label.get(cls_id, str(cls_id))}" for cls_id in classes)

    if classes:
        message = f"Loaded {subset} ({len(dataset)} test samples)"
    else:
        message = "No class metadata detected; sampling at random"
    return gr.update(choices=choices, value="Random"), message
336
+
337
+
338
def parse_target_selection(selection: str) -> Optional[int]:
    """Extract the class id from a ``"<id>: <label>"`` dropdown entry.

    Returns ``None`` for an empty selection, for "Random", and for any value
    whose part before the first colon is not an integer.
    """
    if selection and selection != "Random":
        try:
            prefix, _, _ = selection.partition(":")
            return int(prefix.strip())
        except (ValueError, AttributeError):
            pass
    return None
347
+
348
+
349
def sample_dataset_example(
    subset: str,
    target_id: Optional[int],
) -> Tuple[np.ndarray, str, Dict[str, Any]]:
    """Draw one random record from ``subset``, optionally filtered by class.

    Returns:
        ``(image, caption, meta)`` where ``caption`` is ``"Sample #<index>"``
        and ``meta`` carries the record's ``index``, ``target`` and ``mask``.
    """
    dataset = load_curia_dataset(subset)
    position = pick_random_indices(dataset, target_id)
    row = dataset[position]
    pixels = to_numpy_image(row["image"])

    details = {
        "index": position,
        "target": row.get("target"),
        "mask": coerce_mask_array(row.get("mask")),
    }
    return pixels, f"Sample #{position}", details
366
+
367
+
368
def load_dataset_sample(
    subset: str,
    target_selection: str,
    head: str,
) -> Tuple[
    Optional[np.ndarray],
    str,
    pd.DataFrame,
    Dict[str, Any],
    Optional[Dict[str, Any]],
]:
    """Callback for the "Load dataset sample" button.

    Returns, in order: the display image, the status text, an empty
    predictions table, an update for the ground-truth panel, and the state
    dict (``image`` and ``mask``) consumed by inference. On any failure the
    image and state are ``None`` and the status text carries the error.
    """
    try:
        image, caption, meta = sample_dataset_example(
            subset, parse_target_selection(target_selection)
        )
        mask = meta.get("mask")
        display, mask_msg = render_image_with_mask_info(image, mask)
        target = meta.get("target")

        status = "Image loaded. Click 'Run inference' to compute predictions."
        if mask_msg:
            status = f"{status} {mask_msg}"
        details = caption if target is None else f"{caption} | target={target}"

        # Generate ground truth display
        if target is None:
            ground_truth_update = gr.update(visible=False)
        else:
            id2label = load_model(head).config.id2label or {}
            label_name = id2label.get(target, str(target))
            ground_truth_update = gr.update(
                value=f"**Ground Truth:** {label_name} (class {target})",
                visible=True,
            )

        return (
            display,
            f"{status}\n\n{details}",
            pd.DataFrame(),
            ground_truth_update,
            {"image": image, "mask": mask},
        )
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return None, f"Failed to load sample: {exc}", pd.DataFrame(), gr.update(visible=False), None
409
+
410
+
411
def run_inference(
    sample_state: Optional[Dict[str, Any]],
    head: str,
) -> Tuple[str, pd.DataFrame]:
    """Callback for the "Run inference" button.

    Returns the status markdown and the probability table. The table is empty
    when no image has been loaded yet or when inference fails.
    """
    if sample_state and "image" in sample_state:
        try:
            prediction, table = infer_image(sample_state["image"], head)
            return f"**Prediction:** {prediction}", table
        except Exception as exc:  # pragma: no cover - surfaced in UI
            return f"Failed to run inference: {exc}", pd.DataFrame()
    return "Load a dataset sample or upload an image first.", pd.DataFrame()
425
+
426
+
427
def handle_upload_preview(
    image: np.ndarray | Image.Image | None,
) -> Tuple[Optional[np.ndarray], str, pd.DataFrame, Dict[str, Any], Optional[Dict[str, Any]]]:
    """Callback for image uploads: preview the image and stage it for inference.

    Returns, in order: the display image, the status text, an empty
    predictions table, an update hiding the ground-truth panel, and the state
    dict (``image`` plus a ``None`` mask). Image and state are ``None`` when
    nothing was uploaded or the upload could not be converted.
    """

    def failure(message: str):
        # Shared shape of every "nothing to show" result.
        return None, message, pd.DataFrame(), gr.update(visible=False), None

    if image is None:
        return failure("Please upload an image.")

    try:
        pixels = to_numpy_image(image)
        preview = to_display_image(pixels)
        return (
            preview,
            "Image uploaded. Click 'Run inference' to compute predictions.",
            pd.DataFrame(),
            gr.update(visible=False),
            {"image": pixels, "mask": None},
        )
    except Exception as exc:  # pragma: no cover - surfaced in UI
        return failure(f"Failed to load image: {exc}")
445
+
446
+
447
+ # ---------------------------------------------------------------------------
448
+ # Interface definition
449
+ # ---------------------------------------------------------------------------
450
+
451
+
452
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire its callbacks.

    Layout, top to bottom: intro text, model-head selector, a row with the
    dataset-sample controls and the upload widget, the inference button, a row
    with the image preview and the prediction outputs, and closing notes.
    """
    head_choices = [(label, key) for key, label in HEAD_OPTIONS]
    subset_choices = [(meta["label"], key) for key, meta in DATASET_OPTIONS.items()]

    with gr.Blocks(css=".gr-prose { max-width: 900px; }") as demo:
        gr.Markdown(
            """
            # Curia Model Playground

            Experiment with the multi-head Curia models on CuriaBench evaluation data or
            your own medical images. Each head expects a single 2D slice in the
            corresponding plane/orientation as defined for Curia (PL for axial, IL for
            coronal, IP for sagittal). Ensure images are unwindowed and either raw
            Hounsfield units (CT) or normalised intensity values (MRI).
            """
        )

        head_selector = gr.Dropdown(
            label="Model head",
            choices=head_choices,
            value="anatomy-ct",
        )

        gr.Markdown("---")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Load dataset sample")
                subset_selector = gr.Dropdown(
                    label="CuriaBench subset",
                    choices=subset_choices,
                    value="anatomy-ct",
                )
                subset_status = gr.Markdown("Select a dataset to load class metadata.")
                class_selector = gr.Dropdown(label="Target class filter", choices=["Random"], value="Random")
                sample_button = gr.Button("Load dataset sample")

            with gr.Column():
                gr.Markdown("### Upload custom image")
                uploader = gr.Image(label="Upload image", image_mode="L", type="numpy")

        gr.Markdown("---")

        run_button = gr.Button("Run inference", variant="primary")

        with gr.Row():
            with gr.Column():
                preview = gr.Image(label="Image", interactive=False, type="numpy")
                ground_truth = gr.Markdown(visible=False)

            with gr.Column():
                gr.Markdown("### Predictions")
                status = gr.Markdown()
                probabilities = gr.Dataframe(headers=["class_id", "label", "probability"])

        sample_state = gr.State()

        # Both ways of providing an image refresh the same five components.
        preview_outputs = [preview, status, probabilities, ground_truth, sample_state]

        # Event wiring
        head_selector.change(
            fn=update_dataset_from_head,
            inputs=[head_selector],
            outputs=[subset_selector],
        )

        subset_selector.change(
            fn=load_dataset_metadata,
            inputs=[subset_selector],
            outputs=[class_selector, subset_status],
        )

        sample_button.click(
            fn=load_dataset_sample,
            inputs=[subset_selector, class_selector, head_selector],
            outputs=preview_outputs,
        )

        uploader.upload(
            fn=handle_upload_preview,
            inputs=[uploader],
            outputs=preview_outputs,
        )

        run_button.click(
            fn=run_inference,
            inputs=[sample_state, head_selector],
            outputs=[status, probabilities],
        )

        gr.Markdown(
            """
            ### Notes

            - Configure the `HF_TOKEN` secret in your Space to load private checkpoints
            and datasets from the `raidium` organisation.
            - When masks are available in the dataset sample, they are overlaid on the
            image for visual reference (courtesy of `torchvision.utils.draw_segmentation_masks`).
            - Uploaded images must be single-channel arrays. Multi-channel inputs are
            converted to grayscale automatically.
            """
        )

    return demo
563
+
564
+
565
# Built at import time so the module exposes a top-level ``demo`` object.
demo = build_demo()


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ transformers>=4.41.0
3
+ datasets>=2.19.0
4
+ torch>=2.2.0
5
+ torchvision>=0.17.0
6
+ pandas>=2.2.0
7
+ numpy>=1.26.0
8
+ pillow>=10.2.0
9
+