Spaces:
Sleeping
Sleeping
MilicMilos commited on
Commit ·
46c5d16
1
Parent(s): 142af98
Improve model inference performance and reliability on hardware
Browse filesForce disable Flash Attention and optimize inference loop, pre-load the model, and reduce input size for faster processing.
Replit-Commit-Author: Agent
Replit-Commit-Session-Id: c144be0a-7fab-4a53-a663-fc927a204409
Replit-Commit-Checkpoint-Type: intermediate_checkpoint
Replit-Commit-Event-Id: 63465661-b0cc-45eb-97fa-7aed76fbe293
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/5b4b75b9-1619-404c-a78d-526127514111/c144be0a-7fab-4a53-a663-fc927a204409/35LY8UZ
Replit-Helium-Checkpoint-Created: true
- Dockerfile +4 -0
- main.py +18 -0
- medsam2_pkg/sam2/modeling/sam/transformer.py +7 -7
- models/medsam2_inference.py +127 -73
Dockerfile
CHANGED
|
@@ -23,5 +23,9 @@ RUN mkdir -p uploads checkpoints
|
|
| 23 |
EXPOSE 7860
|
| 24 |
|
| 25 |
ENV PORT=7860
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
CMD ["python", "main.py"]
|
|
|
|
| 23 |
EXPOSE 7860
|
| 24 |
|
| 25 |
ENV PORT=7860
|
| 26 |
+
ENV SAM2_ALLOW_ALL_KERNELS=1
|
| 27 |
+
ENV TORCH_CUDNN_SDPA_ENABLED=0
|
| 28 |
+
ENV U_FLASH_ATTN=0
|
| 29 |
+
ENV MATH_KERNEL_ON=0
|
| 30 |
|
| 31 |
CMD ["python", "main.py"]
|
main.py
CHANGED
|
@@ -1395,7 +1395,25 @@ def batch_report():
|
|
| 1395 |
except Exception as e:
|
| 1396 |
return jsonify({'error': f'Error generating PDF: {str(e)}'}), 500
|
| 1397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1398 |
if __name__ == '__main__':
|
| 1399 |
import sys
|
| 1400 |
port = int(os.environ.get('PORT', sys.argv[1] if len(sys.argv) > 1 else 7860))
|
|
|
|
| 1401 |
app.run(host='0.0.0.0', port=port, debug=True)
|
|
|
|
| 1395 |
except Exception as e:
|
| 1396 |
return jsonify({'error': f'Error generating PDF: {str(e)}'}), 500
|
| 1397 |
|
| 1398 |
+
def preload_medsam2():
|
| 1399 |
+
import threading
|
| 1400 |
+
def _load():
|
| 1401 |
+
try:
|
| 1402 |
+
print("[Startup] Pre-loading MedSAM2 model...")
|
| 1403 |
+
from models.medsam2_inference import load_medsam2_model
|
| 1404 |
+
predictor = load_medsam2_model()
|
| 1405 |
+
if predictor is not None:
|
| 1406 |
+
print("[Startup] MedSAM2 model pre-loaded successfully")
|
| 1407 |
+
else:
|
| 1408 |
+
print("[Startup] MedSAM2 model not available (will retry on first request)")
|
| 1409 |
+
except Exception as e:
|
| 1410 |
+
print(f"[Startup] MedSAM2 pre-load failed: {e}")
|
| 1411 |
+
t = threading.Thread(target=_load, daemon=True)
|
| 1412 |
+
t.start()
|
| 1413 |
+
|
| 1414 |
+
|
| 1415 |
if __name__ == '__main__':
|
| 1416 |
import sys
|
| 1417 |
port = int(os.environ.get('PORT', sys.argv[1] if len(sys.argv) > 1 else 7860))
|
| 1418 |
+
preload_medsam2()
|
| 1419 |
app.run(host='0.0.0.0', port=port, debug=True)
|
medsam2_pkg/sam2/modeling/sam/transformer.py
CHANGED
|
@@ -17,22 +17,22 @@ from torch import nn, Tensor
|
|
| 17 |
|
| 18 |
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 19 |
from sam2.modeling.sam2_utils import MLP
|
| 20 |
-
from sam2.utils.misc import get_sdpa_settings
|
| 21 |
|
| 22 |
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 23 |
-
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
| 24 |
ALLOW_ALL_KERNELS = os.environ.get("SAM2_ALLOW_ALL_KERNELS", "1") == "1"
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def sdp_kernel_context(dropout_p):
|
| 32 |
"""
|
| 33 |
Get the context for the attention scaled dot-product kernel.
|
| 34 |
Defaults to allowing all kernels for maximum compatibility.
|
| 35 |
-
Set SAM2_ALLOW_ALL_KERNELS=0 to use Flash Attention when available.
|
| 36 |
"""
|
| 37 |
if ALLOW_ALL_KERNELS:
|
| 38 |
return contextlib.nullcontext()
|
|
|
|
| 17 |
|
| 18 |
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 19 |
from sam2.modeling.sam2_utils import MLP
|
|
|
|
| 20 |
|
| 21 |
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
|
|
| 22 |
ALLOW_ALL_KERNELS = os.environ.get("SAM2_ALLOW_ALL_KERNELS", "1") == "1"
|
| 23 |
+
OLD_GPU = True
|
| 24 |
+
USE_FLASH_ATTN = False
|
| 25 |
+
MATH_KERNEL_ON = True
|
| 26 |
+
if not ALLOW_ALL_KERNELS:
|
| 27 |
+
from sam2.utils.misc import get_sdpa_settings
|
| 28 |
+
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
| 29 |
+
print(f"[SAM2] Attention config: ALLOW_ALL_KERNELS={ALLOW_ALL_KERNELS}, FLASH={USE_FLASH_ATTN}, MATH={MATH_KERNEL_ON}")
|
| 30 |
|
| 31 |
|
| 32 |
def sdp_kernel_context(dropout_p):
|
| 33 |
"""
|
| 34 |
Get the context for the attention scaled dot-product kernel.
|
| 35 |
Defaults to allowing all kernels for maximum compatibility.
|
|
|
|
| 36 |
"""
|
| 37 |
if ALLOW_ALL_KERNELS:
|
| 38 |
return contextlib.nullcontext()
|
models/medsam2_inference.py
CHANGED
|
@@ -1,11 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import cv2
|
| 3 |
import sys
|
| 4 |
-
import os
|
| 5 |
import traceback
|
|
|
|
|
|
|
| 6 |
|
| 7 |
_medsam2_model = None
|
| 8 |
_medsam2_predictor = None
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
MEDSAM2_PATHS = [
|
| 11 |
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'medsam2_pkg'),
|
|
@@ -40,76 +50,108 @@ def _get_device():
|
|
| 40 |
device = "cuda"
|
| 41 |
print(f"[MedSAM2] Using CUDA device: {torch.cuda.get_device_name(0)}")
|
| 42 |
print(f"[MedSAM2] CUDA memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
|
|
|
|
| 43 |
else:
|
| 44 |
device = "cpu"
|
| 45 |
-
print("[MedSAM2] Using CPU device")
|
| 46 |
return device
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def load_medsam2_model():
|
| 50 |
global _medsam2_model, _medsam2_predictor
|
| 51 |
|
| 52 |
if _medsam2_predictor is not None:
|
| 53 |
return _medsam2_predictor
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
if checkpoint_path is None:
|
| 63 |
-
print("[MedSAM2] Checkpoint not available")
|
| 64 |
-
return None
|
| 65 |
-
|
| 66 |
-
try:
|
| 67 |
-
import torch
|
| 68 |
-
device = _get_device()
|
| 69 |
-
print(f"[MedSAM2] Loading model on device: {device}")
|
| 70 |
-
print(f"[MedSAM2] Checkpoint: {checkpoint_path}")
|
| 71 |
|
| 72 |
-
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
except ImportError as e:
|
| 78 |
-
print(f"[MedSAM2] SAM2 library not importable: {e}")
|
| 79 |
return None
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
print(f"[MedSAM2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
return None
|
| 88 |
|
| 89 |
-
abs_config = '/' + os.path.abspath(config_yaml)
|
| 90 |
-
|
| 91 |
-
os.environ["SAM2_ALLOW_ALL_KERNELS"] = "1"
|
| 92 |
-
|
| 93 |
-
with torch.no_grad():
|
| 94 |
-
_medsam2_model = build_sam2(
|
| 95 |
-
abs_config,
|
| 96 |
-
ckpt_path=str(checkpoint_path),
|
| 97 |
-
device=device
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
_medsam2_predictor = SAM2ImagePredictor(_medsam2_model)
|
| 101 |
-
|
| 102 |
-
print(f"[MedSAM2] Model loaded successfully on {device}")
|
| 103 |
-
print(f"[MedSAM2] Model device: {_medsam2_predictor.device}")
|
| 104 |
-
return _medsam2_predictor
|
| 105 |
-
|
| 106 |
-
except Exception as e:
|
| 107 |
-
print(f"[MedSAM2] Failed to load model: {e}")
|
| 108 |
-
traceback.print_exc()
|
| 109 |
-
_medsam2_model = None
|
| 110 |
-
_medsam2_predictor = None
|
| 111 |
-
return None
|
| 112 |
-
|
| 113 |
|
| 114 |
def segment_with_medsam2(image, click_x, click_y):
|
| 115 |
import torch
|
|
@@ -127,7 +169,6 @@ def segment_with_medsam2(image, click_x, click_y):
|
|
| 127 |
|
| 128 |
click_x = int(max(0, min(click_x, img_w - 1)))
|
| 129 |
click_y = int(max(0, min(click_y, img_h - 1)))
|
| 130 |
-
print(f"[MedSAM2] Clamped point: ({click_x}, {click_y})")
|
| 131 |
|
| 132 |
if len(image.shape) == 2:
|
| 133 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
|
@@ -136,34 +177,36 @@ def segment_with_medsam2(image, click_x, click_y):
|
|
| 136 |
else:
|
| 137 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 138 |
|
| 139 |
-
print(f"[MedSAM2] RGB image: shape={image_rgb.shape}, dtype={image_rgb.dtype}, range=[{image_rgb.min()}, {image_rgb.max()}]")
|
| 140 |
-
|
| 141 |
if image_rgb.dtype != np.uint8:
|
| 142 |
if image_rgb.max() <= 1.0:
|
| 143 |
image_rgb = (image_rgb * 255).astype(np.uint8)
|
| 144 |
else:
|
| 145 |
image_rgb = image_rgb.astype(np.uint8)
|
| 146 |
-
print(f"[MedSAM2] Converted to uint8, range=[{image_rgb.min()}, {image_rgb.max()}]")
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
print("[MedSAM2] Image set successfully")
|
| 152 |
|
| 153 |
point_coords = np.array([[click_x, click_y]], dtype=np.float32)
|
| 154 |
point_labels = np.array([1], dtype=np.int32)
|
| 155 |
-
print(f"[MedSAM2] Point coords: {point_coords}, labels: {point_labels}")
|
| 156 |
-
print(f"[MedSAM2] Point coords dtype: {point_coords.dtype}, labels dtype: {point_labels.dtype}")
|
| 157 |
|
| 158 |
-
print("[MedSAM2]
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
masks, scores, logits = predictor.predict(
|
| 161 |
point_coords=point_coords,
|
| 162 |
point_labels=point_labels,
|
| 163 |
multimask_output=True
|
| 164 |
)
|
| 165 |
-
|
| 166 |
-
print(f"[MedSAM2] predict()
|
| 167 |
|
| 168 |
if masks is None:
|
| 169 |
print("[MedSAM2] ERROR: predict() returned None for masks")
|
|
@@ -171,15 +214,24 @@ def segment_with_medsam2(image, click_x, click_y):
|
|
| 171 |
|
| 172 |
print(f"[MedSAM2] Masks shape: {masks.shape}, dtype: {masks.dtype}")
|
| 173 |
print(f"[MedSAM2] Scores: {scores}")
|
| 174 |
-
print(f"[MedSAM2] Logits shape: {logits.shape}")
|
| 175 |
|
| 176 |
if len(masks) == 0:
|
| 177 |
print("[MedSAM2] ERROR: predict() returned empty masks array")
|
| 178 |
return None
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
for i, (mask, score) in enumerate(zip(masks, scores)):
|
| 181 |
nonzero = np.count_nonzero(mask)
|
| 182 |
-
print(f"[MedSAM2] Mask {i}:
|
| 183 |
|
| 184 |
from utils.image_processing import postprocess_mask
|
| 185 |
|
|
@@ -193,14 +245,12 @@ def segment_with_medsam2(image, click_x, click_y):
|
|
| 193 |
'score': float(score),
|
| 194 |
'area': area
|
| 195 |
})
|
| 196 |
-
print(f"[MedSAM2] Processed mask {i}: area={area} pixels, score={float(score):.4f}")
|
| 197 |
|
| 198 |
mask_list.sort(key=lambda m: m['area'])
|
| 199 |
|
| 200 |
total_area = sum(m['area'] for m in mask_list)
|
| 201 |
if total_area == 0:
|
| 202 |
-
print("[MedSAM2] WARNING: All masks
|
| 203 |
-
print("[MedSAM2] Returning raw masks without postprocessing cleanup")
|
| 204 |
mask_list = []
|
| 205 |
for i, (mask, score) in enumerate(zip(masks, scores)):
|
| 206 |
binary = (mask.astype(np.uint8)) * 255
|
|
@@ -212,7 +262,11 @@ def segment_with_medsam2(image, click_x, click_y):
|
|
| 212 |
})
|
| 213 |
mask_list.sort(key=lambda m: m['area'])
|
| 214 |
|
| 215 |
-
print(f"[MedSAM2] Segmentation complete: {len(mask_list)} masks
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
return mask_list
|
| 217 |
|
| 218 |
except Exception as e:
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
|
| 3 |
+
os.environ["SAM2_ALLOW_ALL_KERNELS"] = "1"
|
| 4 |
+
os.environ["U_FLASH_ATTN"] = "0"
|
| 5 |
+
os.environ["MATH_KERNEL_ON"] = "0"
|
| 6 |
+
|
| 7 |
import numpy as np
|
| 8 |
import cv2
|
| 9 |
import sys
|
|
|
|
| 10 |
import traceback
|
| 11 |
+
import time
|
| 12 |
+
import threading
|
| 13 |
|
| 14 |
_medsam2_model = None
|
| 15 |
_medsam2_predictor = None
|
| 16 |
+
_model_lock = threading.Lock()
|
| 17 |
+
|
| 18 |
+
MAX_INPUT_SIZE = 512
|
| 19 |
|
| 20 |
MEDSAM2_PATHS = [
|
| 21 |
os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'medsam2_pkg'),
|
|
|
|
| 50 |
device = "cuda"
|
| 51 |
print(f"[MedSAM2] Using CUDA device: {torch.cuda.get_device_name(0)}")
|
| 52 |
print(f"[MedSAM2] CUDA memory: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
|
| 53 |
+
print(f"[MedSAM2] CUDA capability: {torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}")
|
| 54 |
else:
|
| 55 |
device = "cpu"
|
| 56 |
+
print("[MedSAM2] Using CPU device (no CUDA available)")
|
| 57 |
return device
|
| 58 |
|
| 59 |
|
| 60 |
+
def _resize_for_model(image_rgb, click_x, click_y):
|
| 61 |
+
h, w = image_rgb.shape[:2]
|
| 62 |
+
if max(h, w) <= MAX_INPUT_SIZE:
|
| 63 |
+
return image_rgb, click_x, click_y, 1.0
|
| 64 |
+
|
| 65 |
+
scale = MAX_INPUT_SIZE / max(h, w)
|
| 66 |
+
new_w = int(w * scale)
|
| 67 |
+
new_h = int(h * scale)
|
| 68 |
+
resized = cv2.resize(image_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
| 69 |
+
new_click_x = int(click_x * scale)
|
| 70 |
+
new_click_y = int(click_y * scale)
|
| 71 |
+
new_click_x = max(0, min(new_click_x, new_w - 1))
|
| 72 |
+
new_click_y = max(0, min(new_click_y, new_h - 1))
|
| 73 |
+
print(f"[MedSAM2] Resized input: {w}x{h} -> {new_w}x{new_h} (scale={scale:.3f})")
|
| 74 |
+
print(f"[MedSAM2] Scaled click: ({click_x},{click_y}) -> ({new_click_x},{new_click_y})")
|
| 75 |
+
return resized, new_click_x, new_click_y, scale
|
| 76 |
+
|
| 77 |
+
|
| 78 |
def load_medsam2_model():
|
| 79 |
global _medsam2_model, _medsam2_predictor
|
| 80 |
|
| 81 |
if _medsam2_predictor is not None:
|
| 82 |
return _medsam2_predictor
|
| 83 |
|
| 84 |
+
with _model_lock:
|
| 85 |
+
if _medsam2_predictor is not None:
|
| 86 |
+
return _medsam2_predictor
|
| 87 |
|
| 88 |
+
if not is_medsam2_available():
|
| 89 |
+
print("[MedSAM2] Dependencies not available")
|
| 90 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
from models.checkpoint_manager import CheckpointManager
|
| 93 |
|
| 94 |
+
checkpoint_path = CheckpointManager.get_medsam2_checkpoint()
|
| 95 |
+
if checkpoint_path is None:
|
| 96 |
+
print("[MedSAM2] Checkpoint not available")
|
|
|
|
|
|
|
| 97 |
return None
|
| 98 |
|
| 99 |
+
try:
|
| 100 |
+
import torch
|
| 101 |
+
device = _get_device()
|
| 102 |
+
print(f"[MedSAM2] Loading model on device: {device}")
|
| 103 |
+
print(f"[MedSAM2] Checkpoint: {checkpoint_path}")
|
| 104 |
+
print(f"[MedSAM2] PyTorch version: {torch.__version__}")
|
| 105 |
+
print(f"[MedSAM2] CUDA available: {torch.cuda.is_available()}")
|
| 106 |
+
|
| 107 |
+
_ensure_medsam2_path()
|
| 108 |
+
|
| 109 |
+
try:
|
| 110 |
+
from sam2.build_sam import build_sam2
|
| 111 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 112 |
+
except ImportError as e:
|
| 113 |
+
print(f"[MedSAM2] SAM2 library not importable: {e}")
|
| 114 |
+
return None
|
| 115 |
+
|
| 116 |
+
medsam2_path = _find_medsam2_path()
|
| 117 |
+
config_dir = os.path.join(medsam2_path, 'sam2', 'configs')
|
| 118 |
+
config_yaml = os.path.join(config_dir, 'sam2.1_hiera_t512.yaml')
|
| 119 |
+
if not os.path.exists(config_yaml):
|
| 120 |
+
yaml_files = [f for f in os.listdir(config_dir)] if os.path.isdir(config_dir) else []
|
| 121 |
+
print(f"[MedSAM2] Config not found at {config_yaml}, available: {yaml_files}")
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
abs_config = '/' + os.path.abspath(config_yaml)
|
| 125 |
+
|
| 126 |
+
t0 = time.time()
|
| 127 |
+
with torch.inference_mode():
|
| 128 |
+
_medsam2_model = build_sam2(
|
| 129 |
+
abs_config,
|
| 130 |
+
ckpt_path=str(checkpoint_path),
|
| 131 |
+
device=device
|
| 132 |
+
)
|
| 133 |
+
load_time = time.time() - t0
|
| 134 |
+
|
| 135 |
+
_medsam2_predictor = SAM2ImagePredictor(_medsam2_model)
|
| 136 |
+
|
| 137 |
+
print(f"[MedSAM2] Model loaded in {load_time:.1f}s on {device}")
|
| 138 |
+
print(f"[MedSAM2] Model device: {_medsam2_predictor.device}")
|
| 139 |
+
print(f"[MedSAM2] Model image_size: {_medsam2_model.image_size}")
|
| 140 |
+
|
| 141 |
+
if device == "cuda":
|
| 142 |
+
mem_alloc = torch.cuda.memory_allocated() / 1e9
|
| 143 |
+
mem_reserved = torch.cuda.memory_reserved() / 1e9
|
| 144 |
+
print(f"[MedSAM2] GPU memory: allocated={mem_alloc:.2f}GB, reserved={mem_reserved:.2f}GB")
|
| 145 |
+
|
| 146 |
+
return _medsam2_predictor
|
| 147 |
+
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"[MedSAM2] Failed to load model: {e}")
|
| 150 |
+
traceback.print_exc()
|
| 151 |
+
_medsam2_model = None
|
| 152 |
+
_medsam2_predictor = None
|
| 153 |
return None
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
def segment_with_medsam2(image, click_x, click_y):
|
| 157 |
import torch
|
|
|
|
| 169 |
|
| 170 |
click_x = int(max(0, min(click_x, img_w - 1)))
|
| 171 |
click_y = int(max(0, min(click_y, img_h - 1)))
|
|
|
|
| 172 |
|
| 173 |
if len(image.shape) == 2:
|
| 174 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
|
|
|
| 177 |
else:
|
| 178 |
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 179 |
|
|
|
|
|
|
|
| 180 |
if image_rgb.dtype != np.uint8:
|
| 181 |
if image_rgb.max() <= 1.0:
|
| 182 |
image_rgb = (image_rgb * 255).astype(np.uint8)
|
| 183 |
else:
|
| 184 |
image_rgb = image_rgb.astype(np.uint8)
|
|
|
|
| 185 |
|
| 186 |
+
image_rgb, click_x, click_y, scale = _resize_for_model(image_rgb, click_x, click_y)
|
| 187 |
+
|
| 188 |
+
print(f"[MedSAM2] Final input: {image_rgb.shape}, click=({click_x},{click_y})")
|
|
|
|
| 189 |
|
| 190 |
point_coords = np.array([[click_x, click_y]], dtype=np.float32)
|
| 191 |
point_labels = np.array([1], dtype=np.int32)
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
print("[MedSAM2] Setting image on predictor...")
|
| 194 |
+
t0 = time.time()
|
| 195 |
+
with torch.inference_mode():
|
| 196 |
+
predictor.set_image(image_rgb)
|
| 197 |
+
set_time = time.time() - t0
|
| 198 |
+
print(f"[MedSAM2] Image set in {set_time:.2f}s")
|
| 199 |
+
|
| 200 |
+
print(f"[MedSAM2] Running predict(): coords={point_coords}, labels={point_labels}")
|
| 201 |
+
t0 = time.time()
|
| 202 |
+
with torch.inference_mode():
|
| 203 |
masks, scores, logits = predictor.predict(
|
| 204 |
point_coords=point_coords,
|
| 205 |
point_labels=point_labels,
|
| 206 |
multimask_output=True
|
| 207 |
)
|
| 208 |
+
pred_time = time.time() - t0
|
| 209 |
+
print(f"[MedSAM2] predict() completed in {pred_time:.2f}s")
|
| 210 |
|
| 211 |
if masks is None:
|
| 212 |
print("[MedSAM2] ERROR: predict() returned None for masks")
|
|
|
|
| 214 |
|
| 215 |
print(f"[MedSAM2] Masks shape: {masks.shape}, dtype: {masks.dtype}")
|
| 216 |
print(f"[MedSAM2] Scores: {scores}")
|
|
|
|
| 217 |
|
| 218 |
if len(masks) == 0:
|
| 219 |
print("[MedSAM2] ERROR: predict() returned empty masks array")
|
| 220 |
return None
|
| 221 |
|
| 222 |
+
if scale < 1.0:
|
| 223 |
+
orig_h, orig_w = img_h, img_w
|
| 224 |
+
upscaled_masks = []
|
| 225 |
+
for m in masks:
|
| 226 |
+
m_uint8 = m.astype(np.uint8) * 255
|
| 227 |
+
m_up = cv2.resize(m_uint8, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
|
| 228 |
+
upscaled_masks.append(m_up > 127)
|
| 229 |
+
masks = np.array(upscaled_masks)
|
| 230 |
+
print(f"[MedSAM2] Upscaled masks back to {orig_w}x{orig_h}")
|
| 231 |
+
|
| 232 |
for i, (mask, score) in enumerate(zip(masks, scores)):
|
| 233 |
nonzero = np.count_nonzero(mask)
|
| 234 |
+
print(f"[MedSAM2] Mask {i}: nonzero={nonzero}, score={score:.4f}")
|
| 235 |
|
| 236 |
from utils.image_processing import postprocess_mask
|
| 237 |
|
|
|
|
| 245 |
'score': float(score),
|
| 246 |
'area': area
|
| 247 |
})
|
|
|
|
| 248 |
|
| 249 |
mask_list.sort(key=lambda m: m['area'])
|
| 250 |
|
| 251 |
total_area = sum(m['area'] for m in mask_list)
|
| 252 |
if total_area == 0:
|
| 253 |
+
print("[MedSAM2] WARNING: All masks zero area after postprocessing, returning raw")
|
|
|
|
| 254 |
mask_list = []
|
| 255 |
for i, (mask, score) in enumerate(zip(masks, scores)):
|
| 256 |
binary = (mask.astype(np.uint8)) * 255
|
|
|
|
| 262 |
})
|
| 263 |
mask_list.sort(key=lambda m: m['area'])
|
| 264 |
|
| 265 |
+
print(f"[MedSAM2] Segmentation complete: {len(mask_list)} masks, total time={set_time + pred_time:.2f}s")
|
| 266 |
+
|
| 267 |
+
if predictor.device.type == "cuda":
|
| 268 |
+
torch.cuda.empty_cache()
|
| 269 |
+
|
| 270 |
return mask_list
|
| 271 |
|
| 272 |
except Exception as e:
|