tp53 committed on
Commit
49e4f07
·
verified ·
1 Parent(s): 469af7a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +26 -338
app.py CHANGED
@@ -1,14 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
  OncoSeg Inference API - HuggingFace Space
4
- Optimized for programmatic access from oncoseg-viewer
5
-
6
- This Space provides GPU-accelerated inference for medical image segmentation.
7
- It exposes both a Gradio UI and programmatic API endpoints.
8
-
9
- Usage from viewer:
10
- POST /api/segment_slice
11
- POST /api/segment_volume
12
  """
13
 
14
  import os
@@ -18,12 +11,9 @@ import tempfile
18
  import time
19
  import logging
20
  from pathlib import Path
21
- from typing import Optional, List, Tuple, Any
22
 
23
  import gradio as gr
24
  import numpy as np
25
- import torch
26
- import cv2
27
 
28
  # Configure logging
29
  logging.basicConfig(level=logging.INFO)
@@ -37,79 +27,12 @@ try:
37
  logger.info("ZeroGPU available")
38
  except ImportError:
39
  ZEROGPU_AVAILABLE = False
40
- logger.info("ZeroGPU not available, using standard GPU/CPU")
41
-
42
- # Device setup
43
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
44
- logger.info(f"Using device: {DEVICE}")
45
-
46
- # Global model cache
47
- MODELS = {}
48
-
49
- # Checkpoint mapping (HuggingFace Hub paths)
50
- CHECKPOINTS = {
51
- "brain": "checkpoints/medsam3-task20_brats_gli-final_latest/last.ckpt",
52
- "liver": "checkpoints/medsam3-task03_liver-final_latest/last.ckpt",
53
- "breast": "checkpoints/medsam3-task25_breastdcedl-final_latest/last.ckpt",
54
- "lung": "checkpoints/medsam3-task06_lung-final_latest/last.ckpt",
55
- "kidney": "checkpoints/medsam3-task17_kits23-final_latest/last.ckpt",
56
- "spine": "checkpoints/medsam3-task11_lctsc-final_latest/last.ckpt",
57
- }
58
-
59
- # HF Repo ID for checkpoints
60
- HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
61
-
62
-
63
- # Flag to track if we're using fallback mode
64
- USE_FALLBACK = False
65
-
66
-
67
- def get_model(checkpoint: str = "brain"):
68
- """Load or retrieve cached model. Falls back to simple segmentation if SAM3 unavailable."""
69
- global MODELS, USE_FALLBACK
70
-
71
- if checkpoint not in MODELS:
72
- logger.info(f"Loading model: {checkpoint}")
73
-
74
- try:
75
- from huggingface_hub import hf_hub_download
76
 
77
- ckpt_file = CHECKPOINTS.get(checkpoint, CHECKPOINTS["brain"])
78
- ckpt_path = hf_hub_download(
79
- repo_id=HF_REPO_ID,
80
- filename=ckpt_file,
81
- )
82
- logger.info(f"Downloaded checkpoint to: {ckpt_path}")
83
 
84
- # Import model (from local model/ directory in this Space)
85
- from model.medsam3 import MedSAM3Model
86
-
87
- # Initialize model with checkpoint
88
- # Note: MedSAM3Model builds SAM3 internally and loads our LoRA weights
89
- model = MedSAM3Model(checkpoint_path=ckpt_path)
90
- model.to(DEVICE)
91
- model.eval()
92
-
93
- MODELS[checkpoint] = model
94
- logger.info(f"Model {checkpoint} loaded on {DEVICE}")
95
-
96
- except ImportError as e:
97
- logger.warning(f"SAM3 not available, using fallback segmentation: {e}")
98
- USE_FALLBACK = True
99
- MODELS[checkpoint] = None
100
- except Exception as e:
101
- logger.error(f"Failed to load model {checkpoint}: {e}")
102
- USE_FALLBACK = True
103
- MODELS[checkpoint] = None
104
-
105
- return MODELS.get(checkpoint)
106
-
107
-
108
- def fallback_segment(slice_2d: np.ndarray):
109
- """
110
- Simple intensity-based segmentation fallback when SAM3 is not available.
111
- Works well for FLAIR MRI where tumors appear hyperintense.
112
- """
113
  from skimage.filters import threshold_otsu
114
  from skimage.morphology import binary_opening, binary_closing, disk
115
 
@@ -126,51 +49,15 @@ def fallback_segment(slice_2d: np.ndarray):
126
 
127
  # Morphological cleanup
128
  try:
129
- mask = binary_opening(mask, disk(2))
130
- mask = binary_closing(mask, disk(3))
131
- except:
132
  pass
133
 
134
- return mask.astype(np.uint8)
135
-
136
-
137
- def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
138
- """
139
- Preprocess a 2D slice for SAM3 input.
140
-
141
- Args:
142
- slice_2d: Input slice (H, W)
143
- target_size: Target size for SAM3 (default 1024)
144
-
145
- Returns:
146
- Preprocessed tensor (1, 3, H, W) on DEVICE
147
- """
148
- import cv2
149
-
150
- # Normalize to [0, 1]
151
- vmin, vmax = slice_2d.min(), slice_2d.max()
152
- if vmax - vmin < 1e-8:
153
- slice_norm = np.zeros_like(slice_2d)
154
- else:
155
- slice_norm = (slice_2d - vmin) / (vmax - vmin)
156
-
157
- # Resize to target size
158
- slice_resized = cv2.resize(
159
- slice_norm.astype(np.float32), (target_size, target_size)
160
- )
161
-
162
- # Scale to [-1, 1] for SAM3
163
- slice_scaled = slice_resized * 2 - 1
164
-
165
- # Convert to 3-channel tensor (B, C, H, W)
166
- slice_tensor = torch.from_numpy(slice_scaled).float()
167
- slice_tensor = slice_tensor.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
168
- slice_tensor = slice_tensor.repeat(1, 3, 1, 1) # (1, 3, H, W)
169
 
170
- return slice_tensor.to(DEVICE)
171
 
172
-
173
- def find_contours(mask: np.ndarray) -> List[List[List[float]]]:
174
  """Extract contours from binary mask."""
175
  try:
176
  from skimage.measure import find_contours as sk_find_contours
@@ -181,7 +68,7 @@ def find_contours(mask: np.ndarray) -> List[List[List[float]]]:
181
  return []
182
 
183
 
184
- def keep_largest_component(mask: np.ndarray) -> np.ndarray:
185
  """Keep only the largest connected component."""
186
  try:
187
  from scipy import ndimage
@@ -196,8 +83,7 @@ def keep_largest_component(mask: np.ndarray) -> np.ndarray:
196
  return mask
197
 
198
 
199
- # Define the inference function with optional ZeroGPU decorator
200
- def _segment_slice_impl(
201
  nifti_b64: str,
202
  slice_idx: int,
203
  text_prompt: str = "tumor",
@@ -205,15 +91,7 @@ def _segment_slice_impl(
205
  ):
206
  """
