| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Iterable, List, MutableMapping, Type |
|
|
| from services.services import TaskType, GenerationService, TService |
|
|
|
|
@dataclass
class ServiceRegistry:
    """Registry for generation service classes, grouped by task type.

    Each class is filed under its ``task_type`` attribute and its
    ``service_id`` attribute (falling back to the class ``__name__`` when the
    class defines no ``service_id``). Ids must be unique within a task type.
    """

    # task_type -> {service_id -> service class}
    _services: MutableMapping[
        TaskType, MutableMapping[str, Type[GenerationService]]
    ] = field(default_factory=dict)

    def register(self, service_cls: Type[GenerationService]) -> None:
        """Add ``service_cls`` to the registry.

        Args:
            service_cls: Class exposing a ``task_type`` attribute and,
                optionally, a ``service_id`` attribute.

        Raises:
            AttributeError: If ``service_cls`` has no ``task_type`` attribute.
            ValueError: If an entry already exists under the same id for the
                same task type (this includes re-registering the same class).
        """
        task_type = service_cls.task_type
        service_id = getattr(service_cls, "service_id", service_cls.__name__)
        bucket = self._services.setdefault(task_type, {})
        if service_id in bucket:
            raise ValueError(
                f"Service id {service_id!r} already registered for task type {task_type}."
            )
        bucket[service_id] = service_cls

    def get(self, task_type: TaskType, service_id: str) -> Type[GenerationService]:
        """Return the class registered under ``(task_type, service_id)``.

        Raises:
            KeyError: If nothing is registered for that task type and id.
        """
        bucket = self._services.get(task_type)
        if bucket is None or service_id not in bucket:
            raise KeyError(
                f"No service registered for type {task_type} with id {service_id!r}."
            )
        return bucket[service_id]

    def create(self, task_type: TaskType, service_id: str, **kwargs) -> GenerationService:
        """Look up a registered class and instantiate it with ``**kwargs``.

        Raises:
            KeyError: If nothing is registered for that task type and id.
        """
        service_cls = self.get(task_type, service_id)
        return service_cls(**kwargs)

    def list_ids(self, task_type: TaskType) -> List[str]:
        """Return the sorted ids registered for ``task_type`` (empty if none)."""
        bucket = self._services.get(task_type, {})
        return sorted(bucket.keys())

    def iter_services(self, task_type: TaskType) -> Iterable[Type[GenerationService]]:
        """Return the classes registered for ``task_type`` in registration order.

        A snapshot list is returned rather than a live view of the internal
        dict. Later registrations therefore do not alter the result, and
        registering services while iterating it cannot raise
        "dictionary changed size during iteration". Returns an empty list for
        an unknown task type.
        """
        bucket = self._services.get(task_type, {})
        return list(bucket.values())
|
|
|
|
# Module-level registry shared by ``register_service`` and ``get_registry``.
_default_registry: ServiceRegistry = ServiceRegistry()
|
|
|
|
def register_service(service_cls: Type[TService]) -> Type[TService]:
    """Add ``service_cls`` to the module's default registry and hand it back.

    Meant to be applied as ``@register_service`` on a class definition; the
    class object is returned unchanged, so decorating it does not alter it.
    Any error from ``ServiceRegistry.register`` (e.g. ``ValueError`` for a
    duplicate id) propagates to the caller.
    """
    registry = _default_registry
    registry.register(service_cls)
    return service_cls
|
|
|
|
def get_registry() -> ServiceRegistry:
    """Return the module-level default registry.

    This is the same instance that ``register_service`` adds classes to, so
    classes decorated with ``@register_service`` can be looked up through it.
    """
    registry = _default_registry
    return registry
|
|