model111 / larm /data /envs /base_env.py
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
from abc import ABC, abstractmethod
from typing import Literal, Dict, Tuple
class BaseEnv(ABC):
ENV_CARD: Literal["STATIC", "DYNAMIC"] = None
def __init__(self, config):
self.config = config
@classmethod
@abstractmethod
def _accuracy_reward(cls, **kwargs):
...
@classmethod
@abstractmethod
def _format_reward(cls, **kwargs):
...
@classmethod
def get_reward_func(cls, func_name: Literal["accuracy", "format"]):
if func_name == "accuracy":
return cls._accuracy_reward
elif func_name == "format":
return cls._format_reward
else:
raise ValueError(f"Unsupported reward func: {func_name}")
class StaticEnv(BaseEnv):
ENV_CARD = "STATIC"
class DynamicEnv(BaseEnv):
ENV_CARD = "DYNAMIC"
@abstractmethod
def set_env(self, task_config: Dict) -> Tuple[str, str]:
...
@classmethod
@abstractmethod
def preprocess_action(self, action: str) -> str:
...
@abstractmethod
def step(self, action: str) -> Tuple[str, bool]:
...
@abstractmethod
def feedback(self) -> Tuple[float, bool]:
...