File size: 1,640 Bytes
dbc69f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | from __future__ import annotations
from collections.abc import Callable
from typing import Any
# Name -> factory lookup tables. The ``register_*`` decorators in this module
# insert entries; the matching ``build_*`` functions look them up by name.
DATA_PROVIDERS: dict[str, Callable[..., Any]] = dict()
REWARDS: dict[str, Callable[..., Any]] = dict()
OBJECTIVES: dict[str, Callable[..., Any]] = dict()
def register_data_provider(name: str):
    """Create a decorator that files a data-provider factory under *name*.

    The decorated callable is stored in ``DATA_PROVIDERS[name]`` (an existing
    entry with the same name is overwritten) and returned unchanged, so the
    decorated name still refers to the original factory.
    """

    def _store(provider_factory: Callable[..., Any]):
        DATA_PROVIDERS.update({name: provider_factory})
        return provider_factory

    return _store
def register_reward(name: str):
    """Create a decorator that files a reward callable under *name*.

    The decorated callable is stored in ``REWARDS[name]`` (an existing entry
    with the same name is overwritten) and returned unchanged.
    """

    def _store(reward_factory: Callable[..., Any]):
        REWARDS.update({name: reward_factory})
        return reward_factory

    return _store
def register_objective(name: str):
    """Create a decorator that files an objective factory under *name*.

    The decorated callable is stored in ``OBJECTIVES[name]`` (an existing
    entry with the same name is overwritten) and returned unchanged.
    """

    def _store(objective_factory: Callable[..., Any]):
        OBJECTIVES.update({name: objective_factory})
        return objective_factory

    return _store
def build_data_provider(name: str, **kwargs: Any):
    """Call the data-provider factory registered under *name*.

    Every keyword argument is forwarded to the factory; whatever the factory
    returns is returned as-is.

    Raises:
        KeyError: If *name* is not registered. The message lists the
            registered names in sorted order.
    """
    try:
        provider_factory = DATA_PROVIDERS[name]
    except KeyError:
        known = ", ".join(sorted(DATA_PROVIDERS))
        raise KeyError(f"Unknown data provider '{name}'. Available: {known}") from None
    # Call outside the try so a KeyError raised by the factory itself is not
    # mistaken for an unknown name.
    return provider_factory(**kwargs)
def build_reward(name: str, **kwargs: Any):
    """Return the reward registered under *name*, optionally with bound kwargs.

    Args:
        name: Key the reward was registered under in ``REWARDS``.
        **kwargs: Keyword arguments to pass on every call of the reward.

    Returns:
        If ``kwargs`` is empty, the registered callable itself. Otherwise, a
        wrapper that calls the registered callable with the positional
        arguments it receives, plus ``kwargs``, plus any call-time keyword
        arguments. Passing a key both here and at call time raises TypeError
        ("got multiple values for keyword argument").

    Raises:
        KeyError: If *name* is not registered. The message lists the
            registered names in sorted order.
    """
    import functools

    if name not in REWARDS:
        available = ", ".join(sorted(REWARDS))
        raise KeyError(f"Unknown reward '{name}'. Available: {available}")
    fn = REWARDS[name]
    if not kwargs:
        return fn

    # Copy __name__/__qualname__/__doc__ (and set __wrapped__) from the
    # registered reward. Without this, every configured reward would report
    # itself as "_reward_with_kwargs", and callers that identify or log
    # rewards by __name__ could not tell them apart.
    @functools.wraps(fn)
    def _reward_with_kwargs(*args, **call_kwargs):
        return fn(*args, **kwargs, **call_kwargs)

    return _reward_with_kwargs
def build_objective(name: str, **kwargs: Any):
    """Call the objective factory registered under *name*.

    Every keyword argument is forwarded to the factory; whatever the factory
    returns is returned as-is.

    Raises:
        KeyError: If *name* is not registered. The message lists the
            registered names in sorted order.
    """
    if name in OBJECTIVES:
        return OBJECTIVES[name](**kwargs)
    listing = ", ".join(sorted(OBJECTIVES))
    raise KeyError(f"Unknown objective '{name}'. Available: {listing}")
|