| |
| """ |
| EchoFlow Integrated Tool |
| |
| This tool integrates EchoFlow into the main echo analysis system. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Dict, List, Optional, Type |
| from pathlib import Path |
| import tempfile |
| import shutil |
| import datetime |
| import os |
| import sys |
|
|
| import numpy as np |
| import cv2 |
| import torch |
| from pydantic import BaseModel, Field, field_validator |
| from langchain_core.tools import BaseTool |
| from langchain_core.callbacks import ( |
| CallbackManagerForToolRun, |
| AsyncCallbackManagerForToolRun, |
| ) |
|
|
| |
| from .echoflow_final_working import EchoFlowFinal |
|
|
| |
|
|
class EchoFlowGenerationInput(BaseModel):
    """Generate synthetic echo images and videos using EchoFlow."""

    # NOTE: the class docstring and the Field descriptions below are exported
    # into the JSON schema that LangChain hands to the model, so they are
    # runtime-visible text — edit them deliberately.

    views: List[str] = Field(
        default_factory=lambda: ["A4C", "PLAX", "PSAX"],
        description="Cardiac echo views to synthesize (e.g., A4C, PLAX, PSAX).",
    )
    ejection_fractions: List[float] = Field(
        default_factory=lambda: [0.35, 0.55, 0.70],
        description="Ejection fraction values (0.0 to 1.0) used to condition the generation.",
    )
    num_frames: int = Field(16, ge=1, le=64, description="Number of frames in generated videos.")
    timestep: float = Field(0.5, ge=0.0, le=1.0, description="Diffusion timestep for generation.")

    outdir: Optional[str] = Field(
        None,
        description="Root output dir. If omitted, a timestamped folder is created under the tool temp dir.",
    )
    save_features: bool = Field(True, description="Save generated features as numpy arrays.")
    save_metadata: bool = Field(True, description="Save generation metadata.")

    @field_validator("views")
    @classmethod
    def _nonempty_views(cls, v: List[str]) -> List[str]:
        """Require at least one view, and reject blank view names.

        Raises:
            ValueError: if the list is empty or any entry is empty/whitespace.
        """
        if not v:
            raise ValueError("At least one view must be provided.")
        # A blank view name would become an empty results key and a file
        # name starting with "_EF...", so reject it up front.
        if any(not view.strip() for view in v):
            raise ValueError("View names must be non-empty strings.")
        return v

    @field_validator("ejection_fractions")
    @classmethod
    def _valid_efs(cls, v: List[float]) -> List[float]:
        """Require at least one ejection fraction, each within [0.0, 1.0].

        Raises:
            ValueError: if the list is empty or a value is out of range or NaN.
        """
        if not v:
            raise ValueError("At least one ejection fraction must be provided.")
        for x in v:
            # Written as a negated chained comparison so NaN (for which every
            # comparison is False) is rejected as well.
            if not (0.0 <= x <= 1.0):
                raise ValueError(f"Ejection fraction {x} out of range [0.0, 1.0].")
        return v
|
|
| |
|
|
class EchoFlowGenerationTool(BaseTool):
    """EchoFlow generation tool integrated with the main echo analysis system.

    Wraps an ``EchoFlowFinal`` generator as a LangChain tool. For every
    requested (view, ejection fraction) pair it calls
    ``generate_synthetic_echo`` and writes the outputs under an output root
    that contains ``features/``, ``metadata/``, ``masks/`` and ``videos/``
    sub-directories.
    """

    name: str = "echoflow_generation"
    description: str = (
        "Generate synthetic echocardiography images and videos using EchoFlow. "
        "Creates realistic echo data for training, testing, and augmentation purposes. "
        "Supports multiple views (A4C, PLAX, PSAX) and ejection fraction conditioning."
    )
    args_schema: Type[BaseModel] = EchoFlowGenerationInput

    # Torch device string; resolved in __init__ when not given explicitly.
    device: Optional[str] = "cuda"
    # Output root used when a run does not pass ``outdir``.
    temp_dir: Path = Path("temp")
    # Loaded generator, or None if initialization failed (runs then raise).
    echoflow_generator: Optional[EchoFlowFinal] = None

    def __init__(self, device: Optional[str] = None, temp_dir: Optional[str] = None):
        """Create the tool and try to load the EchoFlow generator.

        Args:
            device: Torch device string; None selects "cuda" when
                ``torch.cuda.is_available()`` and "cpu" otherwise.
            temp_dir: Default output root; None creates a new directory via
                ``tempfile.mkdtemp()``.

        Loading is best-effort: on any failure the error is printed and
        ``echoflow_generator`` is left as None, so the tool can still be
        constructed but ``_run`` will raise RuntimeError.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.temp_dir = Path(temp_dir or tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        try:
            self.echoflow_generator = EchoFlowFinal(device=self.device)
            # A missing config is tolerated; missing models are not.
            if not self.echoflow_generator.load_config():
                print("⚠️ Could not load EchoFlow config, using defaults")
            if not self.echoflow_generator.load_models():
                raise RuntimeError("Failed to load EchoFlow models")
            print("✅ EchoFlow generator initialized successfully")
        except Exception as e:
            print(f"❌ Failed to initialize EchoFlow generator: {e}")
            self.echoflow_generator = None

    def _ensure_echoflow(self) -> None:
        """Raise RuntimeError if the generator failed to initialize."""
        if self.echoflow_generator is None:
            raise RuntimeError(
                "EchoFlow generator not initialized. Check model loading and dependencies."
            )

    @staticmethod
    def _ensure_dirs(root: Path) -> Dict[str, Path]:
        """Create the output sub-directories under *root* and return them by name.

        Note: ``videos/`` is created but nothing in this class writes to it.
        """
        d = {
            "features": root / "features",
            "metadata": root / "metadata",
            "masks": root / "masks",
            "videos": root / "videos",
        }
        for p in d.values():
            p.mkdir(parents=True, exist_ok=True)
        return d

    @staticmethod
    def _safe_name(text: str) -> str:
        """Return *text* with every character outside ``[A-Za-z0-9_-]`` replaced by "_".

        View names come from tool input (possibly an LLM), so they must not be
        able to inject path separators or ".." into output file names.
        Ordinary names such as "A4C" or "PLAX" are returned unchanged.
        """
        import re

        return re.sub(r"[^A-Za-z0-9_-]", "_", text)

    @staticmethod
    def _save_numpy(path: Path, arr: np.ndarray) -> str:
        """Save *arr* with ``np.save`` and return the path as a string."""
        np.save(str(path), arr)
        return str(path)

    @staticmethod
    def _save_json(path: Path, data: Dict[str, Any]) -> str:
        """Write *data* as indented JSON (non-JSON values via ``str``); return the path."""
        import json

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
        return str(path)

    def _generate_one(
        self,
        view: str,
        ef: float,
        num_frames: int,
        timestep: float,
        paths: Dict[str, Path],
        save_features: bool,
        save_metadata: bool,
        view_rec: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Generate and persist one (view, EF) sample and return its per-EF record.

        Paths of files actually written are appended to ``view_rec`` as soon as
        each file exists. Exceptions from the generator or from disk writes
        propagate to the caller.
        """
        print(f"🎬 Generating {view} view with EF={ef:.2f}")

        # Placeholder conditioning mask of uniform noise over the full uint8
        # range (randint's upper bound is exclusive, hence 256).
        dummy_mask = np.random.randint(0, 256, (400, 400), dtype=np.uint8)

        # NOTE(review): ``timestep`` is recorded in the metadata but is not
        # passed to the generator — confirm whether generate_synthetic_echo
        # accepts it.
        result = self.echoflow_generator.generate_synthetic_echo(
            mask=dummy_mask,
            view_type=view,
            ejection_fraction=ef,
            num_frames=num_frames,
        )

        if not result.get("success"):
            error = result.get("error", "Unknown error")
            print(f"❌ {view} EF={ef:.2f} generation failed: {error}")
            return {"success": False, "error": error}

        stem = f"{self._safe_name(view)}_EF{ef:.2f}"
        features_path: Optional[Path] = None
        metadata_path: Optional[Path] = None

        if save_features:
            features_path = paths["features"] / f"{stem}_features.npy"
            self._save_numpy(features_path, result["video_features"])
            view_rec["features_saved"].append(str(features_path))

        if save_metadata:
            metadata = {
                "view": view,
                "ejection_fraction": ef,
                "num_frames": num_frames,
                "timestep": timestep,
                "video_features_shape": result["video_features"].shape,
                "mask_processed_shape": result["mask_processed"].shape,
                "timestamp": result["timestamp"],
                "device": result["device"],
            }
            metadata_path = paths["metadata"] / f"{stem}_metadata.json"
            self._save_json(metadata_path, metadata)
            view_rec["metadata_saved"].append(str(metadata_path))

        mask_path = paths["masks"] / f"{stem}_mask.npy"
        self._save_numpy(mask_path, result["mask_processed"])

        print(f"✅ {view} EF={ef:.2f} generated successfully")
        return {
            "success": True,
            "video_features_shape": result["video_features"].shape,
            "features_path": str(features_path) if features_path is not None else None,
            "metadata_path": str(metadata_path) if metadata_path is not None else None,
            "mask_path": str(mask_path),
        }

    def _run(
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Generate every (view, EF) combination and save the outputs.

        Args:
            views: View names; each is a key in ``results["views"]``.
            ejection_fractions: EF values; each becomes an ``"EF_<x.xx>"`` key.
            num_frames: Frames per generated sample.
            timestep: Recorded in run/sample metadata.
            outdir: Output root; defaults to a UTC-timestamped folder under
                ``self.temp_dir``.
            save_features: Save ``video_features`` as ``.npy``.
            save_metadata: Save a per-sample JSON metadata file.
            run_manager: Unused.

        Returns:
            A dict with ``outdir``, ``meta``, per-view records,
            ``total_generations``, ``successful_generations``,
            ``success_rate``, and ``success`` (True when at least one sample
            was generated and fully saved).

        Raises:
            RuntimeError: if the generator was not initialized.
        """
        self._ensure_echoflow()

        stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
        root = Path(outdir) if outdir else (self.temp_dir / f"echoflow_generation_{stamp}")
        root.mkdir(parents=True, exist_ok=True)
        paths = self._ensure_dirs(root)

        run_meta = {
            "timestamp_utc": stamp,
            "device": self.device,
            "views": views,
            "ejection_fractions": ejection_fractions,
            "num_frames": num_frames,
            "timestep": timestep,
        }

        results: Dict[str, Any] = {
            "outdir": str(root),
            "meta": run_meta,
            "views": {},
            "success": False,  # finalized after the loop
            "total_generations": 0,
            "successful_generations": 0,
        }

        for view in views:
            view_rec: Dict[str, Any] = {
                "view": view,
                "ejection_fractions": {},
                "features_saved": [],
                "metadata_saved": [],
            }
            results["views"][view] = view_rec

            for ef in ejection_fractions:
                # Count each attempt exactly once, and count a success only
                # after every file for it has been written, so a failure
                # while saving cannot double-count or inflate the success rate.
                results["total_generations"] += 1
                try:
                    record = self._generate_one(
                        view, ef, num_frames, timestep, paths,
                        save_features, save_metadata, view_rec,
                    )
                except Exception as e:
                    record = {"success": False, "error": str(e)}
                    print(f"❌ {view} EF={ef:.2f} generation error: {e}")

                view_rec["ejection_fractions"][f"EF_{ef:.2f}"] = record
                if record["success"]:
                    results["successful_generations"] += 1

        total = results["total_generations"]
        results["success_rate"] = results["successful_generations"] / total if total else 0.0
        results["success"] = results["successful_generations"] > 0

        print("\n📊 Generation Summary:")
        print(f" Total generations: {results['total_generations']}")
        print(f" Successful: {results['successful_generations']}")
        print(f" Success rate: {results['success_rate']:.2%}")

        return results

    async def _arun(
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Async entry point; runs the synchronous ``_run`` directly.

        NOTE(review): this blocks the event loop for the whole generation.
        Offloading to a thread would fix that, but it would also allow
        concurrent calls into the shared generator — confirm it is
        thread-safe before changing.
        """
        return self._run(
            views=views,
            ejection_fractions=ejection_fractions,
            num_frames=num_frames,
            timestep=timestep,
            outdir=outdir,
            save_features=save_features,
            save_metadata=save_metadata,
        )
|
|
| |
|
|
def create_echoflow_tool(device: Optional[str] = None, temp_dir: Optional[str] = None) -> EchoFlowGenerationTool:
    """Build an EchoFlowGenerationTool.

    Args:
        device: Torch device string such as 'cuda' or 'cpu'; None lets the
            tool choose CUDA when available and CPU otherwise.
        temp_dir: Default output root for runs that do not pass ``outdir``;
            None lets the tool create a fresh temporary directory.

    Returns:
        A new tool instance. If EchoFlow fails to load, the instance is still
        returned, but running it raises RuntimeError.
    """
    tool = EchoFlowGenerationTool(device=device, temp_dir=temp_dir)
    return tool
|
|
def test_echoflow_tool():
    """Run a small end-to-end generation and print a short report.

    Requests A4C and PLAX samples at EF 0.35 and 0.65 (8 frames each), written
    under ``Path(__file__).resolve().parents[2] / "temp" / "echoflow_test_output"``.

    Returns:
        The tool's result dict, or None if anything raised (the traceback is
        printed).
    """
    print("🧪 Testing EchoFlow Tool")
    print("=" * 40)

    try:
        echo_tool = create_echoflow_tool()

        target_dir = Path(__file__).resolve().parents[2] / "temp" / "echoflow_test_output"
        request = dict(
            views=["A4C", "PLAX"],
            ejection_fractions=[0.35, 0.65],
            num_frames=8,
            timestep=0.5,
            outdir=str(target_dir),
            save_features=True,
            save_metadata=True,
        )
        outcome = echo_tool.run(request)

        print("\n📊 Test Results:")
        print(f" Success: {outcome['success']}")
        print(f" Success rate: {outcome['success_rate']:.2%}")
        print(f" Output directory: {outcome['outdir']}")
    except Exception as err:
        print(f"❌ EchoFlow tool test failed: {err}")
        import traceback

        traceback.print_exc()
        return None

    return outcome
|
|
if __name__ == "__main__":
    # Executing this module as a script runs the end-to-end smoke test.
    _ = test_echoflow_tool()
|
|