# -*- coding: utf-8 -*-
"""MedSAM2 integration module (consolidated under tools.echo).

Provides :class:`MedSAM2VideoSegmenter`, used by ``echo_tool_managers``, which
wraps the MedSAM2 ``build_sam2_video_predictor`` pipeline: it resolves the
MedSAM2 repository and checkpoint on disk, initializes the predictor via
Hydra, and runs prompted video segmentation over a list of frames.
"""

import os
import sys
import tempfile
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import cv2
import numpy as np
import torch

# Directory containing this file; used as the anchor for all relative lookups.
_current_dir = os.path.dirname(os.path.abspath(__file__))


class MedSAM2VideoSegmenter:
    """Clean MedSAM2 video segmentation class.

    Resolves the model checkpoint and the MedSAM2 repo layout at construction
    time and keeps a ready-to-use video predictor in ``self.predictor``.
    """

    def __init__(self, model_path: str = "checkpoints/MedSAM2_US_Heart.pt"):
        self.model_path = self._resolve_model_path(model_path)
        self.predictor = None
        self._initialize_predictor()

    def _resolve_sam2_paths(self) -> Dict[str, str]:
        """Locate the MedSAM2 repo containing ``sam2/configs``.

        Searches known locations (local ``tool_repos`` and the optional
        ``ECHO_WORKSPACE_ROOT`` env var), inserts the repo root on
        ``sys.path`` so ``import sam2`` works, and returns the paths.

        Returns:
            dict with keys ``root`` (the ``sam2`` package dir) and
            ``configs`` (its ``configs`` subdirectory).

        Raises:
            FileNotFoundError: if no candidate contains ``sam2/configs``.
        """
        candidates = []
        local_tool_repos = os.path.abspath(
            os.path.join(_current_dir, "..", "..", "tool_repos")
        )
        if os.path.isdir(local_tool_repos):
            for repo_name in ("MedSAM2-main", "MedSAM2"):
                candidates.append(os.path.join(local_tool_repos, repo_name))
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(os.path.join(workspace_root, "MedSAM2-main"))

        for base in candidates:
            sam2_root = os.path.join(base, "sam2")
            configs_dir = os.path.join(sam2_root, "configs")
            if os.path.isdir(configs_dir):
                # Make the repo importable (``import sam2``) exactly once.
                if base not in sys.path:
                    sys.path.insert(0, base)
                return {"root": sam2_root, "configs": configs_dir}

        raise FileNotFoundError(
            "Could not locate sam2/configs directory. "
            "Ensure tool_repos/MedSAM2-main is available."
        )

    def _resolve_model_path(self, provided_path: str) -> str:
        """Resolve model checkpoint absolute path from common locations.

        An absolute, existing ``provided_path`` wins immediately; otherwise
        the path is tried relative to this file and to the project root,
        followed by the known default checkpoint locations.

        Raises:
            FileNotFoundError: if no candidate path exists on disk.
        """
        if provided_path and os.path.isabs(provided_path) and os.path.exists(provided_path):
            return provided_path

        # Project root (three levels above this file); computed once.
        new_agent_root = os.path.abspath(os.path.join(_current_dir, "..", "..", ".."))

        candidates = []
        # Provided relative path, tried against both anchors.
        if provided_path:
            candidates.append(os.path.abspath(os.path.join(_current_dir, provided_path)))
            candidates.append(os.path.abspath(os.path.join(new_agent_root, provided_path)))
        # Known defaults.
        candidates.append(os.path.join(new_agent_root, "model_weights", "MedSAM2_US_Heart.pt"))
        candidates.append(os.path.join(new_agent_root, "checkpoints", "MedSAM2_US_Heart.pt"))
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(
                os.path.join(workspace_root, "new_agent", "model_weights", "MedSAM2_US_Heart.pt")
            )
            candidates.append(
                os.path.join(workspace_root, "new_agent", "checkpoints", "MedSAM2_US_Heart.pt")
            )

        for c in candidates:
            if os.path.exists(c):
                return c
        raise FileNotFoundError(f"Model file not found. Tried: {', '.join(candidates)}")

    def _initialize_predictor(self) -> None:
        """Build the MedSAM2 video predictor via Hydra.

        Temporarily chdirs into the repo root so Hydra's relative config
        path resolution works, then restores the previous cwd.

        Raises:
            RuntimeError: wrapping any failure during initialization.
        """
        try:
            paths = self._resolve_sam2_paths()
            configs_dir = paths["configs"]
            base_dir = os.path.dirname(paths["root"])  # parent of sam2

            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file not found: {self.model_path}")

            from sam2.build_sam import build_sam2_video_predictor

            config_file = "sam2.1_hiera_t512.yaml"
            if not os.path.exists(os.path.join(configs_dir, config_file)):
                raise FileNotFoundError(
                    f"Missing config: {os.path.join(configs_dir, config_file)}"
                )

            # Use the original build_sam2_video_predictor function but with
            # proper path setup; cwd is restored even on failure.
            prev_cwd = os.getcwd()
            try:
                os.chdir(base_dir)

                from hydra.core.global_hydra import GlobalHydra
                from hydra import initialize

                # Force clear any existing Hydra instance; clearing an
                # uninitialized instance may raise, which is harmless here.
                try:
                    GlobalHydra.instance().clear()
                except Exception:
                    pass

                # Initialize Hydra with the correct config path.
                rel_config_path = os.path.relpath(configs_dir, base_dir)
                with initialize(config_path=rel_config_path, version_base=None):
                    self.predictor = build_sam2_video_predictor(
                        config_file=config_file,
                        ckpt_path=self.model_path,
                        device="cuda" if torch.cuda.is_available() else "cpu",
                    )
            finally:
                os.chdir(prev_cwd)
        except Exception as e:
            raise RuntimeError(f"MedSAM2 initialization failed: {e}") from e

    def _load_prompt_masks(
        self,
        mask_path: str,
        frame_shape: Tuple[int, int],
        label_value: Optional[int] = None,
        label_map: Optional[Dict[int, int]] = None,
        frame_index: int = 0,
    ) -> Dict[int, np.ndarray]:
        """Load prompt masks from annotation file or directory.

        Args:
            mask_path: path to a grayscale mask image, or a directory of
                zero-padded per-frame PNGs (MedSAM2 convention).
            frame_shape: target ``(height, width)`` to resize the mask to.
            label_value: if given, a single object is extracted where the
                mask equals this pixel value.
            label_map: if given, maps ``pixel_value -> object_id`` and takes
                precedence over ``label_value``.
            frame_index: frame whose mask file to read when ``mask_path`` is
                a directory.

        Returns:
            Mapping of ``object_id -> boolean mask`` aligned to the frame.

        Raises:
            ValueError: if ``mask_path`` is falsy.
            FileNotFoundError: if the mask file does not exist.
            RuntimeError: if the file cannot be decoded or yields no objects.
        """
        if not mask_path:
            raise ValueError("mask_path must be provided when using mask prompts")

        source = Path(mask_path)
        if source.is_dir():
            # Follow MedSAM2 convention: frame files are zero-padded PNGs.
            candidate = source / f"{frame_index:04d}.png"
        else:
            candidate = source
        if not candidate.exists():
            raise FileNotFoundError(f"Prompt mask not found: {candidate}")

        mask = cv2.imread(str(candidate), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise RuntimeError(f"Failed to read prompt mask: {candidate}")

        target_height, target_width = frame_shape
        if mask.shape != frame_shape:
            # Nearest-neighbor keeps label values intact during resize.
            mask = cv2.resize(
                mask,
                (target_width, target_height),
                interpolation=cv2.INTER_NEAREST,
            )

        prompts: Dict[int, np.ndarray] = {}
        if label_map:
            for pixel_value, obj_id in label_map.items():
                prompts[int(obj_id)] = (mask == pixel_value)
        elif label_value is not None:
            prompts[1] = (mask == label_value)
        else:
            prompts[1] = mask > 0

        # Normalize every mask to a plain boolean array.
        for obj_id, obj_mask in prompts.items():
            prompts[obj_id] = obj_mask.astype(bool)

        if not prompts:
            raise RuntimeError("No prompt objects extracted from mask")
        return prompts

    @staticmethod
    def _build_default_prompt_mask(target_name: str, h: int, w: int) -> np.ndarray:
        """Return a heuristic boolean seed mask for a named cardiac chamber.

        Places an ellipse at the chamber's typical location in an apical
        echo view, or a centered circle for unknown targets.
        """
        init = np.zeros((h, w), dtype=np.uint8)
        if target_name == "LV":
            cx, cy = int(w * 0.4), int(h * 0.5)
            cv2.ellipse(init, (cx, cy), (w // 8, h // 6), 0, 0, 360, 255, -1)
        elif target_name == "RV":
            cx, cy = int(w * 0.6), int(h * 0.5)
            cv2.ellipse(init, (cx, cy), (w // 10, h // 7), 0, 0, 360, 255, -1)
        elif target_name == "LA":
            cx, cy = int(w * 0.4), int(h * 0.3)
            cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1)
        elif target_name == "RA":
            cx, cy = int(w * 0.6), int(h * 0.3)
            cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1)
        else:
            cx, cy = w // 2, h // 2
            cv2.circle(init, (cx, cy), min(w, h) // 8, 255, -1)
        return init.astype(bool)

    def segment_video(
        self,
        frames,
        target_name: str = "LV",
        *,
        prompt_mask_path: Optional[str] = None,
        prompt_mask_label: Optional[int] = None,
        prompt_label_map: Optional[Dict[int, int]] = None,
        prompt_points: Optional[Sequence[Tuple[float, float]]] = None,
        prompt_box: Optional[Tuple[float, float, float, float]] = None,
        palette: Optional[Dict[int, Tuple[int, int, int]]] = None,
        progress_callback=None,
    ):
        """Segment a video (list of BGR frames) and return per-frame masks.

        Exactly one prompt source is used, in priority order: mask file,
        normalized points, normalized box, or a heuristic seed mask derived
        from ``target_name``. Point/box coordinates are normalized [0, 1]
        fractions of the frame size.

        Args:
            frames: sequence of HxWx3 frames (cv2-writable images).
            target_name: chamber name used for the fallback seed mask.
            prompt_mask_path/prompt_mask_label/prompt_label_map: see
                :meth:`_load_prompt_masks`.
            prompt_points: normalized (x, y) foreground points.
            prompt_box: normalized (x1, y1, x2, y2) box.
            palette: accepted for API compatibility; not used here.
            progress_callback: optional ``callback(percent, message)``.

        Returns:
            list of uint8 masks (0/255), one per input frame.

        Raises:
            RuntimeError: wrapping any failure during segmentation.
        """
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # The predictor consumes a directory of numbered JPEG frames.
                for i, frame in enumerate(frames):
                    cv2.imwrite(os.path.join(temp_dir, f"{i:07d}.jpg"), frame)
                state = self.predictor.init_state(video_path=temp_dir)

                first_frame = frames[0]
                h, w = first_frame.shape[:2]

                if prompt_mask_path:
                    prompt_masks = self._load_prompt_masks(
                        prompt_mask_path,
                        (h, w),
                        label_value=prompt_mask_label,
                        label_map=prompt_label_map,
                        frame_index=0,
                    )
                    for obj_id, init_mask in prompt_masks.items():
                        self.predictor.add_new_mask(
                            inference_state=state,
                            frame_idx=0,
                            obj_id=int(obj_id),
                            mask=init_mask,
                        )
                elif prompt_points:
                    # Convert normalized coordinates to absolute pixels.
                    abs_points = np.array(
                        [[int(px * w), int(py * h)] for px, py in prompt_points],
                        dtype=np.int32,
                    )
                    point_labels = np.ones(len(abs_points), dtype=np.int32)
                    self.predictor.add_new_points(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        points=abs_points,
                        labels=point_labels,
                    )
                elif prompt_box:
                    x1, y1, x2, y2 = prompt_box
                    abs_box = np.array(
                        [
                            int(x1 * w),
                            int(y1 * h),
                            int(x2 * w),
                            int(y2 * h),
                        ],
                        dtype=np.int32,
                    )
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        box=abs_box,
                    )
                else:
                    # No explicit prompt: seed with a heuristic chamber mask.
                    init_mask = self._build_default_prompt_mask(target_name, h, w)
                    self.predictor.add_new_mask(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        mask=init_mask,
                    )

                masks = []
                total_frames = len(frames)
                processed = 0
                for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(state):
                    processed += 1
                    if progress_callback:
                        # Map propagation progress into the 30-90% band.
                        progress_callback(
                            30 + int((processed / total_frames) * 60),
                            f"Processing frame {processed}/{total_frames}",
                        )
                    if len(mask_logits) > 0:
                        mask = (mask_logits[0] > 0.0).cpu().numpy()
                        if mask.ndim == 3 and mask.shape[0] == 1:
                            mask = mask[0]
                        if mask.shape != (h, w):
                            mask = cv2.resize(
                                mask.astype(np.uint8),
                                (w, h),
                                interpolation=cv2.INTER_NEAREST,
                            ).astype(bool)
                        masks.append(mask.astype(np.uint8) * 255)
                    else:
                        masks.append(np.zeros((h, w), dtype=np.uint8))

                # Pad with empty masks if propagation stopped early.
                while len(masks) < total_frames:
                    masks.append(np.zeros((h, w), dtype=np.uint8))
                return masks
        except Exception as e:
            raise RuntimeError(f"MedSAM2 segmentation failed: {e}") from e