| | from __future__ import annotations |
| |
|
| | from typing import Any, Dict, List, Optional, Type |
| | from pathlib import Path |
| | import tempfile |
| | import shutil |
| | import datetime |
| |
|
| | import numpy as np |
| | import cv2 |
| | import torch |
| | from pydantic import BaseModel, Field, field_validator |
| | from langchain_core.tools import BaseTool |
| | from langchain_core.callbacks import ( |
| | CallbackManagerForToolRun, |
| | AsyncCallbackManagerForToolRun, |
| | ) |
| |
|
| | |
# Best-effort import of the EchoFlow demo primitives.
#
# The directory three levels above this file, joined with
# "model_weights/EchoFlow", is put at the front of sys.path so that the
# top-level `demo` module can be imported from there. If anything in that
# process fails, every primitive is bound to None so this module still imports;
# EchoSynthesisTool checks for None before doing any work.
try:
    import sys
    import os
    from pathlib import Path

    current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    echoflow_path = os.path.join(current_dir, "model_weights", "EchoFlow")
    if echoflow_path not in sys.path:
        sys.path.insert(0, echoflow_path)

    from demo import (
        load_view_mask,
        generate_latent_image,
        convert_latent_to_display,
        decode_latent_to_pixel,
        check_privacy,
        generate_animation,
        latent_animation_to_grayscale,
        decode_animation,
    )
except Exception as _demo_import_error:
    # Deliberately broad: the import stays best-effort whatever `demo` raises
    # while loading. The reason is logged instead of being silently dropped so
    # that a later "not importable" error can actually be diagnosed.
    import logging

    logging.getLogger(__name__).warning(
        "EchoFlow demo import failed (%s: %s); echo synthesis will be unavailable.",
        type(_demo_import_error).__name__,
        _demo_import_error,
    )

    load_view_mask = None
    generate_latent_image = None
    convert_latent_to_display = None
    decode_latent_to_pixel = None
    check_privacy = None
    generate_animation = None
    latent_animation_to_grayscale = None
    decode_animation = None
| |
|
| |
|
| | |
| |
|
class EchoSynthesisInput(BaseModel):
    """Generate synthetic echo images and EF-conditioned videos via EchoFlow demo."""

    # ---- what to synthesize -------------------------------------------------
    views: List[str] = Field(
        description="Cardiac echo views to synthesize (e.g., A4C, PLAX, PSAX).",
        default_factory=lambda: ["A4C", "PLAX", "PSAX"],
    )
    efs: List[int] = Field(
        description="Ejection fraction percentages used to condition the animation.",
        default_factory=lambda: [35, 55, 70],
    )

    # ---- sampling controls --------------------------------------------------
    img_steps: int = Field(
        default=150, ge=1, le=2000, description="Sampling steps for latent image generation."
    )
    vid_steps: int = Field(
        default=150, ge=1, le=2000, description="Sampling steps for latent video."
    )
    cfg_scale: float = Field(
        default=1.0, ge=0.0, le=20.0, description="CFG scale for animation generation."
    )
    max_privacy_retries: int = Field(
        default=3, ge=0, le=20, description="Max retries if privacy filter fails."
    )

    # ---- output controls ----------------------------------------------------
    outdir: Optional[str] = Field(
        default=None,
        description="Root output dir. If omitted, a timestamped folder is created under the tool temp dir.",
    )
    save_decoded_image: bool = Field(default=True, description="Save decoded RGB PNG per view.")
    save_latent_preview: bool = Field(default=True, description="Save latent grayscale PNG per view.")
    keep_failed_privacy_preview: bool = Field(
        default=True,
        description="If privacy fails after retries, save the last latent preview for diagnostics.",
    )

    @field_validator("views")
    @classmethod
    def _nonempty_views(cls, v: List[str]) -> List[str]:
        """Return ``v`` unchanged, raising ValueError when the list is empty."""
        if v:
            return v
        raise ValueError("At least one view must be provided.")

    @field_validator("efs")
    @classmethod
    def _valid_efs(cls, v: List[int]) -> List[int]:
        """Return ``v`` unchanged; raise ValueError if it is empty or holds a value outside [0, 100]."""
        if not v:
            raise ValueError("At least one EF must be provided.")
        # Report the first offending value, in list order.
        offending = next((x for x in v if not 0 <= x <= 100), None)
        if offending is not None:
            raise ValueError(f"EF {offending} out of range [0, 100].")
        return v
