"""
MedSAM2 integration module (consolidated under tools.echo).
Provides MedSAM2VideoSegmenter used by echo_tool_managers.
"""
| |
|
| | import os |
| | import sys |
| | import torch |
| | import numpy as np |
| | import cv2 |
| | import tempfile |
| | from pathlib import Path |
| | from typing import Dict, Optional, Sequence, Tuple |
| |
|
# Directory containing this module; anchors all relative path resolution below.
_current_dir = os.path.dirname(os.path.abspath(__file__))
| |
|
| | class MedSAM2VideoSegmenter: |
| | """Clean MedSAM2 video segmentation class.""" |
| |
|
| | def __init__(self, model_path: str = "checkpoints/MedSAM2_US_Heart.pt"): |
| | self.model_path = self._resolve_model_path(model_path) |
| | self.predictor = None |
| | self._initialize_predictor() |
| |
|
| | def _resolve_sam2_paths(self): |
| | candidates = [] |
| | local_tool_repos = os.path.abspath(os.path.join(_current_dir, "..", "..", "tool_repos")) |
| | if os.path.isdir(local_tool_repos): |
| | for repo_name in ("MedSAM2-main", "MedSAM2"): |
| | repo_path = os.path.join(local_tool_repos, repo_name) |
| | candidates.append(repo_path) |
| | workspace_root = os.getenv("ECHO_WORKSPACE_ROOT") |
| | if workspace_root: |
| | candidates.append(os.path.join(workspace_root, "MedSAM2-main")) |
| | for base in candidates: |
| | sam2_root = os.path.join(base, "sam2") |
| | configs_dir = os.path.join(sam2_root, "configs") |
| | if os.path.isdir(configs_dir): |
| | if base not in sys.path: |
| | sys.path.insert(0, base) |
| | return {"root": sam2_root, "configs": configs_dir} |
| | raise FileNotFoundError("Could not locate sam2/configs directory. Ensure tool_repos/MedSAM2-main is available.") |
| |
|
| | def _resolve_model_path(self, provided_path: str) -> str: |
| | """Resolve model checkpoint absolute path from common locations.""" |
| | if provided_path and os.path.isabs(provided_path) and os.path.exists(provided_path): |
| | return provided_path |
| |
|
| | candidates = [] |
| | |
| | if provided_path: |
| | candidates.append(os.path.abspath(os.path.join(_current_dir, provided_path))) |
| | new_agent_root = os.path.abspath(os.path.join(_current_dir, "..", "..", "..")) |
| | candidates.append(os.path.abspath(os.path.join(new_agent_root, provided_path))) |
| |
|
| | |
| | new_agent_root = os.path.abspath(os.path.join(_current_dir, "..", "..", "..")) |
| | candidates.append(os.path.join(new_agent_root, "model_weights", "MedSAM2_US_Heart.pt")) |
| | candidates.append(os.path.join(new_agent_root, "checkpoints", "MedSAM2_US_Heart.pt")) |
| |
|
| | workspace_root = os.getenv("ECHO_WORKSPACE_ROOT") |
| | if workspace_root: |
| | candidates.append(os.path.join(workspace_root, "new_agent", "model_weights", "MedSAM2_US_Heart.pt")) |
| | candidates.append(os.path.join(workspace_root, "new_agent", "checkpoints", "MedSAM2_US_Heart.pt")) |
| |
|
| | for c in candidates: |
| | if os.path.exists(c): |
| | return c |
| |
|
| | raise FileNotFoundError(f"Model file not found. Tried: {', '.join(candidates)}") |
| |
|
| | def _initialize_predictor(self) -> None: |
| | try: |
| | paths = self._resolve_sam2_paths() |
| | configs_dir = paths["configs"] |
| | base_dir = os.path.dirname(paths["root"]) |
| | if not os.path.exists(self.model_path): |
| | raise FileNotFoundError(f"Model file not found: {self.model_path}") |
| | from sam2.build_sam import build_sam2_video_predictor |
| | config_file = "sam2.1_hiera_t512.yaml" |
| | if not os.path.exists(os.path.join(configs_dir, config_file)): |
| | raise FileNotFoundError(f"Missing config: {os.path.join(configs_dir, config_file)}") |
| | |
| | |
| | prev_cwd = os.getcwd() |
| | try: |
| | os.chdir(base_dir) |
| | from hydra.core.global_hydra import GlobalHydra |
| | from hydra import initialize |
| | |
| | |
| | try: |
| | GlobalHydra.instance().clear() |
| | except: |
| | pass |
| | |
| | |
| | rel_config_path = os.path.relpath(configs_dir, base_dir) |
| | with initialize(config_path=rel_config_path, version_base=None): |
| | |
| | self.predictor = build_sam2_video_predictor( |
| | config_file=config_file, |
| | ckpt_path=self.model_path, |
| | device="cuda" if torch.cuda.is_available() else "cpu", |
| | ) |
| | finally: |
| | os.chdir(prev_cwd) |
| | except Exception as e: |
| | raise RuntimeError(f"MedSAM2 initialization failed: {e}") |
| |
|
| | def _load_prompt_masks( |
| | self, |
| | mask_path: str, |
| | frame_shape: Tuple[int, int], |
| | label_value: Optional[int] = None, |
| | label_map: Optional[Dict[int, int]] = None, |
| | frame_index: int = 0, |
| | ) -> Dict[int, np.ndarray]: |
| | """Load prompt masks from annotation file or directory. |
| | |
| | Returns mapping of object_id -> boolean mask aligned to the requested frame. |
| | """ |
| |
|
| | if not mask_path: |
| | raise ValueError("mask_path must be provided when using mask prompts") |
| |
|
| | source = Path(mask_path) |
| | if source.is_dir(): |
| | |
| | candidate = source / f"{frame_index:04d}.png" |
| | else: |
| | candidate = source |
| |
|
| | if not candidate.exists(): |
| | raise FileNotFoundError(f"Prompt mask not found: {candidate}") |
| |
|
| | mask = cv2.imread(str(candidate), cv2.IMREAD_GRAYSCALE) |
| | if mask is None: |
| | raise RuntimeError(f"Failed to read prompt mask: {candidate}") |
| |
|
| | target_height, target_width = frame_shape |
| | if mask.shape != frame_shape: |
| | mask = cv2.resize( |
| | mask, |
| | (target_width, target_height), |
| | interpolation=cv2.INTER_NEAREST, |
| | ) |
| |
|
| | prompts: Dict[int, np.ndarray] = {} |
| |
|
| | if label_map: |
| | for pixel_value, obj_id in label_map.items(): |
| | prompts[int(obj_id)] = (mask == pixel_value) |
| | elif label_value is not None: |
| | prompts[1] = (mask == label_value) |
| | else: |
| | prompts[1] = mask > 0 |
| |
|
| | for obj_id, obj_mask in prompts.items(): |
| | prompts[obj_id] = obj_mask.astype(np.uint8).astype(bool) |
| |
|
| | if not prompts: |
| | raise RuntimeError("No prompt objects extracted from mask") |
| |
|
| | return prompts |
| |
|
| | def segment_video( |
| | self, |
| | frames, |
| | target_name: str = "LV", |
| | *, |
| | prompt_mask_path: Optional[str] = None, |
| | prompt_mask_label: Optional[int] = None, |
| | prompt_label_map: Optional[Dict[int, int]] = None, |
| | prompt_points: Optional[Sequence[Tuple[float, float]]] = None, |
| | prompt_box: Optional[Tuple[float, float, float, float]] = None, |
| | palette: Optional[Dict[int, Tuple[int, int, int]]] = None, |
| | progress_callback=None, |
| | ): |
| | try: |
| | with tempfile.TemporaryDirectory() as temp_dir: |
| | for i, frame in enumerate(frames): |
| | cv2.imwrite(os.path.join(temp_dir, f"{i:07d}.jpg"), frame) |
| | state = self.predictor.init_state(video_path=temp_dir) |
| | first_frame = frames[0] |
| | h, w = first_frame.shape[:2] |
| |
|
| | if prompt_mask_path: |
| | prompt_masks = self._load_prompt_masks( |
| | prompt_mask_path, |
| | (h, w), |
| | label_value=prompt_mask_label, |
| | label_map=prompt_label_map, |
| | frame_index=0, |
| | ) |
| | for obj_id, init_mask in prompt_masks.items(): |
| | self.predictor.add_new_mask( |
| | inference_state=state, |
| | frame_idx=0, |
| | obj_id=int(obj_id), |
| | mask=init_mask, |
| | ) |
| | elif prompt_points: |
| | abs_points = np.array( |
| | [[int(px * w), int(py * h)] for px, py in prompt_points], |
| | dtype=np.int32, |
| | ) |
| | point_labels = np.ones(len(abs_points), dtype=np.int32) |
| | self.predictor.add_new_points( |
| | inference_state=state, |
| | frame_idx=0, |
| | obj_id=1, |
| | points=abs_points, |
| | labels=point_labels, |
| | ) |
| | elif prompt_box: |
| | x1, y1, x2, y2 = prompt_box |
| | abs_box = np.array( |
| | [ |
| | int(x1 * w), |
| | int(y1 * h), |
| | int(x2 * w), |
| | int(y2 * h), |
| | ], |
| | dtype=np.int32, |
| | ) |
| | self.predictor.add_new_points_or_box( |
| | inference_state=state, |
| | frame_idx=0, |
| | obj_id=1, |
| | box=abs_box, |
| | ) |
| | else: |
| | init = np.zeros((h, w), dtype=np.uint8) |
| | if target_name == "LV": |
| | cx, cy = int(w * 0.4), int(h * 0.5) |
| | cv2.ellipse(init, (cx, cy), (w // 8, h // 6), 0, 0, 360, 255, -1) |
| | elif target_name == "RV": |
| | cx, cy = int(w * 0.6), int(h * 0.5) |
| | cv2.ellipse(init, (cx, cy), (w // 10, h // 7), 0, 0, 360, 255, -1) |
| | elif target_name == "LA": |
| | cx, cy = int(w * 0.4), int(h * 0.3) |
| | cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1) |
| | elif target_name == "RA": |
| | cx, cy = int(w * 0.6), int(h * 0.3) |
| | cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1) |
| | else: |
| | cx, cy = w // 2, h // 2 |
| | cv2.circle(init, (cx, cy), min(w, h) // 8, 255, -1) |
| | init_mask = init.astype(bool) |
| | self.predictor.add_new_mask( |
| | inference_state=state, |
| | frame_idx=0, |
| | obj_id=1, |
| | mask=init_mask, |
| | ) |
| | masks = [] |
| | total_frames = len(frames) |
| | processed = 0 |
| | for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(state): |
| | processed += 1 |
| | if progress_callback: |
| | progress_callback(30 + int((processed / total_frames) * 60), f"Processing frame {processed}/{total_frames}") |
| | if len(mask_logits) > 0: |
| | mask = (mask_logits[0] > 0.0).cpu().numpy() |
| | if mask.ndim == 3 and mask.shape[0] == 1: |
| | mask = mask[0] |
| | if mask.shape != (h, w): |
| | mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool) |
| | masks.append(mask.astype(np.uint8) * 255) |
| | else: |
| | masks.append(np.zeros((h, w), dtype=np.uint8)) |
| | while len(masks) < total_frames: |
| | masks.append(np.zeros((h, w), dtype=np.uint8)) |
| | return masks |
| | except Exception as e: |
| | raise RuntimeError(f"MedSAM2 segmentation failed: {e}") |
| |
|