|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Type |
|
|
from pathlib import Path |
|
|
import tempfile |
|
|
import uuid |
|
|
import json |
|
|
|
|
|
import numpy as np |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
try: |
|
|
import cv2 |
|
|
except Exception as e: |
|
|
cv2 = None |
|
|
|
|
|
import torch |
|
|
from pydantic import BaseModel, Field, validator |
|
|
|
|
|
from langchain_core.callbacks import ( |
|
|
AsyncCallbackManagerForToolRun, |
|
|
CallbackManagerForToolRun, |
|
|
) |
|
|
from langchain_core.tools import BaseTool |
|
|
|
|
|
|
|
|
class EchoSegmentationInput(BaseModel):
    """Input schema for the Echo (ultrasound) segmentation tool.

    Supports MP4/AVI/GIF and single image (PNG/JPG). For DICOM cine, please
    convert to a standard video first or extend this tool to read DICOM directly.
    """

    # Required input media; everything else is optional prompting/output config.
    video_path: str = Field(
        ..., description="Path to echo video (mp4/avi/gif) or single image (png/jpg)"
    )
    prompt_mode: Literal["auto", "points", "box", "mask"] = Field(
        "auto", description="Segmentation prompt mode: auto, points, box, or mask"
    )

    # Prompts are given in normalized [0,1] coordinates and scaled to pixels
    # by the tool; `label` follows the SAM convention (1=positive, 0=negative).
    points: Optional[List[Tuple[float, float, int]]] = Field(
        None, description="List of (x,y,label) in normalized coords for the first frame"
    )

    box: Optional[Tuple[float, float, float, float]] = Field(
        None, description="Normalized box (x1,y1,x2,y2) for the first frame"
    )
    # 'mask' mode inputs: either a single mask image or a directory of
    # per-frame annotation PNGs (selected via mask_frame_index).
    mask_path: Optional[str] = Field(
        None, description="Path to an initial segmentation mask for the first frame (for 'mask' mode)"
    )
    mask_label: Optional[int] = Field(
        None,
        description="Palette label to extract from the provided mask when using dataset annotations",
    )
    mask_frame_index: Optional[int] = Field(
        0,
        ge=0,
        description="Frame index to pick when mask_path points to an annotation directory",
    )
    mask_label_map: Optional[Dict[int, int]] = Field(
        None,
        description="Mapping from palette pixel values to object IDs (e.g. {1:1,2:2,3:3,4:4})",
    )
    mask_palette: Optional[Dict[int, List[int]]] = Field(
        None,
        description="Mapping from object IDs to RGB colors for overlays (each value is [R,G,B])",
    )
    target_name: Optional[str] = Field(
        "LV", description="Optional target label used in metadata/filenames"
    )
    # NOTE(review): sample_rate is accepted and echoed into metadata, but the
    # tool's _run does not currently subsample frames — confirm before relying on it.
    sample_rate: int = Field(
        1,
        description="Process every Nth frame for speed (1 = every frame)",
        ge=1,
    )
    output_fps: Optional[int] = Field(
        None, description="FPS for output video. Defaults to source FPS"
    )
    save_mask_video: bool = Field(True, description="Save binary mask-only video")
    save_overlay_video: bool = Field(True, description="Save overlay video")

    # NOTE(review): pydantic v1-style `validator` — deprecated under pydantic v2
    # (shimmed as @validator); migrate to @field_validator when moving to v2 APIs.
    @validator("points")
    def _validate_points(cls, v):
        # Each prompt point must be exactly an (x, y, label) triple.
        if v is not None:
            for p in v:
                if len(p) != 3:
                    raise ValueError("Each point must be (x,y,label)")
        return v

    @validator("box")
    def _validate_box(cls, v):
        # A box prompt must be exactly (x1, y1, x2, y2).
        if v is not None and len(v) != 4:
            raise ValueError("box must be (x1,y1,x2,y2)")
        return v

    @validator("mask_palette")
    def _validate_palette(cls, v):
        # Every palette entry must be a 3-component color with byte-range values.
        if v is not None:
            for obj_id, color in v.items():
                if len(color) != 3:
                    raise ValueError("mask_palette colors must be RGB triplets")
                if any(c < 0 or c > 255 for c in color):
                    raise ValueError("mask_palette colors must be 0-255 integers")
        return v
