"""
Echo Tool Managers

This module provides tool manager classes for various echo tools.
"""

import os
import sys
import shutil
import zipfile
import urllib.request
from collections import Counter
from typing import Dict, List, Any, Optional, Type, Tuple
from pathlib import Path

import torch
import numpy as np
import cv2

# Add parent directory to path for imports. This runs before the project-local
# imports below so that ``utils`` / ``tools`` can also be found via this directory
# (the same directory as ``_THIS_FILE.parents[2]``), not only via the working directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from utils.video_utils import convert_video_to_h264
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool

from tools.general.base_tool_manager import BaseToolManager, ToolConfig, ToolStatus

# Model caching to prevent multiple loads
_model_cache: Dict[str, Any] = {}

_THIS_FILE = Path(__file__).resolve()
_TOOL_REPO_BASES: List[Path] = [
    _THIS_FILE.parents[2] / "tool_repos",  # echo-agent/tool_repos
    _THIS_FILE.parents[3] / "tool_repos",  # workspace-level tool_repos
]
workspace_root_env = os.getenv("ECHO_WORKSPACE_ROOT")
if workspace_root_env:
    _TOOL_REPO_BASES.append(Path(workspace_root_env) / "tool_repos")

# Deduplicate while preserving order (dict keys keep insertion order).
_TOOL_REPO_BASES = list(dict.fromkeys(_TOOL_REPO_BASES))


def _resolve_tool_repo(repo_names: List[str]) -> Path:
    """Return the first existing ``<base>/<repo_name>`` across the tool-repo bases.

    Each name in ``repo_names`` is tried in order, and for each name every base in
    ``_TOOL_REPO_BASES`` is checked in order. If nothing exists, the default
    location (first base + first name, or cwd + first name when there are no
    bases) is returned.

    Raises:
        ValueError: If ``repo_names`` is empty.
    """
    if not repo_names:
        raise ValueError("repo_names must contain at least one repository name")
    for repo_name in repo_names:
        for base in _TOOL_REPO_BASES:
            candidate = base / repo_name
            if candidate.exists():
                return candidate
    primary_base = _TOOL_REPO_BASES[0] if _TOOL_REPO_BASES else Path.cwd()
    return primary_base / repo_names[0]


MEDSAM_REPO_ROOT = _resolve_tool_repo(["MedSAM2-main", "MedSAM2"])
ECHOPRIME_REPO_ROOT = _resolve_tool_repo(["EchoPrime-main", "EchoPrime"])

ECHO_PRIME_RELEASE_BASE = "https://github.com/echonet/EchoPrime/releases/download/v1.0.0"
ECHO_PRIME_MODEL_ZIP_URL = f"{ECHO_PRIME_RELEASE_BASE}/model_data.zip"
ECHO_PRIME_EMBEDDING_FILES = {
    "candidate_embeddings_p1.pt": f"{ECHO_PRIME_RELEASE_BASE}/candidate_embeddings_p1.pt",
    "candidate_embeddings_p2.pt": f"{ECHO_PRIME_RELEASE_BASE}/candidate_embeddings_p2.pt",
}

DEFAULT_ECHO_SEGMENTATION_MASK = MEDSAM_REPO_ROOT / "0108.png"
DEFAULT_ECHO_SEGMENTATION_MASK_DIR = MEDSAM_REPO_ROOT / "default_masks"
DEFAULT_ECHO_SEGMENTATION_STRUCTURES = {
    "LV": "LV.png",
    "MYO": "MYO.png",
    "LA": "LA.png",
    "RV": "RV.png",
    "RA": "RA.png",
}
DEFAULT_ECHO_SEGMENTATION_CHECKPOINT = MEDSAM_REPO_ROOT / "checkpoints" / "MedSAM2_US_Heart.pt"


def _download_file(url: str, destination: Path, timeout: Optional[float] = 60.0) -> bool:
    """Download ``url`` to ``destination``.

    The payload is streamed into a sibling ``<name>.part`` file that is renamed onto
    ``destination`` only after the transfer completes. An interrupted download
    therefore never leaves a truncated file that later ``exists()`` checks would
    accept as a valid asset.

    Args:
        url: Source URL.
        destination: Target file path; parent directories are created as needed.
        timeout: Socket timeout in seconds passed to ``urlopen`` (``None`` blocks
            indefinitely, which was the previous behavior).

    Returns:
        True on success; False on any error (errors are printed, not raised).
    """
    partial_path = destination.with_name(destination.name + ".part")
    try:
        destination.parent.mkdir(parents=True, exist_ok=True)
        print(f"⬇️ Downloading {url} -> {destination}")
        request = urllib.request.Request(url, headers={"User-Agent": "EchoAgent/1.0"})
        with urllib.request.urlopen(request, timeout=timeout) as response, open(partial_path, "wb") as output_file:
            shutil.copyfileobj(response, output_file)
        partial_path.replace(destination)
        print(f"✅ Downloaded {destination.name}")
        return True
    except Exception as download_error:
        print(f"❌ Failed to download {url}: {download_error}")
        partial_path.unlink(missing_ok=True)
        return False


def ensure_echoprime_assets(echo_prime_path: Path) -> bool:
    """Ensure required EchoPrime assets are available, downloading when missing.

    Weights are fetched by downloading and extracting ``model_data.zip`` into
    ``echo_prime_path``; the candidate embedding shards are downloaded individually.

    Args:
        echo_prime_path: Root of the EchoPrime checkout; assets live under ``model_data/``.

    Returns:
        True when every required file exists after any download attempts.
    """
    model_data_dir = echo_prime_path / "model_data"
    weights_dir = model_data_dir / "weights"
    candidates_dir = model_data_dir / "candidates_data"
    weight_files = [
        weights_dir / "echo_prime_encoder.pt",
        weights_dir / "view_classifier.pt",
    ]
    required_files = weight_files + [candidates_dir / filename for filename in ECHO_PRIME_EMBEDDING_FILES]

    if all(required_path.exists() for required_path in required_files):
        return True

    print("⚠️ EchoPrime assets missing; attempting automatic download...")

    # Attempt to download the zipped model data if weights are missing.
    if not all(weight_path.exists() for weight_path in weight_files):
        temp_zip_path = echo_prime_path / "model_data.zip"
        if _download_file(ECHO_PRIME_MODEL_ZIP_URL, temp_zip_path):
            try:
                with zipfile.ZipFile(temp_zip_path, "r") as zip_ref:
                    zip_ref.extractall(echo_prime_path)
                print("✅ Extracted model_data.zip")
            except zipfile.BadZipFile as zip_error:
                print(f"❌ model_data.zip appears corrupted: {zip_error}")
            finally:
                temp_zip_path.unlink(missing_ok=True)

    # Download candidate embedding shards if still missing.
    for filename, url in ECHO_PRIME_EMBEDDING_FILES.items():
        destination = candidates_dir / filename
        if not destination.exists():
            _download_file(url, destination)

    # Final verification step after download attempts.
    all_present = all(required_path.exists() for required_path in required_files)
    if not all_present:
        print("❌ Required EchoPrime assets are still missing after download attempts.")
    return all_present


def load_panecho_model():
    """Load the PanEcho model via ``models.model_factory.get_model`` and cache it.

    Raises:
        RuntimeError: If the factory import fails or it returns ``None``.
    """
    if "panecho" in _model_cache:
        print("✅ Using cached PanEcho model")
        return _model_cache["panecho"]

    try:
        from models.model_factory import get_model

        print("🔄 Loading PanEcho model...")
        model = get_model("panecho")
        if model is None:
            raise RuntimeError("PanEcho model not available")

        _model_cache["panecho"] = model
        print("✅ PanEcho model loaded and cached")
        return model
    except Exception as e:
        print(f"PanEcho loading failed: {e}")
        raise RuntimeError(f"PanEcho model not available: {e}") from e


