| |
| """ |
| Enhanced MedSAM2 integration (consolidated under tools.echo). |
| Provides EnhancedMedSAM2VideoSegmenter used for multi-structure segmentation. |
| """ |
|
|
| import os |
| import sys |
| import torch |
| import numpy as np |
| import cv2 |
| import tempfile |
| from typing import List, Dict, Optional, Any |
|
|
# Absolute directory that contains this module.
_current_dir = os.path.split(os.path.abspath(__file__))[0]

# Grandparent of _current_dir. It is prepended to the import path (only once)
# so that packages living there can be imported by absolute name.
_parent_dir = os.path.split(os.path.split(_current_dir)[0])[0]
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)
|
|
|
|
class EnhancedMedSAM2VideoSegmenter:
    """Run MedSAM2 on echo videos and build per-structure colour overlays.

    The video predictor is built once in ``__init__``.
    ``segment_video_multi_structure`` seeds one SAM2 object per entry of
    ``DEFAULT_STRUCTURES`` on frame 0 and propagates it through the clip.
    ``create_combined_overlay`` blends the resulting masks onto a frame.
    """

    # Structure codes prompted on frame 0. A structure's SAM2 object id is its
    # 1-based position in this list.
    DEFAULT_STRUCTURES = ['LV', 'MYO', 'LA', 'RV', 'RA']

    # Display name and overlay colour per structure code. The colour tuple is
    # written straight into the frame's channels, so whether it reads as RGB or
    # BGR depends on the channel order of the caller's frames.
    CARDIAC_STRUCTURES = {
        'LV': {'name': 'Left Ventricle', 'color': (0, 255, 0)},
        'MYO': {'name': 'Myocardium', 'color': (255, 105, 180)},
        'LA': {'name': 'Left Atrium', 'color': (0, 0, 255)},
        'RV': {'name': 'Right Ventricle', 'color': (255, 0, 0)},
        'RA': {'name': 'Right Atrium', 'color': (255, 255, 0)},
    }

    # Logit cut-off used to binarise propagated masks. Upstream SAM2 examples
    # threshold at 0.0; this slightly negative value keeps a few more
    # low-confidence pixels as foreground.
    MASK_LOGIT_THRESHOLD = -0.5

    def __init__(self, model_path: str = "model_weights/MedSAM2_US_Heart.pt"):
        """Resolve the checkpoint and build the SAM2 video predictor.

        Args:
            model_path: Absolute checkpoint path, or a path relative to this
                module's directory or to the directory three levels above it.

        Raises:
            FileNotFoundError: If no checkpoint can be located.
            RuntimeError: If the predictor cannot be built.
        """
        self.model_path = self._resolve_model_path(model_path)
        self.predictor = None
        self._initialize_predictor()

    def _resolve_sam2_paths(self) -> Dict[str, str]:
        """Locate the MedSAM2 checkout that provides ``sam2/configs``.

        Candidates are tried in this order:

        1. ``tool_repos/MedSAM2-main`` two directories above this module.
        2. ``tool_repos/MedSAM2`` in the same location.
        3. ``$ECHO_WORKSPACE_ROOT/MedSAM2-main``.

        The first candidate containing ``sam2/configs`` is added to the front
        of ``sys.path`` (if not already present) so ``sam2`` can be imported.

        Returns:
            ``{'root': <absolute sam2 package dir>,
            'configs': <absolute sam2/configs dir>}``.

        Raises:
            FileNotFoundError: If no candidate contains ``sam2/configs``.
        """
        candidates: List[str] = []

        local_tool_repos = os.path.abspath(os.path.join(_current_dir, "..", "..", "tool_repos"))
        if os.path.isdir(local_tool_repos):
            candidates.extend(
                os.path.join(local_tool_repos, repo_name)
                for repo_name in ("MedSAM2-main", "MedSAM2")
            )

        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            # Absolute so sys.path and Hydra do not depend on the current cwd.
            candidates.append(os.path.abspath(os.path.join(workspace_root, "MedSAM2-main")))

        for base in candidates:
            sam2_root = os.path.join(base, "sam2")
            configs_dir = os.path.join(sam2_root, "configs")
            if os.path.isdir(configs_dir):
                if base not in sys.path:
                    sys.path.insert(0, base)
                return {"root": sam2_root, "configs": configs_dir}
        raise FileNotFoundError("Could not locate sam2/configs directory. Ensure MedSAM2-main/sam2 is available.")

    def _resolve_model_path(self, provided_path: Optional[str]) -> str:
        """Resolve the checkpoint to an existing file path.

        An existing absolute path is returned unchanged. Otherwise these
        locations are tried in order, and the first existing file wins:

        1. ``provided_path`` relative to this module's directory.
        2. ``provided_path`` relative to the directory three levels up
           (the project root).
        3. ``<project root>/model_weights/MedSAM2_US_Heart.pt``.
        4. ``<project root>/checkpoints/MedSAM2_US_Heart.pt``.
        5. The same two files under ``$ECHO_WORKSPACE_ROOT/new_agent``.

        Raises:
            FileNotFoundError: Listing every candidate tried.
        """
        if provided_path and os.path.isabs(provided_path) and os.path.isfile(provided_path):
            return provided_path

        project_root = os.path.abspath(os.path.join(_current_dir, "..", "..", ".."))
        candidates: List[str] = []

        if provided_path:
            candidates.append(os.path.abspath(os.path.join(_current_dir, provided_path)))
            candidates.append(os.path.abspath(os.path.join(project_root, provided_path)))

        candidates.append(os.path.join(project_root, "model_weights", "MedSAM2_US_Heart.pt"))
        candidates.append(os.path.join(project_root, "checkpoints", "MedSAM2_US_Heart.pt"))

        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            for sub_dir in ("model_weights", "checkpoints"):
                candidates.append(os.path.join(workspace_root, "new_agent", sub_dir, "MedSAM2_US_Heart.pt"))

        # Drop duplicates (e.g. an absolute provided_path joins to itself) while keeping order.
        candidates = list(dict.fromkeys(candidates))
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate

        raise FileNotFoundError(
            f"Model file not found. Tried: {', '.join(candidates)}"
        )

    def _initialize_predictor(self) -> None:
        """Build ``self.predictor`` from the MedSAM2 Hydra config and checkpoint.

        Raises:
            RuntimeError: Wrapping any failure (missing files, import errors,
                Hydra or model construction errors). The original exception
                is kept as ``__cause__``.
        """
        config_file = "sam2.1_hiera_t512.yaml"
        try:
            paths = self._resolve_sam2_paths()
            configs_dir = os.path.abspath(paths["configs"])
            base_dir = os.path.dirname(paths["root"])

            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file not found: {self.model_path}")

            from sam2.build_sam import build_sam2_video_predictor

            config_yaml = os.path.join(configs_dir, config_file)
            if not os.path.exists(config_yaml):
                raise FileNotFoundError(f"Missing config: {config_yaml}")

            prev_cwd = os.getcwd()
            try:
                # NOTE(review): the cwd switch is kept from the original in case
                # something in the build resolves paths against cwd; confirm
                # whether it is still needed.
                os.chdir(base_dir)
                from hydra import initialize_config_dir
                from hydra.core.global_hydra import GlobalHydra

                # Hydra refuses to initialise twice, so drop any existing global
                # instance (possibly one created when sam2 was imported).
                try:
                    GlobalHydra.instance().clear()
                except Exception:
                    # Best effort: if clearing fails, initialisation below
                    # reports the real problem.
                    pass

                # hydra.initialize() resolves a relative config_path against the
                # *calling module's* directory, not the cwd. Using an absolute
                # directory makes the config lookup independent of where this
                # file lives.
                with initialize_config_dir(config_dir=configs_dir, version_base=None):
                    self.predictor = build_sam2_video_predictor(
                        config_file=config_file,
                        ckpt_path=self.model_path,
                        device="cuda" if torch.cuda.is_available() else "cpu",
                    )
            finally:
                os.chdir(prev_cwd)
        except Exception as e:
            raise RuntimeError(f"Enhanced MedSAM2 initialization failed: {e}") from e

    @staticmethod
    def _prepare_initial_masks(
        initial_masks: Optional[Dict[str, np.ndarray]],
        structures: List[str],
        height: int,
        width: int,
    ) -> Dict[str, np.ndarray]:
        """Normalise user-supplied frame-0 masks to boolean (height, width) arrays.

        Keys are upper-cased. Entries whose key is not in ``structures``, or
        whose value is not an ndarray, are ignored. Size-1 axes are squeezed.
        A remaining 3-D array is assumed to be channels-last (H, W, C), the
        OpenCV layout (TODO confirm with callers), and becomes foreground
        wherever any channel is > 0. Masks are binarised (``> 0``) *before*
        any nearest-neighbour resize, so fractional or negative values are not
        mangled by a uint8 cast.

        Raises:
            ValueError: If a mask cannot be reduced to two dimensions.
        """
        provided: Dict[str, np.ndarray] = {}
        if not isinstance(initial_masks, dict):
            return provided

        for key, arr in initial_masks.items():
            code = str(key).upper()
            if code not in structures or not isinstance(arr, np.ndarray):
                continue
            m = arr.squeeze() if arr.ndim > 2 else arr
            if m.ndim == 3:
                m = (m > 0).any(axis=-1)
            if m.ndim != 2:
                raise ValueError(f"Initial mask for {code} has unsupported shape {arr.shape}")
            foreground = m > 0
            if foreground.shape != (height, width):
                foreground = cv2.resize(
                    foreground.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
                ) > 0
            provided[code] = foreground
        return provided

    @staticmethod
    def _logits_to_frame_masks(
        obj_ids: List[int],
        mask_logits: Any,
        height: int,
        width: int,
        threshold: float,
    ) -> Dict[int, np.ndarray]:
        """Convert one frame of predictor logits into uint8 {0, 255} masks keyed by object id.

        ``mask_logits[i]`` is paired with ``obj_ids[i]``. It must support
        ``> threshold`` followed by ``.cpu().numpy()`` (a torch tensor).
        Objects without a corresponding logit map get an all-zero mask. A
        leading singleton axis is dropped, and masks are nearest-neighbour
        resized to (height, width) when needed.
        """
        frame_masks: Dict[int, np.ndarray] = {}
        for i, obj_id in enumerate(obj_ids):
            if i >= len(mask_logits):
                frame_masks[obj_id] = np.zeros((height, width), dtype=np.uint8)
                continue
            mask_array = (mask_logits[i] > threshold).cpu().numpy()
            if mask_array.ndim == 3 and mask_array.shape[0] == 1:
                mask_array = mask_array[0]
            if mask_array.shape != (height, width):
                mask_array = cv2.resize(
                    mask_array.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST
                )
            frame_masks[obj_id] = (mask_array > 0).astype(np.uint8) * 255
        return frame_masks

    def segment_video_multi_structure(
        self,
        frames: List[np.ndarray],
        progress_callback=None,
        initial_masks: Optional[Dict[str, np.ndarray]] = None,
    ) -> Dict[str, Any]:
        """Run MedSAM2 once and propagate a fixed set of cardiac structure prompts.

        If initial_masks is provided, it should map structure codes (e.g., 'LV','MYO','LA','RV','RA')
        to 2D mask arrays (H×W, non-zero foreground) for frame 0. These will seed the predictor; any
        missing structures fall back to coarse auto prompts.

        Args:
            frames: Non-empty list of images. Each is written to disk with
                ``cv2.imwrite`` as JPEG, so presumably uint8 in OpenCV channel
                order (TODO confirm with callers). All frames are assumed to
                share frame 0's height and width.
            progress_callback: Optional ``callback(percent, message)``, called
                once per propagated frame with ``percent`` between 20 and 90.
            initial_masks: Optional frame-0 masks keyed by structure code.

        Returns:
            A dict with these keys:

            - ``'masks'``: ``{frame_idx: {obj_id: uint8 (H, W) mask of 0/255}}``.
            - ``'structures'``: the structure codes; ``obj_id`` is index + 1.
            - ``'structure_info'``: display name and colour per code.
            - ``'total_frames'``: ``len(frames)``.

        Raises:
            RuntimeError: Wrapping any failure. The original exception is kept
                as ``__cause__``.
        """
        try:
            if not frames:
                raise ValueError("frames must be a non-empty list of images")

            structures = list(self.DEFAULT_STRUCTURES)
            h, w = frames[0].shape[:2]
            total = len(frames)
            # Validate user masks before doing any expensive work.
            provided = self._prepare_initial_masks(initial_masks, structures, h, w)

            with tempfile.TemporaryDirectory() as temp_dir:
                # The predictor is initialised from a directory of numbered JPEGs.
                for i, frame in enumerate(frames):
                    frame_path = os.path.join(temp_dir, f"{i:07d}.jpg")
                    if not cv2.imwrite(frame_path, frame):
                        raise OSError(f"Failed to write frame {i} to {frame_path}")

                state = self.predictor.init_state(video_path=temp_dir)

                for obj_id, structure in enumerate(structures, start=1):
                    mask_bool = provided.get(structure)
                    if mask_bool is None:
                        mask_bool = self._initial_prompt_mask(structure, h, w).astype(bool)
                    self.predictor.add_new_mask(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=obj_id,
                        mask=mask_bool,
                    )

                all_masks: Dict[int, Dict[int, np.ndarray]] = {}
                for processed, (frame_idx, obj_ids, mask_logits) in enumerate(
                    self.predictor.propagate_in_video(state), start=1
                ):
                    if progress_callback:
                        progress_callback(20 + int((processed / total) * 70), f"Processing frame {processed}/{total}")
                    all_masks[frame_idx] = self._logits_to_frame_masks(
                        obj_ids, mask_logits, h, w, self.MASK_LOGIT_THRESHOLD
                    )

            return {
                'masks': all_masks,
                'structures': structures,
                'structure_info': {s: self.CARDIAC_STRUCTURES.get(s, {'name': s, 'color': (0, 255, 0)}) for s in structures},
                'total_frames': len(frames),
            }
        except Exception as e:
            raise RuntimeError(f"Enhanced MedSAM2 segmentation failed: {e}") from e

    def _initial_prompt_mask(self, structure: str, height: int, width: int) -> np.ndarray:
        """Create a coarse initial mask for the requested structure on the first frame.

        Returns a uint8 (height, width) array that is 255 inside a filled
        ellipse (or, for 'MYO', an elliptical ring) and 0 elsewhere.

        Placement is a fixed heuristic:

        - Left-side structures are centred at 42% of the width, right-side
          ones at 58%.
        - Ventricles are centred at 52% of the height, atria at 35%.

        This is presumably tuned for one standard echo view (TODO confirm).
        Unknown codes get a central circle.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        cx_left, cx_right = int(width * 0.42), int(width * 0.58)
        cy_mid = int(height * 0.52)
        cy_atria = int(height * 0.35)

        if structure == 'LV':
            cv2.ellipse(mask, (cx_left, cy_mid), (width // 8, height // 6), 0, 0, 360, 255, -1)
        elif structure == 'MYO':
            # Ring = outer ellipse minus a smaller inner ellipse around the LV centre.
            outer = np.zeros_like(mask)
            inner = np.zeros_like(mask)
            cv2.ellipse(outer, (cx_left, cy_mid), (width // 7, height // 5), 0, 0, 360, 255, -1)
            cv2.ellipse(inner, (cx_left, cy_mid), (width // 10, height // 7), 0, 0, 360, 255, -1)
            ring = cv2.subtract(outer, inner)
            mask[ring > 0] = 255
        elif structure == 'LA':
            cv2.ellipse(mask, (cx_left, cy_atria), (width // 12, height // 9), 0, 0, 360, 255, -1)
        elif structure == 'RV':
            cv2.ellipse(mask, (cx_right, cy_mid), (width // 9, height // 6), 0, 0, 360, 255, -1)
        elif structure == 'RA':
            cv2.ellipse(mask, (cx_right, cy_atria), (width // 12, height // 9), 0, 0, 360, 255, -1)
        else:
            cv2.circle(mask, (width // 2, height // 2), min(width, height) // 6, 255, -1)
        return mask

    @staticmethod
    def create_combined_overlay(frame: np.ndarray, frame_masks: Dict[int, np.ndarray], structures: List[str]) -> np.ndarray:
        """Blend each structure's mask colour onto a copy of ``frame``.

        Inside a mask, a pixel becomes 0.7 × current + 0.3 × structure colour.
        Overlapping masks compound in iteration order. Pixels outside every
        mask are left unchanged. Object ids outside ``1..len(structures)`` and
        empty masks are skipped.

        Args:
            frame: Image with a trailing channel axis matching the colour tuples.
            frame_masks: ``{obj_id: (H, W) mask}``; non-zero means foreground.
            structures: Structure codes; ``obj_id`` maps to
                ``structures[obj_id - 1]``.

        Returns:
            A new array; ``frame`` itself is not modified.
        """
        overlay = frame.copy()
        for obj_id, mask in frame_masks.items():
            if not 1 <= obj_id <= len(structures):
                continue
            region = mask > 0
            if not region.any():
                continue
            sid = structures[obj_id - 1]
            color = EnhancedMedSAM2VideoSegmenter.CARDIAC_STRUCTURES.get(sid, {}).get('color', (0, 255, 0))
            colored = np.zeros_like(frame)
            colored[region] = color
            blended = cv2.addWeighted(overlay, 0.7, colored, 0.3, 0)
            # Only the masked pixels take the blend. Blending the whole image
            # would dim the background once per structure.
            overlay[region] = blended[region]
        return overlay
|
|