StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""Model architecture registry and factory."""
from typing import Dict, Type, Optional
import torch
from taoTrain.core import BaseModel
from taoTrain.config import ModelConfig
# Module-wide registry mapping architecture names to model classes.
# Filled in by @register_architecture and consulted by get_model().
_ARCHITECTURE_REGISTRY: Dict[str, Type[BaseModel]] = dict()
def register_architecture(name: str):
    """Build a class decorator that records the decorated class under ``name``.

    The class is stored in the module registry and handed back unchanged,
    so decorating a class does not alter it.

    Args:
        name: Registry key for the architecture.

    Raises:
        ValueError: When the decorator is applied and ``name`` is already
            present in the registry.
    """
    def _register(cls: Type[BaseModel]):
        already_taken = name in _ARCHITECTURE_REGISTRY
        if already_taken:
            raise ValueError(f"Architecture '{name}' is already registered")
        _ARCHITECTURE_REGISTRY[name] = cls
        return cls

    return _register
def get_registered_architectures() -> Dict[str, Type[BaseModel]]:
    """Return a snapshot of every registered architecture.

    The result is a new (shallow) dict, so mutating it does not affect the
    registry itself.
    """
    snapshot = dict(_ARCHITECTURE_REGISTRY)
    return snapshot
def get_model(
    config: ModelConfig,
    device: Optional[torch.device] = None,
) -> BaseModel:
    """
    Create a model instance from config.

    The architecture class is looked up by name in the module registry
    (populated via ``register_architecture``), instantiated with ``config``,
    and then moved to ``device``.

    Args:
        config: ModelConfig instance. Its ``architecture_type`` may be a plain
            string or an enum-like object exposing ``.value`` (including
            ``str``-mixin Enum members); in the latter case ``.value`` is used
            as the registry key.
        device: Device to move the model to (defaults to CPU).

    Returns:
        Model instance, after ``.to(device)``.

    Raises:
        ValueError: If the resolved architecture name is not registered. This
            also covers a missing/``None`` ``architecture_type``, which
            previously surfaced as an AttributeError.
    """
    if device is None:
        device = torch.device('cpu')

    arch_type = config.architecture_type
    # Unwrap anything exposing ``.value`` (Enum members, including str-mixin
    # enums, which are *also* str instances and would otherwise be used raw and
    # can print as ``Type.MEMBER`` in the error message on newer Pythons).
    # Plain strings have no ``.value`` and pass through unchanged; so does
    # None, which then falls into the informative ValueError below instead of
    # crashing with AttributeError.
    arch_name = getattr(arch_type, 'value', arch_type)

    # Single lookup instead of a membership test followed by indexing.
    model_class = _ARCHITECTURE_REGISTRY.get(arch_name)
    if model_class is None:
        raise ValueError(
            f"Unknown architecture: {arch_name}. "
            f"Available: {list(_ARCHITECTURE_REGISTRY.keys())}"
        )

    return model_class(config).to(device)
def register_builtin_architectures():
    """Import the built-in architecture submodules so they get registered.

    Importing each submodule is what registers its architecture; the imported
    names themselves are never used.
    """
    # Kept function-local (not at module top level) to avoid circular imports.
    # NOTE(review): presumably each submodule applies @register_architecture
    # at import time — confirm against transformer/taonet/taonet_ssm.
    from . import transformer, taonet, taonet_ssm  # noqa: F401
# Auto-register built-in architectures when this module is imported.
# NOTE(review): this runs at import time, after every definition above.
# The built-in submodules presumably import `register_architecture` from this
# module, so keep this call at the very bottom — confirm.
register_builtin_architectures()