def load_medsam2_model():
    """Resolve the MedSAM2 model via ``models.model_factory.get_model`` and cache the result.

    The cached value is whatever the factory returns for ``"medsam2"``; the log
    messages treat it as a model path.

    Raises:
        RuntimeError: If the factory import fails or it returns ``None``.
    """
    if "medsam2" in _model_cache:
        print("✅ Using cached MedSAM2 model path")
        return _model_cache["medsam2"]

    try:
        from models.model_factory import get_model

        print("🔄 Loading MedSAM2 model...")
        model_path = get_model("medsam2")
        if model_path is None:
            raise RuntimeError("MedSAM2 model not available")

        _model_cache["medsam2"] = model_path
        print(f"✅ MedSAM2 model loaded and cached: {model_path}")
        return model_path
    except Exception as e:
        print(f"MedSAM2 loading failed: {e}")
        raise RuntimeError(f"MedSAM2 model not available: {e}") from e


def load_echoflow_model():
    """Load the EchoFlow generator via ``models.model_factory.get_model`` and cache it.

    Raises:
        RuntimeError: If the factory import fails or it returns ``None``.
    """
    if "echoflow" in _model_cache:
        print("✅ Using cached EchoFlow model")
        return _model_cache["echoflow"]

    try:
        from models.model_factory import get_model

        print("🔄 Loading EchoFlow model...")
        model = get_model("echoflow")
        if model is None:
            raise RuntimeError("EchoFlow model not available")

        _model_cache["echoflow"] = model
        print("✅ EchoFlow model loaded and cached")
        return model
    except Exception as e:
        print(f"EchoFlow loading failed: {e}")
        raise RuntimeError(f"EchoFlow model not available: {e}") from e


def clear_model_cache():
    """Clear the model cache so the next load re-initializes every model."""
    _model_cache.clear()
    print("🧹 Model cache cleared")


def load_echo_prime_model():
    """Load EchoPrime from the tool_repos checkout, caching it on success.

    Unlike the other loaders this returns ``None`` (after printing the reason)
    instead of raising, because existing callers rely on that contract.
    """
    if "echo_prime" in _model_cache:
        print("✅ Using cached EchoPrime model")
        return _model_cache["echo_prime"]

    try:
        # Use the EchoPrime directory in tool_repos
        echo_prime_path = ECHOPRIME_REPO_ROOT
        if not echo_prime_path.exists():
            print(f"❌ EchoPrime directory not found: {echo_prime_path}")
            return None

        if not ensure_echoprime_assets(echo_prime_path):
            print("❌ EchoPrime assets unavailable; cannot initialize model.")
            return None

        # Make the checkout importable as ``echo_prime``.
        if str(echo_prime_path) not in sys.path:
            sys.path.insert(0, str(echo_prime_path))

        from echo_prime.model import EchoPrime

        device = "cuda" if torch.cuda.is_available() else "cpu"
        echo_prime_model = EchoPrime(device=device)

        _model_cache["echo_prime"] = echo_prime_model
        print("✅ EchoPrime model loaded successfully")
        return echo_prime_model

    except Exception as e:
        print(f"❌ Failed to load EchoPrime model: {e}")
        import traceback
        traceback.print_exc()
        return None


def _require_echo_prime_model():
    """Return the EchoPrime model, raising instead of handing back ``None``.

    ``load_echo_prime_model`` signals failure with ``None``; without this check
    callers fail later with an opaque ``AttributeError`` on ``None``.
    """
    echo_prime_model = load_echo_prime_model()
    if echo_prime_model is None:
        raise RuntimeError("EchoPrime model not available")
    return echo_prime_model


def _load_echoprime_videos(echo_prime_model, input_dir: str, max_videos: Optional[int]):
    """Preprocess the MP4s in ``input_dir`` with EchoPrime, keeping at most ``max_videos``.

    A falsy ``max_videos`` (``None``/0) keeps every video.

    Raises:
        RuntimeError: If no videos remain after processing.
    """
    print(f"🔄 Processing videos from {input_dir}...")
    stack_of_videos = echo_prime_model.process_mp4s(input_dir)
    if max_videos:
        stack_of_videos = stack_of_videos[:max_videos]
    if len(stack_of_videos) == 0:
        raise RuntimeError("No videos processed successfully")
    print(f"✅ Processed {len(stack_of_videos)} videos")
    return stack_of_videos


def _echoprime_study_embedding(echo_prime_model, stack_of_videos, **view_kwargs):
    """Concatenate EchoPrime video features and view encodings along dim 1.

    ``view_kwargs`` are forwarded to ``echo_prime_model.get_views`` (e.g. ``visualize``).
    """
    video_features = echo_prime_model.embed_videos(stack_of_videos)
    view_encodings = echo_prime_model.get_views(stack_of_videos, **view_kwargs)
    # Fix tensor dimension issue for a single video: make the encodings 2-D so
    # they can be concatenated with the features along dim 1.
    if view_encodings.dim() == 1:
        view_encodings = view_encodings.unsqueeze(0)
    return torch.cat((video_features, view_encodings), dim=1)


class EchoDiseasePredictionInput(BaseModel):
    """Input schema for echo disease prediction."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    save_csv: bool = Field(True, description="Save results to CSV file")
    include_confidence: bool = Field(True, description="Include confidence scores in output")


class EchoImageVideoGenerationInput(BaseModel):
    """Input schema for echo image/video generation."""
    views: List[str] = Field(..., description="List of echo views to generate")
    efs: List[float] = Field(..., description="List of ejection fractions")
    outdir: Optional[str] = Field(None, description="Output directory")
    num_samples: int = Field(10, description="Number of samples to generate")


class EchoMeasurementPredictionInput(BaseModel):
    """Input schema for echo measurement prediction."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    include_report: bool = Field(True, description="Include detailed report")
    save_csv: bool = Field(True, description="Save measurements to CSV")


class EchoReportGenerationInput(BaseModel):
    """Input schema for echo report generation."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    visualize_views: bool = Field(False, description="Generate view visualizations")
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    include_sections: bool = Field(True, description="Include all report sections")


class EchoSegmentationInput(BaseModel):
    """Input schema for echo segmentation."""
    video_path: str = Field(..., description="Path to echo video file")
    prompt_mode: str = Field("auto", description="Prompt mode for segmentation (auto, points, box, mask)")
    target_name: str = Field(
        "all",
        description="Target structure name (all, LV, RV, LA, RA, MV, TV, AV, PV, IVS, LVPW, AORoot, PA)",
    )
    save_mask_video: bool = Field(True, description="Save mask video")
    save_overlay_video: bool = Field(True, description="Save overlay video")
    points: Optional[List[List[float]]] = Field(
        None,
        description="List of [x, y, label] triples in normalized coordinates for the first frame (label: 1 foreground, 0 background)",
    )
    box: Optional[List[float]] = Field(
        None,
        description="Normalized box as [x1, y1, x2, y2] for the first frame",
    )
    mask_path: Optional[str] = Field(
        None,
        description="Path to an initial segmentation mask for the first frame (for 'mask' mode)",
    )
    sample_rate: int = Field(1, description="Process every Nth frame for speed (1 = every frame)")
    output_fps: Optional[int] = Field(None, description="FPS for output video. Defaults to source FPS")
    progress_callback: Optional[Any] = Field(None, description="Progress callback function for UI updates")
    # Dataset-provided initial prompt masks (first frame)
    initial_masks_dir: Optional[str] = Field(
        None,
        description="Directory containing first-frame masks per structure (e.g., LV.png, RV.png)",
    )
    initial_mask_paths: Optional[Dict[str, str]] = Field(
        None,
        description="Mapping of structure code (e.g., 'LV') to first-frame mask file path",
    )
    initial_mask_frame_idx: int = Field(0, description="Frame index the provided masks correspond to (default 0)")
    use_auto_masks_if_missing: bool = Field(
        True,
        description="If a structure mask is missing, fall back to auto coarse prompt",
    )


class EchoViewClassificationInput(BaseModel):
    """Input schema for echo view classification."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    visualize: bool = Field(False, description="Generate visualizations")
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")


# PanEcho clip geometry: frames per clip and square frame size.
_PANECHO_NUM_FRAMES = 16
_PANECHO_FRAME_SIZE = 224