207
  Segment a single slice from a NIfTI volume.
208
-
209
- Args:
210
- nifti_b64: Base64-encoded NIfTI file bytes
211
- slice_idx: Slice index to segment (0-indexed)
212
- text_prompt: Text prompt for segmentation (e.g., "tumor", "lesion")
213
- checkpoint: Model checkpoint name
214
-
215
- Returns:
216
- dict with keys: success, mask_b64, mask_shape, contours, slice_idx, inference_time_ms
217
  """
218
  start_time = time.time()
219
 
@@ -242,56 +120,9 @@ def _segment_slice_impl(
242
  "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
243
  }
244
 
245
- # Extract slice
246
  slice_2d = volume[slice_idx]
247
- original_shape = slice_2d.shape
248
-
249
- # Load model (may return None if fallback mode)
250
- model = get_model(checkpoint)
251
-
252
- if model is None or USE_FALLBACK:
253
- # Use fallback segmentation
254
- logger.info("Using fallback segmentation (SAM3 not available)")
255
- mask = fallback_segment(slice_2d)
256
- backend = "fallback"
257
- else:
258
- # Use SAM3 model
259
- slice_tensor = preprocess_slice(
260
- slice_2d
261
- ) # (1, 3, 1024, 1024) tensor on DEVICE
262
-
263
- # Create full-image bounding box prompt (auto-segment entire image)
264
- # Format: [x_min, y_min, x_max, y_max] in pixel coordinates
265
- target_size = slice_tensor.shape[-1] # 1024
266
- input_boxes = torch.tensor(
267
- [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
268
- )
269
-
270
- # Run inference with text prompt for grounding
271
- with torch.no_grad():
272
- outputs = model(
273
- pixel_values=slice_tensor,
274
- input_boxes=input_boxes,
275
- text_prompt=text_prompt,
276
- )
277
-
278
- # Extract mask from SAM3 output
279
- # SAM3 returns a dict with 'pred_masks' key, shape (B, 1, H, W)
280
- if isinstance(outputs, dict) and "pred_masks" in outputs:
281
- pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
282
- elif hasattr(outputs, "pred_masks"):
283
- pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
284
- else:
285
- # Fallback: try to extract from tuple/list
286
- logger.warning(f"Unexpected output type: {type(outputs)}")
287
- pred_mask = np.zeros((target_size, target_size))
288
-
289
- # Resize mask back to original shape
290
- mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
291
- backend = "sam3"
292
-
293
- # Threshold to binary
294
- mask = (mask > 0.5).astype(np.uint8)
295
  mask = keep_largest_component(mask)
296
 
297
  # Extract contours
@@ -307,7 +138,7 @@ def _segment_slice_impl(
307
 
308
  return {
309
  "success": True,
310
- "backend": backend,
311
  "mask_b64": mask_b64,
312
  "mask_shape": list(mask.shape),
313
  "contours": contours,
@@ -320,146 +151,6 @@ def _segment_slice_impl(
320
  return {"success": False, "error": str(e)}
321
 
322
 
323
- def _segment_volume_impl(
324
- nifti_b64: str,
325
- text_prompt: str = "tumor",
326
- checkpoint: str = "brain",
327
- skip_empty: bool = True,
328
- min_area: int = 50,
329
- ):
330
- """
331
- Segment entire volume and return contours for all slices with detections.
332
-
333
- Args:
334
- nifti_b64: Base64-encoded NIfTI file bytes
335
- text_prompt: Text prompt for segmentation
336
- checkpoint: Model checkpoint name
337
- skip_empty: Skip mostly-empty slices
338
- min_area: Minimum mask area to report
339
-
340
- Returns:
341
- dict with keys: success, contours (dict), num_slices, slices_with_tumor, inference_time_ms
342
- """
343
- start_time = time.time()
344
-
345
- try:
346
- import nibabel as nib
347
-
348
- # Decode NIfTI
349
- nifti_bytes = base64.b64decode(nifti_b64)
350
-
351
- with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
352
- f.write(nifti_bytes)
353
- temp_path = f.name
354
-
355
- nii = nib.load(temp_path)
356
- volume = nii.get_fdata().astype(np.float32)
357
- os.unlink(temp_path)
358
-
359
- logger.info(f"Loaded volume shape: {volume.shape}")
360
-
361
- # Load model (may return None if fallback mode)
362
- model = get_model(checkpoint)
363
- use_fallback = model is None or USE_FALLBACK
364
-
365
- num_slices = volume.shape[0]
366
- all_contours = {}
367
-
368
- target_size = 1024
369
-
370
- for i in range(num_slices):
371
- slice_2d = volume[i]
372
- original_shape = slice_2d.shape
373
-
374
- # Skip mostly-empty slices
375
- if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
376
- continue
377
-
378
- if use_fallback:
379
- # Use fallback segmentation
380
- mask = fallback_segment(slice_2d)
381
- else:
382
- slice_tensor = preprocess_slice(slice_2d, target_size)
383
-
384
- # Create full-image bounding box
385
- input_boxes = torch.tensor(
386
- [[0, 0, target_size, target_size]],
387
- dtype=torch.float32,
388
- device=DEVICE,
389
- )
390
-
391
- with torch.no_grad():
392
- outputs = model(
393
- pixel_values=slice_tensor,
394
- input_boxes=input_boxes,
395
- text_prompt=text_prompt,
396
- )
397
-
398
- # Extract mask from SAM3 output
399
- if isinstance(outputs, dict) and "pred_masks" in outputs:
400
- pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
401
- elif hasattr(outputs, "pred_masks"):
402
- pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
403
- else:
404
- continue # Skip if no valid output
405
-
406
- # Resize to original shape and threshold
407
- mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
408
- mask = (mask > 0.5).astype(np.uint8)
409
-
410
- if mask.sum() >= min_area:
411
- mask = keep_largest_component(mask)
412
- contours = find_contours(mask)
413
- if contours:
414
- all_contours[str(i)] = contours
415
-
416
- inference_time = int((time.time() - start_time) * 1000)
417
- logger.info(
418
- f"Segmented {num_slices} slices in {inference_time}ms, found tumor in {len(all_contours)} slices"
419
- )
420
-
421
- return {
422
- "success": True,
423
- "contours": all_contours,
424
- "num_slices": num_slices,
425
- "slices_with_tumor": list(all_contours.keys()),
426
- "inference_time_ms": inference_time,
427
- }
428
-
429
- except Exception as e:
430
- logger.error(f"Volume segmentation failed: {e}")
431
- return {"success": False, "error": str(e)}
432
-
433
-
434
- # Apply ZeroGPU decorator if available
435
- if ZEROGPU_AVAILABLE:
436
-
437
- @spaces.GPU(duration=60)
438
- def segment_slice_api(
439
- nifti_b64: str,
440
- slice_idx: int,
441
- text_prompt: str = "tumor",
442
- checkpoint: str = "brain",
443
- ):
444
- return _segment_slice_impl(nifti_b64, slice_idx, text_prompt, checkpoint)
445
-
446
- @spaces.GPU(duration=300)
447
- def segment_volume_api(
448
- nifti_b64: str,
449
- text_prompt: str = "tumor",
450
- checkpoint: str = "brain",
451
- skip_empty: bool = True,
452
- min_area: int = 50,
453
- ):
454
- return _segment_volume_impl(
455
- nifti_b64, text_prompt, checkpoint, skip_empty, min_area
456
- )
457
- else:
458
- segment_slice_api = _segment_slice_impl
459
- segment_volume_api = _segment_volume_impl
460
-
461
-
462
- # Gradio UI functions (for interactive demo)
463
  def load_and_display_nifti(file):
464
  """Load NIfTI and return middle slice for display."""
465
  if file is None:
@@ -494,7 +185,7 @@ def load_and_display_nifti(file):
494
  return None, f"Error: {e}", 0
495
 
496
 
497
- def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str):
498
  """Segment a slice and overlay the mask."""
499
  if file is None:
500
  return None, "Please upload a file first"
@@ -504,7 +195,7 @@ def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str)
504
  with open(file.name, "rb") as f:
505
  nifti_b64 = base64.b64encode(f.read()).decode()
506
 
507
- # Call segmentation API
508
  result = segment_slice_api(nifti_b64, int(slice_idx), text_prompt, checkpoint)
509
 
510
  if not result["success"]:
@@ -536,7 +227,7 @@ def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str)
536
  rgb[mask_bool, 1] = rgb[mask_bool, 1] * (1 - alpha) + 50 * alpha
537
  rgb[mask_bool, 2] = rgb[mask_bool, 2] * (1 - alpha) + 50 * alpha
538
 
539
- info = f"Segmented in {result['inference_time_ms']}ms, mask area: {mask.sum()} pixels"
540
 
541
  return rgb.astype(np.uint8), info
542
 
@@ -546,30 +237,27 @@ def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str)
546
 
547
  # Build Gradio interface
548
  def build_demo():
549
- with gr.Blocks(
550
- title="OncoSeg Inference API",
551
- theme=gr.themes.Soft(),
552
- ) as demo:
553
  gr.Markdown("""
