tp53 committed on
Commit
4eaaaee
·
verified ·
1 Parent(s): 59f8f17

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +112 -57
  2. requirements.txt +1 -4
app.py CHANGED
@@ -60,9 +60,13 @@ CHECKPOINTS = {
60
  HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
61
 
62
 
 
 
 
 
63
  def get_model(checkpoint: str = "brain"):
64
- """Load or retrieve cached model."""
65
- global MODELS
66
 
67
  if checkpoint not in MODELS:
68
  logger.info(f"Loading model: {checkpoint}")
@@ -89,11 +93,45 @@ def get_model(checkpoint: str = "brain"):
89
  MODELS[checkpoint] = model
90
  logger.info(f"Model {checkpoint} loaded on {DEVICE}")
91
 
 
 
 
 
92
  except Exception as e:
93
  logger.error(f"Failed to load model {checkpoint}: {e}")
94
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- return MODELS[checkpoint]
97
 
98
 
99
  def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
@@ -204,44 +242,53 @@ def _segment_slice_impl(
204
  "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
205
  }
206
 
207
- # Load model
208
- model = get_model(checkpoint)
209
-
210
- # Extract and preprocess slice
211
  slice_2d = volume[slice_idx]
212
  original_shape = slice_2d.shape
213
- slice_tensor = preprocess_slice(slice_2d) # (1, 3, 1024, 1024) tensor on DEVICE
214
 
215
- # Create full-image bounding box prompt (auto-segment entire image)
216
- # Format: [x_min, y_min, x_max, y_max] in pixel coordinates
217
- target_size = slice_tensor.shape[-1] # 1024
218
- input_boxes = torch.tensor(
219
- [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
220
- )
221
 
222
- # Run inference with text prompt for grounding
223
- with torch.no_grad():
224
- outputs = model(
225
- pixel_values=slice_tensor,
226
- input_boxes=input_boxes,
227
- text_prompt=text_prompt,
 
 
 
 
 
 
 
 
 
 
228
  )
229
 
230
- # Extract mask from SAM3 output
231
- # SAM3 returns a dict with 'pred_masks' key, shape (B, 1, H, W)
232
- if isinstance(outputs, dict) and "pred_masks" in outputs:
233
- pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
234
- elif hasattr(outputs, "pred_masks"):
235
- pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
236
- else:
237
- # Fallback: try to extract from tuple/list
238
- logger.warning(f"Unexpected output type: {type(outputs)}")
239
- pred_mask = np.zeros((target_size, target_size))
240
 
241
- # Resize mask back to original shape
242
- import cv2
 
 
 
 
 
 
 
 
243
 
244
- mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
 
 
245
 
246
  # Threshold to binary
247
  mask = (mask > 0.5).astype(np.uint8)
@@ -260,6 +307,7 @@ def _segment_slice_impl(
260
 
261
  return {
262
  "success": True,
 
263
  "mask_b64": mask_b64,
264
  "mask_shape": list(mask.shape),
265
  "contours": contours,
@@ -310,8 +358,9 @@ def _segment_volume_impl(
310
 
311
  logger.info(f"Loaded volume shape: {volume.shape}")
312
 
313
- # Load model
314
  model = get_model(checkpoint)
 
315
 
316
  num_slices = volume.shape[0]
317
  all_contours = {}
@@ -326,31 +375,37 @@ def _segment_volume_impl(
326
  if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
327
  continue
328
 
329
- slice_tensor = preprocess_slice(slice_2d, target_size)
330
-
331
- # Create full-image bounding box
332
- input_boxes = torch.tensor(
333
- [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
334
- )
335
 
336
- with torch.no_grad():
337
- outputs = model(
338
- pixel_values=slice_tensor,
339
- input_boxes=input_boxes,
340
- text_prompt=text_prompt,
341
  )
342
 
343
- # Extract mask from SAM3 output
344
- if isinstance(outputs, dict) and "pred_masks" in outputs:
345
- pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
346
- elif hasattr(outputs, "pred_masks"):
347
- pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
348
- else:
349
- continue # Skip if no valid output
350
-
351
- # Resize to original shape and threshold
352
- mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
353
- mask = (mask > 0.5).astype(np.uint8)
 
 
 
 
 
 
 
354
 
355
  if mask.sum() >= min_area:
356
  mask = keep_largest_component(mask)
 
60
  HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
61
 
62
 
63
+ # Flag to track if we're using fallback mode
64
+ USE_FALLBACK = False
65
+
66
+
67
  def get_model(checkpoint: str = "brain"):
68
+ """Load or retrieve cached model. Falls back to simple segmentation if SAM3 unavailable."""
69
+ global MODELS, USE_FALLBACK
70
 
71
  if checkpoint not in MODELS:
72
  logger.info(f"Loading model: {checkpoint}")
 
93
  MODELS[checkpoint] = model
94
  logger.info(f"Model {checkpoint} loaded on {DEVICE}")
95
 
96
+ except ImportError as e:
97
+ logger.warning(f"SAM3 not available, using fallback segmentation: {e}")
98
+ USE_FALLBACK = True
99
+ MODELS[checkpoint] = None
100
  except Exception as e:
101
  logger.error(f"Failed to load model {checkpoint}: {e}")
102
+ USE_FALLBACK = True
103
+ MODELS[checkpoint] = None
104
+
105
+ return MODELS.get(checkpoint)
106
+
107
+
108
+ def fallback_segment(slice_2d: np.ndarray):
109
+ """
110
+ Simple intensity-based segmentation fallback when SAM3 is not available.
111
+ Works well for FLAIR MRI where tumors appear hyperintense.
112
+ """
113
+ from skimage.filters import threshold_otsu
114
+ from skimage.morphology import binary_opening, binary_closing, disk
115
+
116
+ # Normalize
117
+ vmin, vmax = slice_2d.min(), slice_2d.max()
118
+ if vmax - vmin < 1e-8:
119
+ return np.zeros_like(slice_2d, dtype=np.uint8)
120
+
121
+ normalized = (slice_2d - vmin) / (vmax - vmin)
122
+
123
+ # Use percentile threshold (top 15% intensity = potential tumor)
124
+ threshold = np.percentile(normalized, 85)
125
+ mask = (normalized > threshold).astype(np.uint8)
126
+
127
+ # Morphological cleanup
128
+ try:
129
+ mask = binary_opening(mask, disk(2))
130
+ mask = binary_closing(mask, disk(3))
131
+ except:
132
+ pass
133
 
134
+ return mask.astype(np.uint8)
135
 
136
 
137
  def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
 
242
  "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
243
  }
244
 
245
+ # Extract slice
 
 
 
246
  slice_2d = volume[slice_idx]
247
  original_shape = slice_2d.shape
 
248
 
249
+ # Load model (may return None if fallback mode)
250
+ model = get_model(checkpoint)
 
 
 
 
251
 
252
+ if model is None or USE_FALLBACK:
253
+ # Use fallback segmentation
254
+ logger.info("Using fallback segmentation (SAM3 not available)")
255
+ mask = fallback_segment(slice_2d)
256
+ backend = "fallback"
257
+ else:
258
+ # Use SAM3 model
259
+ slice_tensor = preprocess_slice(
260
+ slice_2d
261
+ ) # (1, 3, 1024, 1024) tensor on DEVICE
262
+
263
+ # Create full-image bounding box prompt (auto-segment entire image)
264
+ # Format: [x_min, y_min, x_max, y_max] in pixel coordinates
265
+ target_size = slice_tensor.shape[-1] # 1024
266
+ input_boxes = torch.tensor(
267
+ [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
268
  )
269
 
270
+ # Run inference with text prompt for grounding
271
+ with torch.no_grad():
272
+ outputs = model(
273
+ pixel_values=slice_tensor,
274
+ input_boxes=input_boxes,
275
+ text_prompt=text_prompt,
276
+ )
 
 
 
277
 
278
+ # Extract mask from SAM3 output
279
+ # SAM3 returns a dict with 'pred_masks' key, shape (B, 1, H, W)
280
+ if isinstance(outputs, dict) and "pred_masks" in outputs:
281
+ pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
282
+ elif hasattr(outputs, "pred_masks"):
283
+ pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
284
+ else:
285
+ # Fallback: try to extract from tuple/list
286
+ logger.warning(f"Unexpected output type: {type(outputs)}")
287
+ pred_mask = np.zeros((target_size, target_size))
288
 
289
+ # Resize mask back to original shape
290
+ mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
291
+ backend = "sam3"
292
 
293
  # Threshold to binary
294
  mask = (mask > 0.5).astype(np.uint8)
 
307
 
308
  return {
309
  "success": True,
310
+ "backend": backend,
311
  "mask_b64": mask_b64,
312
  "mask_shape": list(mask.shape),
313
  "contours": contours,
 
358
 
359
  logger.info(f"Loaded volume shape: {volume.shape}")
360
 
361
+ # Load model (may return None if fallback mode)
362
  model = get_model(checkpoint)
363
+ use_fallback = model is None or USE_FALLBACK
364
 
365
  num_slices = volume.shape[0]
366
  all_contours = {}
 
375
  if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
376
  continue
377
 
378
+ if use_fallback:
379
+ # Use fallback segmentation
380
+ mask = fallback_segment(slice_2d)
381
+ else:
382
+ slice_tensor = preprocess_slice(slice_2d, target_size)
 
383
 
384
+ # Create full-image bounding box
385
+ input_boxes = torch.tensor(
386
+ [[0, 0, target_size, target_size]],
387
+ dtype=torch.float32,
388
+ device=DEVICE,
389
  )
390
 
391
+ with torch.no_grad():
392
+ outputs = model(
393
+ pixel_values=slice_tensor,
394
+ input_boxes=input_boxes,
395
+ text_prompt=text_prompt,
396
+ )
397
+
398
+ # Extract mask from SAM3 output
399
+ if isinstance(outputs, dict) and "pred_masks" in outputs:
400
+ pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
401
+ elif hasattr(outputs, "pred_masks"):
402
+ pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
403
+ else:
404
+ continue # Skip if no valid output
405
+
406
+ # Resize to original shape and threshold
407
+ mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
408
+ mask = (mask > 0.5).astype(np.uint8)
409
 
410
  if mask.sum() >= min_area:
411
  mask = keep_largest_component(mask)
requirements.txt CHANGED
@@ -17,10 +17,7 @@ opencv-python-headless>=4.8.0
17
  huggingface_hub>=0.19.0
18
  spaces>=0.19.0
19
 
20
- # SAM3 - Install from Facebook Research
21
- git+https://github.com/facebookresearch/sam3.git
22
-
23
- # Additional SAM3 Dependencies
24
  einops>=0.7.0
25
  timm>=0.9.0
26
  peft>=0.7.0
 
17
  huggingface_hub>=0.19.0
18
  spaces>=0.19.0
19
 
20
+ # SAM3 Dependencies (SAM3 itself installed separately if available)
 
 
 
21
  einops>=0.7.0
22
  timm>=0.9.0
23
  peft>=0.7.0