| |
"""
EchoFlow Final Working Implementation

This is the final working implementation that processes videos frame by frame
to avoid the STDiT multi-frame shape issues.
"""
|
|
| import sys |
| import os |
| import json |
| import time |
| import traceback |
| import warnings |
| from pathlib import Path |
| from typing import Dict, Any, Optional, Tuple, List, Union |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| import cv2 |
|
|
# Repository root: two directory levels above the directory holding this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Location of the EchoFlow checkout inside the repository.
ECHOFLOW_ROOT = PROJECT_ROOT / "EchoFlow"

# Put both roots at the front of sys.path (ECHOFLOW_ROOT ends up first) so the
# ``echoflow`` package can be imported; entries already present are not
# inserted a second time.
for candidate in (PROJECT_ROOT, ECHOFLOW_ROOT):
    candidate_str = str(candidate)
    if candidate_str not in sys.path:
        sys.path.insert(0, candidate_str)

# NOTE: this silences *all* warnings for the whole process, not only the ones
# raised while this module is in use.
warnings.filterwarnings("ignore")
|
|
class EchoFlowFinal:
    """Frame-by-frame wrapper around the EchoFlow ResNet18 and STDiT models.

    Multi-frame inputs are fed to STDiT one frame at a time and the per-frame
    outputs are concatenated along the frame axis.

    The public ``load_*``, ``generate_*`` and ``save_results`` methods are
    best-effort: they print the problem and return ``False``, a zero tensor
    or a ``{"success": False, ...}`` dict instead of raising.
    """

    # Spatial size every frame is resized to before it is handed to STDiT.
    _LATENT_SIZE = (32, 32)

    def __init__(self, device: Optional[str] = None):
        """
        Initialize EchoFlow.

        Args:
            device: Device to use ('cuda', 'cpu', or None for auto-detection)
        """
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dtype = torch.float32
        self.models: Dict[str, Any] = {}
        self.config: Dict[str, Any] = {}
        self.initialized = False

        print(f"🔧 EchoFlow Final initialized on {self.device}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_config(self, config_path: Optional[str] = None) -> bool:
        """Load the EchoFlow JSON configuration into ``self.config``.

        Args:
            config_path: Path of the JSON file. ``None`` selects
                ``PROJECT_ROOT / "configs" / "echoflow_config.json"``.

        Returns:
            True when the file was read and parsed, False when it is missing
            or cannot be read/parsed (``self.config`` is left untouched then).
        """
        if config_path is None:
            config_path = PROJECT_ROOT / "configs" / "echoflow_config.json"

        # Open directly instead of testing for existence first: this avoids a
        # race between the check and the read.
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            print(f"⚠️ Config not found at {config_path}")
            return False
        except Exception as e:
            print(f"❌ Error loading config: {e}")
            return False

        self.config = loaded
        print(f"✅ Config loaded from {config_path}")
        return True

    def load_models(self) -> bool:
        """Instantiate ResNet18 and STDiT on ``self.device`` in eval mode.

        Returns:
            True when both models were created, False otherwise. The models
            are only registered once both have been built, so a failure never
            leaves a half-populated ``self.models``.
        """
        try:
            print("🤖 Loading EchoFlow models...")

            # Make the ``echoflow`` package importable without growing
            # sys.path on every call.
            echoflow_path = str(ECHOFLOW_ROOT)
            if echoflow_path not in sys.path:
                sys.path.insert(0, echoflow_path)

            from echoflow.common.models import ResNet18, DiffuserSTDiT

            resnet = ResNet18().to(self.device).eval()
            print("✅ ResNet18 loaded")

            stdit = DiffuserSTDiT().to(self.device).eval()
            print("✅ STDiT loaded")

            self.models['resnet'] = resnet
            self.models['stdit'] = stdit
            self.initialized = True
            return True

        except Exception as e:
            print(f"❌ Error loading models: {e}")
            traceback.print_exc()
            return False

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        """Return *tensor* unchanged if floating point, else cast to float32.

        Bilinear interpolation is not defined for integer tensors.
        """
        return tensor if torch.is_floating_point(tensor) else tensor.float()

    @staticmethod
    def _ensure_four_channels(tensor: torch.Tensor, what: str) -> torch.Tensor:
        """Append an all-ones channel to a 3-channel tensor (channel axis 1).

        Raises:
            ValueError: if the tensor has neither 3 nor 4 channels.
        """
        channels = tensor.shape[1]
        if channels == 4:
            return tensor
        if channels != 3:
            raise ValueError(f"Unsupported {what} channels: {channels}")
        # new_ones keeps the input's device and dtype, so torch.cat also
        # works for tensors that already live on the GPU.
        alpha = tensor.new_ones(tensor.shape[0], 1, *tensor.shape[2:])
        return torch.cat([tensor, alpha], dim=1)

    def _run_stdit(self, latents: torch.Tensor, timestep_tensor: torch.Tensor) -> torch.Tensor:
        """Run STDiT without gradients and return the ``.sample`` of its output."""
        with torch.no_grad():
            return self.models['stdit'](latents, timestep_tensor).sample

    def _encode_video(self, video: Union[np.ndarray, torch.Tensor],
                      timestep: float) -> torch.Tensor:
        """Run STDiT on every frame of *video*; raises instead of falling back.

        Numpy input is taken as (frames, H, W, C) or (batch, frames, H, W, C)
        and scaled by 1/255; tensor input is taken as (C, frames, H, W) or
        (batch, C, frames, H, W) and used as-is.
        """
        if not self.initialized or 'stdit' not in self.models:
            raise RuntimeError("EchoFlow not initialized. Call load_models() first.")

        if isinstance(video, np.ndarray):
            if video.ndim == 4:
                video_tensor = torch.from_numpy(video).permute(3, 0, 1, 2).float() / 255.0
            elif video.ndim == 5:
                video_tensor = torch.from_numpy(video).permute(0, 4, 1, 2, 3).float() / 255.0
            else:
                raise ValueError(f"Unsupported video shape: {video.shape}")
        else:
            video_tensor = self._as_float(video)

        if video_tensor.ndim == 4:
            video_tensor = video_tensor.unsqueeze(0)
        if video_tensor.ndim != 5:
            raise ValueError(f"Unsupported video shape: {tuple(video_tensor.shape)}")

        video_tensor = self._ensure_four_channels(video_tensor, "video")

        num_frames = video_tensor.shape[2]
        if num_frames == 0:
            # torch.cat on an empty list would fail with a less helpful message.
            raise ValueError("Video contains no frames")

        # The timestep is identical for every frame, so build it only once.
        timestep_tensor = torch.tensor([timestep], device=self.device, dtype=self.dtype)

        frame_features = []
        for t in range(num_frames):
            frame = video_tensor[:, :, t, :, :]
            frame_resized = torch.nn.functional.interpolate(
                frame, size=self._LATENT_SIZE, mode='bilinear', align_corners=False
            )
            # Re-insert a frame axis of length 1: (B, 4, 32, 32) -> (B, 4, 1, 32, 32).
            latents = frame_resized.unsqueeze(2).to(self.device, dtype=self.dtype)
            frame_features.append(self._run_stdit(latents, timestep_tensor))

        return torch.cat(frame_features, dim=2)

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """``json.dump`` fallback converting numpy/torch values to plain Python."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def preprocess_mask(self, mask: Union[np.ndarray, "Image.Image", None],
                        target_size: Tuple[int, int] = (112, 112)) -> torch.Tensor:
        """
        Preprocess mask for EchoFlow generation.

        The mask is resized with nearest-neighbour interpolation and
        binarized with the threshold ``> 127``.

        Args:
            mask: Input mask (numpy array, PIL Image, or None)
            target_size: Target size for the mask (height, width)

        Returns:
            Preprocessed mask tensor of shape (1, 1, height, width) on
            ``self.device``; an all-zero tensor of that shape when ``mask``
            is None or preprocessing fails.
        """
        try:
            if mask is None:
                # No mask means an empty mask; nothing to resize or threshold.
                return torch.zeros(1, 1, *target_size, device=self.device, dtype=self.dtype)

            if isinstance(mask, np.ndarray):
                mask_array = mask
            elif isinstance(mask, Image.Image):
                mask_array = np.array(mask.convert('L'))
            else:
                raise ValueError(f"Unsupported mask type: {type(mask)}")

            if mask_array.dtype == np.bool_:
                # cv2.resize rejects boolean arrays; map True to 255 so the
                # threshold below keeps the foreground.
                mask_array = mask_array.astype(np.uint8) * 255

            # cv2.resize expects dsize as (width, height), the reverse of the
            # (height, width) convention used by target_size.
            height, width = target_size
            mask_resized = cv2.resize(mask_array, (width, height), interpolation=cv2.INTER_NEAREST)

            mask_binary = (mask_resized > 127).astype(np.float32)

            mask_tensor = torch.from_numpy(mask_binary).unsqueeze(0).unsqueeze(0)
            return mask_tensor.to(self.device, dtype=self.dtype)

        except Exception as e:
            print(f"❌ Error preprocessing mask: {e}")
            return torch.zeros(1, 1, *target_size, device=self.device, dtype=self.dtype)

    def generate_image_features(self, image: Union[np.ndarray, torch.Tensor],
                                target_size: Tuple[int, int] = (224, 224)) -> torch.Tensor:
        """
        Generate features from an image using ResNet18.

        Args:
            image: Input image. Numpy arrays must be HxWx3 or HxW (replicated
                to 3 channels) and are scaled by 1/255; tensors are used
                as-is (CxHxW or BxCxHxW).
            target_size: Target size for the image (height, width)

        Returns:
            Feature tensor returned by the ResNet18 model, or a zero tensor
            of shape (1, 1000) when anything fails.
        """
        try:
            if not self.initialized or 'resnet' not in self.models:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")

            if isinstance(image, np.ndarray):
                if image.ndim == 3 and image.shape[2] == 3:
                    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
                elif image.ndim == 2:
                    image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
                    image_tensor = image_tensor.repeat(3, 1, 1)
                else:
                    raise ValueError(f"Unsupported image shape: {image.shape}")
            else:
                image_tensor = self._as_float(image)

            # Add the batch axis if it is missing.
            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)

            image_tensor = torch.nn.functional.interpolate(
                image_tensor, size=target_size, mode='bilinear', align_corners=False
            )
            image_tensor = image_tensor.to(self.device, dtype=self.dtype)

            with torch.no_grad():
                return self.models['resnet'](image_tensor)

        except Exception as e:
            print(f"❌ Error generating image features: {e}")
            traceback.print_exc()
            return torch.zeros(1, 1000, device=self.device, dtype=self.dtype)

    def generate_single_frame_features(self, frame: Union[np.ndarray, torch.Tensor],
                                       timestep: float = 0.5) -> torch.Tensor:
        """
        Generate features from a single frame using STDiT.

        Args:
            frame: Input frame. Numpy arrays are HxWxC or HxW (replicated to
                3 channels) and are scaled by 1/255; tensors are CxHxW,
                BxCxHxW or BxCxTxHxW. 3-channel input gets an all-ones
                fourth channel appended.
            timestep: Diffusion timestep (0.0 to 1.0)

        Returns:
            The ``.sample`` tensor of the STDiT output, or a zero tensor of
            shape (1, 4, 1, 32, 32) when anything fails.
        """
        try:
            if not self.initialized or 'stdit' not in self.models:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")

            if isinstance(frame, np.ndarray):
                if frame.ndim == 3:
                    frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                elif frame.ndim == 2:
                    frame_tensor = torch.from_numpy(frame).unsqueeze(0).float() / 255.0
                    frame_tensor = frame_tensor.repeat(3, 1, 1)
                else:
                    raise ValueError(f"Unsupported frame shape: {frame.shape}")
            else:
                frame_tensor = self._as_float(frame)

            # Bring the tensor to (batch, channels, frames, H, W).
            if frame_tensor.ndim == 3:
                frame_tensor = frame_tensor.unsqueeze(0)
            if frame_tensor.ndim == 4:
                frame_tensor = frame_tensor.unsqueeze(2)

            frame_tensor = self._ensure_four_channels(frame_tensor, "frame")

            # Fold batch and channel axes together so the 2-D bilinear resize
            # can be applied, then restore the 5-D layout. reshape (rather
            # than view) also accepts non-contiguous inputs.
            batch_size, channels, depth = frame_tensor.shape[:3]
            resized = torch.nn.functional.interpolate(
                frame_tensor.reshape(-1, *frame_tensor.shape[2:]),
                size=self._LATENT_SIZE,
                mode='bilinear',
                align_corners=False
            )
            frame_tensor = resized.reshape(batch_size, channels, depth, *self._LATENT_SIZE)
            frame_tensor = frame_tensor.to(self.device, dtype=self.dtype)

            timestep_tensor = torch.tensor([timestep], device=self.device, dtype=self.dtype)
            return self._run_stdit(frame_tensor, timestep_tensor)

        except Exception as e:
            print(f"❌ Error generating single frame features: {e}")
            traceback.print_exc()
            return torch.zeros(1, 4, 1, *self._LATENT_SIZE, device=self.device, dtype=self.dtype)

    def generate_video_features_frame_by_frame(self, video: Union[np.ndarray, torch.Tensor],
                                               timestep: float = 0.5) -> torch.Tensor:
        """
        Generate features from a video by processing each frame individually.

        Args:
            video: Input video. Numpy arrays are (frames, H, W, C) or
                (batch, frames, H, W, C) and are scaled by 1/255; tensors are
                (C, frames, H, W) or (batch, C, frames, H, W). 3-channel
                input gets an all-ones fourth channel appended.
            timestep: Diffusion timestep (0.0 to 1.0)

        Returns:
            The per-frame STDiT ``.sample`` tensors concatenated along axis 2,
            or a zero tensor of shape (1, 4, 1, 32, 32) when anything fails
            (note that the fallback always has a single frame).
        """
        try:
            return self._encode_video(video, timestep)
        except Exception as e:
            print(f"❌ Error generating video features: {e}")
            traceback.print_exc()
            return torch.zeros(1, 4, 1, *self._LATENT_SIZE, device=self.device, dtype=self.dtype)

    def generate_synthetic_echo(self, mask: Union[np.ndarray, "Image.Image", None],
                                view_type: str = "A4C",
                                ejection_fraction: float = 0.65,
                                num_frames: int = 16) -> Dict[str, Any]:
        """
        Generate synthetic echocardiogram features from a mask.

        NOTE(review): as written this is a placeholder pipeline. The STDiT
        input is uniformly random noise, ``ejection_fraction`` is passed to
        STDiT as its timestep, and the mask is only preprocessed and returned;
        it does not condition the generation. Confirm the intended behavior.

        Args:
            mask: Input mask for the left ventricle
            view_type: Type of echo view ("A4C", "PSAX", "PLAX"); only echoed
                back in the result
            ejection_fraction: Ejection fraction (0.0 to 1.0)
            num_frames: Number of frames in the generated video (>= 1)

        Returns:
            On success a dict with "success": True, the request parameters,
            "video_features" and "mask_processed" as numpy arrays,
            "timestamp" and "device". On failure a dict with
            "success": False, "error" and "timestamp".
        """
        try:
            if not self.initialized:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")
            if num_frames < 1:
                raise ValueError(f"num_frames must be at least 1, got {num_frames}")

            print(f"🎬 Generating synthetic echo: {view_type}, EF={ejection_fraction:.2f}, frames={num_frames}")

            mask_tensor = self.preprocess_mask(mask)

            # randint's upper bound is exclusive, so 256 covers the full uint8 range.
            dummy_video = np.random.randint(0, 256, (num_frames, 224, 224, 3), dtype=np.uint8)

            # Use the raising helper so that a failed STDiT pass is reported
            # as "success": False instead of being masked by the zero-tensor
            # fallback of generate_video_features_frame_by_frame.
            video_features = self._encode_video(dummy_video, timestep=ejection_fraction)

            result = {
                "success": True,
                "view_type": view_type,
                "ejection_fraction": ejection_fraction,
                "num_frames": num_frames,
                "video_features": video_features.cpu().numpy(),
                "mask_processed": mask_tensor.cpu().numpy(),
                "timestamp": time.time(),
                "device": str(self.device)
            }

            print(f"✅ Synthetic echo generated successfully")
            print(f" Video features shape: {video_features.shape}")
            return result

        except Exception as e:
            print(f"❌ Error generating synthetic echo: {e}")
            traceback.print_exc()
            return {
                "success": False,
                "error": str(e),
                "timestamp": time.time()
            }

    def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
        """Save generation results to a JSON file.

        Numpy arrays, numpy scalars and torch tensors are converted to plain
        lists/numbers at any nesting depth.

        Args:
            results: Mapping to serialize.
            output_path: Destination file; missing parent directories are
                created. A bare filename writes to the current directory.

        Returns:
            True on success, False when writing or serialization failed.
        """
        try:
            # dirname is "" for a bare filename, and os.makedirs("") raises.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, default=self._to_jsonable)

            print(f"✅ Results saved to {output_path}")
            return True

        except Exception as e:
            print(f"❌ Error saving results: {e}")
            return False
|
|
def create_echoflow_generator(device: Optional[str] = None) -> EchoFlowFinal:
    """
    Create and initialize an EchoFlow generator.

    A missing or unreadable config is tolerated (a warning is printed and the
    generator keeps its empty default config); failing to load the models is
    fatal.

    Args:
        device: Device to use ('cuda', 'cpu', or None for auto-detection)

    Returns:
        Initialized EchoFlowFinal instance

    Raises:
        RuntimeError: if ``load_models()`` reports failure.
    """
    generator = EchoFlowFinal(device)

    if not generator.load_config():
        print("⚠️ Could not load config, using defaults")

    if not generator.load_models():
        raise RuntimeError("Failed to load EchoFlow models")

    return generator
|
|
def test_final_echoflow():
    """Run an end-to-end smoke test of the final EchoFlow implementation.

    Exercises image features, single-frame features, multi-frame features for
    several frame counts, and synthetic echo generation, all on random input.

    Returns:
        True when the generator could be created and no step reported a
        failure; False otherwise.
    """
    print("🧪 Testing Final EchoFlow Implementation")
    print("=" * 50)

    try:
        generator = create_echoflow_generator()
        # Number of sub-steps that raised or reported success == False.
        failed_steps = 0

        print("\n1️⃣ Testing image processing...")
        # randint's upper bound is exclusive, so 256 covers the full uint8 range.
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        features = generator.generate_image_features(dummy_image)
        print(f"✅ Image features generated: {features.shape}")

        print("\n2️⃣ Testing single frame processing...")
        dummy_frame = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        single_frame_features = generator.generate_single_frame_features(dummy_frame)
        print(f"✅ Single frame features generated: {single_frame_features.shape}")

        print("\n3️⃣ Testing multi-frame processing...")
        for num_frames in (4, 8, 16, 32):
            try:
                print(f" 🧪 Testing {num_frames} frames...")
                dummy_video = np.random.randint(0, 256, (num_frames, 224, 224, 3), dtype=np.uint8)
                video_features = generator.generate_video_features_frame_by_frame(dummy_video)
                print(f" ✅ {num_frames} frames processed successfully: {video_features.shape}")
            except Exception as e:
                failed_steps += 1
                print(f" ❌ {num_frames} frames failed: {e}")

        print("\n4️⃣ Testing synthetic echo generation...")
        dummy_mask = np.random.randint(0, 256, (400, 400), dtype=np.uint8)

        for num_frames in (4, 8, 16):
            try:
                print(f" 🧪 Testing {num_frames} frame synthetic echo...")
                result = generator.generate_synthetic_echo(
                    mask=dummy_mask,
                    view_type="A4C",
                    ejection_fraction=0.65,
                    num_frames=num_frames
                )

                if result["success"]:
                    print(f" ✅ {num_frames} frame synthetic echo generated successfully")
                    print(f" Video features shape: {result['video_features'].shape}")
                else:
                    failed_steps += 1
                    print(f" ❌ {num_frames} frame synthetic echo failed: {result.get('error', 'Unknown error')}")
            except Exception as e:
                failed_steps += 1
                print(f" ❌ {num_frames} frame synthetic echo error: {e}")

        # Only claim success when no sub-step failed.
        if failed_steps:
            print(f"\n❌ Final EchoFlow test finished with {failed_steps} failed step(s)")
            return False

        print("\n🎉 Final EchoFlow test completed successfully!")
        return True

    except Exception as e:
        print(f"❌ Final EchoFlow test failed: {e}")
        traceback.print_exc()
        return False
|
|
if __name__ == "__main__":
    # Propagate the smoke-test result as the process exit status so that
    # shells and CI jobs can detect a failure.
    sys.exit(0 if test_final_echoflow() else 1)
|
|