# Human-readable labels for every PanEcho task key.
_PANECHO_TASK_DESCRIPTIONS = {
    'pericardial-effusion': 'Pericardial Effusion',
    'EF': 'Ejection Fraction (%)',
    'GLS': 'Global Longitudinal Strain (%)',
    'LVEDV': 'LV End-Diastolic Volume (cm³)',
    'LVESV': 'LV End-Systolic Volume (cm³)',
    'LVSV': 'LV Stroke Volume (cm³)',
    'LVSize': 'LV Size',
    'LVWallThickness-increased-any': 'LV Wall Thickness - Any Increase',
    'LVWallThickness-increased-modsev': 'LV Wall Thickness - Moderate/Severe Increase',
    'LVSystolicFunction': 'LV Systolic Function',
    'LVWallMotionAbnormalities': 'LV Wall Motion Abnormalities',
    'IVSd': 'Interventricular Septum Diastole (cm)',
    'LVPWd': 'LV Posterior Wall Diastole (cm)',
    'LVIDs': 'LV Internal Diameter Systole (cm)',
    'LVIDd': 'LV Internal Diameter Diastole (cm)',
    'LVOTDiam': 'LV Outflow Tract Diameter (cm)',
    'LVDiastolicFunction': 'LV Diastolic Function',
    'E|EAvg': 'E/e\' Ratio',
    'RVSP': 'RV Systolic Pressure (mmHg)',
    'RVSize': 'RV Size',
    'RVSystolicFunction': 'RV Systolic Function',
    'RVIDd': 'RV Internal Diameter Diastole (cm)',
    'TAPSE': 'Tricuspid Annular Plane Systolic Excursion (cm)',
    'RVSVel': 'RV Systolic Excursion Velocity (cm/s)',
    'LASize': 'Left Atrial Size',
    'LAIDs2D': 'LA Internal Diameter Systole 2D (cm)',
    'LAVol': 'LA Volume (cm³)',
    'RASize': 'Right Atrial Size',
    'RADimensionM-L(cm)': 'RA Major Dimension (cm)',
    'AVStructure': 'Aortic Valve Structure',
    'AVStenosis': 'Aortic Valve Stenosis',
    'AVPkVel(m/s)': 'Aortic Valve Peak Velocity (m/s)',
    'AVRegurg': 'Aortic Valve Regurgitation',
    'LVOT20mmHg': 'Elevated LV Outflow Tract Pressure',
    'MVStenosis': 'Mitral Valve Stenosis',
    'MVRegurgitation': 'Mitral Valve Regurgitation',
    'TVRegurgitation': 'Tricuspid Valve Regurgitation',
    'TVPkGrad': 'Tricuspid Valve Peak Gradient (mmHg)',
    'RAP-8-or-higher': 'Elevated RA Pressure',
    'AORoot': 'Aortic Root Diameter (cm)',
}

# Tasks whose single-output head is a continuous measurement (not a probability).
# 'AVPkVel(m/s)' matches the key used by the description/unit tables; the former
# 'AVPkVel(m|s)' spelling is kept too so either key form is treated as regression.
_PANECHO_REGRESSION_TASKS = frozenset({
    'EF', 'GLS', 'LVEDV', 'LVESV', 'LVSV', 'IVSd', 'LVPWd', 'LVIDs', 'LVIDd',
    'LVOTDiam', 'E|EAvg', 'RVSP', 'RVIDd', 'TAPSE', 'RVSVel', 'LAIDs2D', 'LAVol',
    'RADimensionM-L(cm)', 'AVPkVel(m/s)', 'AVPkVel(m|s)', 'TVPkGrad', 'AORoot',
})


class EchoDiseasePredictionTool(BaseTool):
    """Echo disease prediction tool."""
    name: str = "echo_disease_prediction"
    description: str = "Predict cardiac diseases from echo videos using PanEcho."
    args_schema: Type[BaseModel] = EchoDiseasePredictionInput

    def _get_task_units(self, task_name: str) -> str:
        """Get units for a specific task ('N/A' when the task has none)."""
        units_map = {
            'EF': '%', 'GLS': '%', 'LVEDV': 'cm³', 'LVESV': 'cm³', 'LVSV': 'cm³',
            'IVSd': 'cm', 'LVPWd': 'cm', 'LVIDs': 'cm', 'LVIDd': 'cm', 'LVOTDiam': 'cm',
            'E|EAvg': 'ratio', 'RVSP': 'mmHg', 'RVIDd': 'cm', 'TAPSE': 'cm', 'RVSVel': 'cm/s',
            'LAIDs2D': 'cm', 'LAVol': 'cm³', 'RADimensionM-L(cm)': 'cm', 'AVPkVel(m/s)': 'm/s',
            'TVPkGrad': 'mmHg', 'AORoot': 'cm',
        }
        return units_map.get(task_name, 'N/A')

    def _get_class_names(self, task_name: str) -> list:
        """Get class names (indexed by predicted class id) for classification tasks."""
        class_names_map = {
            'LVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LVSystolicFunction': ['Mildly Decreased', 'Moderately|Severely Decreased', 'Normal|Hyperdynamic'],
            'LVDiastolicFunction': ['Mild|Indeterminate', 'Moderate|Severe', 'Normal'],
            'RVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LASize': ['Mildly Dilated', 'Moderately|Severely Dilated', 'Normal'],
            'AVStenosis': ['Mild|Moderate', 'None', 'Severe'],
            'AVRegurg': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'MVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'TVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace'],
        }
        return class_names_map.get(task_name, [])

    @staticmethod
    def _load_clip(video_path: str, normalize) -> torch.Tensor:
        """Read up to 16 frames from ``video_path`` and build a PanEcho input tensor.

        Frames are resized to 224x224, converted BGR->RGB, scaled to [0, 1] and
        normalized with ``normalize``. Clips shorter than 16 frames are padded by
        repeating the last frame.

        Returns:
            Tensor shaped (1, 3, 16, 224, 224): (batch, channels, frames, height, width).

        Raises:
            ValueError: If no frame can be read from the video.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while len(frames) < _PANECHO_NUM_FRAMES and cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, (_PANECHO_FRAME_SIZE, _PANECHO_FRAME_SIZE))
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be read from {video_path}")
        frames.extend([frames[-1]] * (_PANECHO_NUM_FRAMES - len(frames)))

        frames_array = np.asarray(frames, dtype=np.float32) / 255.0  # (16, 224, 224, 3) in [0, 1]
        clip = torch.from_numpy(frames_array).permute(0, 3, 1, 2)  # (16, 3, 224, 224)
        clip = normalize(clip)
        return clip.unsqueeze(0).permute(0, 2, 1, 3, 4)  # (1, 3, 16, 224, 224)

    def _format_prediction(self, task_name: str, pred_value: Any) -> Dict[str, Any]:
        """Convert one PanEcho head output into a result record.

        (1, 1) tensors are regression values for tasks in ``_PANECHO_REGRESSION_TASKS``
        and probabilities otherwise; (1, K>1) tensors are class probabilities.
        Any failure is printed and reported as an error record instead of raising.
        """
        task_description = _PANECHO_TASK_DESCRIPTIONS.get(task_name, f"{task_name} (Unknown Task)")
        try:
            raw_prediction = None
            if not torch.is_tensor(pred_value):
                value = float(pred_value) if isinstance(pred_value, (int, float)) else 0.0
                task_type = 'unknown'
                confidence = 0.0
            elif pred_value.shape == (1, 1):
                raw_prediction = float(pred_value[0, 0].item())
                value = raw_prediction
                if task_name in _PANECHO_REGRESSION_TASKS:
                    task_type = 'regression'
                    confidence = 0.85  # Default confidence for regression
                else:
                    # Binary classification - sigmoid already applied, value is a probability.
                    task_type = 'binary_classification'
                    confidence = max(value, 1.0 - value)  # closeness to 0 or 1
            elif pred_value.shape[1] > 1:
                # Multi-class classification - softmax already applied.
                probs = pred_value[0]  # Remove batch dimension
                predicted_class = int(probs.argmax().item())
                confidence = float(probs.max().item())
                class_names = self._get_class_names(task_name)
                value = class_names[predicted_class] if predicted_class < len(class_names) else predicted_class
                task_type = 'multi-class_classification'
            else:
                # Fallback for other shapes
                value = float(pred_value.flatten().mean().item())
                task_type = 'regression'
                confidence = 0.85

            return {
                'value': value,
                'description': task_description,
                'confidence': confidence,
                'task_type': task_type,
                'units': self._get_task_units(task_name),
                'raw_prediction': raw_prediction,
            }
        except Exception as e:
            print(f"Error processing {task_name}: {e}")
            return {
                'value': 0.0,
                'description': task_description,
                'confidence': 0.0,
                'task_type': 'unknown',
                'units': 'unknown',
                'error': str(e),
            }

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        save_csv: bool = True,
        include_confidence: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run PanEcho on the ``*.mp4`` files directly inside ``input_dir``.

        Files are processed in sorted order, so ``max_videos`` selects a
        deterministic subset. Videos that fail are printed and skipped.
        ``save_csv`` and ``include_confidence`` are accepted for schema
        compatibility but are not used here.

        Raises:
            RuntimeError: If the model cannot be loaded, no MP4s are found, or no
                video is processed successfully.
        """
        try:
            # Load PanEcho model - will raise exception if not available
            panecho_model = load_panecho_model()

            import glob
            import torchvision.transforms as transforms

            video_files = sorted(glob.glob(os.path.join(input_dir, "*.mp4")))
            if max_videos:
                video_files = video_files[:max_videos]
            if not video_files:
                raise RuntimeError(f"No MP4 videos found in {input_dir}")

            # ImageNet normalization (as per PanEcho documentation)
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            device = next(panecho_model.parameters()).device

            all_predictions = []
            for video_path in video_files:
                try:
                    frames_tensor = self._load_clip(video_path, normalize).to(device)
                    with torch.no_grad():
                        predictions = panecho_model(frames_tensor)

                    disease_predictions = {
                        task_name: self._format_prediction(task_name, pred_value)
                        for task_name, pred_value in predictions.items()
                    }
                    all_predictions.append({
                        "video": os.path.basename(video_path),
                        "predictions": disease_predictions,
                    })
                except Exception as e:
                    print(f"Error processing {video_path}: {e}")

            if not all_predictions:
                raise RuntimeError("No videos processed successfully")

            return {
                "status": "success",
                "model": "PanEcho",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(all_predictions),
                "predictions": all_predictions,
                "message": f"Disease prediction completed for {len(all_predictions)} videos using real PanEcho model",
            }

        except Exception as e:
            print(f"PanEcho prediction failed: {e}")
            raise RuntimeError(f"Disease prediction failed: {e}") from e