|
|
|
|
|
|
|
|
class EchoSegmentationTool(BaseTool):
    """Segments cardiac chambers in echocardiography videos using MedSAM2 (HF) with SAM2 video predictor.

    - Downloads MedSAM2 checkpoint from Hugging Face by default (wanglab/MedSAM2) and builds a SAM2 video predictor.
    - Supports auto or prompted segmentation (points/box on first frame) with propagation.
    - Returns paths to generated videos (overlay and/or mask) and basic per-frame metrics.

    Note: You must supply a valid SAM2 model config YAML via `model_cfg` (from the SAM2 repo). The tool will
    auto-download the MedSAM2 checkpoint unless you provide a local `checkpoint` path. Pass the CONFIG NAME
    (e.g., 'sam2.1_hiera_t.yaml'), not a filesystem path.
    """

    # LangChain tool identity: name/description are surfaced to the agent.
    name: str = "echo_segmentation"
    description: str = (
        "Segments echocardiography videos/images with MedSAM2 (SAM2-based). "
        "Downloads MedSAM2 weights from Hugging Face if needed. "
        "Input: video_path and optional prompt (points/box). "
        "Output: paths to generated videos and per-frame metrics."
    )
    args_schema: Type[BaseModel] = EchoSegmentationInput

    # Runtime placement and output/cache directory (overridden in __init__).
    device: Optional[str] = "cuda"
    temp_dir: Path = Path("temp")

    # SAM2 config NAME (for Hydra) and MedSAM2 checkpoint location.
    model_cfg: Optional[str] = None
    checkpoint: Optional[str] = None
    cache_dir: Optional[str] = None

    # Hugging Face repo/file used when no local checkpoint is supplied.
    model_id: Optional[str] = "wanglab/MedSAM2"
    model_filename: Optional[str] = "MedSAM2_US_Heart.pt"

    # Lazily-constructed SAM2 video predictor (built in _load_predictor).
    _predictor: Any = None
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
device: Optional[str] = "cuda", |
|
|
temp_dir: Optional[str] = "temp", |
|
|
model_cfg: Optional[str] = None, |
|
|
checkpoint: Optional[str] = None, |
|
|
cache_dir: Optional[str] = None, |
|
|
model_id: Optional[str] = "wanglab/MedSAM2", |
|
|
model_filename: Optional[str] = "MedSAM2_US_Heart.pt", |
|
|
): |
|
|
super().__init__() |
|
|
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") |
|
|
self.temp_dir = Path(temp_dir or tempfile.mkdtemp()) |
|
|
self.temp_dir.mkdir(exist_ok=True, parents=True) |
|
|
self.model_cfg = model_cfg |
|
|
self.checkpoint = checkpoint |
|
|
self.cache_dir = cache_dir |
|
|
self.model_id = model_id |
|
|
self.model_filename = model_filename |
|
|
|
|
|
|
|
|
self._predictor = None |
|
|
|
|
|
|
|
|
def _ensure_dependencies(self): |
|
|
if cv2 is None: |
|
|
raise RuntimeError( |
|
|
"OpenCV (cv2) is required. Install with: pip install opencv-python" |
|
|
) |
|
|
|
|
|
|
|
|
def _resolve_default_model_cfg(self) -> Optional[str]: |
|
|
"""Resolve a default SAM2 YAML CONFIG NAME if none provided. |
|
|
|
|
|
We rely on the configs packaged inside the installed `sam2` module. |
|
|
Returns a config NAME like 'sam2.1_hiera_t' if found, else None. |
|
|
""" |
|
|
if self.model_cfg: |
|
|
return self.model_cfg |
|
|
|
|
|
try: |
|
|
import importlib.resources as pkg_resources |
|
|
import sam2 |
|
|
|
|
|
candidates = [ |
|
|
"sam2.1_hiera_t512.yaml", |
|
|
"sam2.1_hiera_t.yaml", |
|
|
"sam2_hiera_s.yaml", |
|
|
] |
|
|
for name in candidates: |
|
|
try: |
|
|
cfg_path = pkg_resources.files(sam2) / "configs" / name |
|
|
if cfg_path and cfg_path.is_file(): |
|
|
|
|
|
return name[:-5] if name.endswith('.yaml') else name |
|
|
except Exception: |
|
|
continue |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
return None |
|
|
|
|
|
def _normalize_model_cfg_name(self, cfg: str) -> str: |
|
|
"""Normalize user-provided model_cfg to a config NAME for Hydra. |
|
|
|
|
|
- If a filesystem path is provided, reduce to basename. |
|
|
- Fix common typos: 'sam2.1.hiera' -> 'sam2.1_hiera'. |
|
|
- Remove .yaml extension as Hydra expects just the config name. |
|
|
""" |
|
|
try: |
|
|
p = Path(cfg) |
|
|
if p.exists(): |
|
|
cfg = p.name |
|
|
except Exception: |
|
|
pass |
|
|
if "sam2.1.hiera" in cfg: |
|
|
cfg = cfg.replace("sam2.1.hiera", "sam2.1_hiera") |
|
|
|
|
|
|
|
|
if cfg.endswith('.yaml'): |
|
|
cfg = cfg[:-5] |
|
|
|
|
|
return cfg |
|
|
|
|
|
    def _load_predictor(self):
        """Load the SAM2 video predictor with MedSAM2 weights.

        If `checkpoint` is not provided, attempt to download from Hugging Face using
        `model_id` and `model_filename` (defaults target the ultrasound heart model).
        A valid SAM2 YAML config NAME is required; if not provided, we try to resolve a default.
        """
        # Idempotent: keep the already-built predictor.
        if self._predictor is not None:
            return

        # Resolve a checkpoint: local path wins, otherwise fetch from the Hub.
        if not self.checkpoint:
            if not self.model_id or not self.model_filename:
                raise RuntimeError(
                    "Either provide `checkpoint` or set (`model_id`, `model_filename`) to download MedSAM2."
                )
            try:
                # NOTE(review): `local_dir_use_symlinks` is deprecated in recent
                # huggingface_hub releases — confirm against the pinned version.
                ckpt_path = hf_hub_download(
                    repo_id=self.model_id,
                    filename=self.model_filename,
                    local_dir=self.cache_dir or str(self.temp_dir / "hf_cache"),
                    local_dir_use_symlinks=False,
                )
                self.checkpoint = ckpt_path
            except Exception as e:
                raise RuntimeError(
                    f"Failed to download MedSAM2 checkpoint from Hugging Face ({self.model_id}/{self.model_filename}): {e}"
                )

        # Resolve a config NAME, falling back to one bundled with `sam2`.
        if not self.model_cfg:
            self.model_cfg = self._resolve_default_model_cfg()
            if not self.model_cfg:
                raise RuntimeError(
                    "Could not resolve a SAM2 config automatically. Install `sam2` and pass a config NAME, e.g., --model-cfg sam2.1_hiera_t.yaml"
                )

        cfg_name = self._normalize_model_cfg_name(self.model_cfg)

        try:
            # Imports kept local so the tool can be constructed without sam2/hydra.
            from sam2.build_sam import build_sam2_video_predictor
            from hydra.core.global_hydra import GlobalHydra
            from hydra import initialize_config_dir
            import os

            # Reset any prior Hydra app state before initializing our own
            # config dir (initialize_config_dir fails if one is active).
            GlobalHydra.instance().clear()

            # Point Hydra at the configs packaged inside the sam2 install.
            import sam2
            sam2_configs_dir = os.path.join(os.path.dirname(sam2.__file__), "configs")

            # Build the predictor while the sam2 config dir is active.
            with initialize_config_dir(config_dir=sam2_configs_dir, version_base=None):
                predictor = build_sam2_video_predictor(
                    cfg_name, self.checkpoint, device=self.device
                )
        except Exception as e:
            raise RuntimeError(
                f"Failed to build predictor with MedSAM2 weights. Config: '{cfg_name}', "
                f"Checkpoint: '{self.checkpoint}'. Error: {e}"
            )

        self._predictor = predictor
