|
|
""" |
|
|
Echo Tool Managers |
|
|
|
|
|
This module provides tool manager classes for various echo tools. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import shutil |
|
|
import zipfile |
|
|
import urllib.request |
|
|
from typing import Dict, List, Any, Optional, Type, Tuple |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
from utils.video_utils import convert_video_to_h264 |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
from langchain_core.tools import BaseTool |
|
|
from tools.general.base_tool_manager import BaseToolManager, ToolConfig, ToolStatus |
|
|
|
|
|
|
|
|
|
|
|
# Process-wide cache of loaded models, keyed by short id ("panecho", "medsam2",
# "echoflow", "echo_prime" are the keys used by the loaders below). Emptied by
# clear_model_cache().
_model_cache: Dict[str, Any] = {}

# Absolute path of this module; its ancestors anchor the tool-repo search.
# NOTE(review): Path.parents[3] raises IndexError if this file sits fewer than
# four directory levels below the filesystem root.
_THIS_FILE = Path(__file__).resolve()

# Base directories that may contain third-party tool checkouts, searched in
# order by _resolve_tool_repo().
_TOOL_REPO_BASES: List[Path] = [
    _THIS_FILE.parents[2] / "tool_repos",
    _THIS_FILE.parents[3] / "tool_repos",
]

# Optional extra base: $ECHO_WORKSPACE_ROOT/tool_repos (searched last).
workspace_root_env = os.getenv("ECHO_WORKSPACE_ROOT")
if workspace_root_env:
    _TOOL_REPO_BASES.append(Path(workspace_root_env) / "tool_repos")

# Drop duplicates (e.g. the env var pointing at a default base) while keeping
# the original search order.
_unique_tool_repo_bases: List[Path] = []
for base_path in _TOOL_REPO_BASES:
    if base_path not in _unique_tool_repo_bases:
        _unique_tool_repo_bases.append(base_path)
_TOOL_REPO_BASES = _unique_tool_repo_bases
|
|
|
|
|
|
|
|
def _resolve_tool_repo(repo_names: List[str]) -> Path: |
|
|
"""Return the first existing path for the given tool repo names.""" |
|
|
for repo_name in repo_names: |
|
|
for base in _TOOL_REPO_BASES: |
|
|
candidate = base / repo_name |
|
|
if candidate.exists(): |
|
|
return candidate |
|
|
primary_base = _TOOL_REPO_BASES[0] if _TOOL_REPO_BASES else Path.cwd() |
|
|
return primary_base / repo_names[0] |
|
|
|
|
|
|
|
|
# Local checkouts of the wrapped third-party repositories (first matching
# directory name wins; see _resolve_tool_repo).
MEDSAM_REPO_ROOT = _resolve_tool_repo(["MedSAM2-main", "MedSAM2"])
ECHOPRIME_REPO_ROOT = _resolve_tool_repo(["EchoPrime-main", "EchoPrime"])

# EchoPrime v1.0.0 release assets, downloaded by ensure_echoprime_assets()
# when missing locally.
ECHO_PRIME_RELEASE_BASE = "https://github.com/echonet/EchoPrime/releases/download/v1.0.0"
ECHO_PRIME_MODEL_ZIP_URL = f"{ECHO_PRIME_RELEASE_BASE}/model_data.zip"
ECHO_PRIME_EMBEDDING_FILES = {
    embedding_name: f"{ECHO_PRIME_RELEASE_BASE}/{embedding_name}"
    for embedding_name in ("candidate_embeddings_p1.pt", "candidate_embeddings_p2.pt")
}

# Default MedSAM2 segmentation assets inside the MedSAM2 checkout. These are
# only paths; nothing here checks that they exist.
DEFAULT_ECHO_SEGMENTATION_MASK = MEDSAM_REPO_ROOT / "0108.png"
DEFAULT_ECHO_SEGMENTATION_MASK_DIR = MEDSAM_REPO_ROOT / "default_masks"
DEFAULT_ECHO_SEGMENTATION_STRUCTURES = {
    structure_code: f"{structure_code}.png"
    for structure_code in ("LV", "MYO", "LA", "RV", "RA")
}
DEFAULT_ECHO_SEGMENTATION_CHECKPOINT = MEDSAM_REPO_ROOT.joinpath("checkpoints", "MedSAM2_US_Heart.pt")
|
|
|
|
|
|
|
|
def _download_file(url: str, destination: Path) -> bool: |
|
|
"""Download a file from a URL to the destination path.""" |
|
|
try: |
|
|
destination.parent.mkdir(parents=True, exist_ok=True) |
|
|
print(f"⬇️ Downloading {url} -> {destination}") |
|
|
request = urllib.request.Request(url, headers={"User-Agent": "EchoAgent/1.0"}) |
|
|
with urllib.request.urlopen(request) as response, open(destination, "wb") as output_file: |
|
|
shutil.copyfileobj(response, output_file) |
|
|
print(f"✅ Downloaded {destination.name}") |
|
|
return True |
|
|
except Exception as download_error: |
|
|
print(f"❌ Failed to download {url}: {download_error}") |
|
|
if destination.exists(): |
|
|
destination.unlink() |
|
|
return False |
|
|
|
|
|
|
|
|
def ensure_echoprime_assets(echo_prime_path: Path) -> bool:
    """Make sure EchoPrime's weights and candidate embeddings are on disk.

    If either weight file is missing, ``model_data.zip`` is downloaded and
    extracted into ``echo_prime_path`` and the archive is then deleted. Each
    missing candidate-embedding file is downloaded on its own.

    Args:
        echo_prime_path: Root of the EchoPrime checkout.

    Returns:
        True when all four required files exist after any download attempts.
    """
    data_root = echo_prime_path / "model_data"
    weight_dir = data_root / "weights"
    candidate_dir = data_root / "candidates_data"

    weight_files = [weight_dir / name for name in ("echo_prime_encoder.pt", "view_classifier.pt")]
    embedding_files = [
        candidate_dir / name for name in ("candidate_embeddings_p1.pt", "candidate_embeddings_p2.pt")
    ]
    needed = weight_files + embedding_files

    def _everything_present() -> bool:
        return all(path.exists() for path in needed)

    if _everything_present():
        return True

    print("⚠️ EchoPrime assets missing; attempting automatic download...")

    # Weights ship together in one archive.
    weights_complete = all(path.exists() for path in weight_files)
    if not weights_complete:
        archive = echo_prime_path / "model_data.zip"
        if _download_file(ECHO_PRIME_MODEL_ZIP_URL, archive):
            try:
                with zipfile.ZipFile(archive, "r") as bundle:
                    bundle.extractall(echo_prime_path)
                print("✅ Extracted model_data.zip")
            except zipfile.BadZipFile as zip_error:
                print(f"❌ model_data.zip appears corrupted: {zip_error}")
            finally:
                archive.unlink(missing_ok=True)

    # Embeddings are published as individual release files.
    for file_name, source_url in ECHO_PRIME_EMBEDDING_FILES.items():
        target = candidate_dir / file_name
        if not target.exists():
            _download_file(source_url, target)

    present = _everything_present()
    if not present:
        print("❌ Required EchoPrime assets are still missing after download attempts.")
    return present
|
|
|
|
|
|
|
|
def load_panecho_model():
    """Load the PanEcho model, reusing the module-level cache when possible.

    Returns:
        The object returned by ``models.model_factory.get_model("panecho")``,
        cached under the key ``"panecho"``.

    Raises:
        RuntimeError: If the factory cannot be imported, raises, or returns None.
            When an underlying exception exists it is chained as ``__cause__``.
    """
    if "panecho" in _model_cache:
        print("✅ Using cached PanEcho model")
        return _model_cache["panecho"]

    try:
        # Imported here so the factory is only needed when a model is loaded.
        from models.model_factory import get_model

        print("🔄 Loading PanEcho model...")
        model = get_model("panecho")
    except Exception as e:
        print(f"PanEcho loading failed: {e}")
        raise RuntimeError(f"PanEcho model not available: {e}") from e

    if model is None:
        print("PanEcho loading failed: model factory returned None")
        raise RuntimeError("PanEcho model not available")

    _model_cache["panecho"] = model
    print("✅ PanEcho model loaded and cached")
    return model
|
|
|
|
|
def load_medsam2_model():
    """Load the MedSAM2 model handle, reusing the module-level cache when possible.

    Returns:
        Whatever ``models.model_factory.get_model("medsam2")`` returns (the log
        messages describe it as a model path), cached under ``"medsam2"``.

    Raises:
        RuntimeError: If the factory cannot be imported, raises, or returns None.
            When an underlying exception exists it is chained as ``__cause__``.
    """
    if "medsam2" in _model_cache:
        print("✅ Using cached MedSAM2 model path")
        return _model_cache["medsam2"]

    try:
        # Imported here so the factory is only needed when a model is loaded.
        from models.model_factory import get_model

        print("🔄 Loading MedSAM2 model...")
        model_path = get_model("medsam2")
    except Exception as e:
        print(f"MedSAM2 loading failed: {e}")
        raise RuntimeError(f"MedSAM2 model not available: {e}") from e

    if model_path is None:
        print("MedSAM2 loading failed: model factory returned None")
        raise RuntimeError("MedSAM2 model not available")

    _model_cache["medsam2"] = model_path
    print(f"✅ MedSAM2 model loaded and cached: {model_path}")
    return model_path
|
|
|
|
|
def load_echoflow_model():
    """Load the EchoFlow generation model, reusing the module-level cache when possible.

    Returns:
        The object returned by ``models.model_factory.get_model("echoflow")``,
        cached under the key ``"echoflow"``.

    Raises:
        RuntimeError: If the factory cannot be imported, raises, or returns None.
            When an underlying exception exists it is chained as ``__cause__``.
    """
    if "echoflow" in _model_cache:
        print("✅ Using cached EchoFlow model")
        return _model_cache["echoflow"]

    try:
        # Imported here so the factory is only needed when a model is loaded.
        from models.model_factory import get_model

        print("🔄 Loading EchoFlow model...")
        model = get_model("echoflow")
    except Exception as e:
        print(f"EchoFlow loading failed: {e}")
        raise RuntimeError(f"EchoFlow model not available: {e}") from e

    if model is None:
        print("EchoFlow loading failed: model factory returned None")
        raise RuntimeError("EchoFlow model not available")

    _model_cache["echoflow"] = model
    print("✅ EchoFlow model loaded and cached")
    return model
|
|
|
|
|
def clear_model_cache():
    """Drop every cached model and release memory they were holding.

    The cache dict is mutated in place, so no ``global`` statement is needed.
    After dropping the references, a garbage collection pass frees any
    reference cycles. On CUDA machines, ``torch.cuda.empty_cache()`` then hands
    blocks held by PyTorch's caching allocator back to the driver; dropping
    references alone leaves that memory reserved.
    """
    import gc

    _model_cache.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("🧹 Model cache cleared")
|
|
|
|
|
def load_echo_prime_model():
    """Return a cached or freshly built EchoPrime model, or None on failure.

    Building the model involves four steps:
    1. Check that the EchoPrime checkout exists.
    2. Ensure its weights and embeddings are on disk, downloading them if needed.
    3. Put the checkout on ``sys.path`` so ``echo_prime.model`` is importable.
    4. Construct ``EchoPrime`` on CUDA when available, else on CPU.

    Unlike the other loaders in this module, failures are printed (with a
    traceback for unexpected exceptions) and reported by returning None
    instead of raising.
    """
    if "echo_prime" in _model_cache:
        print("✅ Using cached EchoPrime model")
        return _model_cache["echo_prime"]

    try:
        repo_root = ECHOPRIME_REPO_ROOT

        if not repo_root.exists():
            print(f"❌ EchoPrime directory not found: {repo_root}")
            return None

        if not ensure_echoprime_assets(repo_root):
            print("❌ EchoPrime assets unavailable; cannot initialize model.")
            return None

        repo_entry = str(repo_root)
        if repo_entry not in sys.path:
            sys.path.insert(0, repo_entry)

        from echo_prime.model import EchoPrime

        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        model = EchoPrime(device=target_device)

        _model_cache["echo_prime"] = model
        print("✅ EchoPrime model loaded successfully")
        return model

    except Exception as e:
        print(f"❌ Failed to load EchoPrime model: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
class EchoDiseasePredictionInput(BaseModel):
    """Input schema for echo disease prediction.

    Used as ``args_schema`` of ``EchoDiseasePredictionTool``; field names must
    match that tool's ``_run`` parameters.
    """

    # Directory scanned for ``*.mp4`` files (non-recursive glob in the tool).
    input_dir: str = Field(..., description="Directory containing echo videos")
    # None or 0 means "process every video found" (the tool checks truthiness).
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    # NOTE(review): the tool's _run accepts but does not currently use this flag.
    save_csv: bool = Field(True, description="Save results to CSV file")
    # NOTE(review): the tool's _run accepts but does not currently use this flag.
    include_confidence: bool = Field(True, description="Include confidence scores in output")
|
|
|
|
|
|
|
|
class EchoImageVideoGenerationInput(BaseModel):
    """Input schema for echo image/video generation.

    Used as ``args_schema`` of ``EchoImageVideoGenerationTool``; all values are
    forwarded to the EchoFlow model's ``generate_synthetic_video``.
    """

    # View names to synthesize. NOTE(review): valid names are defined by EchoFlow — confirm.
    views: List[str] = Field(..., description="List of echo views to generate")
    # Target ejection fractions. NOTE(review): pairing with `views` is decided by EchoFlow — confirm.
    efs: List[float] = Field(..., description="List of ejection fractions")
    # None -> the tool writes to "temp/echo_generated".
    outdir: Optional[str] = Field(None, description="Output directory")
    num_samples: int = Field(10, description="Number of samples to generate")
|
|
|
|
|
|
|
|
class EchoMeasurementPredictionInput(BaseModel):
    """Input schema for echo measurement prediction.

    Used as ``args_schema`` of ``EchoMeasurementPredictionTool``.
    """

    # Directory passed as a whole to EchoPrime's process_mp4s().
    input_dir: str = Field(..., description="Directory containing echo videos")
    # NOTE(review): echoed back in the result but not used to limit processing.
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    # NOTE(review): accepted but not currently used by the tool.
    include_report: bool = Field(True, description="Include detailed report")
    # NOTE(review): accepted but not currently used by the tool.
    save_csv: bool = Field(True, description="Save measurements to CSV")
|
|
|
|
|
|
|
|
class EchoReportGenerationInput(BaseModel):
    """Input schema for echo report generation.

    Used as ``args_schema`` of ``EchoReportGenerationTool``.
    """

    # Directory passed as a whole to EchoPrime's process_mp4s().
    input_dir: str = Field(..., description="Directory containing echo videos")
    # Forwarded as `visualize=` to EchoPrime's get_views().
    visualize_views: bool = Field(False, description="Generate view visualizations")
    # NOTE(review): echoed back in the result but not used to limit processing.
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    # NOTE(review): accepted by the tool's _run but not used there.
    include_sections: bool = Field(True, description="Include all report sections")
|
|
|
|
|
|
|
|
class EchoSegmentationInput(BaseModel):
    """Input schema for echo segmentation.

    Used as ``args_schema`` of ``EchoSegmentationTool``. Field names must match
    that tool's ``_run`` parameters. Field order and descriptions make up the
    schema shown to the caller, so keep them stable.
    """

    # --- Core inputs ---
    video_path: str = Field(..., description="Path to echo video file")
    # How the first frame is prompted; "points"/"box"/"mask" use the fields below.
    prompt_mode: str = Field("auto", description="Prompt mode for segmentation (auto, points, box, mask)")
    # NOTE(review): this description lists structures (e.g. MV, AORoot) that have no
    # entry in DEFAULT_ECHO_SEGMENTATION_STRUCTURES and omits MYO, which does —
    # confirm against the segmentation implementation.
    target_name: str = Field(
        "all",
        description="Target structure name (all, LV, RV, LA, RA, MV, TV, AV, PV, IVS, LVPW, AORoot, PA)",
    )
    save_mask_video: bool = Field(True, description="Save mask video")
    save_overlay_video: bool = Field(True, description="Save overlay video")

    # --- Manual prompts (first frame) ---
    # Each entry is [x, y, label]; the tool skips entries shorter than 3 or non-numeric.
    points: Optional[List[List[float]]] = Field(
        None,
        description="List of [x, y, label] triples in normalized coordinates for the first frame (label: 1 foreground, 0 background)",
    )
    # Only the first four values are used by the tool.
    box: Optional[List[float]] = Field(
        None,
        description="Normalized box as [x1, y1, x2, y2] for the first frame",
    )
    mask_path: Optional[str] = Field(
        None,
        description="Path to an initial segmentation mask for the first frame (for 'mask' mode)",
    )

    # --- Processing / output ---
    sample_rate: int = Field(1, description="Process every Nth frame for speed (1 = every frame)")
    # None -> source FPS (the tool falls back to 30.0 when the source FPS is unreadable).
    output_fps: Optional[int] = Field(None, description="FPS for output video. Defaults to source FPS")
    # Typed Any so arbitrary callables pass validation; presumably supplied by
    # the UI rather than by an LLM — TODO confirm.
    progress_callback: Optional[Any] = Field(None, description="Progress callback function for UI updates")

    # --- Per-structure initial masks ---
    initial_masks_dir: Optional[str] = Field(
        None,
        description="Directory containing first-frame masks per structure (e.g., LV.png, RV.png)",
    )
    initial_mask_paths: Optional[Dict[str, str]] = Field(
        None,
        description="Mapping of structure code (e.g., 'LV') to first-frame mask file path",
    )
    initial_mask_frame_idx: int = Field(0, description="Frame index the provided masks correspond to (default 0)")
    use_auto_masks_if_missing: bool = Field(
        True,
        description="If a structure mask is missing, fall back to auto coarse prompt",
    )
|
|
|
|
|
|
|
|
class EchoViewClassificationInput(BaseModel):
    """Input schema for echo view classification.

    NOTE(review): the tool that consumes this schema is not visible in this
    part of the file; field semantics below follow the descriptions only.
    """

    input_dir: str = Field(..., description="Directory containing echo videos")
    visualize: bool = Field(False, description="Generate visualizations")
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
|
|
|
|
|
|
|
|
# Human-readable descriptions for PanEcho task names.
# The aortic-valve peak-velocity task was spelled 'AVPkVel(m/s)' in the
# description/units tables but 'AVPkVel(m|s)' in the regression list, so both
# spellings are accepted everywhere; the key PanEcho actually emits is not
# visible in this file.
_PANECHO_TASK_DESCRIPTIONS: Dict[str, str] = {
    'pericardial-effusion': 'Pericardial Effusion',
    'EF': 'Ejection Fraction (%)',
    'GLS': 'Global Longitudinal Strain (%)',
    'LVEDV': 'LV End-Diastolic Volume (cm³)',
    'LVESV': 'LV End-Systolic Volume (cm³)',
    'LVSV': 'LV Stroke Volume (cm³)',
    'LVSize': 'LV Size',
    'LVWallThickness-increased-any': 'LV Wall Thickness - Any Increase',
    'LVWallThickness-increased-modsev': 'LV Wall Thickness - Moderate/Severe Increase',
    'LVSystolicFunction': 'LV Systolic Function',
    'LVWallMotionAbnormalities': 'LV Wall Motion Abnormalities',
    'IVSd': 'Interventricular Septum Diastole (cm)',
    'LVPWd': 'LV Posterior Wall Diastole (cm)',
    'LVIDs': 'LV Internal Diameter Systole (cm)',
    'LVIDd': 'LV Internal Diameter Diastole (cm)',
    'LVOTDiam': 'LV Outflow Tract Diameter (cm)',
    'LVDiastolicFunction': 'LV Diastolic Function',
    'E|EAvg': 'E/e\' Ratio',
    'RVSP': 'RV Systolic Pressure (mmHg)',
    'RVSize': 'RV Size',
    'RVSystolicFunction': 'RV Systolic Function',
    'RVIDd': 'RV Internal Diameter Diastole (cm)',
    'TAPSE': 'Tricuspid Annular Plane Systolic Excursion (cm)',
    'RVSVel': 'RV Systolic Excursion Velocity (cm/s)',
    'LASize': 'Left Atrial Size',
    'LAIDs2D': 'LA Internal Diameter Systole 2D (cm)',
    'LAVol': 'LA Volume (cm³)',
    'RASize': 'Right Atrial Size',
    'RADimensionM-L(cm)': 'RA Major Dimension (cm)',
    'AVStructure': 'Aortic Valve Structure',
    'AVStenosis': 'Aortic Valve Stenosis',
    'AVPkVel(m/s)': 'Aortic Valve Peak Velocity (m/s)',
    'AVPkVel(m|s)': 'Aortic Valve Peak Velocity (m/s)',
    'AVRegurg': 'Aortic Valve Regurgitation',
    'LVOT20mmHg': 'Elevated LV Outflow Tract Pressure',
    'MVStenosis': 'Mitral Valve Stenosis',
    'MVRegurgitation': 'Mitral Valve Regurgitation',
    'TVRegurgitation': 'Tricuspid Valve Regurgitation',
    'TVPkGrad': 'Tricuspid Valve Peak Gradient (mmHg)',
    'RAP-8-or-higher': 'Elevated RA Pressure',
    'AORoot': 'Aortic Root Diameter (cm)',
}

# Tasks whose single-output head is treated as a continuous measurement;
# every other (1, 1) output is treated as a binary-classification probability.
_PANECHO_REGRESSION_TASKS = frozenset({
    'EF', 'GLS', 'LVEDV', 'LVESV', 'LVSV', 'IVSd', 'LVPWd', 'LVIDs', 'LVIDd',
    'LVOTDiam', 'E|EAvg', 'RVSP', 'RVIDd', 'TAPSE', 'RVSVel', 'LAIDs2D',
    'LAVol', 'RADimensionM-L(cm)', 'AVPkVel(m/s)', 'AVPkVel(m|s)', 'TVPkGrad',
    'AORoot',
})

# Clip geometry fed to PanEcho: 16 frames of 224x224 RGB.
_PANECHO_CLIP_LENGTH = 16
_PANECHO_FRAME_SIZE = 224


def _describe_panecho_task(task_name: str) -> str:
    """Return the display description for a PanEcho task, with a fallback label."""
    return _PANECHO_TASK_DESCRIPTIONS.get(task_name, f"{task_name} (Unknown Task)")


def _load_panecho_clip(video_path: str) -> torch.Tensor:
    """Read the first frames of a video as a normalized PanEcho input clip.

    Each frame is resized to 224x224, converted BGR->RGB, scaled to [0, 1] and
    normalized with ImageNet mean/std. Videos shorter than 16 frames are padded
    by repeating the last decoded frame.

    Returns:
        Float tensor of shape (1, 3, 16, 224, 224), ordered (batch, channel,
        time, height, width).

    Raises:
        RuntimeError: If no frame could be decoded from ``video_path``.
    """
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while len(frames) < _PANECHO_CLIP_LENGTH and cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (_PANECHO_FRAME_SIZE, _PANECHO_FRAME_SIZE))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder even if resize/convert raises.
        cap.release()

    if not frames:
        raise RuntimeError(f"Could not read any frames from {video_path}")

    # Pad short clips by repeating the last decoded frame.
    frames.extend([frames[-1]] * (_PANECHO_CLIP_LENGTH - len(frames)))

    frames_array = np.array(frames, dtype=np.float32) / 255.0
    # (T, H, W, C) -> (T, C, H, W) so Normalize sees channel-first images.
    clip = torch.tensor(frames_array).permute(0, 3, 1, 2)
    clip = normalize(clip)
    # (T, C, H, W) -> (1, C, T, H, W)
    return clip.permute(1, 0, 2, 3).unsqueeze(0)


class EchoDiseasePredictionTool(BaseTool):
    """LangChain tool that runs PanEcho on each MP4 in a directory."""

    name: str = "echo_disease_prediction"
    description: str = "Predict cardiac diseases from echo videos using PanEcho."
    args_schema: Type[BaseModel] = EchoDiseasePredictionInput

    def _get_task_units(self, task_name: str) -> str:
        """Return the display unit for a PanEcho regression task, or 'N/A'."""
        units_map = {
            'EF': '%',
            'GLS': '%',
            'LVEDV': 'cm³',
            'LVESV': 'cm³',
            'LVSV': 'cm³',
            'IVSd': 'cm',
            'LVPWd': 'cm',
            'LVIDs': 'cm',
            'LVIDd': 'cm',
            'LVOTDiam': 'cm',
            'E|EAvg': 'ratio',
            'RVSP': 'mmHg',
            'RVIDd': 'cm',
            'TAPSE': 'cm',
            'RVSVel': 'cm/s',
            'LAIDs2D': 'cm',
            'LAVol': 'cm³',
            'RADimensionM-L(cm)': 'cm',
            'AVPkVel(m/s)': 'm/s',
            'AVPkVel(m|s)': 'm/s',
            'TVPkGrad': 'mmHg',
            'AORoot': 'cm'
        }
        return units_map.get(task_name, 'N/A')

    def _get_class_names(self, task_name: str) -> list:
        """Return class labels (indexed by model output position) for a multi-class task.

        An empty list means the task has no known labels, and the raw class
        index is reported instead.
        """
        class_names_map = {
            'LVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LVSystolicFunction': ['Mildly Decreased', 'Moderately|Severely Decreased', 'Normal|Hyperdynamic'],
            'LVDiastolicFunction': ['Mild|Indeterminate', 'Moderate|Severe', 'Normal'],
            'RVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LASize': ['Mildly Dilated', 'Moderately|Severely Dilated', 'Normal'],
            'AVStenosis': ['Mild|Moderate', 'None', 'Severe'],
            'AVRegurg': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'MVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'TVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace']
        }
        return class_names_map.get(task_name, [])

    def _format_prediction(self, task_name: str, pred_value: Any) -> Dict[str, Any]:
        """Convert one raw PanEcho head output into a result record.

        The record type is chosen from the output shape:

        * (1, 1) tensor: a regression value for tasks in
          ``_PANECHO_REGRESSION_TASKS``, otherwise a binary-classification
          score whose confidence is ``max(p, 1 - p)``.
        * Tensor whose second dimension is > 1: multi-class. Row 0 is
          arg-maxed and mapped to a label when one is known.
        * Any other tensor: the mean of all elements, reported as regression.
        * Non-tensor: numeric value when int/float, else 0.0, with type 'unknown'.

        Tensors without a second dimension raise IndexError; the caller turns
        that into a per-task error record.
        """
        raw_prediction = None

        if torch.is_tensor(pred_value):
            if pred_value.shape == (1, 1):
                value = float(pred_value[0, 0].item())
                raw_prediction = value
                if task_name in _PANECHO_REGRESSION_TASKS:
                    task_type = 'regression'
                    # Fixed placeholder; not derived from the model output.
                    confidence = 0.85
                else:
                    task_type = 'binary_classification'
                    # NOTE(review): assumes the head emits a probability in [0, 1].
                    confidence = max(value, 1.0 - value)
            elif pred_value.shape[1] > 1:
                probs = pred_value[0]
                predicted_class = int(probs.argmax().item())
                # NOTE(review): meaningful as a confidence only if these are probabilities.
                confidence = float(probs.max().item())
                class_names = self._get_class_names(task_name)
                if class_names and predicted_class < len(class_names):
                    value = class_names[predicted_class]
                else:
                    value = predicted_class
                task_type = 'multi-class_classification'
            else:
                value = float(pred_value.flatten().mean().item())
                task_type = 'regression'
                confidence = 0.85
        else:
            value = float(pred_value) if isinstance(pred_value, (int, float)) else 0.0
            task_type = 'unknown'
            confidence = 0.0

        return {
            'value': value,
            'description': _describe_panecho_task(task_name),
            'confidence': confidence,
            'task_type': task_type,
            'units': self._get_task_units(task_name),
            'raw_prediction': raw_prediction,
        }

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        save_csv: bool = True,
        include_confidence: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run PanEcho on every ``*.mp4`` in ``input_dir``.

        Videos are processed in sorted filename order, so ``max_videos``
        selects a deterministic subset. A failure on one video is printed and
        that video is skipped. A failure on one task becomes an error record
        for that task.

        ``save_csv`` and ``include_confidence`` are accepted for schema
        compatibility but are not currently used.

        Returns:
            Dict with status, model name, processed count, and per-video
            ``predictions`` keyed by PanEcho task name.

        Raises:
            RuntimeError: If the model cannot be loaded, no MP4 files are found,
                or no video is processed successfully. The underlying error is
                chained as ``__cause__``.
        """
        try:
            panecho_model = load_panecho_model()

            import glob

            video_files = sorted(glob.glob(os.path.join(input_dir, "*.mp4")))
            if max_videos:
                video_files = video_files[:max_videos]
            if not video_files:
                raise RuntimeError(f"No MP4 videos found in {input_dir}")

            # Inputs go to whichever device holds the model's parameters.
            device = next(panecho_model.parameters()).device

            all_predictions = []
            for video_path in video_files:
                try:
                    frames_tensor = _load_panecho_clip(video_path).to(device)
                    with torch.no_grad():
                        predictions = panecho_model(frames_tensor)

                    disease_predictions = {}
                    for task_name, pred_value in predictions.items():
                        try:
                            disease_predictions[task_name] = self._format_prediction(task_name, pred_value)
                        except Exception as e:
                            print(f"Error processing {task_name}: {e}")
                            disease_predictions[task_name] = {
                                'value': 0.0,
                                'description': _describe_panecho_task(task_name),
                                'confidence': 0.0,
                                'task_type': 'unknown',
                                'units': 'unknown',
                                'error': str(e)
                            }

                    all_predictions.append({
                        "video": os.path.basename(video_path),
                        "predictions": disease_predictions
                    })
                except Exception as e:
                    print(f"Error processing {video_path}: {e}")
                    continue

            if not all_predictions:
                raise RuntimeError("No videos processed successfully")

            return {
                "status": "success",
                "model": "PanEcho",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(all_predictions),
                "predictions": all_predictions,
                "message": f"Disease prediction completed for {len(all_predictions)} videos using real PanEcho model"
            }

        except Exception as e:
            print(f"PanEcho prediction failed: {e}")
            raise RuntimeError(f"Disease prediction failed: {e}") from e
