"""Minimal service registry for dependency injection."""

import logging
import traceback
from typing import Any, Dict, Optional

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# String constants naming well-known services.
# NOTE(review): presumably meant to be used as ServiceRegistry keys; they are
# not referenced anywhere else in the visible part of this file -- confirm
# against callers.
MODEL = "model"
PRETRAINED_MODEL = "pretrained_model"
TOKENIZER = "tokenizer"
MODEL_MANAGER = "model_manager"
COMMUNICATOR = "communicator"
PIPELINE = "pipeline"
TRANSFORMER = "transformer"


class ServiceRegistry:
    """A minimal service registry that avoids loading heavy models.

    Services are kept as-is in an in-memory dict keyed by string; the registry
    never constructs or loads anything itself.
    """

    # Re-registering this key without ``overwrite`` is reported at DEBUG
    # instead of WARNING.
    _QUIET_DUPLICATE_KEY = "model_class_custom"

    def __init__(self) -> None:
        """Create an empty registry."""
        self._services: Dict[str, Any] = {}

    def register(self, key: str, service: Any, overwrite: bool = False) -> None:
        """Register ``service`` under ``key``.

        Args:
            key: Name to store the service under.
            service: Object to store; it is kept as-is.
            overwrite: When False (the default) and ``key`` is already
                registered, the existing service is kept and the call only
                logs a message. When True, the existing service is replaced.
        """
        if key in self._services:
            if not overwrite:
                # Keep the first registration; only the log level differs.
                if key == self._QUIET_DUPLICATE_KEY:
                    logger.debug("Service with key '%s' already registered", key)
                else:
                    logger.warning("Service with key '%s' already registered", key)
                return
            logger.debug("Overwriting service with key: %s", key)

        self._services[key] = service
        logger.debug("Registered service with key: %s", key)

    def get(self, key: str) -> Optional[Any]:
        """Return the service registered under ``key``, or None if absent.

        A service that was registered with the value None is indistinguishable
        from a missing key here; use ``has`` to tell them apart.
        """
        return self._services.get(key)

    def has(self, key: str) -> bool:
        """Return True if a service is registered under ``key``."""
        return key in self._services

    def clear(self) -> None:
        """Remove every registered service."""
        self._services.clear()


# Module-level registry instance, created at import time.
registry = ServiceRegistry()


def ensure_models_registered() -> bool:
    """Placeholder function - don't actually register models at startup.

    Returns:
        Always True; nothing is registered or loaded.
    """
    return True