| |
|
| |
|
| | |
| |
|
class EchoSynthesisTool(BaseTool):
    """LangChain tool that synthesizes echo images and EF-conditioned videos.

    The tool is a thin orchestration layer over the EchoFlow ``demo`` functions
    imported at module level. For every requested view it generates a latent,
    runs it through ``check_privacy`` (re-sampling on rejection), saves the
    optional PNG artifacts and then renders one latent and one decoded video
    per ejection fraction.

    Attributes documented here rather than with attribute docstrings:
        device: Device name chosen at construction time. Within this class it
            is only recorded in the run metadata.
        temp_dir: Directory under which timestamped run folders are created
            when the caller does not pass ``outdir``.
    """

    name: str = "echo_synthesis"
    description: str = (
        "Synthesize echocardiography images and EF-conditioned videos using EchoFlow demo primitives. "
        "For each view: generate latent, save latent preview/decoded image, pass privacy filter, then render EF videos "
        "(latent grayscale MP4 and decoded RGB MP4). Returns artifact paths and metadata."
    )
    args_schema: Type[BaseModel] = EchoSynthesisInput

    device: Optional[str] = "cuda"
    temp_dir: Path = Path("temp")

    def __init__(self, device: Optional[str] = None, temp_dir: Optional[str] = None):
        """Create the tool.

        Args:
            device: Device name to record. When falsy, "cuda" is used if
                ``torch.cuda.is_available()`` and "cpu" otherwise.
            temp_dir: Base directory for run output. When falsy, a fresh
                directory from ``tempfile.mkdtemp()`` is used.
        """
        super().__init__()

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.temp_dir = Path(temp_dir or tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _ensure_demo(self):
        """Raise RuntimeError if any EchoFlow demo function failed to import."""
        required = (
            load_view_mask,
            generate_latent_image,
            convert_latent_to_display,
            decode_latent_to_pixel,
            check_privacy,
            generate_animation,
            latent_animation_to_grayscale,
            decode_animation,
        )
        if any(fn is None for fn in required):
            raise RuntimeError(
                "EchoFlow demo functions not importable. Ensure the 'demo' module and assets are in PYTHONPATH / working directory."
            )

    @staticmethod
    def _ensure_dirs(root: Path) -> Dict[str, Path]:
        """Create the per-run sub-directories under ``root`` and return them by name."""
        names = ("grayscale_frames", "decoded_images", "latent_videos", "decoded_videos", "meta")
        dirs = {name: root / name for name in names}
        for directory in dirs.values():
            directory.mkdir(parents=True, exist_ok=True)
        return dirs

    @staticmethod
    def _save_png(path: Path, arr: np.ndarray) -> str:
        """Write ``arr`` to ``path`` with OpenCV and return the path as a string.

        Non-uint8 input is clipped to [0, 255] and cast to uint8. For a
        3-channel array the last axis is reversed before writing.

        Raises:
            IOError: If ``cv2.imwrite`` reports failure.
        """
        if arr.dtype != np.uint8:
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        if arr.ndim == 3 and arr.shape[2] == 3:
            # presumably RGB -> BGR for OpenCV; verify against the demo's output format
            arr = arr[:, :, ::-1]
        if not cv2.imwrite(str(path), arr):
            raise IOError(f"Failed to write image: {path}")
        return str(path)

    def _run(
        self,
        views: List[str],
        efs: List[int],
        img_steps: int = 150,
        vid_steps: int = 150,
        cfg_scale: float = 1.0,
        max_privacy_retries: int = 3,
        outdir: Optional[str] = None,
        save_decoded_image: bool = True,
        save_latent_preview: bool = True,
        keep_failed_privacy_preview: bool = True,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Synthesize images and videos for every view and return their paths.

        Args:
            views: View names passed to the demo functions and used in file names.
            efs: Ejection fractions; one video pair is rendered per value.
            img_steps: ``sampling_steps`` for ``generate_latent_image``.
            vid_steps: ``sampling_steps`` for ``generate_animation``.
            cfg_scale: ``cfg_scale`` for ``generate_animation``.
            max_privacy_retries: How many times a rejected latent is re-sampled.
            outdir: Output root. When falsy, a timestamped folder under
                ``self.temp_dir`` is used.
            save_decoded_image: Save the decoded PNG for each view.
            save_latent_preview: Save the latent preview PNG for each view.
            keep_failed_privacy_preview: When a view is still rejected after
                all retries, save a preview of the last rejected latent.
            run_manager: Accepted for the LangChain interface; not used here.

        Returns:
            A dict with ``outdir``, ``meta`` (the run parameters) and ``views``,
            which maps each view name to its record of artifact paths, privacy
            outcome and rendered videos.

        Raises:
            RuntimeError: If the EchoFlow demo functions could not be imported.
        """
        self._ensure_demo()

        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the formatted stamp is the same.
        stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
        root = Path(outdir) if outdir else (self.temp_dir / f"echoflow_run_{stamp}")
        root.mkdir(parents=True, exist_ok=True)
        paths = self._ensure_dirs(root)

        run_meta = {
            "timestamp_utc": stamp,
            "device": self.device,
            "views": views,
            "efs": efs,
            "img_steps": img_steps,
            "vid_steps": vid_steps,
            "cfg_scale": cfg_scale,
            "max_privacy_retries": max_privacy_retries,
        }

        results: Dict[str, Any] = {"outdir": str(root), "meta": run_meta, "views": {}}

        for view in views:
            view_rec: Dict[str, Any] = {
                "view": view,
                "latent_preview_png": None,
                "decoded_image_png": None,
                "privacy_passed": False,
                "privacy_message": None,
                "privacy_reject_preview_png": None,
                "videos": [],
            }
            results["views"][view] = view_rec

            mask = load_view_mask(view)

            # Sample a latent and re-sample while the privacy check rejects it.
            # A None first element from check_privacy is treated as a rejection.
            latent = generate_latent_image(mask, view, sampling_steps=img_steps)
            filtered_latent, msg = check_privacy(latent, view)
            tries = 0
            while filtered_latent is None and tries < max_privacy_retries:
                latent = generate_latent_image(mask, view, sampling_steps=img_steps)
                filtered_latent, msg = check_privacy(latent, view)
                tries += 1
            view_rec["privacy_message"] = msg

            # Save the still images only now, from the latent that went through
            # the final privacy check. Saving them before the retry loop would
            # leave PNGs of a discarded latent next to videos rendered from a
            # different one.
            if save_latent_preview:
                preview = convert_latent_to_display(latent)
                view_rec["latent_preview_png"] = self._save_png(
                    paths["grayscale_frames"] / f"{view}_latent.png", preview
                )
            if save_decoded_image:
                decoded = decode_latent_to_pixel(latent)
                view_rec["decoded_image_png"] = self._save_png(
                    paths["decoded_images"] / f"{view}_decoded.png", decoded
                )

            if filtered_latent is None:
                # Still rejected after all retries: optionally keep a diagnostic
                # preview, report its path, and render no videos for this view.
                if keep_failed_privacy_preview:
                    rejected_preview = convert_latent_to_display(latent)
                    view_rec["privacy_reject_preview_png"] = self._save_png(
                        paths["grayscale_frames"] / f"{view}_privacy_reject.png", rejected_preview
                    )
                continue
            view_rec["privacy_passed"] = True

            for ef in efs:
                lat_vid = generate_animation(
                    filtered_latent, int(ef), sampling_steps=vid_steps, cfg_scale=cfg_scale
                )

                # The two demo renderers return a file path, which is moved
                # into this run's output tree.
                gray_tmp = latent_animation_to_grayscale(lat_vid)
                gray_target = paths["latent_videos"] / f"{view}_EF{ef}.mp4"
                shutil.move(gray_tmp, gray_target)

                dec_tmp = decode_animation(lat_vid)
                dec_target = paths["decoded_videos"] / f"{view}_EF{ef}.mp4"
                shutil.move(dec_tmp, dec_target)

                view_rec["videos"].append({
                    "ef": int(ef),
                    "latent_grayscale_mp4": str(gray_target),
                    "decoded_rgb_mp4": str(dec_target),
                })

        return results

    async def _arun(
        self,
        views: List[str],
        efs: List[int],
        img_steps: int = 150,
        vid_steps: int = 150,
        cfg_scale: float = 1.0,
        max_privacy_retries: int = 3,
        outdir: Optional[str] = None,
        save_decoded_image: bool = True,
        save_latent_preview: bool = True,
        keep_failed_privacy_preview: bool = True,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Async entry point: run ``_run`` with the same arguments off the event loop.

        ``_run`` is fully synchronous, so awaiting it inline would stall every
        other task on the loop until synthesis finishes. It is dispatched to
        the loop's default executor instead. ``run_manager`` is not forwarded,
        as ``_run`` does not use it.
        """
        import asyncio
        import functools

        blocking_call = functools.partial(
            self._run,
            views=views,
            efs=efs,
            img_steps=img_steps,
            vid_steps=vid_steps,
            cfg_scale=cfg_scale,
            max_privacy_retries=max_privacy_retries,
            outdir=outdir,
            save_decoded_image=save_decoded_image,
            save_latent_preview=save_latent_preview,
            keep_failed_privacy_preview=keep_failed_privacy_preview,
        )
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, blocking_call)
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |