| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from typing import Any |
|
|
# Name -> callable lookup tables. Each is filled by its ``register_*``
# decorator and read by the matching ``build_*`` helper. DATA_PROVIDERS and
# OBJECTIVES hold factories that ``build_*`` calls; REWARDS holds reward
# callables that ``build_reward`` hands back without calling them.
DATA_PROVIDERS: dict[str, Callable[..., Any]] = dict()
REWARDS: dict[str, Callable[..., Any]] = dict()
OBJECTIVES: dict[str, Callable[..., Any]] = dict()
|
|
|
|
def register_data_provider(name: str):
    """Return a decorator that files a data-provider factory under *name*.

    Applying the decorator stores the decorated callable in
    ``DATA_PROVIDERS[name]``, replacing any earlier entry with that name.
    The callable itself is returned unchanged, so the decorated name still
    refers to the original object.
    """

    def _record(provider_factory: Callable[..., Any]):
        DATA_PROVIDERS[name] = provider_factory
        return provider_factory

    return _record
|
|
|
|
def register_reward(name: str):
    """Return a decorator that files a reward callable under *name*.

    Applying the decorator stores the decorated callable in
    ``REWARDS[name]``, replacing any earlier entry with that name. The
    callable itself is returned unchanged, so the decorated name still
    refers to the original object.
    """

    def _record(reward_fn: Callable[..., Any]):
        REWARDS[name] = reward_fn
        return reward_fn

    return _record
|
|
|
|
def register_objective(name: str):
    """Return a decorator that files an objective factory under *name*.

    Applying the decorator stores the decorated callable in
    ``OBJECTIVES[name]``, replacing any earlier entry with that name. The
    callable itself is returned unchanged, so the decorated name still
    refers to the original object.
    """

    def _record(objective_factory: Callable[..., Any]):
        OBJECTIVES[name] = objective_factory
        return objective_factory

    return _record
|
|
|
|
def build_data_provider(name: str, **kwargs: Any):
    """Construct the data provider registered under *name*.

    The factory stored in ``DATA_PROVIDERS`` is called with ``**kwargs`` and
    its result is returned.

    Raises:
        KeyError: if nothing is registered under *name*; the message lists
            every registered name in sorted order.
    """
    if name in DATA_PROVIDERS:
        factory = DATA_PROVIDERS[name]
        return factory(**kwargs)
    registered = ", ".join(sorted(DATA_PROVIDERS))
    raise KeyError(f"Unknown data provider '{name}'. Available: {registered}")
|
|
|
|
def build_reward(name: str, **kwargs: Any):
    """Return the reward callable registered under *name*, optionally with bound kwargs.

    Unlike the other ``build_*`` helpers, the registered callable is not
    invoked here: the reward function itself is handed back.

    Args:
        name: Key the reward was registered under via ``register_reward``.
        **kwargs: Keyword arguments to bind. When any are given, the returned
            callable forwards them on every call. Keywords supplied at call
            time take precedence over the bound ones (as with
            ``functools.partial``) instead of raising ``TypeError``.

    Returns:
        The registered callable itself when no kwargs are given. Otherwise,
        a wrapper whose ``__name__``, ``__qualname__``, ``__doc__`` and
        ``__wrapped__`` are copied from the registered callable via
        ``functools.wraps``. This lets code that identifies rewards by name
        see the real function rather than the wrapper.

    Raises:
        KeyError: if nothing is registered under *name*; the message lists
            every registered name in sorted order.
    """
    from functools import wraps

    if name not in REWARDS:
        available = ", ".join(sorted(REWARDS))
        raise KeyError(f"Unknown reward '{name}'. Available: {available}")
    fn = REWARDS[name]
    if not kwargs:
        return fn

    @wraps(fn)
    def _reward_with_kwargs(*args: Any, **call_kwargs: Any) -> Any:
        # Merge instead of splatting both dicts: ``fn(**kwargs, **call_kwargs)``
        # raises TypeError when a key appears in both. Here call-time wins.
        return fn(*args, **{**kwargs, **call_kwargs})

    return _reward_with_kwargs
|
|
|
|
def build_objective(name: str, **kwargs: Any):
    """Construct the objective registered under *name*.

    The factory stored in ``OBJECTIVES`` is called with ``**kwargs`` and its
    result is returned.

    Raises:
        KeyError: if nothing is registered under *name*; the message lists
            every registered name in sorted order.
    """
    if name in OBJECTIVES:
        factory = OBJECTIVES[name]
        return factory(**kwargs)
    registered = ", ".join(sorted(OBJECTIVES))
    raise KeyError(f"Unknown objective '{name}'. Available: {registered}")
|
|