File size: 1,640 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import annotations

import functools
from collections.abc import Callable
from typing import Any

# Module-level name -> callable registries. Each is filled in by its matching
# ``register_*`` decorator and read back by the matching ``build_*`` helper.
DATA_PROVIDERS: dict[str, Callable[..., Any]] = dict()
REWARDS: dict[str, Callable[..., Any]] = dict()
OBJECTIVES: dict[str, Callable[..., Any]] = dict()


def register_data_provider(name: str):
    """Return a decorator that files a data-provider factory under ``name``.

    The decorated callable is stored in ``DATA_PROVIDERS`` and handed back
    unchanged, so the decorator does not wrap or alter it.

    NOTE(review): registering the same ``name`` twice silently replaces the
    earlier factory.
    """

    def _store(provider_factory: Callable[..., Any]):
        DATA_PROVIDERS[name] = provider_factory
        return provider_factory

    return _store


def register_reward(name: str):
    """Return a decorator that records a reward callable in ``REWARDS``.

    The callable is stored under ``name`` exactly as given and returned
    unchanged. ``build_reward`` later hands it back without invoking it.

    NOTE(review): a repeated ``name`` overwrites the previous entry without
    warning.
    """

    def _record(reward_fn: Callable[..., Any]):
        REWARDS[name] = reward_fn
        return reward_fn

    return _record


def register_objective(name: str):
    """Return a decorator that adds an objective factory to ``OBJECTIVES``.

    The factory is keyed by ``name`` and passed through untouched, so it can
    still be used directly after decoration.

    NOTE(review): a repeated ``name`` silently replaces the earlier factory.
    """

    def _add(objective_factory: Callable[..., Any]):
        OBJECTIVES[name] = objective_factory
        return objective_factory

    return _add


def build_data_provider(name: str, **kwargs: Any):
    """Call the data-provider factory registered as ``name`` and return its result.

    Args:
        name: Key the factory was registered under via ``register_data_provider``.
        **kwargs: Forwarded verbatim as keyword arguments to the factory.

    Raises:
        KeyError: If ``name`` is not registered. The message lists every
            registered name in sorted order.
    """
    if name in DATA_PROVIDERS:
        return DATA_PROVIDERS[name](**kwargs)
    known = ", ".join(sorted(DATA_PROVIDERS))
    raise KeyError(f"Unknown data provider '{name}'. Available: {known}")


def build_reward(name: str, **kwargs: Any):
    """Return the reward callable registered as ``name``, optionally with bound kwargs.

    Args:
        name: Key the reward was registered under via ``register_reward``.
        **kwargs: Keyword arguments to pre-bind. When given, they are passed
            to every call of the returned function together with the
            call-time arguments. Passing the same keyword again at call time
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The registered callable itself when no ``kwargs`` are given.
        Otherwise, a forwarding wrapper. The wrapper copies the original's
        ``__name__``, ``__qualname__``, ``__doc__`` and ``__module__`` and
        sets ``__wrapped__``. Code that identifies rewards by name therefore
        sees the registered function's name, not the wrapper's.

    Raises:
        KeyError: If ``name`` is not registered. The message lists every
            registered name in sorted order.
    """
    if name not in REWARDS:
        available = ", ".join(sorted(REWARDS))
        raise KeyError(f"Unknown reward '{name}'. Available: {available}")
    fn = REWARDS[name]
    if not kwargs:
        return fn

    # Without wraps, every bound reward would report the same name
    # ("_reward_with_kwargs"), making them indistinguishable by name.
    @functools.wraps(fn)
    def _reward_with_kwargs(*args, **call_kwargs):
        return fn(*args, **kwargs, **call_kwargs)

    return _reward_with_kwargs


def build_objective(name: str, **kwargs: Any):
    """Call the objective factory registered as ``name`` and return its result.

    Args:
        name: Key the factory was registered under via ``register_objective``.
        **kwargs: Forwarded verbatim as keyword arguments to the factory.

    Raises:
        KeyError: If ``name`` is not registered. The message lists every
            registered name in sorted order.
    """
    if name in OBJECTIVES:
        make_objective = OBJECTIVES[name]
        return make_objective(**kwargs)
    registered = ", ".join(sorted(OBJECTIVES))
    raise KeyError(f"Unknown objective '{name}'. Available: {registered}")