class EchoImageVideoGenerationTool(BaseTool):
    """Echo image/video generation tool."""
    name: str = "echo_image_video_generation"
    description: str = "Generate synthetic echo images and videos using EchoFlow."
    args_schema: Type[BaseModel] = EchoImageVideoGenerationInput

    def _run(
        self,
        views: List[str],
        efs: List[float],
        outdir: Optional[str] = None,
        num_samples: int = 10,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Generate synthetic echo videos with EchoFlow into ``outdir``.

        ``outdir`` defaults to ``temp/echo_generated`` and is created if missing.

        Raises:
            RuntimeError: If the model cannot be loaded or generation fails.
        """
        try:
            echoflow_model = load_echoflow_model()

            output_dir = Path(outdir or "temp/echo_generated")
            output_dir.mkdir(parents=True, exist_ok=True)

            generated_files = echoflow_model.generate_synthetic_video(
                views=views,
                efs=efs,
                num_samples=num_samples,
                output_dir=str(output_dir),
            )
            successful_generations = len(generated_files)

            return {
                "status": "success",
                "model": "EchoFlow",
                "views": views,
                "efs": efs,
                "num_samples": num_samples,
                "successful_generations": successful_generations,
                "generated_files": generated_files,
                "output_dir": str(output_dir),
                "message": f"Generated {successful_generations} synthetic echo videos using real EchoFlow model",
            }

        except Exception as e:
            print(f"EchoFlow generation failed: {e}")
            raise RuntimeError(f"Echo generation failed: {e}") from e


class EchoMeasurementPredictionTool(BaseTool):
    """Echo measurement prediction tool."""
    name: str = "echo_measurement_prediction"
    description: str = "Extract echocardiography measurements using EchoPrime."
    args_schema: Type[BaseModel] = EchoMeasurementPredictionInput

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        include_report: bool = True,
        save_csv: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Predict study-level measurements for the videos in ``input_dir`` with EchoPrime.

        At most ``max_videos`` preprocessed videos are used (all when falsy).
        ``include_report`` and ``save_csv`` are accepted for schema compatibility
        but are not used here.

        Raises:
            RuntimeError: If EchoPrime is unavailable, no videos are processed, or
                prediction fails.
        """
        try:
            echo_prime_model = _require_echo_prime_model()
            stack_of_videos = _load_echoprime_videos(echo_prime_model, input_dir, max_videos)

            print("🔄 Predicting measurements...")
            study_embedding = _echoprime_study_embedding(echo_prime_model, stack_of_videos)
            measurements = echo_prime_model.predict_metrics(study_embedding)

            formatted_measurements = {}
            for key, value in measurements.items():
                if isinstance(value, (int, float)) and not np.isnan(value):
                    # NOTE(review): name-based unit heuristic ("d" in key -> cm, else mL);
                    # confirm against EchoPrime's metric names.
                    unit = "%" if key == "EF" else "cm" if "d" in key else "mL"
                    formatted_measurements[key] = {
                        "value": float(value),
                        "unit": unit,
                        "confidence": 0.85,
                    }

            all_measurements = [{
                "video": "study_measurements",
                "measurements": formatted_measurements,
            }]

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "measurements": all_measurements,
                "message": f"Measurement prediction completed for {len(stack_of_videos)} videos using real EchoPrime model",
            }

        except Exception as e:
            print(f"EchoPrime measurement prediction failed: {e}")
            raise RuntimeError(f"Measurement prediction failed: {e}") from e