554
  # OncoSeg Medical Image Segmentation API
555
 
556
  GPU-accelerated segmentation for CT and MRI volumes.
557
 
 
 
558
  **API Endpoints** (for programmatic access):
559
  - `POST /api/segment_slice_api` - Segment a single slice
560
- - `POST /api/segment_volume_api` - Segment entire volume
561
-
562
- **Interactive Demo** below:
563
  """)
564
 
565
  with gr.Row():
566
  with gr.Column(scale=1):
567
  file_input = gr.File(
568
- label="Upload NIfTI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"]
 
569
  )
570
 
571
  checkpoint = gr.Dropdown(
572
- choices=list(CHECKPOINTS.keys()),
573
  value="brain",
574
  label="Model Checkpoint",
575
  )
@@ -622,7 +310,7 @@ def build_demo():
622
 
623
  # Call API
624
  response = requests.post(
625
- "https://YOUR-SPACE.hf.space/api/segment_slice_api",
626
  json={
627
  "nifti_b64": nifti_b64,
628
  "slice_idx": 77,
 
1
  #!/usr/bin/env python3
2
  """
3
  OncoSeg Inference API - HuggingFace Space
4
+ Minimal version for initial deployment.
 
 
 
 
 
 
 
5
  """
6
 
7
  import os
 
11
  import time
12
  import logging
13
  from pathlib import Path
 
14
 
15
  import gradio as gr
16
  import numpy as np
 
 
17
 
18
  # Configure logging
19
  logging.basicConfig(level=logging.INFO)
 
27
  logger.info("ZeroGPU available")
28
  except ImportError:
29
  ZEROGPU_AVAILABLE = False
30
+ logger.info("ZeroGPU not available")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
32
 
33
+ def fallback_segment(slice_2d):
34
+ """Simple intensity-based segmentation."""
35
+ from scipy import ndimage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  from skimage.filters import threshold_otsu
37
  from skimage.morphology import binary_opening, binary_closing, disk
38
 
 
49
 
50
  # Morphological cleanup
51
  try:
52
+ mask = binary_opening(mask, disk(2)).astype(np.uint8)
53
+ mask = binary_closing(mask, disk(3)).astype(np.uint8)
54
+ except Exception:
55
  pass
56
 
57
+ return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
59
 
60
+ def find_contours(mask):
 
61
  """Extract contours from binary mask."""
62
  try:
63
  from skimage.measure import find_contours as sk_find_contours
 
68
  return []
69
 
70
 
71
+ def keep_largest_component(mask):
72
  """Keep only the largest connected component."""
73
  try:
74
  from scipy import ndimage
 
83
  return mask
84
 
85
 
86
+ def segment_slice_api(
 
87
  nifti_b64: str,
88
  slice_idx: int,
89
  text_prompt: str = "tumor",
 
91
  ):
92
  """
