# Echo — tools/echo/medsam2_integration.py
# (Hugging Face Space upload header: moein99, "Initial Echo Space", commit 8f51ef2)
# -*- coding: utf-8 -*-
"""
MedSAM2 integration module (consolidated under tools.echo).
Provides MedSAM2VideoSegmenter used by echo_tool_managers.
"""
import os
import sys
import torch
import numpy as np
import cv2
import tempfile
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple
# Absolute directory containing this module; anchor for relative path lookups.
_current_dir = os.path.abspath(os.path.dirname(__file__))
class MedSAM2VideoSegmenter:
    """Clean MedSAM2 video segmentation class.

    Wraps the MedSAM2/sam2 video predictor: resolves the sam2 repository and
    model checkpoint from conventional project locations, initializes the
    predictor through Hydra, and propagates an initial prompt (mask file,
    normalized points, normalized box, or a heuristic chamber ellipse)
    through a sequence of echo frames.
    """

    def __init__(self, model_path: str = "checkpoints/MedSAM2_US_Heart.pt"):
        """Resolve the checkpoint path and build the video predictor.

        Args:
            model_path: Absolute or relative path to a MedSAM2 checkpoint.
                Relative paths are searched in several common locations
                (see ``_resolve_model_path``).

        Raises:
            FileNotFoundError: If the checkpoint cannot be located.
            RuntimeError: If predictor initialization fails.
        """
        self.model_path = self._resolve_model_path(model_path)
        self.predictor = None
        self._initialize_predictor()

    def _resolve_sam2_paths(self) -> Dict[str, str]:
        """Locate a MedSAM2 checkout that contains ``sam2/configs``.

        Side effect: prepends the repo root to ``sys.path`` so that
        ``import sam2`` resolves afterwards.

        Returns:
            Mapping with keys ``"root"`` (the sam2 package directory) and
            ``"configs"`` (its configs directory).

        Raises:
            FileNotFoundError: If no candidate directory qualifies.
        """
        candidates = []
        local_tool_repos = os.path.abspath(
            os.path.join(_current_dir, "..", "..", "tool_repos")
        )
        if os.path.isdir(local_tool_repos):
            for repo_name in ("MedSAM2-main", "MedSAM2"):
                candidates.append(os.path.join(local_tool_repos, repo_name))
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(os.path.join(workspace_root, "MedSAM2-main"))
        for base in candidates:
            sam2_root = os.path.join(base, "sam2")
            configs_dir = os.path.join(sam2_root, "configs")
            if os.path.isdir(configs_dir):
                # Make the repo importable before handing paths back.
                if base not in sys.path:
                    sys.path.insert(0, base)
                return {"root": sam2_root, "configs": configs_dir}
        raise FileNotFoundError("Could not locate sam2/configs directory. Ensure tool_repos/MedSAM2-main is available.")

    def _resolve_model_path(self, provided_path: str) -> str:
        """Resolve model checkpoint absolute path from common locations.

        Args:
            provided_path: Absolute path (returned as-is if it exists) or a
                path tried relative to this module and the project root,
                followed by known default checkpoint locations.

        Returns:
            The first existing candidate path.

        Raises:
            FileNotFoundError: If no candidate exists on disk.
        """
        if provided_path and os.path.isabs(provided_path) and os.path.exists(provided_path):
            return provided_path
        candidates = []
        # Project root: three levels above this module (computed once).
        new_agent_root = os.path.abspath(os.path.join(_current_dir, "..", "..", ".."))
        # Provided relative path, tried next to this module and at the root.
        if provided_path:
            candidates.append(os.path.abspath(os.path.join(_current_dir, provided_path)))
            candidates.append(os.path.abspath(os.path.join(new_agent_root, provided_path)))
        # Known default checkpoint locations.
        candidates.append(os.path.join(new_agent_root, "model_weights", "MedSAM2_US_Heart.pt"))
        candidates.append(os.path.join(new_agent_root, "checkpoints", "MedSAM2_US_Heart.pt"))
        workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
        if workspace_root:
            candidates.append(os.path.join(workspace_root, "new_agent", "model_weights", "MedSAM2_US_Heart.pt"))
            candidates.append(os.path.join(workspace_root, "new_agent", "checkpoints", "MedSAM2_US_Heart.pt"))
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"Model file not found. Tried: {', '.join(candidates)}")

    def _initialize_predictor(self) -> None:
        """Build the sam2 video predictor for ``self.model_path``.

        Temporarily chdirs into the MedSAM2 repo and re-initializes Hydra so
        the relative config path resolves; the original cwd is always
        restored.

        Raises:
            RuntimeError: Wrapping any underlying failure (chained).
        """
        try:
            paths = self._resolve_sam2_paths()
            configs_dir = paths["configs"]
            base_dir = os.path.dirname(paths["root"])  # parent of sam2
            if not os.path.exists(self.model_path):
                raise FileNotFoundError(f"Model file not found: {self.model_path}")
            from sam2.build_sam import build_sam2_video_predictor
            config_file = "sam2.1_hiera_t512.yaml"
            if not os.path.exists(os.path.join(configs_dir, config_file)):
                raise FileNotFoundError(f"Missing config: {os.path.join(configs_dir, config_file)}")
            # build_sam2_video_predictor resolves configs relative to the
            # Hydra search path, so run it from the repo root with a freshly
            # initialized Hydra instance.
            prev_cwd = os.getcwd()
            try:
                os.chdir(base_dir)
                from hydra.core.global_hydra import GlobalHydra
                from hydra import initialize
                # Force-clear any existing Hydra instance; clearing an
                # uninitialized one may raise and is safe to ignore.
                try:
                    GlobalHydra.instance().clear()
                except Exception:
                    pass
                # Initialize Hydra with the correct (repo-relative) config path.
                rel_config_path = os.path.relpath(configs_dir, base_dir)
                with initialize(config_path=rel_config_path, version_base=None):
                    self.predictor = build_sam2_video_predictor(
                        config_file=config_file,
                        ckpt_path=self.model_path,
                        device="cuda" if torch.cuda.is_available() else "cpu",
                    )
            finally:
                os.chdir(prev_cwd)
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"MedSAM2 initialization failed: {e}") from e

    def _load_prompt_masks(
        self,
        mask_path: str,
        frame_shape: Tuple[int, int],
        label_value: Optional[int] = None,
        label_map: Optional[Dict[int, int]] = None,
        frame_index: int = 0,
    ) -> Dict[int, np.ndarray]:
        """Load prompt masks from annotation file or directory.

        Args:
            mask_path: A mask image file, or a directory of zero-padded
                per-frame PNGs (MedSAM2 convention, e.g. ``0000.png``).
            frame_shape: ``(height, width)`` of the video frames; the mask is
                nearest-neighbor resized to match if needed.
            label_value: If given, a single object (id 1) is extracted from
                pixels equal to this value.
            label_map: If given, maps pixel values to object ids (takes
                precedence over ``label_value``).
            frame_index: Which frame's mask to load when ``mask_path`` is a
                directory.

        Returns:
            Mapping of object_id -> boolean mask aligned to the requested frame.

        Raises:
            ValueError: If ``mask_path`` is empty.
            FileNotFoundError: If the resolved mask file does not exist.
            RuntimeError: If the image cannot be read or yields no objects.
        """
        if not mask_path:
            raise ValueError("mask_path must be provided when using mask prompts")
        source = Path(mask_path)
        if source.is_dir():
            # Follow MedSAM2 convention: frame files are zero-padded PNGs
            candidate = source / f"{frame_index:04d}.png"
        else:
            candidate = source
        if not candidate.exists():
            raise FileNotFoundError(f"Prompt mask not found: {candidate}")
        mask = cv2.imread(str(candidate), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise RuntimeError(f"Failed to read prompt mask: {candidate}")
        target_height, target_width = frame_shape
        if mask.shape != frame_shape:
            # cv2.resize takes (width, height); nearest keeps labels intact.
            mask = cv2.resize(
                mask,
                (target_width, target_height),
                interpolation=cv2.INTER_NEAREST,
            )
        prompts: Dict[int, np.ndarray] = {}
        if label_map:
            for pixel_value, obj_id in label_map.items():
                prompts[int(obj_id)] = mask == pixel_value
        elif label_value is not None:
            prompts[1] = mask == label_value
        else:
            # No label info: treat any non-zero pixel as the single object.
            prompts[1] = mask > 0
        for obj_id, obj_mask in prompts.items():
            # Comparisons above already yield bool arrays; normalize dtype.
            prompts[obj_id] = obj_mask.astype(bool)
        if not prompts:
            raise RuntimeError("No prompt objects extracted from mask")
        return prompts

    def segment_video(
        self,
        frames,
        target_name: str = "LV",
        *,
        prompt_mask_path: Optional[str] = None,
        prompt_mask_label: Optional[int] = None,
        prompt_label_map: Optional[Dict[int, int]] = None,
        prompt_points: Optional[Sequence[Tuple[float, float]]] = None,
        prompt_box: Optional[Tuple[float, float, float, float]] = None,
        palette: Optional[Dict[int, Tuple[int, int, int]]] = None,
        progress_callback=None,
    ):
        """Segment a target structure across a sequence of video frames.

        Prompt priority: mask file > points > box > heuristic ellipse for the
        named chamber (LV/RV/LA/RA) or a centered circle otherwise.

        Args:
            frames: Non-empty sequence of BGR frames (as from cv2).
            target_name: Chamber name used only for the heuristic prompt.
            prompt_mask_path / prompt_mask_label / prompt_label_map: See
                ``_load_prompt_masks``.
            prompt_points: Normalized (x, y) foreground points in [0, 1].
            prompt_box: Normalized (x1, y1, x2, y2) box in [0, 1].
            palette: Unused here; accepted for caller compatibility.
            progress_callback: Optional ``callback(percent, message)``;
                invoked with progress mapped into the 30-90 range.

        Returns:
            List of uint8 masks (0/255), one per input frame.

        Raises:
            RuntimeError: Wrapping any underlying failure (chained),
                including empty ``frames``.
        """
        try:
            if len(frames) == 0:
                # Guard: frames[0] below and the progress divisor require >= 1.
                raise ValueError("frames must be a non-empty sequence")
            with tempfile.TemporaryDirectory() as temp_dir:
                # The predictor consumes a directory of numbered JPEG frames.
                for i, frame in enumerate(frames):
                    cv2.imwrite(os.path.join(temp_dir, f"{i:07d}.jpg"), frame)
                state = self.predictor.init_state(video_path=temp_dir)
                first_frame = frames[0]
                h, w = first_frame.shape[:2]
                if prompt_mask_path:
                    prompt_masks = self._load_prompt_masks(
                        prompt_mask_path,
                        (h, w),
                        label_value=prompt_mask_label,
                        label_map=prompt_label_map,
                        frame_index=0,
                    )
                    for obj_id, init_mask in prompt_masks.items():
                        self.predictor.add_new_mask(
                            inference_state=state,
                            frame_idx=0,
                            obj_id=int(obj_id),
                            mask=init_mask,
                        )
                elif prompt_points:
                    # Normalized [0, 1] coordinates -> absolute pixel points.
                    abs_points = np.array(
                        [[int(px * w), int(py * h)] for px, py in prompt_points],
                        dtype=np.int32,
                    )
                    point_labels = np.ones(len(abs_points), dtype=np.int32)  # all foreground
                    self.predictor.add_new_points(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        points=abs_points,
                        labels=point_labels,
                    )
                elif prompt_box:
                    # Normalized [0, 1] box -> absolute pixel box.
                    x1, y1, x2, y2 = prompt_box
                    abs_box = np.array(
                        [
                            int(x1 * w),
                            int(y1 * h),
                            int(x2 * w),
                            int(y2 * h),
                        ],
                        dtype=np.int32,
                    )
                    self.predictor.add_new_points_or_box(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        box=abs_box,
                    )
                else:
                    # No explicit prompt: seed with a rough ellipse at the
                    # chamber's typical location in a standard echo view.
                    init = np.zeros((h, w), dtype=np.uint8)
                    if target_name == "LV":
                        cx, cy = int(w * 0.4), int(h * 0.5)
                        cv2.ellipse(init, (cx, cy), (w // 8, h // 6), 0, 0, 360, 255, -1)
                    elif target_name == "RV":
                        cx, cy = int(w * 0.6), int(h * 0.5)
                        cv2.ellipse(init, (cx, cy), (w // 10, h // 7), 0, 0, 360, 255, -1)
                    elif target_name == "LA":
                        cx, cy = int(w * 0.4), int(h * 0.3)
                        cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1)
                    elif target_name == "RA":
                        cx, cy = int(w * 0.6), int(h * 0.3)
                        cv2.ellipse(init, (cx, cy), (w // 12, h // 8), 0, 0, 360, 255, -1)
                    else:
                        # Unknown target: centered circle fallback.
                        cx, cy = w // 2, h // 2
                        cv2.circle(init, (cx, cy), min(w, h) // 8, 255, -1)
                    init_mask = init.astype(bool)
                    self.predictor.add_new_mask(
                        inference_state=state,
                        frame_idx=0,
                        obj_id=1,
                        mask=init_mask,
                    )
                masks = []
                total_frames = len(frames)
                processed = 0
                for frame_idx, obj_ids, mask_logits in self.predictor.propagate_in_video(state):
                    processed += 1
                    if progress_callback:
                        # Map frame progress into the 30-90% band of the overall task.
                        progress_callback(30 + int((processed / total_frames) * 60), f"Processing frame {processed}/{total_frames}")
                    if len(mask_logits) > 0:
                        # Threshold logits at 0 to get a boolean mask for the
                        # first (only) tracked object.
                        mask = (mask_logits[0] > 0.0).cpu().numpy()
                        if mask.ndim == 3 and mask.shape[0] == 1:
                            mask = mask[0]  # drop singleton channel dim
                        if mask.shape != (h, w):
                            mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
                        masks.append(mask.astype(np.uint8) * 255)
                    else:
                        masks.append(np.zeros((h, w), dtype=np.uint8))
                # Pad with empty masks if propagation yielded fewer frames.
                while len(masks) < total_frames:
                    masks.append(np.zeros((h, w), dtype=np.uint8))
                return masks
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"MedSAM2 segmentation failed: {e}") from e