# Source: Echo / tools/echo/enhanced_medsam2_integration.py
# Last commit: 8f51ef2 "Initial Echo Space" (by moein99)
# -*- coding: utf-8 -*-
"""
Enhanced MedSAM2 integration (consolidated under tools.echo).
Provides EnhancedMedSAM2VideoSegmenter used for multi-structure segmentation.
"""
import os
import sys
import torch
import numpy as np
import cv2
import tempfile
from typing import List, Dict, Optional, Any
# Absolute directory that contains this module; the class below resolves
# model weights and bundled repositories relative to it.
_current_dir = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above _current_dir. It is prepended to the import path
# (once) so that top-level packages living there can be imported.
_parent_dir = os.path.dirname(os.path.dirname(_current_dir))
if not any(entry == _parent_dir for entry in sys.path):
    sys.path.insert(0, _parent_dir)
class EnhancedMedSAM2VideoSegmenter:
    """Utility wrapper that runs MedSAM2 on echo videos and returns combined overlays.

    On construction the MedSAM2 checkpoint and the bundled ``sam2`` config
    directory are located on disk and a sam2 video predictor is built.
    ``segment_video_multi_structure`` seeds one mask per cardiac structure on
    frame 0 and propagates all of them through the clip in a single pass.
    """

    # Structure codes seeded on frame 0; the predictor object id of a structure
    # is its 1-based position in this list.
    DEFAULT_STRUCTURES = ['LV', 'MYO', 'LA', 'RV', 'RA']
    # Display name and overlay colour per structure code. Colours are handed to
    # OpenCV unchanged, so they are presumably BGR when frames are BGR — confirm.
    CARDIAC_STRUCTURES = {
        'LV': {'name': 'Left Ventricle', 'color': (0, 255, 0)},
        'MYO': {'name': 'Myocardium', 'color': (255, 105, 180)},
        'LA': {'name': 'Left Atrium', 'color': (0, 0, 255)},
        'RV': {'name': 'Right Ventricle', 'color': (255, 0, 0)},
        'RA': {'name': 'Right Atrium', 'color': (255, 255, 0)},
    }
    # Logit cut-off used to binarise predictor output. -0.5 (sigmoid ~0.38) is
    # slightly more permissive than the usual 0.0; kept at its historical value.
    MASK_LOGIT_THRESHOLD = -0.5

    def __init__(self, model_path: str = "model_weights/MedSAM2_US_Heart.pt"):
        """Resolve the checkpoint and build the predictor.

        Args:
            model_path: Absolute path, or a path relative to this module's
                directory or to the agent root (three levels above it).

        Raises:
            FileNotFoundError: if no checkpoint candidate exists.
            RuntimeError: if the predictor cannot be built.
        """
        self.model_path = self._resolve_model_path(model_path)
        self.predictor = None
        self._initialize_predictor()

    def _resolve_sam2_paths(self) -> Dict[str, str]:
        """Locate the ``sam2`` package root and its ``configs`` directory.

        Candidates are ``tool_repos/MedSAM2-main`` and ``tool_repos/MedSAM2``
        (two levels above this module) and, when set,
        ``$ECHO_WORKSPACE_ROOT/MedSAM2-main``. The first candidate containing
        ``sam2/configs`` is prepended to ``sys.path`` so ``import sam2`` works.

        Returns:
            ``{'root': <sam2 dir>, 'configs': <sam2/configs dir>}``.

        Raises:
            FileNotFoundError: if no candidate contains ``sam2/configs``.
        """
        candidates: List[str] = []
        # Preferred: tool_repos bundled with the agent.
        local_tool_repos = os.path.abspath(os.path.join(_current_dir, "..", "..", "tool_repos"))
        if os.path.isdir(local_tool_repos):
            for repo_name in ("MedSAM2-main", "MedSAM2"):
                candidates.append(os.path.join(local_tool_repos, repo_name))
        # Workspace override via env var.
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(os.path.join(workspace_root, "MedSAM2-main"))
        for base in candidates:
            sam2_root = os.path.join(base, "sam2")
            configs_dir = os.path.join(sam2_root, "configs")
            if os.path.isdir(configs_dir):
                # Make `import sam2` resolve to this checkout (base is the parent of sam2).
                if base not in sys.path:
                    sys.path.insert(0, base)
                return {"root": sam2_root, "configs": configs_dir}
        raise FileNotFoundError("Could not locate sam2/configs directory. Ensure MedSAM2-main/sam2 is available.")

    def _resolve_model_path(self, provided_path: Optional[str]) -> str:
        """Resolve the model checkpoint to an existing path.

        Search order: ``provided_path`` if it is absolute and exists; then
        ``provided_path`` relative to this module's directory and to the agent
        root (three levels up); then ``model_weights/`` and ``checkpoints/``
        under the agent root; then the same two under
        ``$ECHO_WORKSPACE_ROOT/new_agent`` when that variable is set.

        Raises:
            FileNotFoundError: listing every path that was tried.
        """
        if provided_path and os.path.isabs(provided_path) and os.path.exists(provided_path):
            return provided_path

        new_agent_root = os.path.abspath(os.path.join(_current_dir, "..", "..", ".."))
        candidates: List[str] = []
        if provided_path:
            candidates.append(os.path.abspath(os.path.join(_current_dir, provided_path)))
            candidates.append(os.path.abspath(os.path.join(new_agent_root, provided_path)))
        # Known default locations.
        candidates.append(os.path.join(new_agent_root, "model_weights", "MedSAM2_US_Heart.pt"))
        candidates.append(os.path.join(new_agent_root, "checkpoints", "MedSAM2_US_Heart.pt"))
        # Workspace-level fallbacks.
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(os.path.join(workspace_root, "new_agent", "model_weights", "MedSAM2_US_Heart.pt"))
            candidates.append(os.path.join(workspace_root, "new_agent", "checkpoints", "MedSAM2_US_Heart.pt"))

        # The default relative path usually resolves to the same file as a known
        # default; drop duplicates (order preserved) so the error message is readable.
        candidates = list(dict.fromkeys(candidates))
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"Model file not found. Tried: {', '.join(candidates)}"
        )

    def _initialize_predictor(self) -> None:
        """Build the sam2 video predictor from ``sam2.1_hiera_t512.yaml`` and the checkpoint.

        Uses CUDA when available, otherwise CPU.

        Raises:
            RuntimeError: wrapping any failure (missing files, imports, Hydra
                configuration, checkpoint loading); the original exception is
                chained as ``__cause__``.
        """
        try:
            paths = self._resolve_sam2_paths()
            # Hydra's initialize_config_dir requires an absolute directory;
            # ECHO_WORKSPACE_ROOT may have been relative. Resolve before any chdir.
            configs_dir = os.path.abspath(paths["configs"])
            base_dir = os.path.dirname(os.path.abspath(paths["root"]))  # parent of sam2

            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file not found: {self.model_path}")

            # Import only after _resolve_sam2_paths has put the repo on sys.path.
            from sam2.build_sam import build_sam2_video_predictor
            from hydra import initialize_config_dir
            from hydra.core.global_hydra import GlobalHydra

            config_file = "sam2.1_hiera_t512.yaml"
            config_yaml = os.path.join(configs_dir, config_file)
            if not os.path.exists(config_yaml):
                raise FileNotFoundError(f"Missing config: {config_yaml}")

            prev_cwd = os.getcwd()
            try:
                # Kept from the original setup in case the sam2 build relies on
                # cwd-relative lookups — TODO confirm; Hydra itself no longer needs it.
                os.chdir(base_dir)
                # Drop any Hydra instance configured elsewhere so the configs dir
                # below is the only config search root.
                global_hydra = GlobalHydra.instance()
                if global_hydra.is_initialized():
                    global_hydra.clear()
                # A relative `initialize(config_path=...)` is resolved against the
                # calling file's directory, not the cwd; use the absolute dir instead.
                with initialize_config_dir(config_dir=configs_dir, version_base=None):
                    self.predictor = build_sam2_video_predictor(
                        config_file=config_file,
                        ckpt_path=self.model_path,
                        device="cuda" if torch.cuda.is_available() else "cpu",
                    )
            finally:
                os.chdir(prev_cwd)
        except Exception as e:
            raise RuntimeError(f"Enhanced MedSAM2 initialization failed: {e}") from e

    def segment_video_multi_structure(
        self,
        frames: List[np.ndarray],
        progress_callback=None,
        initial_masks: Optional[Dict[str, np.ndarray]] = None,
    ) -> Dict[str, Any]:
        """Run MedSAM2 once and propagate a fixed set of cardiac structure prompts.

        Args:
            frames: Video frames, all assumed to share frame 0's height/width.
                They are written to JPEG with ``cv2.imwrite`` (so presumably
                BGR ``uint8`` — confirm with callers).
            progress_callback: Optional ``callback(percent, message)``; called
                once per propagated frame with percent in the 20..90 range.
            initial_masks: Optional mapping of structure codes (e.g. 'LV',
                'MYO', 'LA', 'RV', 'RA'; case-insensitive) to 2D masks for
                frame 0 (values > 0 are foreground; resized to the frame size
                if needed). Missing structures fall back to coarse auto prompts.

        Returns:
            ``{'masks': {frame_idx: {obj_id: HxW uint8 mask with 0/255}},
            'structures': [...], 'structure_info': {code: {'name', 'color'}},
            'total_frames': int}``; ``obj_id`` is the 1-based index into
            ``structures``.

        Raises:
            RuntimeError: wrapping any failure (including empty ``frames``),
                with the original exception chained.
        """
        try:
            # len() rather than truthiness so a stacked ndarray of frames still works.
            if len(frames) == 0:
                raise ValueError("No frames provided")
            structures = list(self.DEFAULT_STRUCTURES)
            h, w = frames[0].shape[:2]
            all_masks: Dict[int, Dict[int, np.ndarray]] = {}
            with tempfile.TemporaryDirectory() as temp_dir:
                self._write_frames(frames, temp_dir)
                state = self.predictor.init_state(video_path=temp_dir)

                provided = self._normalize_initial_masks(initial_masks, structures, h, w)
                for obj_id, structure in enumerate(structures, start=1):
                    mask_bool = provided.get(structure)
                    if mask_bool is None:
                        mask_bool = self._initial_prompt_mask(structure, h, w).astype(bool)
                    self.predictor.add_new_mask(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=obj_id,
                        mask=mask_bool
                    )

                total = len(frames)
                propagation = self.predictor.propagate_in_video(state)
                for processed, (frame_idx, obj_ids, mask_logits) in enumerate(propagation, start=1):
                    if progress_callback:
                        # Offset of 20 presumably leaves room for caller-side setup progress.
                        progress_callback(20 + int((processed / total) * 70), f"Processing frame {processed}/{total}")
                    frame_masks: Dict[int, np.ndarray] = {}
                    for i, obj_id in enumerate(obj_ids):
                        if i < len(mask_logits):
                            frame_masks[obj_id] = self._logits_to_mask(mask_logits[i], h, w)
                        else:
                            frame_masks[obj_id] = np.zeros((h, w), dtype=np.uint8)
                    all_masks[frame_idx] = frame_masks

            return {
                'masks': all_masks,
                'structures': structures,
                'structure_info': {s: self.CARDIAC_STRUCTURES.get(s, {'name': s, 'color': (0, 255, 0)}) for s in structures},
                'total_frames': len(frames),
            }
        except Exception as e:
            raise RuntimeError(f"Enhanced MedSAM2 segmentation failed: {e}") from e

    @staticmethod
    def _write_frames(frames: List[np.ndarray], directory: str) -> None:
        """Write frames as ``0000000.jpg, 0000001.jpg, ...`` into ``directory``.

        Raises:
            OSError: if OpenCV reports that a frame could not be written.
        """
        for idx, frame in enumerate(frames):
            path = os.path.join(directory, f"{idx:07d}.jpg")
            # imwrite signals failure by returning False rather than raising.
            if not cv2.imwrite(path, frame):
                raise OSError(f"Failed to write frame {idx} to {path}")

    @staticmethod
    def _normalize_initial_masks(
        initial_masks: Optional[Dict[str, np.ndarray]],
        structures: List[str],
        height: int,
        width: int,
    ) -> Dict[str, np.ndarray]:
        """Return ``{structure_code: HxW bool mask}`` for usable user-provided masks.

        Keys are upper-cased; entries whose key is not in ``structures`` or whose
        value is not an ndarray are ignored. Non-dict input yields ``{}``.
        """
        provided: Dict[str, np.ndarray] = {}
        if not isinstance(initial_masks, dict):
            return provided
        for key, arr in initial_masks.items():
            code = str(key).upper()
            if code not in structures or not isinstance(arr, np.ndarray):
                continue
            mask = arr.squeeze() if arr.ndim > 2 else arr
            # Binarise BEFORE any uint8 cast: casting first would truncate soft
            # values in (0, 1) to 0 and wrap values outside 0..255.
            mask = (mask > 0).astype(np.uint8)
            if mask.shape != (height, width):
                mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
            provided[code] = mask > 0
        return provided

    def _logits_to_mask(self, logits, height: int, width: int) -> np.ndarray:
        """Threshold one object's predictor logits into an HxW uint8 mask (0/255).

        ``logits`` is expected to be a torch tensor (``.cpu()`` is called),
        presumably shaped (1, H', W'); it is resized to (height, width) if needed.
        """
        mask = (logits > self.MASK_LOGIT_THRESHOLD).cpu().numpy()
        if mask.ndim == 3 and mask.shape[0] == 1:
            mask = mask[0]
        if mask.shape != (height, width):
            mask = cv2.resize(mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)
        return (mask > 0).astype(np.uint8) * 255

    def _initial_prompt_mask(self, structure: str, height: int, width: int) -> np.ndarray:
        """Create a coarse initial mask (uint8, 0/255) for a structure on the first frame.

        Ellipses are placed at fixed fractions of the frame: left-of-centre for
        LV/MYO/LA, right-of-centre for RV/RA, atria higher than ventricles. MYO
        is a ring (outer ellipse minus inner ellipse). Unknown codes get a
        centred circle.
        """
        mask = np.zeros((height, width), dtype=np.uint8)
        cx_left, cx_right = int(width * 0.42), int(width * 0.58)
        cy_mid = int(height * 0.52)
        cy_atria = int(height * 0.35)
        if structure == 'LV':
            cv2.ellipse(mask, (cx_left, cy_mid), (width // 8, height // 6), 0, 0, 360, 255, -1)
        elif structure == 'MYO':
            outer = np.zeros_like(mask)
            inner = np.zeros_like(mask)
            cv2.ellipse(outer, (cx_left, cy_mid), (width // 7, height // 5), 0, 0, 360, 255, -1)
            cv2.ellipse(inner, (cx_left, cy_mid), (width // 10, height // 7), 0, 0, 360, 255, -1)
            ring = cv2.subtract(outer, inner)
            mask[ring > 0] = 255
        elif structure == 'LA':
            cv2.ellipse(mask, (cx_left, cy_atria), (width // 12, height // 9), 0, 0, 360, 255, -1)
        elif structure == 'RV':
            cv2.ellipse(mask, (cx_right, cy_mid), (width // 9, height // 6), 0, 0, 360, 255, -1)
        elif structure == 'RA':
            cv2.ellipse(mask, (cx_right, cy_atria), (width // 12, height // 9), 0, 0, 360, 255, -1)
        else:
            cv2.circle(mask, (width // 2, height // 2), min(width, height) // 6, 255, -1)
        return mask

    @staticmethod
    def create_combined_overlay(frame: np.ndarray, frame_masks: Dict[int, np.ndarray], structures: List[str]) -> np.ndarray:
        """Tint each structure's pixels on a copy of ``frame``.

        Args:
            frame: Image to draw on; assumed 3-channel with the same HxW as
                the masks (colours are 3-tuples).
            frame_masks: ``{obj_id: mask}`` where ``obj_id`` is the 1-based
                index into ``structures``; ids outside that range are skipped.
            structures: Structure codes in object-id order.

        Returns:
            A new image; inside each mask the pixels become
            ``0.7 * current + 0.3 * colour``. Pixels outside every mask are unchanged.
        """
        overlay = frame.copy()
        for obj_id, mask in frame_masks.items():
            if not 1 <= obj_id <= len(structures):
                continue
            sid = structures[obj_id - 1]
            color = EnhancedMedSAM2VideoSegmenter.CARDIAC_STRUCTURES.get(sid, {}).get('color', (0, 255, 0))
            region = mask > 0
            if not region.any():
                continue
            colored = np.zeros_like(frame)
            colored[region] = color
            blended = cv2.addWeighted(overlay, 0.7, colored, 0.3, 0)
            # Copy back only this structure's pixels. Blending the whole frame
            # per structure dimmed the background by 0.7 for every structure drawn.
            overlay[region] = blended[region]
        return overlay