|
|
|
|
|
|
|
|
def _read_video(self, path: str) -> Tuple[List[np.ndarray], float]: |
|
|
"""Read video into list of RGB frames and return frames + fps. |
|
|
If it's an image, return single frame and default fps=25. |
|
|
""" |
|
|
p = Path(path) |
|
|
if not p.exists(): |
|
|
raise FileNotFoundError(f"Video/image not found: {path}") |
|
|
|
|
|
if p.suffix.lower() in {".png", ".jpg", ".jpeg", ".bmp"}: |
|
|
img = cv2.imread(str(p), cv2.IMREAD_COLOR) |
|
|
if img is None: |
|
|
raise RuntimeError("Failed to read image.") |
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
return [img], 25.0 |
|
|
|
|
|
cap = cv2.VideoCapture(str(p)) |
|
|
if not cap.isOpened(): |
|
|
raise RuntimeError("Failed to open video.") |
|
|
|
|
|
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 |
|
|
frames: List[np.ndarray] = [] |
|
|
while True: |
|
|
ret, frame = cap.read() |
|
|
if not ret: |
|
|
break |
|
|
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|
|
frames.append(frame) |
|
|
cap.release() |
|
|
if not frames: |
|
|
raise RuntimeError("No frames read from video.") |
|
|
return frames, float(fps) |
|
|
|
|
|
def _write_video(self, frames: List[np.ndarray], fps: float, out_path: Path): |
|
|
out_path.parent.mkdir(exist_ok=True, parents=True) |
|
|
h, w = frames[0].shape[:2] |
|
|
fourcc = cv2.VideoWriter_fourcc(*"H264") |
|
|
writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w, h)) |
|
|
for fr in frames: |
|
|
bgr = cv2.cvtColor(fr, cv2.COLOR_RGB2BGR) |
|
|
writer.write(bgr) |
|
|
writer.release() |
|
|
|
|
|
|
|
|
def _normalized_to_abs_points(self, points: List[Tuple[float, float, int]], w: int, h: int): |
|
|
coords = np.array([[int(x * w), int(y * h)] for x, y, _ in points], dtype=np.int32) |
|
|
labels = np.array([int(lbl) for _, _, lbl in points], dtype=np.int32) |
|
|
return coords, labels |
|
|
|
|
|
def _normalized_to_abs_box(self, box: Tuple[float, float, float, float], w: int, h: int): |
|
|
x1, y1, x2, y2 = box |
|
|
return np.array([int(x1 * w), int(y1 * h), int(x2 * w), int(y2 * h)], dtype=np.int32) |
|
|
|
|
|
def _compose_color_layer( |
|
|
self, |
|
|
masks: Dict[int, np.ndarray], |
|
|
palette: Dict[int, Tuple[int, int, int]], |
|
|
fallback: Tuple[Tuple[int, int, int], ...], |
|
|
) -> np.ndarray: |
|
|
"""Create an RGB layer where each object mask is painted with its palette color.""" |
|
|
|
|
|
|
|
|
sample_mask = next(iter(masks.values())) |
|
|
height, width = sample_mask.shape[:2] |
|
|
color_layer = np.zeros((height, width, 3), dtype=np.uint8) |
|
|
|
|
|
for obj_id, mask in masks.items(): |
|
|
if mask.shape != (height, width): |
|
|
mask = cv2.resize(mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST) |
|
|
mask_bool = mask.astype(bool) |
|
|
color = palette.get(obj_id) |
|
|
if color is None: |
|
|
color = fallback[obj_id % len(fallback)] |
|
|
color_layer[mask_bool] = color |
|
|
|
|
|
return color_layer |
|
|
|
|
|
def _render_overlay( |
|
|
self, |
|
|
frame: np.ndarray, |
|
|
masks: Dict[int, np.ndarray], |
|
|
palette: Dict[int, Tuple[int, int, int]], |
|
|
fallback: Tuple[Tuple[int, int, int], ...], |
|
|
alpha: float = 0.5, |
|
|
) -> np.ndarray: |
|
|
"""Alpha blend colorized masks onto the frame.""" |
|
|
|
|
|
color_layer = self._compose_color_layer(masks, palette, fallback) |
|
|
overlay = cv2.addWeighted(frame, 1 - alpha, color_layer, alpha, 0) |
|
|
return overlay |
|
|
|
|
|
def _load_mask_prompt( |
|
|
self, |
|
|
mask_path: str, |
|
|
frame_shape: Tuple[int, int], |
|
|
mask_label: Optional[int] = None, |
|
|
mask_frame_index: Optional[int] = 0, |
|
|
mask_label_map: Optional[Dict[int, int]] = None, |
|
|
) -> Dict[int, np.ndarray]: |
|
|
"""Load prompt masks (object_id -> binary mask) from annotation.""" |
|
|
|
|
|
if mask_path is None: |
|
|
raise ValueError("mask_path must be provided for mask prompts") |
|
|
|
|
|
candidate = Path(mask_path) |
|
|
if candidate.is_dir(): |
|
|
if mask_frame_index is None: |
|
|
mask_frame_index = 0 |
|
|
frame_name = f"{int(mask_frame_index):04d}.png" |
|
|
candidate = candidate / frame_name |
|
|
|
|
|
if not candidate.exists(): |
|
|
raise FileNotFoundError(f"Mask prompt not found at {candidate}") |
|
|
|
|
|
mask = cv2.imread(str(candidate), cv2.IMREAD_GRAYSCALE) |
|
|
if mask is None: |
|
|
raise RuntimeError(f"Failed to read mask prompt: {candidate}") |
|
|
|
|
|
height, width = frame_shape |
|
|
if mask.shape != frame_shape: |
|
|
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) |
|
|
|
|
|
prompt_masks: Dict[int, np.ndarray] = {} |
|
|
|
|
|
if mask_label_map: |
|
|
for pixel_value, obj_id in mask_label_map.items(): |
|
|
prompt_masks[int(obj_id)] = (mask == pixel_value).astype(np.uint8) |
|
|
elif mask_label is not None: |
|
|
prompt_masks[1] = (mask == mask_label).astype(np.uint8) |
|
|
else: |
|
|
prompt_masks[1] = (mask > 0).astype(np.uint8) |
|
|
|
|
|
if not prompt_masks: |
|
|
raise RuntimeError("No foreground objects extracted from mask prompt") |
|
|
|
|
|
return prompt_masks |
|
|
|
|
|
def _run( |
|
|
self, |
|
|
video_path: str, |
|
|
prompt_mode: Literal["auto", "points", "box", "mask"] = "auto", |
|
|
points: Optional[List[Tuple[float, float, int]]] = None, |
|
|
box: Optional[Tuple[float, float, float, float]] = None, |
|
|
mask_path: Optional[str] = None, |
|
|
mask_label: Optional[int] = None, |
|
|
mask_frame_index: Optional[int] = 0, |
|
|
mask_label_map: Optional[Dict[int, int]] = None, |
|
|
mask_palette: Optional[Dict[int, List[int]]] = None, |
|
|
target_name: Optional[str] = "LV", |
|
|
sample_rate: int = 1, |
|
|
output_fps: Optional[int] = None, |
|
|
save_mask_video: bool = True, |
|
|
save_overlay_video: bool = True, |
|
|
run_manager: Optional[CallbackManagerForToolRun] = None, |
|
|
) -> Tuple[Dict[str, Any], Dict]: |
|
|
"""Run MedSAM2/SAM2 video segmentation on an echo video or image. |
|
|
|
|
|
Returns (output, metadata), where output contains file paths; |
|
|
metadata contains additional info and basic per-frame metrics. |
|
|
""" |
|
|
self._ensure_dependencies() |
|
|
|
|
|
|
|
|
self._load_predictor() |
|
|
predictor = self._predictor |
|
|
|
|
|
|
|
|
frames, src_fps = self._read_video(video_path) |
|
|
fps = float(output_fps) if output_fps else src_fps |
|
|
h, w = frames[0].shape[:2] |
|
|
|
|
|
default_palette = { |
|
|
1: (0, 255, 0), |
|
|
2: (255, 0, 0), |
|
|
3: (255, 255, 0), |
|
|
4: (0, 0, 255), |
|
|
5: (255, 0, 255), |
|
|
} |
|
|
palette_rgb: Dict[int, Tuple[int, int, int]] = dict(default_palette) |
|
|
if mask_palette: |
|
|
palette_rgb.update( |
|
|
{ |
|
|
int(obj_id): tuple(int(c) for c in color) |
|
|
for obj_id, color in mask_palette.items() |
|
|
} |
|
|
) |
|
|
fallback_colors: Tuple[Tuple[int, int, int], ...] = ( |
|
|
(0, 255, 0), |
|
|
(255, 0, 0), |
|
|
(0, 0, 255), |
|
|
(255, 255, 0), |
|
|
(255, 0, 255), |
|
|
(0, 255, 255), |
|
|
) |
|
|
active_object_ids: List[int] = [] |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
state = predictor.init_state(video_path) |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
f"Failed to initialize SAM2 state with video: {video_path}. " |
|
|
f"SAM2 may only support MP4 videos and JPEG folders. Error: {e}" |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
if state is None: |
|
|
raise RuntimeError("SAM2 state initialization failed") |
|
|
|
|
|
if prompt_mode == "mask" and mask_path: |
|
|
prompt_masks = self._load_mask_prompt( |
|
|
mask_path, |
|
|
(h, w), |
|
|
mask_label=mask_label, |
|
|
mask_frame_index=mask_frame_index, |
|
|
mask_label_map=mask_label_map, |
|
|
) |
|
|
for obj_id, obj_mask in prompt_masks.items(): |
|
|
predictor.add_new_mask( |
|
|
state, |
|
|
frame_idx=0, |
|
|
obj_id=int(obj_id), |
|
|
mask=obj_mask.astype(bool), |
|
|
) |
|
|
active_object_ids.append(int(obj_id)) |
|
|
elif prompt_mode == "points" and points: |
|
|
abs_points, point_labels = self._normalized_to_abs_points(points, w, h) |
|
|
predictor.add_new_points( |
|
|
state, |
|
|
frame_idx=0, |
|
|
obj_id=1, |
|
|
points=abs_points, |
|
|
labels=point_labels, |
|
|
) |
|
|
active_object_ids.append(1) |
|
|
elif prompt_mode == "box" and box: |
|
|
abs_box = self._normalized_to_abs_box(box, w, h) |
|
|
predictor.add_new_points_or_box( |
|
|
state, |
|
|
frame_idx=0, |
|
|
obj_id=1, |
|
|
box=abs_box, |
|
|
) |
|
|
active_object_ids.append(1) |
|
|
else: |
|
|
|
|
|
center_x, center_y = w // 2, h // 2 |
|
|
center_points = np.array([[center_x, center_y]]) |
|
|
center_labels = np.array([1]) |
|
|
predictor.add_new_points( |
|
|
state, |
|
|
frame_idx=0, |
|
|
obj_id=1, |
|
|
points=center_points, |
|
|
labels=center_labels, |
|
|
) |
|
|
active_object_ids.append(1) |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
f"Prompting API mismatch. Please adapt the add_new_points calls " |
|
|
f"to your installed SAM2/MedSAM2 version. Error: {e}" |
|
|
) |
|
|
|
|
|
|
|
|
mask_frames: List[Dict[int, np.ndarray]] = [] |
|
|
overlay_frames: List[np.ndarray] = [] |
|
|
per_frame_metrics: List[Dict[str, Any]] = [] |
|
|
|
|
|
try: |
|
|
for out in predictor.propagate_in_video(state): |
|
|
if not (isinstance(out, tuple) and len(out) == 3): |
|
|
continue |
|
|
|
|
|
frame_idx, obj_ids, mask_logits = out |
|
|
if len(mask_logits) == 0: |
|
|
continue |
|
|
|
|
|
frame_masks: Dict[int, np.ndarray] = {} |
|
|
for idx, obj_id in enumerate(obj_ids): |
|
|
logits = mask_logits[idx] |
|
|
mask = (torch.sigmoid(logits).cpu().numpy() > 0.5).astype(np.uint8) |
|
|
frame_masks[int(obj_id)] = mask |
|
|
|
|
|
if not frame_masks: |
|
|
continue |
|
|
|
|
|
mask_frames.append(frame_masks) |
|
|
|
|
|
if save_overlay_video and frame_idx < len(frames): |
|
|
overlay = self._render_overlay( |
|
|
frames[frame_idx], frame_masks, palette_rgb, fallback_colors |
|
|
) |
|
|
overlay_frames.append(overlay) |
|
|
|
|
|
per_frame_metrics.append( |
|
|
{ |
|
|
"frame_index": int(frame_idx), |
|
|
"object_areas": { |
|
|
int(obj_id): int(mask.sum()) for obj_id, mask in frame_masks.items() |
|
|
}, |
|
|
} |
|
|
) |
|
|
except Exception as e: |
|
|
raise RuntimeError(f"Error during propagation: {e}") |
|
|
|
|
|
|
|
|
out_base = f"echo_seg_{target_name}_{uuid.uuid4().hex[:8]}" |
|
|
outputs: Dict[str, Any] = {} |
|
|
|
|
|
if save_overlay_video and overlay_frames: |
|
|
overlay_path = self.temp_dir / f"{out_base}_overlay.mp4" |
|
|
self._write_video(overlay_frames, fps, overlay_path) |
|
|
outputs["overlay_video_path"] = str(overlay_path) |
|
|
|
|
|
if save_mask_video and mask_frames: |
|
|
mask_rgb_frames: List[np.ndarray] = [] |
|
|
for frame_masks in mask_frames: |
|
|
color_layer = self._compose_color_layer(frame_masks, palette_rgb, fallback_colors) |
|
|
mask_rgb_frames.append(color_layer) |
|
|
mask_path = self.temp_dir / f"{out_base}_mask.mp4" |
|
|
self._write_video(mask_rgb_frames, fps, mask_path) |
|
|
outputs["mask_video_path"] = str(mask_path) |
|
|
|
|
|
metadata: Dict[str, Any] = { |
|
|
"video_path": video_path, |
|
|
"frames_processed": len(mask_frames), |
|
|
"source_frames": len(frames), |
|
|
"sample_rate": sample_rate, |
|
|
"fps_out": fps, |
|
|
"resolution": [h, w], |
|
|
"target_name": target_name, |
|
|
"active_object_ids": sorted(set(active_object_ids) or {1}), |
|
|
"per_frame_metrics": per_frame_metrics, |
|
|
"analysis_status": "completed", |
|
|
} |
|
|
|
|
|
return outputs, metadata |
|
|
|
|
|
async def _arun( |
|
|
self, |
|
|
video_path: str, |
|
|
prompt_mode: Literal["auto", "points", "box", "mask"] = "auto", |
|
|
points: Optional[List[Tuple[float, float, int]]] = None, |
|
|
box: Optional[Tuple[float, float, float, float]] = None, |
|
|
mask_path: Optional[str] = None, |
|
|
mask_label: Optional[int] = None, |
|
|
mask_frame_index: Optional[int] = 0, |
|
|
mask_label_map: Optional[Dict[int, int]] = None, |
|
|
mask_palette: Optional[Dict[int, List[int]]] = None, |
|
|
target_name: Optional[str] = "LV", |
|
|
sample_rate: int = 1, |
|
|
output_fps: Optional[int] = None, |
|
|
save_mask_video: bool = True, |
|
|
save_overlay_video: bool = True, |
|
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, |
|
|
) -> Tuple[Dict[str, Any], Dict]: |
|
|
return self._run( |
|
|
video_path, |
|
|
prompt_mode, |
|
|
points, |
|
|
box, |
|
|
mask_path, |
|
|
mask_label, |
|
|
mask_frame_index, |
|
|
mask_label_map, |
|
|
mask_palette, |
|
|
target_name, |
|
|
sample_rate, |
|
|
output_fps, |
|
|
save_mask_video, |
|
|
save_overlay_video, |
|
|
) |
|
|
|