|
|
|
|
|
|
|
|
class EchoImageVideoGenerationTool(BaseTool):
    """LangChain tool that synthesizes echo videos with EchoFlow."""

    name: str = "echo_image_video_generation"
    description: str = "Generate synthetic echo images and videos using EchoFlow."
    args_schema: Type[BaseModel] = EchoImageVideoGenerationInput

    def _run(
        self,
        views: List[str],
        efs: List[float],
        outdir: Optional[str] = None,
        num_samples: int = 10,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Generate synthetic echo videos and report the produced files.

        Args:
            views: View names forwarded to EchoFlow.
            efs: Ejection fractions forwarded to EchoFlow.
            outdir: Output directory; defaults to ``temp/echo_generated`` and
                is created if missing.
            num_samples: Number of samples requested from EchoFlow.
            run_manager: LangChain callback manager (unused).

        Returns:
            Dict with status, the echoed request parameters, the list returned
            by ``generate_synthetic_video`` and its length.

        Raises:
            RuntimeError: If loading or generation fails; the original error
                is chained as ``__cause__``.
        """
        try:
            echoflow_model = load_echoflow_model()

            output_dir = Path(outdir or "temp/echo_generated")
            output_dir.mkdir(parents=True, exist_ok=True)

            generated_files = echoflow_model.generate_synthetic_video(
                views=views,
                efs=efs,
                num_samples=num_samples,
                output_dir=str(output_dir)
            )
            successful_generations = len(generated_files)

            return {
                "status": "success",
                "model": "EchoFlow",
                "views": views,
                "efs": efs,
                "num_samples": num_samples,
                "successful_generations": successful_generations,
                "generated_files": generated_files,
                "output_dir": str(output_dir),
                "message": f"Generated {successful_generations} synthetic echo videos using real EchoFlow model"
            }

        except Exception as e:
            print(f"EchoFlow generation failed: {e}")
            raise RuntimeError(f"Echo generation failed: {e}") from e
|
|
|
|
|
|
|
|
class EchoMeasurementPredictionTool(BaseTool):
    """LangChain tool that predicts study-level echo measurements with EchoPrime."""

    name: str = "echo_measurement_prediction"
    description: str = "Extract echocardiography measurements using EchoPrime."
    args_schema: Type[BaseModel] = EchoMeasurementPredictionInput

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        include_report: bool = True,
        save_csv: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Predict one set of measurements for all videos in ``input_dir``.

        The whole directory is passed to EchoPrime's ``process_mp4s``. All
        processed videos are embedded into a single study embedding, and
        ``predict_metrics`` is run on it. ``max_videos`` is only echoed back;
        ``include_report`` and ``save_csv`` are accepted for schema
        compatibility but are not used.

        Returns:
            Dict with status, processed-video count, and one
            ``study_measurements`` entry mapping metric name to
            ``{"value", "unit", "confidence"}``. Non-numeric and NaN metrics
            are omitted.

        Raises:
            RuntimeError: If the model is unavailable, no videos were processed,
                or any EchoPrime call fails. The original error is chained as
                ``__cause__``.
        """
        try:
            echo_prime_model = load_echo_prime_model()
            if echo_prime_model is None:
                # load_echo_prime_model reports failures by returning None.
                raise RuntimeError("EchoPrime model not available")

            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)
            if len(stack_of_videos) == 0:
                raise RuntimeError("No videos processed successfully")
            print(f"✅ Processed {len(stack_of_videos)} videos")

            print("🔄 Predicting measurements...")
            video_features = echo_prime_model.embed_videos(stack_of_videos)
            view_encodings = echo_prime_model.get_views(stack_of_videos)
            # get_views may return a 1-D tensor (presumably for a single video);
            # make it 2-D so it can be concatenated along dim 1.
            if view_encodings.dim() == 1:
                view_encodings = view_encodings.unsqueeze(0)
            study_embedding = torch.cat((video_features, view_encodings), dim=1)

            measurements = echo_prime_model.predict_metrics(study_embedding)

            formatted_measurements = {}
            for key, value in measurements.items():
                # numpy scalars (e.g. np.float32) are not `float` subclasses, so
                # accept them explicitly instead of silently dropping them.
                if isinstance(value, (int, float, np.integer, np.floating)) and not np.isnan(value):
                    # Heuristic unit from the metric name.
                    # NOTE(review): verify against EchoPrime's metric definitions.
                    unit = "%" if key == "EF" else "cm" if "d" in key else "mL"
                    formatted_measurements[key] = {
                        "value": float(value),
                        "unit": unit,
                        "confidence": 0.85
                    }

            all_measurements = [{
                "video": "study_measurements",
                "measurements": formatted_measurements
            }]

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "measurements": all_measurements,
                "message": f"Measurement prediction completed for {len(stack_of_videos)} videos using real EchoPrime model"
            }

        except Exception as e:
            print(f"EchoPrime measurement prediction failed: {e}")
            raise RuntimeError(f"Measurement prediction failed: {e}") from e