class EchoReportGenerationTool(BaseTool):
    """Echo report generation tool."""
    name: str = "echo_report_generation"
    description: str = "Generate comprehensive echo report using EchoPrime."
    args_schema: Type[BaseModel] = EchoReportGenerationInput

    def _run(
        self,
        input_dir: str,
        visualize_views: bool = False,
        max_videos: Optional[int] = None,
        include_sections: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Generate an EchoPrime report plus measurements and view statistics.

        At most ``max_videos`` preprocessed videos are used (all when falsy).
        ``visualize_views`` is forwarded to EchoPrime's view classification.
        ``include_sections`` is accepted for schema compatibility but is not used here.

        Raises:
            RuntimeError: If EchoPrime is unavailable, no videos are processed, or
                report generation fails.
        """
        try:
            echo_prime_model = _require_echo_prime_model()
            stack_of_videos = _load_echoprime_videos(echo_prime_model, input_dir, max_videos)

            print("🔄 Generating comprehensive report...")
            study_embedding = _echoprime_study_embedding(
                echo_prime_model, stack_of_videos, visualize=visualize_views
            )
            report = echo_prime_model.generate_report(study_embedding)
            measurements = echo_prime_model.predict_metrics(study_embedding)
            views = echo_prime_model.get_views(stack_of_videos, return_view_list=True)

            analysis = {
                "video": "study_analysis",
                "view_classification": {
                    "predicted_views": views,
                    "view_distribution": dict(Counter(views)),
                },
                "measurements": measurements,
                "disease_predictions": {},  # EchoPrime doesn't have disease predictions
                "quality_assessment": {
                    "confidence": 0.85,
                    "model_used": "EchoPrime",
                },
                "confidence": 0.85,
            }

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "report": report,
                "analysis": analysis,
                "message": f"Report generation completed for {len(stack_of_videos)} videos using real EchoPrime model",
            }

        except Exception as e:
            print(f"EchoPrime report generation failed: {e}")
            raise RuntimeError(f"Report generation failed: {e}") from e

    @staticmethod
    def _measurement_value(entry: Any) -> Optional[float]:
        """Return the numeric value of a measurement entry, or None if absent/invalid.

        Accepts either a raw number or a ``{"value": ...}`` dict; NaN counts as absent.
        """
        if isinstance(entry, dict):
            entry = entry.get("value")
        if entry is None:
            return None
        try:
            number = float(entry)
        except (TypeError, ValueError):
            return None
        return None if np.isnan(number) else number

    def _generate_comprehensive_report(self, analyses, include_sections):
        """Aggregate per-video analyses into a summary report dict.

        Measurements missing from an analysis are excluded from the averages
        (previously they were averaged in as 0), and summary/recommendation text
        that depends on EF or GLS is produced only when that value is present.
        """
        view_distribution = dict(Counter(
            analysis.get("view_classification", {}).get("predicted_view", "unknown")
            for analysis in analyses
        ))
        all_measurements = [analysis.get("measurements", {}) for analysis in analyses]

        avg_measurements = {}
        for key in ('EF', 'LVEDV', 'LVESV', 'GLS', 'IVSd', 'LVPWd', 'LVIDs', 'LVIDd'):
            values = [
                number
                for number in (self._measurement_value(m.get(key)) for m in all_measurements if m)
                if number is not None
            ]
            if values:
                avg_measurements[key] = float(np.mean(values))

        ef = avg_measurements.get("EF")
        gls = avg_measurements.get("GLS")

        if ef is None:
            summary = "Left ventricular ejection fraction is not available. "
        else:
            if ef > 55:
                ef_status = "Normal"
            elif ef > 45:
                ef_status = "Mildly reduced"
            else:
                ef_status = "Moderately to severely reduced"
            summary = f"Left ventricular ejection fraction is {ef_status} ({ef:.1f}%). "
        if "LVEDV" in avg_measurements:
            summary += f"Left ventricular end-diastolic volume is {avg_measurements['LVEDV']:.1f} mL. "
        if gls is not None:
            summary += f"Global longitudinal strain is {gls:.1f}%."

        recommendations = []
        if ef is not None and ef < 50:
            recommendations.append("Consider cardiology consultation")
        # NOTE(review): the previous check `GLS < -18` fired in the normal range under the
        # negative-strain convention; flag a reduced strain magnitude instead. abs() tolerates
        # either sign convention. Confirm the 18% threshold with clinical owners.
        if gls is not None and abs(gls) < 18:
            recommendations.append("Monitor for heart failure")
        if not recommendations:
            recommendations.append("Routine follow-up in 1 year")

        sections = ["findings", "measurements", "view_analysis", "recommendations"] if include_sections else []

        return {
            "summary": summary,
            "recommendations": recommendations,
            "sections": sections,
            "measurements": {k: f"{v:.1f}" for k, v in avg_measurements.items()},
            "view_distribution": view_distribution,
            "processed_videos": len(analyses),
            "overall_confidence": float(np.mean([a.get("confidence", 0) for a in analyses])) if analyses else 0.0,
        }

    def _create_view_visualization(self, analyses, input_dir):
        """Save a pie chart of predicted views to ``<input_dir>/view_distribution.png``.

        Returns the image path as a string, or None if plotting fails (the error is printed).
        """
        try:
            import matplotlib.pyplot as plt

            view_counts = Counter(
                analysis.get("view_classification", {}).get("predicted_view", "unknown")
                for analysis in analyses
            )

            plt.figure(figsize=(8, 6))
            plt.pie(view_counts.values(), labels=view_counts.keys(), autopct='%1.1f%%')
            plt.title("Echo View Distribution")

            output_path = Path(input_dir) / "view_distribution.png"
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()

            return str(output_path)

        except Exception as e:
            print(f"Visualization creation failed: {e}")
            return None


class EchoSegmentationTool(BaseTool):
    """Echo segmentation tool."""
    name: str = "echo_segmentation"
    description: str = "Segment cardiac chambers in echo videos using MedSAM2."
args_schema: Type[BaseModel] = EchoSegmentationInput def _run( self, video_path: str, prompt_mode: str = "auto", target_name: str = "all", save_mask_video: bool = True, save_overlay_video: bool = True, points: Optional[List[List[float]]] = None, box: Optional[List[float]] = None, mask_path: Optional[str] = None, sample_rate: int = 1, output_fps: Optional[int] = None, progress_callback: Optional[callable] = None, # New initial-mask plumbing initial_masks_dir: Optional[str] = None, initial_mask_paths: Optional[Dict[str, str]] = None, initial_mask_frame_idx: int = 0, use_auto_masks_if_missing: bool = True, run_manager: Optional[Any] = None, query: Optional[str] = None, ) -> Dict[str, Any]: """Run echo segmentation using real MedSAM2 model.""" try: normalized_points: Optional[List[Tuple[float, float, int]]] = None if points: normalized_points = [] for entry in points: if not isinstance(entry, (list, tuple)) or len(entry) < 3: continue try: x_val = float(entry[0]) y_val = float(entry[1]) label_val = int(entry[2]) normalized_points.append((x_val, y_val, label_val)) except (TypeError, ValueError): continue if not normalized_points: normalized_points = None normalized_box: Optional[Tuple[float, float, float, float]] = None if box and isinstance(box, (list, tuple)) and len(box) >= 4: try: normalized_box = ( float(box[0]), float(box[1]), float(box[2]), float(box[3]), ) except (TypeError, ValueError): normalized_box = None # Load MedSAM2 model medsam2_model_path = load_medsam2_model() # Process video with MedSAM2 cap = cv2.VideoCapture(video_path) masks = [] frames = [] fps = cap.get(cv2.CAP_PROP_FPS) if not fps or fps <= 1e-3: fps = 30.0 # Read all frames first while cap.isOpened(): ret, frame = cap.read() if not ret: break frames.append(frame) cap.release() if not frames: raise RuntimeError(f"No frames found in video: {video_path}") # Use enhanced MedSAM2 for video segmentation try: frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames] # Prepare optional 
initial masks for first frame height, width = frames[0].shape[:2] provided_masks = None if (initial_masks_dir or initial_mask_paths): provided_masks = self._load_initial_masks( height, width, initial_masks_dir=initial_masks_dir, initial_mask_paths=initial_mask_paths, ) if provided_masks: print(f"✅ Using {len(provided_masks)} annotation masks: {list(provided_masks.keys())}") else: # Attempt to load annotation-based prompts from Config annotation_prompts = self._load_annotation_prompts_from_config( height, width, video_path, ) if annotation_prompts: provided_masks = annotation_prompts print(f"✅ Loaded config-based annotation masks: {list(provided_masks.keys())}") else: print("⚠️ No annotation masks found for video; falling back to auto prompts") if not provided_masks: # Fallback to global default from Config if set (single-structure) try: from config import Config default_path = getattr(Config, 'DEFAULT_INITIAL_MASK_PATH', '') default_structure = getattr(Config, 'DEFAULT_INITIAL_MASK_STRUCTURE', 'LV') if isinstance(default_path, str) and default_path: import os as _os if _os.path.exists(default_path): provided_masks = self._load_initial_masks( height, width, initial_mask_paths={str(default_structure).upper(): default_path}, ) print(f"⚠️ Using default single-structure mask for {default_structure}") except Exception: pass segmentation_result = self._segment_with_medsam2( frames_rgb, medsam2_model_path, progress_callback, initial_masks=provided_masks, ) masks = [] for frame_idx in range(len(frames)): combined = np.zeros((height, width), dtype=np.uint8) frame_masks = segmentation_result['masks'].get(frame_idx, {}) for obj_mask in frame_masks.values(): mask_array = obj_mask if mask_array.shape != (height, width): mask_array = cv2.resize(mask_array, (width, height), interpolation=cv2.INTER_NEAREST) combined = np.maximum(combined, mask_array) masks.append(combined) self._enhanced_segmentation_result = segmentation_result except Exception as e: print(f"Error in enhanced 
MedSAM2 video segmentation: {e}") import traceback traceback.print_exc() # Fallback to basic segmentation try: from tools.echo.medsam2_integration import MedSAM2VideoSegmenter segmenter = MedSAM2VideoSegmenter(medsam2_model_path) masks = segmenter.segment_video(frames_rgb, target_name, progress_callback) self._enhanced_segmentation_result = None except Exception as e2: # Strict mode: do not fallback silently; raise the original error raise except Exception as e: print(f"Error loading MedSAM2 model or processing video: {e}") return { "status": "error", "error": str(e), "video_path": video_path, "target_name": target_name } # Save outputs if requested outputs = {} if save_mask_video or save_overlay_video: output_dir = Path("temp") / "segmentation_outputs" output_dir.mkdir(parents=True, exist_ok=True) if save_mask_video: mask_video_path = output_dir / f"mask_{target_name}_{Path(video_path).stem}.mp4" mask_video_path = self._save_mask_video(frames, masks, str(mask_video_path)) outputs["mask_video"] = mask_video_path if save_overlay_video: overlay_video_path = output_dir / f"overlay_{target_name}_{Path(video_path).stem}.mp4" overlay_video_path = self._save_overlay_video(frames, masks, str(overlay_video_path)) outputs["overlay_video"] = overlay_video_path # Add enhanced outputs if available enhanced_outputs = {} if hasattr(self, '_enhanced_segmentation_result') and self._enhanced_segmentation_result: try: enhanced_outputs = self._create_enhanced_videos(frames, self._enhanced_segmentation_result, output_dir, fps=fps) except Exception as e: print(f"Error creating enhanced videos: {e}") enhanced_outputs = {} # Ensure basic outputs are always included final_outputs = { **outputs, **enhanced_outputs, "masks": len(masks), "frames_processed": len(frames) } # Add basic video outputs if they exist if "mask_video" in outputs: final_outputs["segmented_video"] = outputs["mask_video"] if "overlay_video" in outputs: final_outputs["overlay_video"] = outputs["overlay_video"] # Prefer 
combined multi-structure overlay as the primary segmented video when targeting 'all' if target_name == "all" and "combined_segmentation_video" in enhanced_outputs: final_outputs["segmented_video"] = enhanced_outputs["combined_segmentation_video"] final_outputs["overlay_video"] = enhanced_outputs["combined_segmentation_video"] if hasattr(self, '_enhanced_segmentation_result') and self._enhanced_segmentation_result: final_outputs.setdefault( "structures", self._enhanced_segmentation_result.get("structures", []), ) final_outputs.setdefault( "structure_info", self._enhanced_segmentation_result.get("structure_info", {}), ) return { "status": "success", "model": "MedSAM2", "video_path": video_path, "target_name": target_name, "prompt_mode": prompt_mode, "outputs": final_outputs, "message": f"Enhanced segmentation completed with {len(masks)} frames processed using MedSAM2 model" } def _segment_with_medsam2(self, frames, model_path, progress_callback=None, initial_masks: Optional[Dict[str, np.ndarray]] = None): """Segment video using enhanced MedSAM2 model with multi-structure support.""" try: # Import the enhanced MedSAM2 integration import sys import os current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir))) sys.path.insert(0, project_root) from tools.echo.enhanced_medsam2_integration import EnhancedMedSAM2VideoSegmenter print("✅ Successfully imported EnhancedMedSAM2VideoSegmenter") if progress_callback: progress_callback(10, "Initializing enhanced MedSAM2 model...") # Create segmenter instance segmenter = EnhancedMedSAM2VideoSegmenter(model_path) if progress_callback: progress_callback(20, "Starting multi-structure video segmentation...") # Segment the video with progress updates result = segmenter.segment_video_multi_structure(frames, progress_callback, initial_masks=initial_masks) print(f"✅ Generated multi-structure masks for {result['total_frames']} frames") print(f"🎯 Segmented structures: 
{result['structures']}") if progress_callback: progress_callback(100, "Multi-structure segmentation completed!") return result except Exception as e: print(f"❌ Enhanced MedSAM2 segmentation error: {e}") import traceback traceback.print_exc() if progress_callback: progress_callback(0, f"Segmentation failed: {e}") raise RuntimeError(f"Enhanced MedSAM2 segmentation failed: {e}") def _load_initial_masks( self, height: int, width: int, initial_masks_dir: Optional[str] = None, initial_mask_paths: Optional[Dict[str, str]] = None, ) -> Dict[str, np.ndarray]: """Load first-frame masks from a directory or explicit mapping. Supported structure keys: 'LV','MYO','LA','RV','RA' (others ignored for now). Images are read as grayscale; any non-zero treated as foreground; resized to (width,height). """ import os import glob valid_structures = {"LV", "MYO", "LA", "RV", "RA"} paths: Dict[str, str] = {} def read_mask(path: str) -> Optional[np.ndarray]: try: if path.lower().endswith(".npy"): arr = np.load(path) if arr.ndim == 3: arr = arr.squeeze() arr = (arr > 0).astype(np.uint8) * 255 else: img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if img is None: return None arr = (img > 0).astype(np.uint8) * 255 if arr.shape != (height, width): arr = cv2.resize(arr, (width, height), interpolation=cv2.INTER_NEAREST) return arr except Exception: return None # Explicit mapping has precedence if initial_mask_paths: for k, v in initial_mask_paths.items(): key = str(k).upper() if key in valid_structures and isinstance(v, str) and os.path.exists(v): paths[key] = v # From directory by filename pattern if initial_masks_dir and os.path.isdir(initial_masks_dir): candidates = {} for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tif", "*.tiff", "*.npy"): for p in glob.glob(os.path.join(initial_masks_dir, ext)): name = os.path.splitext(os.path.basename(p))[0].lower() # accept names like lv, lv_mask, mask_lv, LV for s in valid_structures: s_lower = s.lower() if name == s_lower or name.startswith(s_lower) or 
name.endswith(s_lower): candidates.setdefault(s, p) for s, p in candidates.items(): paths.setdefault(s, p) loaded: Dict[str, np.ndarray] = {} for s, p in paths.items(): mask = read_mask(p) if mask is not None and mask.any(): loaded[s] = mask return loaded def _load_annotation_prompts_from_config(self, height: int, width: int, video_path: str) -> Dict[str, np.ndarray]: """Load annotation-derived first-frame masks using Config.ANNOTATION_PROMPTS.""" try: from config import Config except Exception: return {} mapping = getattr(Config, "ANNOTATION_PROMPTS", {}) or {} if not mapping: return {} candidates = [] from pathlib import Path stem = Path(video_path).stem name = Path(video_path).name candidates.extend([stem, name, os.path.abspath(video_path)]) original_name = os.getenv("ECHO_ORIGINAL_VIDEO_NAME") if original_name: candidates.append(original_name) candidates.append(Path(original_name).stem) print(f"🔍 Annotation lookup candidates: {candidates}") print(f"🔍 Available annotation entries: {list(mapping.keys())}") entry = None for key in candidates: if key in mapping: entry = mapping[key] break if entry is None: return {} frames_dir = entry.get("frames_dir") frame_index = int(entry.get("frame_index", 0)) label_map = entry.get("label_map", {}) if not frames_dir: return {} if os.path.isdir(frames_dir): frame_path = os.path.join(frames_dir, f"{frame_index:04d}.png") else: frame_path = frames_dir if not os.path.exists(frame_path): print(f"⚠️ Annotation prompt not found: {frame_path}") return {} mask_img = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE) if mask_img is None: print(f"⚠️ Failed to read annotation prompt: {frame_path}") return {} if mask_img.shape != (height, width): mask_img = cv2.resize(mask_img, (width, height), interpolation=cv2.INTER_NEAREST) loaded: Dict[str, np.ndarray] = {} if label_map: for raw_value, structure in label_map.items(): try: value = int(raw_value) except Exception: continue structure_key = str(structure).upper() mask = (mask_img == 
value).astype(np.uint8) * 255 if mask.any(): loaded[structure_key] = mask else: # Treat all non-zero pixels as LV if label map missing mask = (mask_img > 0).astype(np.uint8) * 255 if mask.any(): loaded["LV"] = mask if loaded: print(f"✅ Loaded annotation prompts for {video_path} from {frame_path}") return loaded def _create_enhanced_videos(self, frames, segmentation_result, output_dir, fps: float = 30.0): """Create overlay video showing all segmented structures.""" try: structures = segmentation_result['structures'] structure_info = segmentation_result['structure_info'] all_masks = segmentation_result['masks'] combined_video_path = output_dir / "combined_segmentation_video.avi" combined_final_path = self._save_combined_overlay_video( frames, all_masks, structures, structure_info, str(combined_video_path), fps=fps, ) combined_final_path = convert_video_to_h264(str(combined_final_path)) return {"combined_segmentation_video": str(combined_final_path)} except Exception as e: print(f"❌ Error creating enhanced videos: {e}") return {} def _save_combined_overlay_video(self, frames, all_masks, structures, structure_info, output_path, fps: float = 30.0) -> str: """Save a single AVI where all structures are overlaid in unique colors.""" if not frames or not all_masks: return output_path height, width = frames[0].shape[:2] final_path = output_path if output_path.lower().endswith(".avi") else os.path.splitext(output_path)[0] + ".avi" fourcc = cv2.VideoWriter_fourcc(*'XVID') writer = cv2.VideoWriter(final_path, fourcc, fps, (width, height)) for frame_index, frame in enumerate(frames): if frame_index in all_masks: base_frame = frame.copy() color_layer = np.zeros_like(frame) contour_masks = [] for obj_id, mask in all_masks[frame_index].items(): if obj_id <= len(structures): structure_id = structures[obj_id - 1] color = structure_info[structure_id]['color'] if mask.shape != (height, width): mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) mask_bool = mask > 
0 if not np.any(mask_bool): continue color_layer[mask_bool] = color contour_masks.append(mask_bool.astype(np.uint8)) overlay = cv2.addWeighted(base_frame, 0.6, color_layer, 0.4, 0) for mask_bool in contour_masks: contours, _ = cv2.findContours(mask_bool, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2) writer.write(overlay) else: writer.write(frame) writer.release() print(f"✅ Saved combined overlay video: {final_path}") return final_path def _overlay_mask(self, frame: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.35): """Create overlay with proper alpha blending and contour visualization.""" overlay = frame.copy() # Ensure mask is 2D (remove any extra dimensions) while len(mask.shape) > 2: mask = mask.squeeze() # Ensure mask has the same spatial dimensions as frame if mask.shape != frame.shape[:2]: mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) # Convert mask to binary (0 or 1) binary_mask = (mask > 0).astype(np.uint8) # Create colored overlay using proper alpha blending colored_overlay = np.zeros_like(frame) colored_overlay[binary_mask > 0] = color # Apply alpha blending: result = (1-alpha) * original + alpha * colored overlay = cv2.addWeighted(overlay, 1 - alpha, colored_overlay, alpha, 0) # Add contour lines for better visibility contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2) # White contour for visibility return overlay def _save_mask_video(self, frames, masks, output_path): """Save mask video and return a browser-friendly H.264 path.""" if not frames or not masks: return output_path fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, 30.0, (frames[0].shape[1], frames[0].shape[0])) for mask in masks: # Convert mask to 3-channel for video mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) # Add contour lines for better 
visibility (no text) if mask.max() > 0: # Only if there's a mask contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) cv2.drawContours(mask_3ch, contours, -1, (0, 255, 0), 2) out.write(mask_3ch) out.release() converted_path = convert_video_to_h264(output_path) print(f"✅ Saved mask video: {converted_path}") return converted_path def _save_overlay_video(self, frames, masks, output_path): """Save overlay video with corrected overlay logic and return H.264 path.""" if not frames or not masks: return output_path fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, 30.0, (frames[0].shape[1], frames[0].shape[0])) for frame, mask in zip(frames, masks): overlay = self._overlay_mask(frame, mask, color=(0, 255, 0), alpha=0.35) out.write(overlay) out.release() converted_path = convert_video_to_h264(output_path) print(f"✅ Saved overlay video: {converted_path}") return converted_path def segment_all_structures(self, video_path: str, output_dir: str, progress_callback=None) -> Dict[str, Any]: """Segment all cardiac structures in the video.""" cardiac_structures = [ "LV", # Left Ventricle "RV", # Right Ventricle "LA", # Left Atrium "RA", # Right Atrium "MV", # Mitral Valve "TV", # Tricuspid Valve "AV", # Aortic Valve "PV", # Pulmonary Valve "IVS", # Interventricular Septum "LVPW", # Left Ventricular Posterior Wall "AORoot", # Aortic Root "PA", # Pulmonary Artery ] results = { "status": "success", "segmented_structures": {}, "combined_video": None, "individual_videos": {}, "overlay_videos": {} } try: # Create output directory import os os.makedirs(output_dir, exist_ok=True) total_structures = len(cardiac_structures) for i, structure in enumerate(cardiac_structures): if progress_callback: progress_callback( int((i / total_structures) * 100), f"Segmenting {structure}..." 
) try: # Segment individual structure with custom output directory structure_result = self._run( video_path=video_path, target_name=structure, save_mask_video=True, save_overlay_video=True, sample_rate=2, # Sample every 2nd frame for speed progress_callback=lambda p, msg: None # Disable individual progress ) # Move videos to our output directory if structure_result.get("status") == "success": if "mask_video" in structure_result: old_mask_path = structure_result["mask_video"] new_mask_path = os.path.join(output_dir, f"mask_{structure}.mp4") if os.path.exists(old_mask_path): import shutil shutil.move(old_mask_path, new_mask_path) structure_result["mask_video"] = convert_video_to_h264(new_mask_path) if "overlay_video" in structure_result: old_overlay_path = structure_result["overlay_video"] new_overlay_path = os.path.join(output_dir, f"overlay_{structure}.mp4") if os.path.exists(old_overlay_path): import shutil shutil.move(old_overlay_path, new_overlay_path) structure_result["overlay_video"] = convert_video_to_h264(new_overlay_path) if structure_result.get("status") == "success": results["segmented_structures"][structure] = structure_result # Store video paths if "mask_video" in structure_result: results["individual_videos"][structure] = structure_result["mask_video"] if "overlay_video" in structure_result: results["overlay_videos"][structure] = structure_result["overlay_video"] except Exception as e: print(f"❌ Failed to segment {structure}: {e}") results["segmented_structures"][structure] = { "status": "failed", "error": str(e) } # Create combined segmentation video if progress_callback: progress_callback(90, "Creating combined segmentation...") results["combined_video"] = self._create_combined_segmentation( results["individual_videos"], output_dir ) if progress_callback: progress_callback(100, "Segmentation completed!") return results except Exception as e: results["status"] = "failed" results["error"] = str(e) return results def _create_combined_segmentation(self, 
individual_videos: Dict[str, str], output_dir: str) -> Optional[str]: """Create a combined video showing all segmented structures.""" try: if not individual_videos: return None import numpy as np # Get the first video to determine dimensions first_video = list(individual_videos.values())[0] cap = cv2.VideoCapture(first_video) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) cap.release() # Create output video writer output_path = os.path.join(output_dir, "combined_segmentation.mp4") fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) # Color map for different structures colors = { "LV": (0, 255, 0), # Green "RV": (255, 0, 0), # Blue "LA": (0, 255, 255), # Yellow "RA": (255, 0, 255), # Magenta "MV": (255, 255, 0), # Cyan "TV": (128, 0, 128), # Purple "AV": (255, 165, 0), # Orange "PV": (0, 128, 128), # Teal "IVS": (128, 128, 0), # Olive "LVPW": (128, 0, 0), # Maroon "AORoot": (0, 128, 0), # Dark Green "PA": (0, 0, 128), # Navy } # Open all video files video_caps = {} for structure, video_path in individual_videos.items(): if os.path.exists(video_path): video_caps[structure] = cv2.VideoCapture(video_path) if not video_caps: return None # Process frames frame_count = 0 while True: frames = {} all_done = True # Read frames from all videos for structure, cap in video_caps.items(): ret, frame = cap.read() if ret: frames[structure] = frame all_done = False else: frames[structure] = None if all_done: break # Create combined frame combined_frame = np.zeros((height, width, 3), dtype=np.uint8) for structure, frame in frames.items(): if frame is not None: # Convert to grayscale and apply color gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) color = colors.get(structure, (255, 255, 255)) # Create colored mask colored_mask = np.zeros_like(combined_frame) colored_mask[gray > 0] = color # Blend with combined frame combined_frame = 
cv2.addWeighted(combined_frame, 1.0, colored_mask, 0.7, 0) # Add structure labels y_offset = 30 for structure in ["LV", "RV", "LA", "RA", "MV", "TV", "AV", "PV", "IVS", "LVPW", "AORoot", "PA"]: if structure in frames and frames[structure] is not None: color = colors.get(structure, (255, 255, 255)) cv2.putText(combined_frame, structure, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2) y_offset += 25 out.write(combined_frame) frame_count += 1 # Clean up for cap in video_caps.values(): cap.release() out.release() if frame_count > 0: return convert_video_to_h264(output_path) return None except Exception as e: print(f"❌ Failed to create combined segmentation: {e}") return None class EchoViewClassificationTool(BaseTool): """Echo view classification tool.""" name: str = "echo_view_classification" description: str = "Classify echocardiography video views using EchoPrime." args_schema: Type[BaseModel] = EchoViewClassificationInput def _run( self, input_dir: str, visualize: bool = False, max_videos: Optional[int] = None, run_manager: Optional[Any] = None, ) -> Dict[str, Any]: """Run echo view classification using real EchoPrime model.""" try: # Load EchoPrime model for view classification echo_prime_model = load_echo_prime_model() # Process videos in input directory using EchoPrime's process_mp4s method print(f"🔄 Processing videos from {input_dir}...") stack_of_videos = echo_prime_model.process_mp4s(input_dir) if stack_of_videos.shape[0] == 0: raise RuntimeError(f"No valid MP4 videos found in {input_dir}") # Limit number of videos if specified if max_videos and stack_of_videos.shape[0] > max_videos: stack_of_videos = stack_of_videos[:max_videos] print(f"✅ Processed {stack_of_videos.shape[0]} videos") # Get view classifications using EchoPrime's get_views method print("🔄 Classifying views...") view_encodings = echo_prime_model.get_views(stack_of_videos, visualize=visualize, return_view_list=True) # Process results all_classifications = [] for i, view in 
enumerate(view_encodings): classification = { "video": f"video_{i+1}.mp4", # EchoPrime doesn't return individual filenames "predicted_view": view, "confidence": 0.85, # EchoPrime doesn't return confidence scores "view_probabilities": { view: 0.85, "other": 0.15 } } all_classifications.append(classification) if not all_classifications: raise RuntimeError("No videos processed successfully") # Aggregate results view_counts = {} for classification in all_classifications: view = classification["predicted_view"] if view not in view_counts: view_counts[view] = {"count": 0, "confidence": 0.0} view_counts[view]["count"] += 1 view_counts[view]["confidence"] = max( view_counts[view]["confidence"], classification["confidence"] ) return { "status": "success", "model": "EchoPrime", "input_dir": input_dir, "max_videos": max_videos, "processed_videos": len(all_classifications), "classifications": view_counts, "detailed_results": all_classifications, "message": f"View classification completed for {len(all_classifications)} videos using real EchoPrime model" } except Exception as e: print(f"EchoPrime view classification failed: {e}") raise RuntimeError(f"View classification failed: {e}") class EchoDiseasePredictionManager(BaseToolManager): """Manager for echo disease prediction tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_disease_prediction", tool_type="disease_prediction", description="Echo disease prediction tool" ) super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the disease prediction tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the disease prediction tool.""" return EchoDiseasePredictionTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" 
return EchoDiseasePredictionTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the disease prediction tool.""" if not self.tool: return {"error": "Tool not available"} try: return self.tool._run(**input_data) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"} class EchoImageVideoGenerationManager(BaseToolManager): """Manager for echo image/video generation tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_image_video_generation", tool_type="generation", description="Echo image/video generation tool" ) super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the image/video generation tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the image/video generation tool.""" return EchoImageVideoGenerationTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" return EchoImageVideoGenerationTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the image/video generation tool.""" if not self.tool: return {"error": "Tool not available"} try: return self.tool._run(**input_data) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"} class EchoMeasurementPredictionManager(BaseToolManager): """Manager for echo measurement prediction tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_measurement_prediction", tool_type="measurement", description="Echo measurement prediction tool" ) super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the measurement prediction tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: 
print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the measurement prediction tool.""" return EchoMeasurementPredictionTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" return EchoMeasurementPredictionTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the measurement prediction tool.""" if not self.tool: return {"error": "Tool not available"} try: return self.tool._run(**input_data) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"} class EchoReportGenerationManager(BaseToolManager): """Manager for echo report generation tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_report_generation", tool_type="report", description="Echo report generation tool" ) super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the report generation tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the report generation tool.""" return EchoReportGenerationTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" return EchoReportGenerationTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the report generation tool.""" if not self.tool: return {"error": "Tool not available"} try: return self.tool._run(**input_data) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"} class EchoSegmentationManager(BaseToolManager): """Manager for echo segmentation tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_segmentation", tool_type="segmentation", description="Echo segmentation tool" ) 
super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the segmentation tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the segmentation tool.""" return EchoSegmentationTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" return EchoSegmentationTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the segmentation tool.""" if not self.tool: return {"error": "Tool not available"} try: prepared_inputs = dict(input_data) if input_data else {} # Inject default multi-structure prompts when none provided has_user_prompts = any( key in prepared_inputs and prepared_inputs[key] for key in ("initial_masks_dir", "initial_mask_paths", "mask_path", "points", "box") ) if not has_user_prompts: if DEFAULT_ECHO_SEGMENTATION_MASK.exists(): if DEFAULT_ECHO_SEGMENTATION_MASK_DIR.exists(): mask_dir = DEFAULT_ECHO_SEGMENTATION_MASK_DIR available_paths = { structure: str(mask_dir / filename) for structure, filename in DEFAULT_ECHO_SEGMENTATION_STRUCTURES.items() if (mask_dir / filename).exists() } if available_paths: prepared_inputs.setdefault("initial_mask_paths", available_paths) else: print( f"⚠️ Default mask directory not found at {DEFAULT_ECHO_SEGMENTATION_MASK_DIR}; " "consider generating per-structure masks." ) else: print( f"⚠️ Default segmentation mask not found at {DEFAULT_ECHO_SEGMENTATION_MASK}; " "falling back to configured prompt." 
) return self.tool._run(**prepared_inputs) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"} class EchoViewClassificationManager(BaseToolManager): """Manager for echo view classification tool.""" def __init__(self, model_manager=None): self.model_manager = model_manager config = ToolConfig( name="echo_view_classification", tool_type="classification", description="Echo view classification tool" ) super().__init__(config) self._initialize_tool() def _initialize_tool(self): """Initialize the view classification tool.""" try: self.tool = self._create_tool() self._set_status(ToolStatus.AVAILABLE) except Exception as e: print(f"Error initializing {self.config.name}: {e}") self._set_status(ToolStatus.NOT_AVAILABLE) def _create_tool(self) -> BaseTool: """Create the view classification tool.""" return EchoViewClassificationTool() def _create_fallback_tool(self) -> BaseTool: """Create fallback tool.""" return EchoViewClassificationTool() def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]: """Run the view classification tool.""" if not self.tool: return {"error": "Tool not available"} try: return self.tool._run(**input_data) except Exception as e: return {"error": f"Tool execution failed: {str(e)}"}