Echo / tools /echo /echo_segmentation_corrected.py
moein99's picture
Initial Echo Space
8f51ef2
from __future__ import annotations
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
from pathlib import Path
import tempfile
import uuid
import json
import numpy as np
from huggingface_hub import hf_hub_download
try:
import cv2 # type: ignore
except Exception as e: # pragma: no cover
cv2 = None # lazy import error handled in _ensure_dependencies
import torch
from pydantic import BaseModel, Field, validator
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
class EchoSegmentationInput(BaseModel):
"""Input schema for the Echo (ultrasound) segmentation tool.
Supports MP4/AVI/GIF and single image (PNG/JPG). For DICOM cine, please
convert to a standard video first or extend this tool to read DICOM directly.
"""
video_path: str = Field(
..., description="Path to echo video (mp4/avi/gif) or single image (png/jpg)"
)
prompt_mode: Literal["auto", "points", "box", "mask"] = Field(
"auto", description="Segmentation prompt mode: auto, points, box, or mask"
)
# Normalized coordinates in [0,1], labels: 1=foreground, 0=background
points: Optional[List[Tuple[float, float, int]]] = Field(
None, description="List of (x,y,label) in normalized coords for the first frame"
)
# Normalized box [x1,y1,x2,y2] in [0,1]
box: Optional[Tuple[float, float, float, float]] = Field(
None, description="Normalized box (x1,y1,x2,y2) for the first frame"
)
mask_path: Optional[str] = Field(
None, description="Path to an initial segmentation mask for the first frame (for 'mask' mode)"
)
mask_label: Optional[int] = Field(
None,
description="Palette label to extract from the provided mask when using dataset annotations",
)
mask_frame_index: Optional[int] = Field(
0,
ge=0,
description="Frame index to pick when mask_path points to an annotation directory",
)
mask_label_map: Optional[Dict[int, int]] = Field(
None,
description="Mapping from palette pixel values to object IDs (e.g. {1:1,2:2,3:3,4:4})",
)
mask_palette: Optional[Dict[int, List[int]]] = Field(
None,
description="Mapping from object IDs to RGB colors for overlays (each value is [R,G,B])",
)
target_name: Optional[str] = Field(
"LV", description="Optional target label used in metadata/filenames"
)
sample_rate: int = Field(
1,
description="Process every Nth frame for speed (1 = every frame)",
ge=1,
)
output_fps: Optional[int] = Field(
None, description="FPS for output video. Defaults to source FPS"
)
save_mask_video: bool = Field(True, description="Save binary mask-only video")
save_overlay_video: bool = Field(True, description="Save overlay video")
@validator("points")
def _validate_points(cls, v):
if v is not None:
for p in v:
if len(p) != 3:
raise ValueError("Each point must be (x,y,label)")
return v
@validator("box")
def _validate_box(cls, v):
if v is not None and len(v) != 4:
raise ValueError("box must be (x1,y1,x2,y2)")
return v
@validator("mask_palette")
def _validate_palette(cls, v):
if v is not None:
for obj_id, color in v.items():
if len(color) != 3:
raise ValueError("mask_palette colors must be RGB triplets")
if any(c < 0 or c > 255 for c in color):
raise ValueError("mask_palette colors must be 0-255 integers")
return v
class EchoSegmentationTool(BaseTool):
"""Segments cardiac chambers in echocardiography videos using MedSAM2 (HF) with SAM2 video predictor.
- Downloads MedSAM2 checkpoint from Hugging Face by default (wanglab/MedSAM2) and builds a SAM2 video predictor.
- Supports auto or prompted segmentation (points/box on first frame) with propagation.
- Returns paths to generated videos (overlay and/or mask) and basic per-frame metrics.
Note: You must supply a valid SAM2 model config YAML via `model_cfg` (from the SAM2 repo). The tool will
auto-download the MedSAM2 checkpoint unless you provide a local `checkpoint` path. Pass the CONFIG NAME
(e.g., 'sam2.1_hiera_t.yaml'), not a filesystem path.
"""
name: str = "echo_segmentation"
description: str = (
"Segments echocardiography videos/images with MedSAM2 (SAM2-based). "
"Downloads MedSAM2 weights from Hugging Face if needed. "
"Input: video_path and optional prompt (points/box). "
"Output: paths to generated videos and per-frame metrics."
)
args_schema: Type[BaseModel] = EchoSegmentationInput
# Runtime
device: Optional[str] = "cuda"
temp_dir: Path = Path("temp")
# Model config
model_cfg: Optional[str] = None
checkpoint: Optional[str] = None
cache_dir: Optional[str] = None
# Hugging Face model info (used if checkpoint not provided)
model_id: Optional[str] = "wanglab/MedSAM2"
model_filename: Optional[str] = "MedSAM2_US_Heart.pt"
# Internal predictor (SAM2/MedSAM2 video predictor)
_predictor: Any = None
def __init__(
self,
device: Optional[str] = "cuda",
temp_dir: Optional[str] = "temp",
model_cfg: Optional[str] = None,
checkpoint: Optional[str] = None,
cache_dir: Optional[str] = None,
model_id: Optional[str] = "wanglab/MedSAM2",
model_filename: Optional[str] = "MedSAM2_US_Heart.pt",
):
super().__init__()
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.temp_dir = Path(temp_dir or tempfile.mkdtemp())
self.temp_dir.mkdir(exist_ok=True, parents=True)
self.model_cfg = model_cfg
self.checkpoint = checkpoint
self.cache_dir = cache_dir
self.model_id = model_id
self.model_filename = model_filename
# Lazy-load predictor on first run to avoid heavy startup if unused
self._predictor = None
# ------------- SAM2/MedSAM2 predictor helpers -------------
def _ensure_dependencies(self):
if cv2 is None:
raise RuntimeError(
"OpenCV (cv2) is required. Install with: pip install opencv-python"
)
# Torch is imported already; SAM2/MedSAM2 imports happen in _load_predictor
def _resolve_default_model_cfg(self) -> Optional[str]:
"""Resolve a default SAM2 YAML CONFIG NAME if none provided.
We rely on the configs packaged inside the installed `sam2` module.
Returns a config NAME like 'sam2.1_hiera_t' if found, else None.
"""
if self.model_cfg:
return self.model_cfg
try:
import importlib.resources as pkg_resources
import sam2 # type: ignore
candidates = [
"sam2.1_hiera_t512.yaml",
"sam2.1_hiera_t.yaml",
"sam2_hiera_s.yaml",
]
for name in candidates:
try:
cfg_path = pkg_resources.files(sam2) / "configs" / name
if cfg_path and cfg_path.is_file():
# Return the NAME without .yaml extension for Hydra
return name[:-5] if name.endswith('.yaml') else name
except Exception:
continue
except Exception:
pass
# If not found, return None and let caller raise a clear error.
return None
def _normalize_model_cfg_name(self, cfg: str) -> str:
"""Normalize user-provided model_cfg to a config NAME for Hydra.
- If a filesystem path is provided, reduce to basename.
- Fix common typos: 'sam2.1.hiera' -> 'sam2.1_hiera'.
- Remove .yaml extension as Hydra expects just the config name.
"""
try:
p = Path(cfg)
if p.exists():
cfg = p.name
except Exception:
pass
if "sam2.1.hiera" in cfg:
cfg = cfg.replace("sam2.1.hiera", "sam2.1_hiera")
# Remove .yaml extension - Hydra expects just the config name
if cfg.endswith('.yaml'):
cfg = cfg[:-5]
return cfg
def _load_predictor(self):
"""Load the SAM2 video predictor with MedSAM2 weights.
If `checkpoint` is not provided, attempt to download from Hugging Face using
`model_id` and `model_filename` (defaults target the ultrasound heart model).
A valid SAM2 YAML config NAME is required; if not provided, we try to resolve a default.
"""
if self._predictor is not None:
return
# Ensure checkpoint (local or download)
if not self.checkpoint:
if not self.model_id or not self.model_filename:
raise RuntimeError(
"Either provide `checkpoint` or set (`model_id`, `model_filename`) to download MedSAM2."
)
try:
ckpt_path = hf_hub_download(
repo_id=self.model_id,
filename=self.model_filename,
local_dir=self.cache_dir or str(self.temp_dir / "hf_cache"),
local_dir_use_symlinks=False,
)
self.checkpoint = ckpt_path
except Exception as e:
raise RuntimeError(
f"Failed to download MedSAM2 checkpoint from Hugging Face ({self.model_id}/{self.model_filename}): {e}"
)
# Ensure a model config NAME
if not self.model_cfg:
self.model_cfg = self._resolve_default_model_cfg()
if not self.model_cfg:
raise RuntimeError(
"Could not resolve a SAM2 config automatically. Install `sam2` and pass a config NAME, e.g., --model-cfg sam2.1_hiera_t.yaml"
)
cfg_name = self._normalize_model_cfg_name(self.model_cfg)
try:
# Build SAM2 video predictor with MedSAM2 weights
from sam2.build_sam import build_sam2_video_predictor # type: ignore
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize_config_dir
import os
# Clear any existing Hydra configuration to avoid conflicts
GlobalHydra.instance().clear()
# Get SAM2 configs directory path
import sam2
sam2_configs_dir = os.path.join(os.path.dirname(sam2.__file__), "configs")
# Initialize Hydra with SAM2 configs directory
with initialize_config_dir(config_dir=sam2_configs_dir, version_base=None):
predictor = build_sam2_video_predictor(
cfg_name, self.checkpoint, device=self.device
)
except Exception as e:
raise RuntimeError(
f"Failed to build predictor with MedSAM2 weights. Config: '{cfg_name}', "
f"Checkpoint: '{self.checkpoint}'. Error: {e}"
)
self._predictor = predictor
# ------------- Video IO helpers -------------
def _read_video(self, path: str) -> Tuple[List[np.ndarray], float]:
"""Read video into list of RGB frames and return frames + fps.
If it's an image, return single frame and default fps=25.
"""
p = Path(path)
if not p.exists():
raise FileNotFoundError(f"Video/image not found: {path}")
if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp"}:
img = cv2.imread(str(p), cv2.IMREAD_COLOR)
if img is None:
raise RuntimeError("Failed to read image.")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return [img], 25.0
cap = cv2.VideoCapture(str(p))
if not cap.isOpened():
raise RuntimeError("Failed to open video.")
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
frames: List[np.ndarray] = []
while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame)
cap.release()
if not frames:
raise RuntimeError("No frames read from video.")
return frames, float(fps)
def _write_video(self, frames: List[np.ndarray], fps: float, out_path: Path):
out_path.parent.mkdir(exist_ok=True, parents=True)
h, w = frames[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*"H264")
writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w, h))
for fr in frames:
bgr = cv2.cvtColor(fr, cv2.COLOR_RGB2BGR)
writer.write(bgr)
writer.release()
# ------------- Segmentation core -------------
def _normalized_to_abs_points(self, points: List[Tuple[float, float, int]], w: int, h: int):
coords = np.array([[int(x * w), int(y * h)] for x, y, _ in points], dtype=np.int32)
labels = np.array([int(lbl) for _, _, lbl in points], dtype=np.int32)
return coords, labels
def _normalized_to_abs_box(self, box: Tuple[float, float, float, float], w: int, h: int):
x1, y1, x2, y2 = box
return np.array([int(x1 * w), int(y1 * h), int(x2 * w), int(y2 * h)], dtype=np.int32)
def _compose_color_layer(
self,
masks: Dict[int, np.ndarray],
palette: Dict[int, Tuple[int, int, int]],
fallback: Tuple[Tuple[int, int, int], ...],
) -> np.ndarray:
"""Create an RGB layer where each object mask is painted with its palette color."""
# Determine output spatial size from first mask
sample_mask = next(iter(masks.values()))
height, width = sample_mask.shape[:2]
color_layer = np.zeros((height, width, 3), dtype=np.uint8)
for obj_id, mask in masks.items():
if mask.shape != (height, width):
mask = cv2.resize(mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)
mask_bool = mask.astype(bool)
color = palette.get(obj_id)
if color is None:
color = fallback[obj_id % len(fallback)]
color_layer[mask_bool] = color
return color_layer
def _render_overlay(
self,
frame: np.ndarray,
masks: Dict[int, np.ndarray],
palette: Dict[int, Tuple[int, int, int]],
fallback: Tuple[Tuple[int, int, int], ...],
alpha: float = 0.5,
) -> np.ndarray:
"""Alpha blend colorized masks onto the frame."""
color_layer = self._compose_color_layer(masks, palette, fallback)
overlay = cv2.addWeighted(frame, 1 - alpha, color_layer, alpha, 0)
return overlay
def _load_mask_prompt(
self,
mask_path: str,
frame_shape: Tuple[int, int],
mask_label: Optional[int] = None,
mask_frame_index: Optional[int] = 0,
mask_label_map: Optional[Dict[int, int]] = None,
) -> Dict[int, np.ndarray]:
"""Load prompt masks (object_id -> binary mask) from annotation."""
if mask_path is None:
raise ValueError("mask_path must be provided for mask prompts")
candidate = Path(mask_path)
if candidate.is_dir():
if mask_frame_index is None:
mask_frame_index = 0
frame_name = f"{int(mask_frame_index):04d}.png"
candidate = candidate / frame_name
if not candidate.exists():
raise FileNotFoundError(f"Mask prompt not found at {candidate}")
mask = cv2.imread(str(candidate), cv2.IMREAD_GRAYSCALE)
if mask is None:
raise RuntimeError(f"Failed to read mask prompt: {candidate}")
height, width = frame_shape
if mask.shape != frame_shape:
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
prompt_masks: Dict[int, np.ndarray] = {}
if mask_label_map:
for pixel_value, obj_id in mask_label_map.items():
prompt_masks[int(obj_id)] = (mask == pixel_value).astype(np.uint8)
elif mask_label is not None:
prompt_masks[1] = (mask == mask_label).astype(np.uint8)
else:
prompt_masks[1] = (mask > 0).astype(np.uint8)
if not prompt_masks:
raise RuntimeError("No foreground objects extracted from mask prompt")
return prompt_masks
def _run(
self,
video_path: str,
prompt_mode: Literal["auto", "points", "box", "mask"] = "auto",
points: Optional[List[Tuple[float, float, int]]] = None,
box: Optional[Tuple[float, float, float, float]] = None,
mask_path: Optional[str] = None,
mask_label: Optional[int] = None,
mask_frame_index: Optional[int] = 0,
mask_label_map: Optional[Dict[int, int]] = None,
mask_palette: Optional[Dict[int, List[int]]] = None,
target_name: Optional[str] = "LV",
sample_rate: int = 1,
output_fps: Optional[int] = None,
save_mask_video: bool = True,
save_overlay_video: bool = True,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[Dict[str, Any], Dict]:
"""Run MedSAM2/SAM2 video segmentation on an echo video or image.
Returns (output, metadata), where output contains file paths;
metadata contains additional info and basic per-frame metrics.
"""
self._ensure_dependencies()
# Load predictor lazily
self._load_predictor()
predictor = self._predictor
# Get video info for output formatting
frames, src_fps = self._read_video(video_path)
fps = float(output_fps) if output_fps else src_fps
h, w = frames[0].shape[:2]
default_palette = {
1: (0, 255, 0), # LV - green
2: (255, 0, 0), # RV - red
3: (255, 255, 0), # LA - yellow
4: (0, 0, 255), # RA - blue
5: (255, 0, 255), # myocardium/other
}
palette_rgb: Dict[int, Tuple[int, int, int]] = dict(default_palette)
if mask_palette:
palette_rgb.update(
{
int(obj_id): tuple(int(c) for c in color)
for obj_id, color in mask_palette.items()
}
)
fallback_colors: Tuple[Tuple[int, int, int], ...] = (
(0, 255, 0),
(255, 0, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
)
active_object_ids: List[int] = []
# Initialize video state (SAM2 expects video path directly)
try:
# SAM2 wants the video file path, not processed frames
state = predictor.init_state(video_path)
except Exception as e:
raise RuntimeError(
f"Failed to initialize SAM2 state with video: {video_path}. "
f"SAM2 may only support MP4 videos and JPEG folders. Error: {e}"
)
# Feed prompt to predictor on first frame
try:
if state is None:
raise RuntimeError("SAM2 state initialization failed")
if prompt_mode == "mask" and mask_path:
prompt_masks = self._load_mask_prompt(
mask_path,
(h, w),
mask_label=mask_label,
mask_frame_index=mask_frame_index,
mask_label_map=mask_label_map,
)
for obj_id, obj_mask in prompt_masks.items():
predictor.add_new_mask(
state,
frame_idx=0,
obj_id=int(obj_id),
mask=obj_mask.astype(bool),
)
active_object_ids.append(int(obj_id))
elif prompt_mode == "points" and points:
abs_points, point_labels = self._normalized_to_abs_points(points, w, h)
predictor.add_new_points(
state,
frame_idx=0,
obj_id=1,
points=abs_points,
labels=point_labels,
)
active_object_ids.append(1)
elif prompt_mode == "box" and box:
abs_box = self._normalized_to_abs_box(box, w, h)
predictor.add_new_points_or_box(
state,
frame_idx=0,
obj_id=1,
box=abs_box,
)
active_object_ids.append(1)
else:
# Default: use center point as prompt
center_x, center_y = w // 2, h // 2
center_points = np.array([[center_x, center_y]])
center_labels = np.array([1])
predictor.add_new_points(
state,
frame_idx=0,
obj_id=1,
points=center_points,
labels=center_labels,
)
active_object_ids.append(1)
except Exception as e:
raise RuntimeError(
f"Prompting API mismatch. Please adapt the add_new_points calls "
f"to your installed SAM2/MedSAM2 version. Error: {e}"
)
# Propagate segmentation across frames
mask_frames: List[Dict[int, np.ndarray]] = []
overlay_frames: List[np.ndarray] = []
per_frame_metrics: List[Dict[str, Any]] = []
try:
for out in predictor.propagate_in_video(state):
if not (isinstance(out, tuple) and len(out) == 3):
continue
frame_idx, obj_ids, mask_logits = out
if len(mask_logits) == 0:
continue
frame_masks: Dict[int, np.ndarray] = {}
for idx, obj_id in enumerate(obj_ids):
logits = mask_logits[idx]
mask = (torch.sigmoid(logits).cpu().numpy() > 0.5).astype(np.uint8)
frame_masks[int(obj_id)] = mask
if not frame_masks:
continue
mask_frames.append(frame_masks)
if save_overlay_video and frame_idx < len(frames):
overlay = self._render_overlay(
frames[frame_idx], frame_masks, palette_rgb, fallback_colors
)
overlay_frames.append(overlay)
per_frame_metrics.append(
{
"frame_index": int(frame_idx),
"object_areas": {
int(obj_id): int(mask.sum()) for obj_id, mask in frame_masks.items()
},
}
)
except Exception as e:
raise RuntimeError(f"Error during propagation: {e}")
# Write outputs
out_base = f"echo_seg_{target_name}_{uuid.uuid4().hex[:8]}"
outputs: Dict[str, Any] = {}
if save_overlay_video and overlay_frames:
overlay_path = self.temp_dir / f"{out_base}_overlay.mp4"
self._write_video(overlay_frames, fps, overlay_path)
outputs["overlay_video_path"] = str(overlay_path)
if save_mask_video and mask_frames:
mask_rgb_frames: List[np.ndarray] = []
for frame_masks in mask_frames:
color_layer = self._compose_color_layer(frame_masks, palette_rgb, fallback_colors)
mask_rgb_frames.append(color_layer)
mask_path = self.temp_dir / f"{out_base}_mask.mp4"
self._write_video(mask_rgb_frames, fps, mask_path)
outputs["mask_video_path"] = str(mask_path)
metadata: Dict[str, Any] = {
"video_path": video_path,
"frames_processed": len(mask_frames),
"source_frames": len(frames),
"sample_rate": sample_rate,
"fps_out": fps,
"resolution": [h, w],
"target_name": target_name,
"active_object_ids": sorted(set(active_object_ids) or {1}),
"per_frame_metrics": per_frame_metrics,
"analysis_status": "completed",
}
return outputs, metadata
async def _arun(
self,
video_path: str,
prompt_mode: Literal["auto", "points", "box", "mask"] = "auto",
points: Optional[List[Tuple[float, float, int]]] = None,
box: Optional[Tuple[float, float, float, float]] = None,
mask_path: Optional[str] = None,
mask_label: Optional[int] = None,
mask_frame_index: Optional[int] = 0,
mask_label_map: Optional[Dict[int, int]] = None,
mask_palette: Optional[Dict[int, List[int]]] = None,
target_name: Optional[str] = "LV",
sample_rate: int = 1,
output_fps: Optional[int] = None,
save_mask_video: bool = True,
save_overlay_video: bool = True,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Tuple[Dict[str, Any], Dict]:
return self._run(
video_path,
prompt_mode,
points,
box,
mask_path,
mask_label,
mask_frame_index,
mask_label_map,
mask_palette,
target_name,
sample_rate,
output_fps,
save_mask_video,
save_overlay_video,
)