Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- app.py +112 -57
- requirements.txt +1 -4
app.py
CHANGED
|
@@ -60,9 +60,13 @@ CHECKPOINTS = {
|
|
| 60 |
HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
|
| 61 |
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def get_model(checkpoint: str = "brain"):
|
| 64 |
-
"""Load or retrieve cached model."""
|
| 65 |
-
global MODELS
|
| 66 |
|
| 67 |
if checkpoint not in MODELS:
|
| 68 |
logger.info(f"Loading model: {checkpoint}")
|
|
@@ -89,11 +93,45 @@ def get_model(checkpoint: str = "brain"):
|
|
| 89 |
MODELS[checkpoint] = model
|
| 90 |
logger.info(f"Model {checkpoint} loaded on {DEVICE}")
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
except Exception as e:
|
| 93 |
logger.error(f"Failed to load model {checkpoint}: {e}")
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
return
|
| 97 |
|
| 98 |
|
| 99 |
def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
|
|
@@ -204,44 +242,53 @@ def _segment_slice_impl(
|
|
| 204 |
"error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
|
| 205 |
}
|
| 206 |
|
| 207 |
-
#
|
| 208 |
-
model = get_model(checkpoint)
|
| 209 |
-
|
| 210 |
-
# Extract and preprocess slice
|
| 211 |
slice_2d = volume[slice_idx]
|
| 212 |
original_shape = slice_2d.shape
|
| 213 |
-
slice_tensor = preprocess_slice(slice_2d) # (1, 3, 1024, 1024) tensor on DEVICE
|
| 214 |
|
| 215 |
-
#
|
| 216 |
-
|
| 217 |
-
target_size = slice_tensor.shape[-1] # 1024
|
| 218 |
-
input_boxes = torch.tensor(
|
| 219 |
-
[[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
|
| 220 |
-
)
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# Fallback: try to extract from tuple/list
|
| 238 |
-
logger.warning(f"Unexpected output type: {type(outputs)}")
|
| 239 |
-
pred_mask = np.zeros((target_size, target_size))
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# Threshold to binary
|
| 247 |
mask = (mask > 0.5).astype(np.uint8)
|
|
@@ -260,6 +307,7 @@ def _segment_slice_impl(
|
|
| 260 |
|
| 261 |
return {
|
| 262 |
"success": True,
|
|
|
|
| 263 |
"mask_b64": mask_b64,
|
| 264 |
"mask_shape": list(mask.shape),
|
| 265 |
"contours": contours,
|
|
@@ -310,8 +358,9 @@ def _segment_volume_impl(
|
|
| 310 |
|
| 311 |
logger.info(f"Loaded volume shape: {volume.shape}")
|
| 312 |
|
| 313 |
-
# Load model
|
| 314 |
model = get_model(checkpoint)
|
|
|
|
| 315 |
|
| 316 |
num_slices = volume.shape[0]
|
| 317 |
all_contours = {}
|
|
@@ -326,31 +375,37 @@ def _segment_volume_impl(
|
|
| 326 |
if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
|
| 327 |
continue
|
| 328 |
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
)
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
)
|
| 342 |
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if mask.sum() >= min_area:
|
| 356 |
mask = keep_largest_component(mask)
|
|
|
|
| 60 |
HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
|
| 61 |
|
| 62 |
|
| 63 |
+
# Flag to track if we're using fallback mode
|
| 64 |
+
USE_FALLBACK = False
|
| 65 |
+
|
| 66 |
+
|
| 67 |
def get_model(checkpoint: str = "brain"):
|
| 68 |
+
"""Load or retrieve cached model. Falls back to simple segmentation if SAM3 unavailable."""
|
| 69 |
+
global MODELS, USE_FALLBACK
|
| 70 |
|
| 71 |
if checkpoint not in MODELS:
|
| 72 |
logger.info(f"Loading model: {checkpoint}")
|
|
|
|
| 93 |
MODELS[checkpoint] = model
|
| 94 |
logger.info(f"Model {checkpoint} loaded on {DEVICE}")
|
| 95 |
|
| 96 |
+
except ImportError as e:
|
| 97 |
+
logger.warning(f"SAM3 not available, using fallback segmentation: {e}")
|
| 98 |
+
USE_FALLBACK = True
|
| 99 |
+
MODELS[checkpoint] = None
|
| 100 |
except Exception as e:
|
| 101 |
logger.error(f"Failed to load model {checkpoint}: {e}")
|
| 102 |
+
USE_FALLBACK = True
|
| 103 |
+
MODELS[checkpoint] = None
|
| 104 |
+
|
| 105 |
+
return MODELS.get(checkpoint)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def fallback_segment(slice_2d: np.ndarray):
|
| 109 |
+
"""
|
| 110 |
+
Simple intensity-based segmentation fallback when SAM3 is not available.
|
| 111 |
+
Works well for FLAIR MRI where tumors appear hyperintense.
|
| 112 |
+
"""
|
| 113 |
+
from skimage.filters import threshold_otsu
|
| 114 |
+
from skimage.morphology import binary_opening, binary_closing, disk
|
| 115 |
+
|
| 116 |
+
# Normalize
|
| 117 |
+
vmin, vmax = slice_2d.min(), slice_2d.max()
|
| 118 |
+
if vmax - vmin < 1e-8:
|
| 119 |
+
return np.zeros_like(slice_2d, dtype=np.uint8)
|
| 120 |
+
|
| 121 |
+
normalized = (slice_2d - vmin) / (vmax - vmin)
|
| 122 |
+
|
| 123 |
+
# Use percentile threshold (top 15% intensity = potential tumor)
|
| 124 |
+
threshold = np.percentile(normalized, 85)
|
| 125 |
+
mask = (normalized > threshold).astype(np.uint8)
|
| 126 |
+
|
| 127 |
+
# Morphological cleanup
|
| 128 |
+
try:
|
| 129 |
+
mask = binary_opening(mask, disk(2))
|
| 130 |
+
mask = binary_closing(mask, disk(3))
|
| 131 |
+
except:
|
| 132 |
+
pass
|
| 133 |
|
| 134 |
+
return mask.astype(np.uint8)
|
| 135 |
|
| 136 |
|
| 137 |
def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
|
|
|
|
| 242 |
"error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
|
| 243 |
}
|
| 244 |
|
| 245 |
+
# Extract slice
|
|
|
|
|
|
|
|
|
|
| 246 |
slice_2d = volume[slice_idx]
|
| 247 |
original_shape = slice_2d.shape
|
|
|
|
| 248 |
|
| 249 |
+
# Load model (may return None if fallback mode)
|
| 250 |
+
model = get_model(checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
+
if model is None or USE_FALLBACK:
|
| 253 |
+
# Use fallback segmentation
|
| 254 |
+
logger.info("Using fallback segmentation (SAM3 not available)")
|
| 255 |
+
mask = fallback_segment(slice_2d)
|
| 256 |
+
backend = "fallback"
|
| 257 |
+
else:
|
| 258 |
+
# Use SAM3 model
|
| 259 |
+
slice_tensor = preprocess_slice(
|
| 260 |
+
slice_2d
|
| 261 |
+
) # (1, 3, 1024, 1024) tensor on DEVICE
|
| 262 |
+
|
| 263 |
+
# Create full-image bounding box prompt (auto-segment entire image)
|
| 264 |
+
# Format: [x_min, y_min, x_max, y_max] in pixel coordinates
|
| 265 |
+
target_size = slice_tensor.shape[-1] # 1024
|
| 266 |
+
input_boxes = torch.tensor(
|
| 267 |
+
[[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
|
| 268 |
)
|
| 269 |
|
| 270 |
+
# Run inference with text prompt for grounding
|
| 271 |
+
with torch.no_grad():
|
| 272 |
+
outputs = model(
|
| 273 |
+
pixel_values=slice_tensor,
|
| 274 |
+
input_boxes=input_boxes,
|
| 275 |
+
text_prompt=text_prompt,
|
| 276 |
+
)
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
+
# Extract mask from SAM3 output
|
| 279 |
+
# SAM3 returns a dict with 'pred_masks' key, shape (B, 1, H, W)
|
| 280 |
+
if isinstance(outputs, dict) and "pred_masks" in outputs:
|
| 281 |
+
pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
|
| 282 |
+
elif hasattr(outputs, "pred_masks"):
|
| 283 |
+
pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
|
| 284 |
+
else:
|
| 285 |
+
# Fallback: try to extract from tuple/list
|
| 286 |
+
logger.warning(f"Unexpected output type: {type(outputs)}")
|
| 287 |
+
pred_mask = np.zeros((target_size, target_size))
|
| 288 |
|
| 289 |
+
# Resize mask back to original shape
|
| 290 |
+
mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
|
| 291 |
+
backend = "sam3"
|
| 292 |
|
| 293 |
# Threshold to binary
|
| 294 |
mask = (mask > 0.5).astype(np.uint8)
|
|
|
|
| 307 |
|
| 308 |
return {
|
| 309 |
"success": True,
|
| 310 |
+
"backend": backend,
|
| 311 |
"mask_b64": mask_b64,
|
| 312 |
"mask_shape": list(mask.shape),
|
| 313 |
"contours": contours,
|
|
|
|
| 358 |
|
| 359 |
logger.info(f"Loaded volume shape: {volume.shape}")
|
| 360 |
|
| 361 |
+
# Load model (may return None if fallback mode)
|
| 362 |
model = get_model(checkpoint)
|
| 363 |
+
use_fallback = model is None or USE_FALLBACK
|
| 364 |
|
| 365 |
num_slices = volume.shape[0]
|
| 366 |
all_contours = {}
|
|
|
|
| 375 |
if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
|
| 376 |
continue
|
| 377 |
|
| 378 |
+
if use_fallback:
|
| 379 |
+
# Use fallback segmentation
|
| 380 |
+
mask = fallback_segment(slice_2d)
|
| 381 |
+
else:
|
| 382 |
+
slice_tensor = preprocess_slice(slice_2d, target_size)
|
|
|
|
| 383 |
|
| 384 |
+
# Create full-image bounding box
|
| 385 |
+
input_boxes = torch.tensor(
|
| 386 |
+
[[0, 0, target_size, target_size]],
|
| 387 |
+
dtype=torch.float32,
|
| 388 |
+
device=DEVICE,
|
| 389 |
)
|
| 390 |
|
| 391 |
+
with torch.no_grad():
|
| 392 |
+
outputs = model(
|
| 393 |
+
pixel_values=slice_tensor,
|
| 394 |
+
input_boxes=input_boxes,
|
| 395 |
+
text_prompt=text_prompt,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# Extract mask from SAM3 output
|
| 399 |
+
if isinstance(outputs, dict) and "pred_masks" in outputs:
|
| 400 |
+
pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
|
| 401 |
+
elif hasattr(outputs, "pred_masks"):
|
| 402 |
+
pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
|
| 403 |
+
else:
|
| 404 |
+
continue # Skip if no valid output
|
| 405 |
+
|
| 406 |
+
# Resize to original shape and threshold
|
| 407 |
+
mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
|
| 408 |
+
mask = (mask > 0.5).astype(np.uint8)
|
| 409 |
|
| 410 |
if mask.sum() >= min_area:
|
| 411 |
mask = keep_largest_component(mask)
|
requirements.txt
CHANGED
|
@@ -17,10 +17,7 @@ opencv-python-headless>=4.8.0
|
|
| 17 |
huggingface_hub>=0.19.0
|
| 18 |
spaces>=0.19.0
|
| 19 |
|
| 20 |
-
# SAM3
|
| 21 |
-
git+https://github.com/facebookresearch/sam3.git
|
| 22 |
-
|
| 23 |
-
# Additional SAM3 Dependencies
|
| 24 |
einops>=0.7.0
|
| 25 |
timm>=0.9.0
|
| 26 |
peft>=0.7.0
|
|
|
|
| 17 |
huggingface_hub>=0.19.0
|
| 18 |
spaces>=0.19.0
|
| 19 |
|
| 20 |
+
# SAM3 Dependencies (SAM3 itself installed separately if available)
|
|
|
|
|
|
|
|
|
|
| 21 |
einops>=0.7.0
|
| 22 |
timm>=0.9.0
|
| 23 |
peft>=0.7.0
|