|
|
|
|
|
|
|
|
class EchoReportGenerationTool(BaseTool):
    """LangChain tool that produces a study-level echo report with EchoPrime."""

    name: str = "echo_report_generation"
    description: str = "Generate comprehensive echo report using EchoPrime."
    args_schema: Type[BaseModel] = EchoReportGenerationInput

    def _run(
        self,
        input_dir: str,
        visualize_views: bool = False,
        max_videos: Optional[int] = None,
        include_sections: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Generate a report, measurements and a view breakdown for ``input_dir``.

        The whole directory is passed to EchoPrime's ``process_mp4s``, and all
        processed videos are embedded into one study embedding. ``max_videos``
        is only echoed back; ``include_sections`` is accepted for schema
        compatibility but is not used here.

        Returns:
            Dict with status, the EchoPrime ``report``, and an ``analysis`` dict
            (predicted views, view counts, raw measurements).

        Raises:
            RuntimeError: If the model is unavailable, no videos were processed,
                or any EchoPrime call fails. The original error is chained as
                ``__cause__``.
        """
        from collections import Counter

        try:
            echo_prime_model = load_echo_prime_model()
            if echo_prime_model is None:
                # load_echo_prime_model reports failures by returning None.
                raise RuntimeError("EchoPrime model not available")

            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)
            if len(stack_of_videos) == 0:
                raise RuntimeError("No videos processed successfully")
            print(f"✅ Processed {len(stack_of_videos)} videos")

            print("🔄 Generating comprehensive report...")
            video_features = echo_prime_model.embed_videos(stack_of_videos)
            view_encodings = echo_prime_model.get_views(stack_of_videos, visualize=visualize_views)
            # get_views may return a 1-D tensor (presumably for a single video);
            # make it 2-D so it can be concatenated along dim 1.
            if view_encodings.dim() == 1:
                view_encodings = view_encodings.unsqueeze(0)
            study_embedding = torch.cat((video_features, view_encodings), dim=1)

            report = echo_prime_model.generate_report(study_embedding)
            measurements = echo_prime_model.predict_metrics(study_embedding)
            views = echo_prime_model.get_views(stack_of_videos, return_view_list=True)

            analysis = {
                "video": "study_analysis",
                "view_classification": {
                    "predicted_views": views,
                    # Single O(n) pass instead of views.count() per distinct view.
                    "view_distribution": dict(Counter(views))
                },
                "measurements": measurements,
                "disease_predictions": {},
                "quality_assessment": {
                    "confidence": 0.85,
                    "model_used": "EchoPrime"
                },
                "confidence": 0.85
            }

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "report": report,
                "analysis": analysis,
                "message": f"Report generation completed for {len(stack_of_videos)} videos using real EchoPrime model"
            }

        except Exception as e:
            print(f"EchoPrime report generation failed: {e}")
            raise RuntimeError(f"Report generation failed: {e}") from e

    def _generate_comprehensive_report(self, analyses, include_sections):
        """Build a rule-based summary report from per-video analysis dicts.

        Each analysis may contain ``measurements`` (metric -> number or
        ``{"value": number}``), ``view_classification.predicted_view`` and
        ``confidence``. Measurements are averaged only over analyses that
        actually report them; missing values are no longer counted as 0.

        Args:
            analyses: Iterable of analysis dicts.
            include_sections: When truthy, list the standard report sections.

        Returns:
            Dict with summary text, recommendations, sections, averaged
            measurements (formatted to one decimal), view counts, processed
            count and mean confidence (0.0 when ``analyses`` is empty).
        """
        analyses = list(analyses)

        all_measurements = []
        view_distribution = {}
        for analysis in analyses:
            all_measurements.append(analysis.get("measurements", {}))
            view = analysis.get("view_classification", {}).get("predicted_view", "unknown")
            view_distribution[view] = view_distribution.get(view, 0) + 1

        # Average each metric over the analyses that report it.
        avg_measurements = {}
        measurement_keys = ['EF', 'LVEDV', 'LVESV', 'GLS', 'IVSd', 'LVPWd', 'LVIDs', 'LVIDd']
        for key in measurement_keys:
            values = []
            for m in all_measurements:
                if not m or key not in m:
                    continue
                entry = m[key]
                if isinstance(entry, dict):
                    entry = entry.get("value")
                if entry is None:
                    continue
                values.append(entry)
            if values:
                avg_measurements[key] = float(np.mean(values))

        ef = avg_measurements.get("EF")
        if ef is None:
            summary = "Left ventricular ejection fraction could not be determined. "
        else:
            if ef > 55:
                ef_status = "Normal"
            elif ef > 45:
                ef_status = "Mildly reduced"
            else:
                ef_status = "Moderately to severely reduced"
            summary = f"Left ventricular ejection fraction is {ef_status} ({ef:.1f}%). "
        if "LVEDV" in avg_measurements:
            summary += f"Left ventricular end-diastolic volume is {avg_measurements['LVEDV']:.1f} mL. "
        if "GLS" in avg_measurements:
            summary += f"Global longitudinal strain is {avg_measurements['GLS']:.1f}%."

        recommendations = []
        if ef is not None and ef < 50:
            recommendations.append("Consider cardiology consultation")
        gls = avg_measurements.get("GLS")
        # GLS is conventionally negative and healthier when more negative; an
        # absolute value below 18% indicates impaired strain. abs() also covers
        # sources that report the magnitude as a positive number.
        if gls is not None and abs(gls) < 18:
            recommendations.append("Monitor for heart failure")
        if not recommendations:
            recommendations.append("Routine follow-up in 1 year")

        sections = []
        if include_sections:
            sections = ["findings", "measurements", "view_analysis", "recommendations"]

        confidences = [a.get("confidence", 0) for a in analyses]
        report = {
            "summary": summary,
            "recommendations": recommendations,
            "sections": sections,
            "measurements": {k: f"{v:.1f}" for k, v in avg_measurements.items()},
            "view_distribution": view_distribution,
            "processed_videos": len(analyses),
            "overall_confidence": float(np.mean(confidences)) if confidences else 0.0
        }

        return report

    def _create_view_visualization(self, analyses, input_dir):
        """Save a pie chart of predicted views to ``<input_dir>/view_distribution.png``.

        Returns:
            The saved file path as a string, or None if there is nothing to plot
            or plotting fails (failures are printed, not raised).
        """
        try:
            import matplotlib.pyplot as plt

            view_counts = {}
            for analysis in analyses:
                view = analysis.get("view_classification", {}).get("predicted_view", "unknown")
                view_counts[view] = view_counts.get(view, 0) + 1

            if not view_counts:
                return None

            output_path = Path(input_dir) / "view_distribution.png"
            fig, ax = plt.subplots(figsize=(8, 6))
            try:
                # pie() needs real sequences; dict views are not converted to arrays.
                ax.pie(list(view_counts.values()), labels=list(view_counts.keys()), autopct='%1.1f%%')
                ax.set_title("Echo View Distribution")
                fig.savefig(output_path, dpi=300, bbox_inches='tight')
            finally:
                # Always free the figure, even if saving fails.
                plt.close(fig)

            return str(output_path)

        except Exception as e:
            print(f"Visualization creation failed: {e}")
            return None
|
|
|
|
|
|
|
|
class EchoSegmentationTool(BaseTool):
    """Echo segmentation tool.

    Segments cardiac structures in an echo video with MedSAM2 and writes mask /
    overlay videos under ``temp/segmentation_outputs``.
    """

    name: str = "echo_segmentation"
    description: str = "Segment cardiac chambers in echo videos using MedSAM2."
    args_schema: Type[BaseModel] = EchoSegmentationInput

    # ------------------------------------------------------------------
    # Small, stateless helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_binary_mask(mask: Any, height: int, width: int) -> np.ndarray:
        """Return *mask* as a ``(height, width)`` uint8 array holding 0 or 255.

        Leading singleton axes are dropped and any remaining trailing axes are
        collapsed with ``max`` (NOTE(review): assumes a channel axis, if any,
        is last). Every non-zero value counts as foreground. Binarising before
        resizing also makes bool/float masks safe for ``cv2.resize`` and
        ``cv2.findContours``.
        """
        arr = np.asarray(mask)
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr[0]
        while arr.ndim > 2:
            arr = arr.max(axis=-1)
        binary = (arr > 0).astype(np.uint8) * 255
        if binary.shape != (height, width):
            binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
        return binary

    @staticmethod
    def _normalize_points(points) -> Optional[List[Tuple[float, float, int]]]:
        """Parse ``[[x, y, label], ...]`` into ``(float, float, int)`` tuples.

        Malformed entries are skipped; returns ``None`` when nothing is valid.
        """
        if points is None:
            return None
        normalized: List[Tuple[float, float, int]] = []
        for entry in points:
            if not isinstance(entry, (list, tuple)) or len(entry) < 3:
                continue
            try:
                normalized.append((float(entry[0]), float(entry[1]), int(entry[2])))
            except (TypeError, ValueError):
                continue
        return normalized or None

    @staticmethod
    def _normalize_box(box) -> Optional[Tuple[float, float, float, float]]:
        """Parse ``[x0, y0, x1, y1]`` into a float 4-tuple, or ``None``."""
        if not isinstance(box, (list, tuple)) or len(box) < 4:
            return None
        try:
            return (float(box[0]), float(box[1]), float(box[2]), float(box[3]))
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _read_video_frames(video_path: str) -> Tuple[List[np.ndarray], float]:
        """Read every frame (as returned by OpenCV) and the container frame rate.

        The frame rate falls back to 30.0 when the container reports none.

        Raises:
            RuntimeError: if no frame could be read.
        """
        cap = cv2.VideoCapture(video_path)
        frames: List[np.ndarray] = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()

        if not fps or not np.isfinite(fps) or fps <= 1e-3:
            fps = 30.0
        if not frames:
            raise RuntimeError(f"No frames found in video: {video_path}")
        return frames, float(fps)

    @staticmethod
    def _resolve_output_fps(output_fps: Optional[float], source_fps: float) -> float:
        """Return *output_fps* when it is a positive finite number, else *source_fps*."""
        try:
            requested = float(output_fps) if output_fps is not None else 0.0
        except (TypeError, ValueError):
            requested = 0.0
        return requested if requested > 0 and np.isfinite(requested) else source_fps

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def _run(
        self,
        video_path: str,
        prompt_mode: str = "auto",
        target_name: str = "all",
        save_mask_video: bool = True,
        save_overlay_video: bool = True,
        points: Optional[List[List[float]]] = None,
        box: Optional[List[float]] = None,
        mask_path: Optional[str] = None,
        sample_rate: int = 1,
        output_fps: Optional[int] = None,
        progress_callback: Optional[callable] = None,
        initial_masks_dir: Optional[str] = None,
        initial_mask_paths: Optional[Dict[str, str]] = None,
        initial_mask_frame_idx: int = 0,
        use_auto_masks_if_missing: bool = True,
        run_manager: Optional[Any] = None,
        query: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run echo segmentation using real MedSAM2 model.

        The enhanced multi-structure segmenter is tried first. It is seeded with
        first-frame masks from ``initial_masks_dir`` / ``initial_mask_paths``,
        then ``Config.ANNOTATION_PROMPTS``, then
        ``Config.DEFAULT_INITIAL_MASK_PATH``. If it fails, the single-target
        ``MedSAM2VideoSegmenter`` is used instead.

        Args:
            video_path: Video readable by ``cv2.VideoCapture``.
            prompt_mode: Echoed back in the result; not otherwise used here.
            target_name: Used in output file names, passed to the fallback
                segmenter, and ``"all"`` makes the combined multi-structure
                video the primary output.
            save_mask_video: Write the binary-mask MP4.
            save_overlay_video: Write the overlay MP4. Enhanced (combined)
                videos are only written when at least one save flag is True.
            points, box: Validated, but not yet forwarded to the segmenter
                (a warning is printed when they are supplied).
            output_fps: Frame rate for written videos; defaults to the source
                video's frame rate.
            progress_callback: Optional ``callback(percent, message)``.
            initial_masks_dir, initial_mask_paths: Optional first-frame masks.
            mask_path, sample_rate, initial_mask_frame_idx,
            use_auto_masks_if_missing, run_manager, query: Accepted for
                interface compatibility; unused by this implementation.

        Returns:
            A ``{"status": "success", "outputs": {...}, ...}`` dict, or
            ``{"status": "error", "error": ..., ...}`` when the model, the video
            or both segmentation paths fail.
        """
        import traceback

        enhanced_result: Optional[Dict[str, Any]] = None
        try:
            if self._normalize_points(points) or self._normalize_box(box):
                print("⚠️ Point/box prompts are not yet used by MedSAM2 segmentation; ignoring them")

            medsam2_model_path = load_medsam2_model()

            frames, fps = self._read_video_frames(video_path)
            height, width = frames[0].shape[:2]
            frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]

            try:
                provided_masks = self._resolve_initial_masks(
                    height,
                    width,
                    video_path,
                    initial_masks_dir=initial_masks_dir,
                    initial_mask_paths=initial_mask_paths,
                )
                enhanced_result = self._segment_with_medsam2(
                    frames_rgb,
                    medsam2_model_path,
                    progress_callback,
                    initial_masks=provided_masks,
                )
                masks = self._combine_object_masks(enhanced_result, len(frames), height, width)
            except Exception as seg_error:
                print(f"Error in enhanced MedSAM2 video segmentation: {seg_error}")
                traceback.print_exc()
                enhanced_result = None
                masks = self._fallback_segmentation(
                    frames_rgb, medsam2_model_path, target_name, progress_callback, seg_error
                )

            # Kept for external readers of this attribute; the rest of this
            # method uses the local ``enhanced_result``.
            self._enhanced_segmentation_result = enhanced_result

        except Exception as e:
            print(f"Error loading MedSAM2 model or processing video: {e}")
            return {
                "status": "error",
                "error": str(e),
                "video_path": video_path,
                "target_name": target_name
            }

        write_fps = self._resolve_output_fps(output_fps, fps)
        save_any_video = save_mask_video or save_overlay_video
        output_dir = Path("temp") / "segmentation_outputs"
        if save_any_video:
            output_dir.mkdir(parents=True, exist_ok=True)
        video_stem = Path(video_path).stem

        outputs: Dict[str, Any] = {}
        if save_mask_video:
            mask_video_path = output_dir / f"mask_{target_name}_{video_stem}.mp4"
            outputs["mask_video"] = self._save_mask_video(frames, masks, str(mask_video_path), fps=write_fps)
        if save_overlay_video:
            overlay_video_path = output_dir / f"overlay_{target_name}_{video_stem}.mp4"
            outputs["overlay_video"] = self._save_overlay_video(frames, masks, str(overlay_video_path), fps=write_fps)

        enhanced_outputs: Dict[str, Any] = {}
        if enhanced_result and save_any_video:
            try:
                enhanced_outputs = self._create_enhanced_videos(frames, enhanced_result, output_dir, fps=write_fps)
            except Exception as e:
                print(f"Error creating enhanced videos: {e}")
                enhanced_outputs = {}

        final_outputs = {
            **outputs,
            **enhanced_outputs,
            "masks": len(masks),
            "frames_processed": len(frames)
        }
        if "mask_video" in outputs:
            final_outputs["segmented_video"] = outputs["mask_video"]
        if "overlay_video" in outputs:
            final_outputs["overlay_video"] = outputs["overlay_video"]

        # For "all", the multi-colour combined video replaces the single-colour ones.
        if target_name == "all" and "combined_segmentation_video" in enhanced_outputs:
            final_outputs["segmented_video"] = enhanced_outputs["combined_segmentation_video"]
            final_outputs["overlay_video"] = enhanced_outputs["combined_segmentation_video"]

        if enhanced_result:
            final_outputs.setdefault("structures", enhanced_result.get("structures", []))
            final_outputs.setdefault("structure_info", enhanced_result.get("structure_info", {}))

        return {
            "status": "success",
            "model": "MedSAM2",
            "video_path": video_path,
            "target_name": target_name,
            "prompt_mode": prompt_mode,
            "outputs": final_outputs,
            "message": f"Enhanced segmentation completed with {len(masks)} frames processed using MedSAM2 model"
        }

    # ------------------------------------------------------------------
    # Segmentation helpers
    # ------------------------------------------------------------------

    def _resolve_initial_masks(
        self,
        height: int,
        width: int,
        video_path: str,
        initial_masks_dir: Optional[str] = None,
        initial_mask_paths: Optional[Dict[str, str]] = None,
    ) -> Optional[Dict[str, np.ndarray]]:
        """Pick first-frame masks: explicit inputs, then config annotations, then the config default."""
        provided_masks: Optional[Dict[str, np.ndarray]] = None
        if initial_masks_dir or initial_mask_paths:
            provided_masks = self._load_initial_masks(
                height,
                width,
                initial_masks_dir=initial_masks_dir,
                initial_mask_paths=initial_mask_paths,
            )

        if provided_masks:
            print(f"✅ Using {len(provided_masks)} annotation masks: {list(provided_masks.keys())}")
        else:
            annotation_prompts = self._load_annotation_prompts_from_config(height, width, video_path)
            if annotation_prompts:
                provided_masks = annotation_prompts
                print(f"✅ Loaded config-based annotation masks: {list(provided_masks.keys())}")
            else:
                print("⚠️ No annotation masks found for video; falling back to auto prompts")

        if not provided_masks:
            default_masks = self._load_default_initial_mask(height, width)
            if default_masks is not None:
                provided_masks = default_masks
        return provided_masks

    def _load_default_initial_mask(self, height: int, width: int) -> Optional[Dict[str, np.ndarray]]:
        """Load ``Config.DEFAULT_INITIAL_MASK_PATH`` as a single-structure mask, if configured.

        Best effort: returns ``None`` when the config or file is unavailable and
        reports (rather than raises) unexpected load errors.
        """
        try:
            from config import Config
        except Exception:
            return None

        try:
            default_path = getattr(Config, 'DEFAULT_INITIAL_MASK_PATH', '')
            default_structure = getattr(Config, 'DEFAULT_INITIAL_MASK_STRUCTURE', 'LV')
            if not (isinstance(default_path, str) and default_path and os.path.exists(default_path)):
                return None
            masks = self._load_initial_masks(
                height,
                width,
                initial_mask_paths={str(default_structure).upper(): default_path},
            )
            print(f"⚠️ Using default single-structure mask for {default_structure}")
            return masks
        except Exception as exc:
            print(f"⚠️ Could not load default initial mask: {exc}")
            return None

    def _combine_object_masks(self, segmentation_result: Dict[str, Any], num_frames: int, height: int, width: int) -> List[np.ndarray]:
        """Union all per-object masks of each frame into one 0/255 uint8 mask.

        Reads ``segmentation_result['masks']`` as ``{frame_idx: {obj_id: mask}}``;
        frames without an entry get an empty mask.
        """
        per_frame = segmentation_result['masks']
        combined_masks: List[np.ndarray] = []
        for frame_idx in range(num_frames):
            combined = np.zeros((height, width), dtype=np.uint8)
            for obj_mask in per_frame.get(frame_idx, {}).values():
                combined = np.maximum(combined, self._to_binary_mask(obj_mask, height, width))
            combined_masks.append(combined)
        return combined_masks

    def _fallback_segmentation(self, frames_rgb, model_path, target_name, progress_callback, cause: Exception):
        """Segment with the single-target MedSAM2 segmenter after the enhanced path failed.

        Raises:
            RuntimeError: chaining the fallback error and naming the original *cause*.
        """
        try:
            from tools.echo.medsam2_integration import MedSAM2VideoSegmenter

            segmenter = MedSAM2VideoSegmenter(model_path)
            return segmenter.segment_video(frames_rgb, target_name, progress_callback)
        except Exception as fallback_error:
            raise RuntimeError(
                f"Enhanced MedSAM2 segmentation failed ({cause}); "
                f"fallback segmentation also failed ({fallback_error})"
            ) from fallback_error

    def _segment_with_medsam2(self, frames, model_path, progress_callback=None, initial_masks: Optional[Dict[str, np.ndarray]] = None):
        """Segment video using enhanced MedSAM2 model with multi-structure support.

        Returns the segmenter's result dict. Callers in this class read its
        ``masks``, ``structures`` and ``structure_info`` entries; this method
        prints ``total_frames`` and ``structures``.

        Raises:
            RuntimeError: wrapping any import or segmentation failure.
        """
        try:
            # Make the project root importable. Insert it only once so
            # repeated calls do not keep growing sys.path.
            current_dir = os.path.dirname(os.path.abspath(__file__))
            project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
            if project_root not in sys.path:
                sys.path.insert(0, project_root)

            from tools.echo.enhanced_medsam2_integration import EnhancedMedSAM2VideoSegmenter

            print("✅ Successfully imported EnhancedMedSAM2VideoSegmenter")
            if progress_callback:
                progress_callback(10, "Initializing enhanced MedSAM2 model...")

            segmenter = EnhancedMedSAM2VideoSegmenter(model_path)

            if progress_callback:
                progress_callback(20, "Starting multi-structure video segmentation...")

            result = segmenter.segment_video_multi_structure(frames, progress_callback, initial_masks=initial_masks)

            print(f"✅ Generated multi-structure masks for {result['total_frames']} frames")
            print(f"🎯 Segmented structures: {result['structures']}")
            if progress_callback:
                progress_callback(100, "Multi-structure segmentation completed!")

            return result

        except Exception as e:
            print(f"❌ Enhanced MedSAM2 segmentation error: {e}")
            import traceback
            traceback.print_exc()
            if progress_callback:
                progress_callback(0, f"Segmentation failed: {e}")
            raise RuntimeError(f"Enhanced MedSAM2 segmentation failed: {e}") from e

    # ------------------------------------------------------------------
    # Initial-mask loading
    # ------------------------------------------------------------------

    def _load_initial_masks(
        self,
        height: int,
        width: int,
        initial_masks_dir: Optional[str] = None,
        initial_mask_paths: Optional[Dict[str, str]] = None,
    ) -> Dict[str, np.ndarray]:
        """Load first-frame masks from a directory or explicit mapping.

        Supported structure keys: 'LV','MYO','LA','RV','RA' (others ignored for now).
        Images are read as grayscale (``.npy`` files via ``np.load``); any
        non-zero value is foreground. Masks are returned as 0/255 uint8 arrays
        of shape (height, width).

        Explicit ``initial_mask_paths`` win over files found in
        ``initial_masks_dir``. A directory file is matched to a structure when
        its stem starts or ends with the lower-case structure name. Files are
        scanned in sorted order so the choice is deterministic. Empty masks are
        dropped.
        """
        import glob

        structure_order = ("LV", "MYO", "LA", "RV", "RA")
        valid_structures = set(structure_order)
        paths: Dict[str, str] = {}

        def read_mask(path: str) -> Optional[np.ndarray]:
            try:
                if path.lower().endswith(".npy"):
                    raw = np.load(path)  # allow_pickle stays at its safe default (False)
                else:
                    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                    if raw is None:
                        return None
                return self._to_binary_mask(raw, height, width)
            except Exception:
                return None

        if initial_mask_paths:
            for key, value in initial_mask_paths.items():
                structure = str(key).upper()
                if structure in valid_structures and isinstance(value, str) and os.path.exists(value):
                    paths[structure] = value

        if initial_masks_dir and os.path.isdir(initial_masks_dir):
            candidates: Dict[str, str] = {}
            for pattern in ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tif", "*.tiff", "*.npy"):
                for file_path in sorted(glob.glob(os.path.join(initial_masks_dir, pattern))):
                    stem = os.path.splitext(os.path.basename(file_path))[0].lower()
                    for structure in structure_order:
                        token = structure.lower()
                        if stem.startswith(token) or stem.endswith(token):
                            candidates.setdefault(structure, file_path)
            for structure, file_path in candidates.items():
                paths.setdefault(structure, file_path)

        loaded: Dict[str, np.ndarray] = {}
        for structure, file_path in paths.items():
            mask = read_mask(file_path)
            if mask is not None and mask.any():
                loaded[structure] = mask
        return loaded

    def _load_annotation_prompts_from_config(self, height: int, width: int, video_path: str) -> Dict[str, np.ndarray]:
        """Load annotation-derived first-frame masks using Config.ANNOTATION_PROMPTS.

        The video is looked up by stem, file name and absolute path, and by the
        ``ECHO_ORIGINAL_VIDEO_NAME`` environment variable (name and stem) when
        it is set. A matching dict entry supplies ``frames_dir``: either a
        directory containing ``NNNN.png`` label images or a single image path.
        It may also supply ``frame_index`` and ``label_map``, a mapping of pixel
        value to structure name. Without a label map every non-zero pixel
        becomes an ``"LV"`` mask. Returns ``{}`` whenever nothing usable is
        found.
        """
        try:
            from config import Config
        except Exception:
            return {}

        mapping = getattr(Config, "ANNOTATION_PROMPTS", {}) or {}
        if not mapping:
            return {}

        video = Path(video_path)
        candidates = [video.stem, video.name, os.path.abspath(video_path)]
        original_name = os.getenv("ECHO_ORIGINAL_VIDEO_NAME")
        if original_name:
            candidates.append(original_name)
            candidates.append(Path(original_name).stem)

        print(f"🔍 Annotation lookup candidates: {candidates}")
        print(f"🔍 Available annotation entries: {list(mapping.keys())}")

        entry = next((mapping[key] for key in candidates if key in mapping), None)
        if entry is None:
            return {}
        if not isinstance(entry, dict):
            print(f"⚠️ Ignoring malformed annotation entry for {video_path}")
            return {}

        frames_dir = entry.get("frames_dir")
        if not frames_dir:
            return {}
        try:
            frame_index = int(entry.get("frame_index", 0))
        except (TypeError, ValueError):
            print(f"⚠️ Invalid annotation frame_index for {video_path}: {entry.get('frame_index')!r}")
            return {}
        label_map = entry.get("label_map") or {}
        if not isinstance(label_map, dict):
            print(f"⚠️ Ignoring non-mapping label_map for {video_path}")
            return {}

        if os.path.isdir(frames_dir):
            frame_path = os.path.join(frames_dir, f"{frame_index:04d}.png")
        else:
            frame_path = frames_dir

        if not os.path.exists(frame_path):
            print(f"⚠️ Annotation prompt not found: {frame_path}")
            return {}

        mask_img = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            print(f"⚠️ Failed to read annotation prompt: {frame_path}")
            return {}
        if mask_img.shape != (height, width):
            mask_img = cv2.resize(mask_img, (width, height), interpolation=cv2.INTER_NEAREST)

        loaded: Dict[str, np.ndarray] = {}
        if label_map:
            for raw_value, structure in label_map.items():
                try:
                    value = int(raw_value)
                except (TypeError, ValueError, OverflowError):
                    continue
                mask = (mask_img == value).astype(np.uint8) * 255
                if mask.any():
                    loaded[str(structure).upper()] = mask
        else:
            # No label map: treat the whole annotation as the left ventricle.
            mask = (mask_img > 0).astype(np.uint8) * 255
            if mask.any():
                loaded["LV"] = mask

        if loaded:
            print(f"✅ Loaded annotation prompts for {video_path} from {frame_path}")
        return loaded

    # ------------------------------------------------------------------
    # Video writers
    # ------------------------------------------------------------------

    def _create_enhanced_videos(self, frames, segmentation_result, output_dir, fps: float = 30.0):
        """Create overlay video showing all segmented structures.

        Returns ``{"combined_segmentation_video": <h264 path>}``, or ``{}``
        when anything fails (the error is printed).
        """
        try:
            structures = segmentation_result['structures']
            structure_info = segmentation_result['structure_info']
            all_masks = segmentation_result['masks']

            combined_video_path = output_dir / "combined_segmentation_video.avi"
            combined_final_path = self._save_combined_overlay_video(
                frames,
                all_masks,
                structures,
                structure_info,
                str(combined_video_path),
                fps=fps,
            )
            combined_final_path = convert_video_to_h264(str(combined_final_path))
            return {"combined_segmentation_video": str(combined_final_path)}

        except Exception as e:
            print(f"❌ Error creating enhanced videos: {e}")
            return {}

    def _save_combined_overlay_video(self, frames, all_masks, structures, structure_info, output_path, fps: float = 30.0) -> str:
        """Save a single AVI where all structures are overlaid in unique colors.

        ``all_masks`` is ``{frame_idx: {obj_id: mask}}``. Object ids are
        1-based positions in ``structures``; ids outside ``1..len(structures)``
        are skipped. Colours come from ``structure_info[structure]['color']``,
        defaulting to white. Frames without masks are written unchanged.

        Raises:
            RuntimeError: if the video writer cannot be opened.
        """
        if not frames or not all_masks:
            return output_path

        height, width = frames[0].shape[:2]
        final_path = output_path if output_path.lower().endswith(".avi") else os.path.splitext(output_path)[0] + ".avi"

        writer = cv2.VideoWriter(final_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (width, height))
        if not writer.isOpened():
            writer.release()
            raise RuntimeError(f"Could not open video writer for {final_path}")

        try:
            for frame_index, frame in enumerate(frames):
                if frame_index not in all_masks:
                    writer.write(frame)
                    continue

                color_layer = np.zeros_like(frame)
                outlines = []
                for obj_id, mask in all_masks[frame_index].items():
                    if not isinstance(obj_id, (int, np.integer)) or not 1 <= obj_id <= len(structures):
                        continue
                    structure_id = structures[obj_id - 1]
                    color = (structure_info.get(structure_id) or {}).get('color', (255, 255, 255))

                    binary = self._to_binary_mask(mask, height, width)
                    if not binary.any():
                        continue
                    color_layer[binary > 0] = color
                    outlines.append(binary)

                overlay = cv2.addWeighted(frame, 0.6, color_layer, 0.4, 0)
                for binary in outlines:
                    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)
                writer.write(overlay)
        finally:
            writer.release()

        print(f"✅ Saved combined overlay video: {final_path}")
        return final_path

    def _overlay_mask(self, frame: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.35):
        """Create overlay with proper alpha blending and contour visualization.

        The mask is normalised to the frame size. Any non-zero value is
        foreground. Foreground is tinted with *color* at opacity *alpha* and
        outlined in white. *frame* itself is not modified.
        """
        height, width = frame.shape[:2]
        # _to_binary_mask reduces extra axes safely; the previous
        # ``while ndim > 2: squeeze()`` never terminated for (H, W, C>1) masks.
        binary_mask = self._to_binary_mask(mask, height, width)

        colored_overlay = np.zeros_like(frame)
        colored_overlay[binary_mask > 0] = color
        overlay = cv2.addWeighted(frame, 1 - alpha, colored_overlay, alpha, 0)

        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)
        return overlay

    def _save_mask_video(self, frames, masks, output_path, fps: float = 30.0):
        """Save mask video and return a browser-friendly H.264 path.

        Each mask is shown as white-on-black (0/255, so 0/1 or bool masks are
        visible) with green outlines, at *fps* frames per second.
        """
        if frames is None or masks is None or len(frames) == 0 or len(masks) == 0:
            return output_path

        height, width = frames[0].shape[:2]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        try:
            for mask in masks:
                binary = self._to_binary_mask(mask, height, width)
                mask_bgr = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
                if binary.any():
                    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(mask_bgr, contours, -1, (0, 255, 0), 2)
                writer.write(mask_bgr)
        finally:
            writer.release()

        converted_path = convert_video_to_h264(output_path)
        print(f"✅ Saved mask video: {converted_path}")
        return converted_path

    def _save_overlay_video(self, frames, masks, output_path, fps: float = 30.0):
        """Save overlay video with corrected overlay logic and return H.264 path.

        Frames and masks are paired with ``zip``, so the shorter sequence
        decides the length.
        """
        if frames is None or masks is None or len(frames) == 0 or len(masks) == 0:
            return output_path

        height, width = frames[0].shape[:2]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        try:
            for frame, mask in zip(frames, masks):
                writer.write(self._overlay_mask(frame, mask, color=(0, 255, 0), alpha=0.35))
        finally:
            writer.release()

        converted_path = convert_video_to_h264(output_path)
        print(f"✅ Saved overlay video: {converted_path}")
        return converted_path

    # ------------------------------------------------------------------
    # Per-structure batch segmentation
    # ------------------------------------------------------------------

    @staticmethod
    def _relocate_output_video(outputs: Dict[str, Any], key: str, destination: str) -> Optional[str]:
        """Move ``outputs[key]`` to *destination* and return the new (H.264) path.

        Other entries of *outputs* pointing at the same file (e.g.
        ``segmented_video``) are updated as well. Returns ``None`` when *key*
        is absent; a path that no longer exists is returned unchanged.
        """
        source = outputs.get(key)
        if not source:
            return None
        if not os.path.exists(source):
            return source

        shutil.move(source, destination)
        relocated = convert_video_to_h264(destination)
        for alias, value in list(outputs.items()):
            if isinstance(value, str) and value == source:
                outputs[alias] = relocated
        return relocated

    def segment_all_structures(self, video_path: str, output_dir: str, progress_callback=None) -> Dict[str, Any]:
        """Segment all cardiac structures in the video.

        Runs ``_run`` once per structure name. The mask and overlay videos are
        moved into *output_dir* as ``mask_<S>.mp4`` / ``overlay_<S>.mp4``, and
        the masks are then composited into ``combined_segmentation.mp4``.
        Per-structure failures are recorded as ``{"status": "failed", "error": ...}``.
        """
        cardiac_structures = [
            "LV", "RV", "LA", "RA", "MV", "TV",
            "AV", "PV", "IVS", "LVPW", "AORoot", "PA",
        ]
        results: Dict[str, Any] = {
            "status": "success",
            "segmented_structures": {},
            "combined_video": None,
            "individual_videos": {},
            "overlay_videos": {}
        }

        try:
            os.makedirs(output_dir, exist_ok=True)
            total_structures = len(cardiac_structures)

            for index, structure in enumerate(cardiac_structures):
                if progress_callback:
                    progress_callback(int((index / total_structures) * 100), f"Segmenting {structure}...")

                try:
                    # NOTE(review): in ``_run`` the enhanced segmenter does not
                    # receive ``target_name``, so each iteration appears to
                    # repeat the full multi-structure segmentation.
                    structure_result = self._run(
                        video_path=video_path,
                        target_name=structure,
                        save_mask_video=True,
                        save_overlay_video=True,
                        sample_rate=2,
                        progress_callback=lambda p, msg: None
                    )
                    if structure_result.get("status") != "success":
                        results["segmented_structures"][structure] = {
                            "status": "failed",
                            "error": structure_result.get("error", "unknown error")
                        }
                        continue

                    # ``_run`` nests the video paths under "outputs".
                    outputs = structure_result.get("outputs") or {}
                    mask_video = self._relocate_output_video(
                        outputs, "mask_video", os.path.join(output_dir, f"mask_{structure}.mp4")
                    )
                    overlay_video = self._relocate_output_video(
                        outputs, "overlay_video", os.path.join(output_dir, f"overlay_{structure}.mp4")
                    )

                    results["segmented_structures"][structure] = structure_result
                    if mask_video:
                        results["individual_videos"][structure] = mask_video
                    if overlay_video:
                        results["overlay_videos"][structure] = overlay_video

                except Exception as e:
                    print(f"❌ Failed to segment {structure}: {e}")
                    results["segmented_structures"][structure] = {
                        "status": "failed",
                        "error": str(e)
                    }

            if progress_callback:
                progress_callback(90, "Creating combined segmentation...")

            results["combined_video"] = self._create_combined_segmentation(
                results["individual_videos"],
                output_dir
            )

            if progress_callback:
                progress_callback(100, "Segmentation completed!")

            return results

        except Exception as e:
            results["status"] = "failed"
            results["error"] = str(e)
            return results

    @staticmethod
    def _render_combined_video(individual_videos: Dict[str, str], output_path: str) -> int:
        """Composite per-structure mask videos into *output_path*; return frames written.

        Non-black pixels of each mask video are painted in that structure's
        colour and a legend of the structures present in the frame is drawn.
        All captures and the writer are released even on error.
        """
        colors = {
            "LV": (0, 255, 0),
            "RV": (255, 0, 0),
            "LA": (0, 255, 255),
            "RA": (255, 0, 255),
            "MV": (255, 255, 0),
            "TV": (128, 0, 128),
            "AV": (255, 165, 0),
            "PV": (0, 128, 128),
            "IVS": (128, 128, 0),
            "LVPW": (128, 0, 0),
            "AORoot": (0, 128, 0),
            "PA": (0, 0, 128),
        }

        captures: Dict[str, Any] = {}
        writer = None
        frame_count = 0
        try:
            for structure, video_path in individual_videos.items():
                if video_path and os.path.exists(video_path):
                    captures[structure] = cv2.VideoCapture(video_path)
            if not captures:
                return 0

            # Geometry and timing come from the first readable video.
            probe = next(iter(captures.values()))
            width = int(probe.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(probe.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = probe.get(cv2.CAP_PROP_FPS)
            if not fps or not np.isfinite(fps) or fps <= 1e-3:
                fps = 30.0
            if width <= 0 or height <= 0:
                return 0

            writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

            while True:
                current: Dict[str, Optional[np.ndarray]] = {}
                for structure, capture in captures.items():
                    ret, frame = capture.read()
                    current[structure] = frame if ret else None
                if all(frame is None for frame in current.values()):
                    break

                combined_frame = np.zeros((height, width, 3), dtype=np.uint8)
                for structure, frame in current.items():
                    if frame is None:
                        continue
                    if frame.shape[:2] != (height, width):
                        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_NEAREST)
                    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    colored_mask = np.zeros_like(combined_frame)
                    colored_mask[gray > 0] = colors.get(structure, (255, 255, 255))
                    combined_frame = cv2.addWeighted(combined_frame, 1.0, colored_mask, 0.7, 0)

                y_offset = 30
                for structure, color in colors.items():
                    if current.get(structure) is not None:
                        cv2.putText(combined_frame, structure, (10, y_offset),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
                        y_offset += 25

                writer.write(combined_frame)
                frame_count += 1
        finally:
            for capture in captures.values():
                capture.release()
            if writer is not None:
                writer.release()

        return frame_count

    def _create_combined_segmentation(self, individual_videos: Dict[str, str], output_dir: str) -> Optional[str]:
        """Create a combined video showing all segmented structures.

        Returns the H.264 path of ``<output_dir>/combined_segmentation.mp4``,
        or ``None`` when there is no input, nothing was rendered, or an error
        occurred (the error is printed).
        """
        try:
            if not individual_videos:
                return None
            output_path = os.path.join(output_dir, "combined_segmentation.mp4")
            if self._render_combined_video(individual_videos, output_path) > 0:
                return convert_video_to_h264(output_path)
            return None
        except Exception as e:
            print(f"❌ Failed to create combined segmentation: {e}")
            return None
|
|
|
|
|
|
|
|
class EchoViewClassificationTool(BaseTool):
    """Echo view classification tool.

    Wraps EchoPrime's MP4 preprocessing (``process_mp4s``) and view
    classifier (``get_views``).
    """

    name: str = "echo_view_classification"
    description: str = "Classify echocardiography video views using EchoPrime."
    args_schema: Type[BaseModel] = EchoViewClassificationInput

    def _run(
        self,
        input_dir: str,
        visualize: bool = False,
        max_videos: Optional[int] = None,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run echo view classification using real EchoPrime model.

        Args:
            input_dir: Directory handed to ``process_mp4s``.
            visualize: Forwarded to ``get_views``.
            max_videos: If a positive number, only the first ``max_videos``
                processed videos are classified; ``None``/0/negative means no
                limit.
            run_manager: Unused; accepted for the LangChain tool interface.

        Returns:
            Summary dict with per-view counts (``classifications``) and
            per-video entries (``detailed_results``).

        Raises:
            RuntimeError: if loading, preprocessing or classification fails;
                the original exception is chained as ``__cause__``.
        """
        try:
            echo_prime_model = load_echo_prime_model()

            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)
            if stack_of_videos.shape[0] == 0:
                raise RuntimeError(f"No valid MP4 videos found in {input_dir}")

            if max_videos is not None and max_videos > 0 and stack_of_videos.shape[0] > max_videos:
                stack_of_videos = stack_of_videos[:max_videos]
            print(f"✅ Processed {stack_of_videos.shape[0]} videos")

            print("🔄 Classifying views...")
            view_encodings = echo_prime_model.get_views(stack_of_videos, visualize=visualize, return_view_list=True)

            # NOTE(review): get_views is asked only for the view list, so the
            # confidence and probability values below are fixed placeholders,
            # not model outputs. The video names are positional, not real file
            # names.
            all_classifications = [
                {
                    "video": f"video_{index}.mp4",
                    "predicted_view": view,
                    "confidence": 0.85,
                    "view_probabilities": {
                        view: 0.85,
                        "other": 0.15
                    }
                }
                for index, view in enumerate(view_encodings, start=1)
            ]
            if not all_classifications:
                raise RuntimeError("No videos processed successfully")

            view_counts: Dict[Any, Dict[str, Any]] = {}
            for classification in all_classifications:
                summary = view_counts.setdefault(
                    classification["predicted_view"], {"count": 0, "confidence": 0.0}
                )
                summary["count"] += 1
                summary["confidence"] = max(summary["confidence"], classification["confidence"])

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(all_classifications),
                "classifications": view_counts,
                "detailed_results": all_classifications,
                "message": f"View classification completed for {len(all_classifications)} videos using real EchoPrime model"
            }

        except Exception as e:
            print(f"EchoPrime view classification failed: {e}")
            raise RuntimeError(f"View classification failed: {e}") from e
|
|
|
|
|
|
|
|
class EchoDiseasePredictionManager(BaseToolManager):
    """Manager for echo disease prediction tool.

    Builds an ``EchoDiseasePredictionTool`` at construction and exposes it
    through ``run``. Failures are reported as ``{"error": ...}`` dicts.
    """

    def __init__(self, model_manager=None):
        # Stored for callers; not used elsewhere in this class.
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_disease_prediction",
            tool_type="disease_prediction",
            description="Echo disease prediction tool"
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the disease prediction tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the disease prediction tool."""
        return EchoDiseasePredictionTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same class as the primary tool)."""
        return EchoDiseasePredictionTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the disease prediction tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.

        Returns:
            The tool's result, or ``{"error": ...}`` when the tool is missing
            or raises.
        """
        # ``self.tool`` is only assigned in this class when creation succeeded.
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            return tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|
|
|
|
|
|
class EchoImageVideoGenerationManager(BaseToolManager):
    """Manager for echo image/video generation tool.

    Builds an ``EchoImageVideoGenerationTool`` at construction and exposes it
    through ``run``. Failures are reported as ``{"error": ...}`` dicts.
    """

    def __init__(self, model_manager=None):
        # Stored for callers; not used elsewhere in this class.
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_image_video_generation",
            tool_type="generation",
            description="Echo image/video generation tool"
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the image/video generation tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the image/video generation tool."""
        return EchoImageVideoGenerationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same class as the primary tool)."""
        return EchoImageVideoGenerationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the image/video generation tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.

        Returns:
            The tool's result, or ``{"error": ...}`` when the tool is missing
            or raises.
        """
        # ``self.tool`` is only assigned in this class when creation succeeded.
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            return tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|
|
|
|
|
|
class EchoMeasurementPredictionManager(BaseToolManager):
    """Manager for echo measurement prediction tool.

    Builds an ``EchoMeasurementPredictionTool`` at construction and exposes it
    through ``run``. Failures are reported as ``{"error": ...}`` dicts.
    """

    def __init__(self, model_manager=None):
        # Stored for callers; not used elsewhere in this class.
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_measurement_prediction",
            tool_type="measurement",
            description="Echo measurement prediction tool"
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the measurement prediction tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the measurement prediction tool."""
        return EchoMeasurementPredictionTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same class as the primary tool)."""
        return EchoMeasurementPredictionTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the measurement prediction tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.

        Returns:
            The tool's result, or ``{"error": ...}`` when the tool is missing
            or raises.
        """
        # ``self.tool`` is only assigned in this class when creation succeeded.
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            return tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|
|
|
|
|
|
class EchoReportGenerationManager(BaseToolManager):
    """Manager for echo report generation tool.

    Builds an ``EchoReportGenerationTool`` at construction and exposes it
    through ``run``. Failures are reported as ``{"error": ...}`` dicts.
    """

    def __init__(self, model_manager=None):
        # Stored for callers; not used elsewhere in this class.
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_report_generation",
            tool_type="report",
            description="Echo report generation tool"
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the report generation tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the report generation tool."""
        return EchoReportGenerationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same class as the primary tool)."""
        return EchoReportGenerationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the report generation tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.

        Returns:
            The tool's result, or ``{"error": ...}`` when the tool is missing
            or raises.
        """
        # ``self.tool`` is only assigned in this class when creation succeeded.
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            return tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|
|
|
|
|
|
class EchoSegmentationManager(BaseToolManager):
    """Manager for echo segmentation tool.

    Creates an ``EchoSegmentationTool`` at construction time. When a caller
    supplies no prompt of its own (see ``_PROMPT_KEYS``), :meth:`run` injects
    the per-structure default masks found under
    ``DEFAULT_ECHO_SEGMENTATION_MASK_DIR`` as ``initial_mask_paths``.
    """

    # Input keys that count as a caller-supplied segmentation prompt. A key
    # only counts when its value is truthy.
    _PROMPT_KEYS = ("initial_masks_dir", "initial_mask_paths", "mask_path", "points", "box")

    def __init__(self, model_manager=None):
        """Register the tool configuration and eagerly create the tool.

        Args:
            model_manager: Optional model manager stored on the instance as
                ``self.model_manager``; not otherwise used by this class.
        """
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_segmentation",
            tool_type="segmentation",
            description="Echo segmentation tool",
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the segmentation tool and update the status.

        On any exception the error is printed, ``self.tool`` is reset to
        ``None`` (so :meth:`run` reports the tool as unavailable rather than
        touching an unset attribute) and the status becomes ``NOT_AVAILABLE``.
        """
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self.tool = None
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the segmentation tool."""
        return EchoSegmentationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (currently the same tool class as the primary)."""
        return EchoSegmentationTool()

    def _default_mask_prompts(self) -> Dict[str, str]:
        """Return ``{structure: mask_path}`` for default masks present on disk.

        Per-structure masks are only looked up when
        ``DEFAULT_ECHO_SEGMENTATION_MASK`` exists; that combined mask itself is
        not forwarded to the tool. NOTE(review): confirm this gating is
        intended. A warning is printed and ``{}`` returned whenever no usable
        default mask is found.
        """
        if not DEFAULT_ECHO_SEGMENTATION_MASK.exists():
            print(
                f"⚠️ Default segmentation mask not found at {DEFAULT_ECHO_SEGMENTATION_MASK}; "
                "falling back to configured prompt."
            )
            return {}

        mask_dir = DEFAULT_ECHO_SEGMENTATION_MASK_DIR
        if not mask_dir.exists():
            print(
                f"⚠️ Default mask directory not found at {DEFAULT_ECHO_SEGMENTATION_MASK_DIR}; "
                "consider generating per-structure masks."
            )
            return {}

        available_paths = {
            structure: str(mask_dir / filename)
            for structure, filename in DEFAULT_ECHO_SEGMENTATION_STRUCTURES.items()
            if (mask_dir / filename).exists()
        }
        if not available_paths:
            # Previously this case was silent; surface it like the others.
            print(
                f"⚠️ No per-structure default masks found in {mask_dir}; "
                "falling back to configured prompt."
            )
        return available_paths

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the segmentation tool, injecting default mask prompts if needed.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.
                ``None`` or an empty mapping forwards no caller arguments.

        Returns:
            The tool's ``_run`` result, or ``{"error": ...}`` when the tool is
            unavailable or raises.
        """
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            # Copy so the caller's mapping is never mutated.
            prepared_inputs = dict(input_data) if input_data else {}

            if not any(prepared_inputs.get(key) for key in self._PROMPT_KEYS):
                default_paths = self._default_mask_prompts()
                if default_paths:
                    # Assign instead of setdefault: the key may be present with
                    # a falsy value (e.g. None or {}), which setdefault would
                    # keep, silently skipping the defaults.
                    prepared_inputs["initial_mask_paths"] = default_paths

            return tool._run(**prepared_inputs)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|
|
|
|
|
|
class EchoViewClassificationManager(BaseToolManager):
    """Manager for echo view classification tool.

    Creates an ``EchoViewClassificationTool`` at construction time and
    exposes it through :meth:`run`, which converts any exception raised by
    the tool into an ``{"error": ...}`` result dict.
    """

    def __init__(self, model_manager=None):
        """Register the tool configuration and eagerly create the tool.

        Args:
            model_manager: Optional model manager stored on the instance as
                ``self.model_manager``; not otherwise used by this class.
        """
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_view_classification",
            tool_type="classification",
            description="Echo view classification tool",
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the view classification tool and update the status.

        On any exception the error is printed, ``self.tool`` is reset to
        ``None`` (so :meth:`run` reports the tool as unavailable rather than
        touching an unset attribute) and the status becomes ``NOT_AVAILABLE``.
        """
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self.tool = None
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the view classification tool."""
        return EchoViewClassificationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (currently the same tool class as the primary)."""
        return EchoViewClassificationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the view classification tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``.
                ``None`` or an empty mapping forwards no arguments.

        Returns:
            The tool's ``_run`` result, or ``{"error": ...}`` when the tool is
            unavailable or raises.
        """
        # getattr: self.tool may be unset if initialization was skipped or
        # failed before assigning it.
        tool = getattr(self, "tool", None)
        if not tool:
            return {"error": "Tool not available"}
        try:
            # Treat a None payload as "no arguments" instead of failing on **None.
            return tool._run(**(input_data or {}))
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
|
|
|