velai / services /registry.py
cansik's picture
Upload folder via script
691f45a verified
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Iterable, List, MutableMapping, Type
from services.services import TaskType, GenerationService, TService
@dataclass
class ServiceRegistry:
    """Registry of generation service classes, grouped by task type.

    Service *classes* (not instances) are stored in a two-level mapping
    ``task_type -> service_id -> service class``; instances are built on
    demand by :meth:`create`.
    """

    # task_type -> {service_id -> service class}
    _services: MutableMapping[TaskType, MutableMapping[str, Type[GenerationService]]] = field(
        default_factory=dict
    )

    def register(self, service_cls: Type[GenerationService]) -> None:
        """Add *service_cls* to the bucket for its ``task_type`` attribute.

        The id is the class attribute ``service_id`` when present, otherwise
        the class ``__name__``.

        Raises:
            AttributeError: if *service_cls* has no ``task_type`` attribute.
            ValueError: if the id is already registered for that task type.
        """
        task_type = service_cls.task_type
        service_id = getattr(service_cls, "service_id", service_cls.__name__)
        bucket = self._services.setdefault(task_type, {})
        if service_id in bucket:
            raise ValueError(
                f"Service id {service_id!r} already registered for task type {task_type}."
            )
        bucket[service_id] = service_cls

    def get(self, task_type: TaskType, service_id: str) -> Type[GenerationService]:
        """Return the class registered under *task_type* and *service_id*.

        Raises:
            KeyError: if the task type or the id within it is unknown.
        """
        bucket = self._services.get(task_type)
        if bucket is None or service_id not in bucket:
            raise KeyError(
                f"No service registered for type {task_type} with id {service_id!r}."
            )
        return bucket[service_id]

    def create(self, task_type: TaskType, service_id: str, **kwargs) -> GenerationService:
        """Look up a service class via :meth:`get` and instantiate it.

        All keyword arguments are forwarded to the class constructor.
        Raises ``KeyError`` (from :meth:`get`) if the service is unknown.
        """
        service_cls = self.get(task_type, service_id)
        return service_cls(**kwargs)

    def list_ids(self, task_type: TaskType) -> List[str]:
        """Return the registered ids for *task_type*, sorted; empty if none."""
        bucket = self._services.get(task_type, {})
        return sorted(bucket.keys())

    def iter_services(self, task_type: TaskType) -> Iterable[Type[GenerationService]]:
        """Return the service classes registered for *task_type*.

        The result is a snapshot list in insertion order, not a live view of
        the internal mapping. This means callers may register further services
        while iterating it without hitting "dictionary changed size during
        iteration". Returns an empty list when nothing is registered.
        """
        return list(self._services.get(task_type, {}).values())
# Module-level registry that register_service() populates and get_registry() returns.
_default_registry = ServiceRegistry(_services={})
def register_service(service_cls: Type[TService]) -> Type[TService]:
    """Record *service_cls* in the module's default registry.

    Meant to be used as a class decorator. The class itself is returned
    unchanged. Any error raised by ``ServiceRegistry.register`` (for example
    a duplicate id) propagates to the point where the decorator is applied.
    """
    registry = _default_registry
    registry.register(service_cls)
    return service_cls
def get_registry() -> ServiceRegistry:
    """Return the module-level default registry.

    This is the same instance that ``register_service`` adds classes to.
    """
    registry = _default_registry
    return registry