| | """ |
| | Model Factory |
| | |
| | This module provides a factory for loading and managing different types of models |
| | using real weights from the model_weights directory. |
| | """ |
| |
|
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import torch

# Make the package root importable when this module is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
| |
|
class ModelFactory:
    """Factory that loads, caches, and disposes of the supported echo models.

    Models are loaded lazily on first request via :meth:`get_model` and
    cached in ``self._models`` so subsequent requests return the same
    instance. Loader failures are reported to stdout and yield ``None``.
    """

    def __init__(self):
        # Cache of successfully loaded models, keyed by model name.
        self._models: Dict[str, Any] = {}
        # Vendored model sources/weights (e.g. model_weights/echo_prime).
        # NOTE(review): relative paths — resolved against the process CWD.
        self.model_weights_dir = Path("model_weights")
        # Standalone checkpoint files (e.g. MedSAM2_US_Heart.pt).
        self.checkpoints_dir = Path("checkpoints")

    def load_echo_prime(self) -> Optional[Any]:
        """Load the EchoPrime model from the local weights directory.

        Returns:
            The EchoPrime model instance, or None if the weights directory
            is missing or loading fails.
        """
        try:
            echo_prime_path = self.model_weights_dir / "echo_prime"
            if not echo_prime_path.exists():
                print(f"❌ EchoPrime weights directory not found: {echo_prime_path}")
                return None

            # The vendored package must be importable before the local
            # `from model import EchoPrime` below can succeed.
            sys.path.insert(0, str(echo_prime_path))
            from model import EchoPrime

            model = EchoPrime(device="cuda" if torch.cuda.is_available() else "cpu")
            print("✅ EchoPrime model loaded successfully")
            return model
        except Exception as e:
            print(f"❌ Failed to load EchoPrime model: {e}")
            import traceback
            traceback.print_exc()
            return None

    def load_panecho(self) -> Optional[Any]:
        """Load the PanEcho model from torch.hub with all available tasks.

        Returns:
            The PanEcho model in eval mode on the best available device,
            or None if loading fails (e.g. no network access).
        """
        try:
            model = torch.hub.load(
                'CarDS-Yale/PanEcho',
                'PanEcho',
                force_reload=False,
                tasks='all',
                clip_len=16
            )
            model.eval()

            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)

            # Collect task names for reporting; guard in case this hub
            # build does not expose a `.tasks` attribute.
            available_tasks = list(model.tasks) if hasattr(model, 'tasks') else []
            task_names = [task.task_name for task in available_tasks] if available_tasks else []

            print(f"✅ PanEcho model loaded successfully with {len(task_names)} tasks")
            print(f"   Total tasks available: {len(task_names)}")
            return model
        except Exception as e:
            print(f"❌ Failed to load PanEcho model: {e}")
            return None

    def load_medsam2(self) -> Optional[Any]:
        """Resolve the MedSAM2 checkpoint path, downloading it if necessary.

        NOTE: unlike the other loaders this returns a checkpoint *path*
        (str), not a model instance — callers are expected to load it.

        Returns:
            Path to the checkpoint file, or None on failure.
        """
        try:
            # Prefer a local checkpoint to avoid a network round-trip.
            checkpoint_path = self.checkpoints_dir / "MedSAM2_US_Heart.pt"
            if checkpoint_path.exists():
                print(f"✅ Using local MedSAM2 checkpoint: {checkpoint_path}")
                return str(checkpoint_path)

            # Fall back to downloading from the Hugging Face Hub.
            from huggingface_hub import hf_hub_download
            model_path = hf_hub_download(repo_id="wanglab/MedSAM2", filename="MedSAM2_US_Heart.pt")
            print(f"✅ Downloaded MedSAM2 model: {model_path}")
            return model_path
        except Exception as e:
            print(f"❌ Failed to load MedSAM2 model: {e}")
            return None

    def load_echoflow(self) -> Optional[Any]:
        """Load the EchoFlow model from a local clone under tool_repos.

        Searches candidate locations (including ECHO_WORKSPACE_ROOT, if
        set) for the EchoFlow repository, imports it, and loads its
        components.

        Returns:
            The EchoFlow model instance, or None if the repository is
            missing or component loading fails.
        """
        try:
            root = Path(__file__).resolve().parents[1]
            candidates = [
                root / "tool_repos" / "EchoFlow",
                root / "tool_repos" / "EchoFlow-main",
            ]
            workspace_root = os.getenv("ECHO_WORKSPACE_ROOT")
            if workspace_root:
                candidates.append(Path(workspace_root) / "EchoFlow")
                candidates.append(Path(workspace_root) / "tool_repos" / "EchoFlow")

            echoflow_path = next((path for path in candidates if path.exists()), None)
            if echoflow_path is None:
                print("❌ EchoFlow directory not found in tool_repos. Please clone EchoFlow into tool_repos/EchoFlow")
                return None

            # The repo must be on sys.path before the package import below.
            sys.path.insert(0, str(echoflow_path))
            from echoflow.common.echoflow_model import EchoFlowModel

            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = EchoFlowModel(device=device, model_dir=echoflow_path)

            if model.load_components():
                print("✅ EchoFlow model loaded successfully")
                return model
            print("❌ Failed to load EchoFlow components")
            return None
        except Exception as e:
            print(f"❌ Failed to load EchoFlow model: {e}")
            import traceback
            traceback.print_exc()
            return None

    def get_model(self, model_name: str) -> Optional[Any]:
        """Return the named model, loading and caching it on first use.

        Args:
            model_name: One of the names in :meth:`get_available_models`.

        Returns:
            The cached or freshly loaded model, or None for unknown names
            and failed loads.
        """
        if model_name in self._models:
            return self._models[model_name]

        # Dispatch table keeps the name -> loader mapping in one place.
        loaders = {
            "echo_prime": self.load_echo_prime,
            "panecho": self.load_panecho,
            "medsam2": self.load_medsam2,
            "echoflow": self.load_echoflow,
        }
        loader = loaders.get(model_name)
        if loader is None:
            print(f"❌ Unknown model: {model_name}")
            return None

        model = loader()
        if model is not None:
            # Cache only successful loads so a later retry can still succeed.
            self._models[model_name] = model

        return model

    def get_available_models(self) -> list:
        """Return the names of all models this factory knows how to load."""
        return ["echo_prime", "panecho", "medsam2", "echoflow"]

    def cleanup(self):
        """Move cached torch models off the GPU and drop all cache entries."""
        for model in self._models.values():
            # Free GPU memory for torch models; checkpoint-path strings
            # (MedSAM2) have no `.cpu` and are simply dropped.
            if hasattr(model, 'cpu'):
                model.cpu()
        self._models.clear()
        print("✅ All models cleaned up")
| |
|
| |
|
| | |
# Shared module-level factory instance; all loaded models are cached here.
model_factory = ModelFactory()
| |
|
def get_model(model_name: str) -> Optional[Any]:
    """Convenience wrapper: fetch *model_name* from the shared factory."""
    factory = model_factory
    return factory.get_model(model_name)
| |
|
def get_available_models() -> list:
    """Convenience wrapper: list the model names the shared factory supports."""
    return model_factory.get_available_models()
| |
|
def cleanup_all_models():
    """Convenience wrapper: release every model held by the shared factory."""
    model_factory.cleanup()
| |
|