# neuralese_temp/src/hackable/registry.py
# psidharth567: Export neuralese codebase (cache and .env excluded).
# Commit: dbc69f3
from __future__ import annotations

from collections.abc import Callable
from functools import wraps
from typing import Any
# Name -> callable lookup tables. Each is filled by its ``register_*``
# decorator and read by the matching ``build_*`` function in this module.
DATA_PROVIDERS: dict[str, Callable[..., Any]] = dict()
REWARDS: dict[str, Callable[..., Any]] = dict()
OBJECTIVES: dict[str, Callable[..., Any]] = dict()
def register_data_provider(name: str):
    """Return a decorator that files a data-provider factory under ``name``.

    The decorated callable is stored in ``DATA_PROVIDERS`` and returned
    unchanged, so the decorated name still refers to the original function.
    Registering another factory under an existing ``name`` replaces the
    earlier entry without any warning.
    """

    def _add_to_registry(provider_factory: Callable[..., Any]):
        DATA_PROVIDERS.update({name: provider_factory})
        return provider_factory

    return _add_to_registry
def register_reward(name: str):
    """Decorator factory that records a reward callable in ``REWARDS``.

    The callable is stored as ``REWARDS[name]`` and returned as-is. A later
    registration under the same ``name`` overwrites the earlier one.
    """

    def _record(reward_fn: Callable[..., Any]):
        registry = REWARDS
        registry[name] = reward_fn
        return reward_fn

    return _record
def register_objective(name: str):
    """Build a decorator that adds an objective factory to ``OBJECTIVES``.

    Applying the decorator maps ``name`` to the decorated callable and hands
    the callable back untouched. An existing entry with the same ``name`` is
    silently replaced.
    """

    def _remember(objective_factory: Callable[..., Any]):
        OBJECTIVES.update([(name, objective_factory)])
        return objective_factory

    return _remember
def build_data_provider(name: str, **kwargs: Any):
    """Instantiate the data provider registered under ``name``.

    Args:
        name: Key the provider factory was registered under.
        **kwargs: Forwarded verbatim to the registered factory.

    Returns:
        Whatever the registered factory returns.

    Raises:
        KeyError: If nothing is registered under ``name``. The message lists
            every registered name in sorted order.
    """
    if name in DATA_PROVIDERS:
        return DATA_PROVIDERS[name](**kwargs)
    known = ", ".join(sorted(DATA_PROVIDERS))
    raise KeyError(f"Unknown data provider '{name}'. Available: {known}")
def build_reward(name: str, **kwargs: Any) -> Callable[..., Any]:
    """Look up the reward registered under ``name``, optionally binding kwargs.

    Args:
        name: Key the reward was registered under.
        **kwargs: Keyword arguments bound to every call of the reward.

    Returns:
        The registered callable itself when no ``kwargs`` are given.
        Otherwise, a wrapper that forwards its positional and keyword call
        arguments to the registered callable together with the bound
        ``kwargs``. Keywords passed at call time take precedence over bound
        ones, the same rule ``functools.partial`` uses. The wrapper carries
        the registered callable's ``__name__``, ``__qualname__``, ``__doc__``
        and ``__module__``, so code that identifies rewards by name still
        sees the registered function's name.

    Raises:
        KeyError: If nothing is registered under ``name``. The message lists
            every registered name in sorted order.
    """
    if name not in REWARDS:
        available = ", ".join(sorted(REWARDS))
        raise KeyError(f"Unknown reward '{name}'. Available: {available}")
    fn = REWARDS[name]
    if not kwargs:
        return fn

    @wraps(fn)
    def _reward_with_kwargs(*args: Any, **call_kwargs: Any) -> Any:
        # Merge rather than splat both mappings: ``fn(**a, **b)`` raises
        # TypeError ("got multiple values") when a key appears in both.
        return fn(*args, **{**kwargs, **call_kwargs})

    return _reward_with_kwargs
def build_objective(name: str, **kwargs: Any):
    """Create the objective registered under ``name``.

    Args:
        name: Key the objective factory was registered under.
        **kwargs: Passed straight through to that factory.

    Returns:
        The factory's return value.

    Raises:
        KeyError: If ``name`` has no registered factory. The message includes
            the sorted list of registered names.
    """
    missing = object()  # distinct from any value that could be registered
    factory = OBJECTIVES.get(name, missing)
    if factory is missing:
        names = ", ".join(sorted(OBJECTIVES))
        raise KeyError(f"Unknown objective '{name}'. Available: {names}")
    return factory(**kwargs)