93
  Segment a single slice from a NIfTI volume.
94
+ Currently uses fallback segmentation (SAM3 to be integrated).
 
 
 
 
 
 
 
 
95
  """
96
  start_time = time.time()
97
 
 
120
  "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
121
  }
122
 
123
+ # Extract slice and segment
124
  slice_2d = volume[slice_idx]
125
+ mask = fallback_segment(slice_2d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  mask = keep_largest_component(mask)
127
 
128
  # Extract contours
 
138
 
139
  return {
140
  "success": True,
141
+ "backend": "fallback",
142
  "mask_b64": mask_b64,
143
  "mask_shape": list(mask.shape),
144
  "contours": contours,
 
151
  return {"success": False, "error": str(e)}
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def load_and_display_nifti(file):
155
  """Load NIfTI and return middle slice for display."""
156
  if file is None:
 
185
  return None, f"Error: {e}", 0
186
 
187
 
188
+ def segment_and_overlay(file, slice_idx, text_prompt, checkpoint):
189
  """Segment a slice and overlay the mask."""
190
  if file is None:
191
  return None, "Please upload a file first"
 
195
  with open(file.name, "rb") as f:
196
  nifti_b64 = base64.b64encode(f.read()).decode()
197
 
198
+ # Call segmentation
199
  result = segment_slice_api(nifti_b64, int(slice_idx), text_prompt, checkpoint)
200
 
201
  if not result["success"]:
 
227
  rgb[mask_bool, 1] = rgb[mask_bool, 1] * (1 - alpha) + 50 * alpha
228
  rgb[mask_bool, 2] = rgb[mask_bool, 2] * (1 - alpha) + 50 * alpha
229
 
230
+ info = f"Backend: {result['backend']}, Segmented in {result['inference_time_ms']}ms, mask area: {mask.sum()} pixels"
231
 
232
  return rgb.astype(np.uint8), info
233
 
 
237
 
238
  # Build Gradio interface
239
  def build_demo():
240
+ with gr.Blocks(title="OncoSeg Inference API", theme=gr.themes.Soft()) as demo:
 
 
 
241
  gr.Markdown("""
242
  # OncoSeg Medical Image Segmentation API
243
 
244
  GPU-accelerated segmentation for CT and MRI volumes.
245
 
246
+ **Note:** Currently using fallback segmentation. Full SAM3 model coming soon!
247
+
248
  **API Endpoints** (for programmatic access):
249
  - `POST /api/segment_slice_api` - Segment a single slice
 
 
 
250
  """)
251
 
252
  with gr.Row():
253
  with gr.Column(scale=1):
254
  file_input = gr.File(
255
+ label="Upload NIfTI (.nii, .nii.gz)",
256
+ file_types=[".nii", ".nii.gz", ".gz"],
257
  )
258
 
259
  checkpoint = gr.Dropdown(
260
+ choices=["brain", "liver", "breast", "lung", "kidney", "spine"],
261
  value="brain",
262
  label="Model Checkpoint",
263
  )
 
310
 
311
  # Call API
312
  response = requests.post(
313
+ "https://tp53-oncoseg-api.hf.space/api/segment_slice_api",
314
  json={
315
  "nifti_b64": nifti_b64,
316
  "slice_idx": 77,