MilicMilos committed
Commit 46c5d16 · Parent: 142af98

Improve model inference performance and reliability on hardware


Force-disable Flash Attention, optimize the inference loop (torch.no_grad() → torch.inference_mode()), pre-load the model at startup, and cap the input size for faster processing. A short sketch of the inference_mode change follows the commit metadata below.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: c144be0a-7fab-4a53-a663-fc927a204409
Replit-Commit-Checkpoint-Type: intermediate_checkpoint
Replit-Commit-Event-Id: 63465661-b0cc-45eb-97fa-7aed76fbe293
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/5b4b75b9-1619-404c-a78d-526127514111/c144be0a-7fab-4a53-a663-fc927a204409/35LY8UZ
Replit-Helium-Checkpoint-Created: true
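
The "optimize the inference loop" part of this commit is the switch from torch.no_grad() to torch.inference_mode() around set_image() and predict() (visible in the models/medsam2_inference.py hunks below). A minimal, self-contained sketch of the difference between the two contexts; TinyNet is a hypothetical stand-in, not part of this repo:

import torch
import torch.nn as nn

# Hypothetical stand-in model; the repo's real model is MedSAM2.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)
        )

    def forward(self, x):
        return self.layers(x)

net = TinyNet().eval()
x = torch.randn(64, 512)

# torch.no_grad() only disables gradient tracking for the block.
with torch.no_grad():
    y1 = net(x)

# torch.inference_mode() additionally skips view tracking and
# version-counter bookkeeping; tensors created inside can never be
# used in autograd later, which makes it the stricter (and typically
# faster) choice for pure inference paths.
with torch.inference_mode():
    y2 = net(x)

print(y1.requires_grad, y2.requires_grad)  # False False
print(y2.is_inference())                   # True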

Dockerfile CHANGED
@@ -23,5 +23,9 @@ RUN mkdir -p uploads checkpoints
 EXPOSE 7860
 
 ENV PORT=7860
+ENV SAM2_ALLOW_ALL_KERNELS=1
+ENV TORCH_CUDNN_SDPA_ENABLED=0
+ENV U_FLASH_ATTN=0
+ENV MATH_KERNEL_ON=0
 
 CMD ["python", "main.py"]
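
Of these four variables, only SAM2_ALLOW_ALL_KERNELS is consumed by code visible in this diff (transformer.py below). TORCH_CUDNN_SDPA_ENABLED appears to be the PyTorch-level switch for the cuDNN SDPA backend in recent releases, while U_FLASH_ATTN and MATH_KERNEL_ON are not read anywhere shown here, so they may be dead or consumed elsewhere. A minimal sketch of the flag convention the code uses (the string "1" means on):

import os

def env_flag(name: str, default: str = "1") -> bool:
    # Same convention as transformer.py below: the flag is on
    # only when the variable is exactly the string "1".
    return os.environ.get(name, default) == "1"

if __name__ == "__main__":
    print("allow_all_kernels:", env_flag("SAM2_ALLOW_ALL_KERNELS"))
    print("cudnn_sdpa:", env_flag("TORCH_CUDNN_SDPA_ENABLED", default="0"))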
main.py CHANGED
@@ -1395,7 +1395,25 @@ def batch_report():
     except Exception as e:
         return jsonify({'error': f'Error generating PDF: {str(e)}'}), 500
 
+def preload_medsam2():
+    import threading
+    def _load():
+        try:
+            print("[Startup] Pre-loading MedSAM2 model...")
+            from models.medsam2_inference import load_medsam2_model
+            predictor = load_medsam2_model()
+            if predictor is not None:
+                print("[Startup] MedSAM2 model pre-loaded successfully")
+            else:
+                print("[Startup] MedSAM2 model not available (will retry on first request)")
+        except Exception as e:
+            print(f"[Startup] MedSAM2 pre-load failed: {e}")
+    t = threading.Thread(target=_load, daemon=True)
+    t.start()
+
+
 if __name__ == '__main__':
     import sys
     port = int(os.environ.get('PORT', sys.argv[1] if len(sys.argv) > 1 else 7860))
+    preload_medsam2()
     app.run(host='0.0.0.0', port=port, debug=True)
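
The preload runs on a daemon thread so startup is never blocked, and it is safe only because load_medsam2_model() is idempotent under _model_lock (see models/medsam2_inference.py below). A stripped-down sketch of that warm-up idiom, with load_model() as a hypothetical stand-in for the real loader:

import threading
import time

_model = None
_lock = threading.Lock()

def load_model():
    # Double-checked locking, as in load_medsam2_model() below:
    # a lock-free fast path, then one thread does the real load.
    global _model
    if _model is not None:
        return _model
    with _lock:
        if _model is not None:
            return _model
        time.sleep(0.5)  # stand-in for expensive checkpoint loading
        _model = object()
        return _model

def preload():
    # Daemon thread: warms the model without blocking startup and
    # without keeping the process alive at shutdown.
    threading.Thread(target=load_model, daemon=True).start()

if __name__ == "__main__":
    preload()
    m1 = load_model()  # may wait on the same lock as the preload thread
    m2 = load_model()  # served from the lock-free fast path
    print(m1 is m2)    # True: all callers share one instance

One caveat: app.run(..., debug=True) starts the Werkzeug reloader, which re-executes the main module in a child process, so the preload thread can run in both the reloader parent and the serving child.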
medsam2_pkg/sam2/modeling/sam/transformer.py CHANGED
@@ -17,22 +17,22 @@ from torch import nn, Tensor
 
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
 
 warnings.simplefilter(action="ignore", category=FutureWarning)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
 ALLOW_ALL_KERNELS = os.environ.get("SAM2_ALLOW_ALL_KERNELS", "1") == "1"
-if ALLOW_ALL_KERNELS:
-    print("[SAM2] Flash Attention DISABLED — using all available kernels fallback for maximum compatibility")
-else:
-    print(f"[SAM2] Flash Attention: {USE_FLASH_ATTN}, Math kernel: {MATH_KERNEL_ON}, Old GPU: {OLD_GPU}")
+OLD_GPU = True
+USE_FLASH_ATTN = False
+MATH_KERNEL_ON = True
+if not ALLOW_ALL_KERNELS:
+    from sam2.utils.misc import get_sdpa_settings
+    OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+print(f"[SAM2] Attention config: ALLOW_ALL_KERNELS={ALLOW_ALL_KERNELS}, FLASH={USE_FLASH_ATTN}, MATH={MATH_KERNEL_ON}")
 
 
 def sdp_kernel_context(dropout_p):
     """
     Get the context for the attention scaled dot-product kernel.
     Defaults to allowing all kernels for maximum compatibility.
-    Set SAM2_ALLOW_ALL_KERNELS=0 to use Flash Attention when available.
     """
     if ALLOW_ALL_KERNELS:
        return contextlib.nullcontext()
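
When ALLOW_ALL_KERNELS is off, sdp_kernel_context presumably still falls through to the restricted-kernel context inherited from upstream SAM2 (that branch is not shown in this hunk). For reference, a sketch of that branch as it appears in upstream SAM2, under the assumption that this repo left it unchanged:

import contextlib
import os
import torch

ALLOW_ALL_KERNELS = os.environ.get("SAM2_ALLOW_ALL_KERNELS", "1") == "1"
# Defaults match the new module-level values in the hunk above.
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = True, False, True

def sdp_kernel_context(dropout_p):
    if ALLOW_ALL_KERNELS:
        # Let PyTorch's SDPA dispatcher pick any available kernel.
        return contextlib.nullcontext()
    # Upstream SAM2 restricts the kernel set explicitly. Note that
    # torch.backends.cuda.sdp_kernel is deprecated in newer PyTorch
    # releases in favor of torch.nn.attention.sdpa_kernel.
    return torch.backends.cuda.sdp_kernel(
        enable_flash=USE_FLASH_ATTN,
        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        enable_mem_efficient=OLD_GPU,
    )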
models/medsam2_inference.py CHANGED
@@ -1,11 +1,21 @@
+import os
+os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
+os.environ["SAM2_ALLOW_ALL_KERNELS"] = "1"
+os.environ["U_FLASH_ATTN"] = "0"
+os.environ["MATH_KERNEL_ON"] = "0"
+
 import numpy as np
 import cv2
 import sys
-import os
 import traceback
+import time
+import threading
 
 _medsam2_model = None
 _medsam2_predictor = None
+_model_lock = threading.Lock()
+
+MAX_INPUT_SIZE = 512
 
 MEDSAM2_PATHS = [
     os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'medsam2_pkg'),
@@ -40,76 +50,108 @@ def _get_device():
         device = "cuda"
         print(f"[MedSAM2] Using CUDA device: {torch.cuda.get_device_name(0)}")
         print(f"[MedSAM2] CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+        print(f"[MedSAM2] CUDA capability: {torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}")
     else:
         device = "cpu"
-        print("[MedSAM2] Using CPU device")
+        print("[MedSAM2] Using CPU device (no CUDA available)")
     return device
 
 
+def _resize_for_model(image_rgb, click_x, click_y):
+    h, w = image_rgb.shape[:2]
+    if max(h, w) <= MAX_INPUT_SIZE:
+        return image_rgb, click_x, click_y, 1.0
+
+    scale = MAX_INPUT_SIZE / max(h, w)
+    new_w = int(w * scale)
+    new_h = int(h * scale)
+    resized = cv2.resize(image_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
+    new_click_x = int(click_x * scale)
+    new_click_y = int(click_y * scale)
+    new_click_x = max(0, min(new_click_x, new_w - 1))
+    new_click_y = max(0, min(new_click_y, new_h - 1))
+    print(f"[MedSAM2] Resized input: {w}x{h} -> {new_w}x{new_h} (scale={scale:.3f})")
+    print(f"[MedSAM2] Scaled click: ({click_x},{click_y}) -> ({new_click_x},{new_click_y})")
+    return resized, new_click_x, new_click_y, scale
+
+
 def load_medsam2_model():
     global _medsam2_model, _medsam2_predictor
 
     if _medsam2_predictor is not None:
         return _medsam2_predictor
 
-    if not is_medsam2_available():
-        print("[MedSAM2] Dependencies not available")
-        return None
-
-    from models.checkpoint_manager import CheckpointManager
-
-    checkpoint_path = CheckpointManager.get_medsam2_checkpoint()
-    if checkpoint_path is None:
-        print("[MedSAM2] Checkpoint not available")
-        return None
-
-    try:
-        import torch
-        device = _get_device()
-        print(f"[MedSAM2] Loading model on device: {device}")
-        print(f"[MedSAM2] Checkpoint: {checkpoint_path}")
-
-        _ensure_medsam2_path()
-
-        try:
-            from sam2.build_sam import build_sam2
-            from sam2.sam2_image_predictor import SAM2ImagePredictor
-        except ImportError as e:
-            print(f"[MedSAM2] SAM2 library not importable: {e}")
-            return None
-
-        medsam2_path = _find_medsam2_path()
-        config_dir = os.path.join(medsam2_path, 'sam2', 'configs')
-        config_yaml = os.path.join(config_dir, 'sam2.1_hiera_t512.yaml')
-        if not os.path.exists(config_yaml):
-            yaml_files = [f for f in os.listdir(config_dir)] if os.path.isdir(config_dir) else []
-            print(f"[MedSAM2] Config not found at {config_yaml}, available: {yaml_files}")
-            return None
-
-        abs_config = '/' + os.path.abspath(config_yaml)
-
-        os.environ["SAM2_ALLOW_ALL_KERNELS"] = "1"
-
-        with torch.no_grad():
-            _medsam2_model = build_sam2(
-                abs_config,
-                ckpt_path=str(checkpoint_path),
-                device=device
-            )
-
-        _medsam2_predictor = SAM2ImagePredictor(_medsam2_model)
-
-        print(f"[MedSAM2] Model loaded successfully on {device}")
-        print(f"[MedSAM2] Model device: {_medsam2_predictor.device}")
-        return _medsam2_predictor
-
-    except Exception as e:
-        print(f"[MedSAM2] Failed to load model: {e}")
-        traceback.print_exc()
-        _medsam2_model = None
-        _medsam2_predictor = None
-        return None
-
+    with _model_lock:
+        if _medsam2_predictor is not None:
+            return _medsam2_predictor
+
+        if not is_medsam2_available():
+            print("[MedSAM2] Dependencies not available")
+            return None
+
+        from models.checkpoint_manager import CheckpointManager
+
+        checkpoint_path = CheckpointManager.get_medsam2_checkpoint()
+        if checkpoint_path is None:
+            print("[MedSAM2] Checkpoint not available")
+            return None
+
+        try:
+            import torch
+            device = _get_device()
+            print(f"[MedSAM2] Loading model on device: {device}")
+            print(f"[MedSAM2] Checkpoint: {checkpoint_path}")
+            print(f"[MedSAM2] PyTorch version: {torch.__version__}")
+            print(f"[MedSAM2] CUDA available: {torch.cuda.is_available()}")
+
+            _ensure_medsam2_path()
+
+            try:
+                from sam2.build_sam import build_sam2
+                from sam2.sam2_image_predictor import SAM2ImagePredictor
+            except ImportError as e:
+                print(f"[MedSAM2] SAM2 library not importable: {e}")
+                return None
+
+            medsam2_path = _find_medsam2_path()
+            config_dir = os.path.join(medsam2_path, 'sam2', 'configs')
+            config_yaml = os.path.join(config_dir, 'sam2.1_hiera_t512.yaml')
+            if not os.path.exists(config_yaml):
+                yaml_files = [f for f in os.listdir(config_dir)] if os.path.isdir(config_dir) else []
+                print(f"[MedSAM2] Config not found at {config_yaml}, available: {yaml_files}")
+                return None
+
+            abs_config = '/' + os.path.abspath(config_yaml)
+
+            t0 = time.time()
+            with torch.inference_mode():
+                _medsam2_model = build_sam2(
+                    abs_config,
+                    ckpt_path=str(checkpoint_path),
+                    device=device
+                )
+            load_time = time.time() - t0
+
+            _medsam2_predictor = SAM2ImagePredictor(_medsam2_model)
+
+            print(f"[MedSAM2] Model loaded in {load_time:.1f}s on {device}")
+            print(f"[MedSAM2] Model device: {_medsam2_predictor.device}")
+            print(f"[MedSAM2] Model image_size: {_medsam2_model.image_size}")
+
+            if device == "cuda":
+                mem_alloc = torch.cuda.memory_allocated() / 1e9
+                mem_reserved = torch.cuda.memory_reserved() / 1e9
+                print(f"[MedSAM2] GPU memory: allocated={mem_alloc:.2f}GB, reserved={mem_reserved:.2f}GB")
+
+            return _medsam2_predictor
+
+        except Exception as e:
+            print(f"[MedSAM2] Failed to load model: {e}")
+            traceback.print_exc()
+            _medsam2_model = None
+            _medsam2_predictor = None
+            return None
+
 
 def segment_with_medsam2(image, click_x, click_y):
     import torch
@@ -127,7 +169,6 @@ def segment_with_medsam2(image, click_x, click_y):
 
         click_x = int(max(0, min(click_x, img_w - 1)))
         click_y = int(max(0, min(click_y, img_h - 1)))
-        print(f"[MedSAM2] Clamped point: ({click_x}, {click_y})")
 
         if len(image.shape) == 2:
             image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
@@ -136,34 +177,36 @@ def segment_with_medsam2(image, click_x, click_y):
         else:
             image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-        print(f"[MedSAM2] RGB image: shape={image_rgb.shape}, dtype={image_rgb.dtype}, range=[{image_rgb.min()}, {image_rgb.max()}]")
-
         if image_rgb.dtype != np.uint8:
             if image_rgb.max() <= 1.0:
                 image_rgb = (image_rgb * 255).astype(np.uint8)
             else:
                 image_rgb = image_rgb.astype(np.uint8)
-            print(f"[MedSAM2] Converted to uint8, range=[{image_rgb.min()}, {image_rgb.max()}]")
 
-        print("[MedSAM2] Setting image on predictor...")
-        with torch.no_grad():
-            predictor.set_image(image_rgb)
-        print("[MedSAM2] Image set successfully")
+        image_rgb, click_x, click_y, scale = _resize_for_model(image_rgb, click_x, click_y)
+
+        print(f"[MedSAM2] Final input: {image_rgb.shape}, click=({click_x},{click_y})")
 
         point_coords = np.array([[click_x, click_y]], dtype=np.float32)
        point_labels = np.array([1], dtype=np.int32)
-        print(f"[MedSAM2] Point coords: {point_coords}, labels: {point_labels}")
-        print(f"[MedSAM2] Point coords dtype: {point_coords.dtype}, labels dtype: {point_labels.dtype}")
 
-        print("[MedSAM2] Running predict()...")
-        with torch.no_grad():
+        print("[MedSAM2] Setting image on predictor...")
+        t0 = time.time()
+        with torch.inference_mode():
+            predictor.set_image(image_rgb)
+        set_time = time.time() - t0
+        print(f"[MedSAM2] Image set in {set_time:.2f}s")
+
+        print(f"[MedSAM2] Running predict(): coords={point_coords}, labels={point_labels}")
+        t0 = time.time()
+        with torch.inference_mode():
            masks, scores, logits = predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=True
            )
+        pred_time = time.time() - t0
+        print(f"[MedSAM2] predict() completed in {pred_time:.2f}s")
 
        if masks is None:
            print("[MedSAM2] ERROR: predict() returned None for masks")
@@ -171,15 +214,24 @@ def segment_with_medsam2(image, click_x, click_y):
 
        print(f"[MedSAM2] Masks shape: {masks.shape}, dtype: {masks.dtype}")
        print(f"[MedSAM2] Scores: {scores}")
-        print(f"[MedSAM2] Logits shape: {logits.shape}")
 
        if len(masks) == 0:
            print("[MedSAM2] ERROR: predict() returned empty masks array")
            return None
 
+        if scale < 1.0:
+            orig_h, orig_w = img_h, img_w
+            upscaled_masks = []
+            for m in masks:
+                m_uint8 = m.astype(np.uint8) * 255
+                m_up = cv2.resize(m_uint8, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
+                upscaled_masks.append(m_up > 127)
+            masks = np.array(upscaled_masks)
+            print(f"[MedSAM2] Upscaled masks back to {orig_w}x{orig_h}")
+
        for i, (mask, score) in enumerate(zip(masks, scores)):
            nonzero = np.count_nonzero(mask)
-            print(f"[MedSAM2] Mask {i}: shape={mask.shape}, nonzero_pixels={nonzero}, score={score:.4f}")
+            print(f"[MedSAM2] Mask {i}: nonzero={nonzero}, score={score:.4f}")
 
        from utils.image_processing import postprocess_mask
 
@@ -193,14 +245,12 @@ def segment_with_medsam2(image, click_x, click_y):
                'score': float(score),
                'area': area
            })
-            print(f"[MedSAM2] Processed mask {i}: area={area} pixels, score={float(score):.4f}")
 
        mask_list.sort(key=lambda m: m['area'])
 
        total_area = sum(m['area'] for m in mask_list)
        if total_area == 0:
-            print("[MedSAM2] WARNING: All masks have zero area after postprocessing")
-            print("[MedSAM2] Returning raw masks without postprocessing cleanup")
+            print("[MedSAM2] WARNING: All masks zero area after postprocessing, returning raw")
            mask_list = []
            for i, (mask, score) in enumerate(zip(masks, scores)):
                binary = (mask.astype(np.uint8)) * 255
@@ -212,7 +262,11 @@ def segment_with_medsam2(image, click_x, click_y):
                })
            mask_list.sort(key=lambda m: m['area'])
 
-        print(f"[MedSAM2] Segmentation complete: {len(mask_list)} masks returned")
+        print(f"[MedSAM2] Segmentation complete: {len(mask_list)} masks, total time={set_time + pred_time:.2f}s")
+
+        if predictor.device.type == "cuda":
+            torch.cuda.empty_cache()
+
        return mask_list
 
    except Exception as e:
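
The new size cap trades boundary fidelity for speed: the longest side is reduced to MAX_INPUT_SIZE, the click is mapped into the resized frame, and the predicted masks are nearest-neighbor upscaled back. A self-contained round-trip check of that geometry on synthetic data (resize_for_model mirrors the diff's _resize_for_model; the image dimensions and click values are made up):

import cv2
import numpy as np

MAX_INPUT_SIZE = 512

def resize_for_model(image_rgb, click_x, click_y):
    # Mirrors _resize_for_model in the diff above.
    h, w = image_rgb.shape[:2]
    if max(h, w) <= MAX_INPUT_SIZE:
        return image_rgb, click_x, click_y, 1.0
    scale = MAX_INPUT_SIZE / max(h, w)
    new_w, new_h = int(w * scale), int(h * scale)
    resized = cv2.resize(image_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
    cx = max(0, min(int(click_x * scale), new_w - 1))
    cy = max(0, min(int(click_y * scale), new_h - 1))
    return resized, cx, cy, scale

# Synthetic 2048x1536 image with a click at (1000, 700).
img = np.zeros((1536, 2048, 3), dtype=np.uint8)
small, cx, cy, scale = resize_for_model(img, 1000, 700)
print(small.shape, cx, cy, round(scale, 3))  # (384, 512, 3) 250 175 0.25

# Upscale a model-resolution mask back to the original frame,
# as the diff does with INTER_NEAREST plus a 127 threshold.
mask_small = np.zeros(small.shape[:2], dtype=np.uint8)
cv2.circle(mask_small, (cx, cy), 20, 255, -1)
mask_full = cv2.resize(mask_small, (img.shape[1], img.shape[0]),
                       interpolation=cv2.INTER_NEAREST) > 127
print(mask_full.shape, mask_full[700, 1000])  # (1536, 2048) True

INTER_NEAREST keeps the mask binary on the way back up; the > 127 threshold then restores a boolean array at the original resolution.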