"""Model architecture registry and factory.

Architectures register themselves by name; ``get_model`` then builds an
instance from a ``ModelConfig`` by looking that name up.
"""

from typing import Dict, Type, Optional

import torch

from taoTrain.core import BaseModel
from taoTrain.config import ModelConfig

# Name -> model class. Filled in by the ``register_architecture`` decorator;
# callers outside this module should read it via
# ``get_registered_architectures`` rather than touching it directly.
_ARCHITECTURE_REGISTRY: Dict[str, Type[BaseModel]] = {}


def register_architecture(name: str):
    """Build a class decorator that records the decorated class under *name*.

    The class itself is handed back unchanged, so decorating it has no effect
    other than the registration.

    Args:
        name: Key under which the class will be stored in the registry.

    Raises:
        ValueError: (when the decorator is applied) if *name* is already taken.
    """
    def _register(arch_cls: Type[BaseModel]):
        if name not in _ARCHITECTURE_REGISTRY:
            _ARCHITECTURE_REGISTRY[name] = arch_cls
            return arch_cls
        raise ValueError(f"Architecture '{name}' is already registered")

    return _register


def get_registered_architectures() -> Dict[str, Type[BaseModel]]:
    """Return a shallow copy of the name -> class registry.

    Adding or removing keys on the returned dict leaves the registry intact.
    """
    return dict(_ARCHITECTURE_REGISTRY)


def get_model(
    config: ModelConfig,
    device: Optional[torch.device] = None,
) -> BaseModel:
    """Instantiate the architecture selected by ``config.architecture_type``.

    Args:
        config: ModelConfig whose ``architecture_type`` is either a plain
            ``str`` or an enum-like object carrying the name in ``.value``.
        device: Target device for the new model; falls back to
            ``torch.device('cpu')`` when omitted.

    Returns:
        ``cls(config)`` for the registered class, moved via ``.to(device)``.

    Raises:
        ValueError: if no architecture is registered under the resolved name.
    """
    target = device if device is not None else torch.device('cpu')

    raw_type = config.architecture_type
    # Strings are used as-is; anything else (e.g. an Enum member) is unwrapped.
    arch_name = raw_type if isinstance(raw_type, str) else raw_type.value

    if arch_name not in _ARCHITECTURE_REGISTRY:
        known = list(_ARCHITECTURE_REGISTRY.keys())
        raise ValueError(
            f"Unknown architecture: {arch_name}. "
            f"Available: {known}"
        )

    model_cls = _ARCHITECTURE_REGISTRY[arch_name]
    return model_cls(config).to(target)


def register_builtin_architectures():
    """Import the built-in architecture modules so they register themselves.

    The imports live inside this function, rather than at the top of the
    module, to avoid circular imports. Each import stays a separate statement
    so the modules load in this fixed order.
    """
    from . import transformer  # noqa: F401
    from . import taonet  # noqa: F401
    from . import taonet_ssm  # noqa: F401


# Run at import time so the built-in architectures are available as soon as
# this module has been loaded.
register_builtin_architectures()