# stratego-human-vs-ai / stratego/env/backup/registration.py
# Uploaded by DarshanScripts with huggingface_hub (commit 82dd7b3, verified).
import re, random, importlib
from typing import Any, Union, List, Callable, Dict, Tuple, Optional
from dataclasses import dataclass, field
import textarena as ta
# Global environment registry: maps each registered environment ID to the
# EnvSpec describing how to build it (populated by register/register_with_versions).
ENV_REGISTRY: Dict[str, "EnvSpec"] = {}
@dataclass
class EnvSpec:
    """A specification for creating environments.

    ``make`` merges the stored ``kwargs`` with call-time overrides and invokes the
    entry point. ``make`` does NOT apply ``default_wrappers``; the module-level
    ``make`` function does that.
    """

    # Registry key this spec was registered under.
    id: str
    # Either a callable factory or a "module.path:ClassName" string that is
    # imported lazily when the environment is built.
    entry_point: Union[Callable, str]
    # Wrappers applied by the module-level ``make``; None means "no wrappers".
    default_wrappers: Optional[List["ta.Wrapper"]] = None
    # Default constructor kwargs; call-time kwargs passed to ``make`` override them.
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def make(self, **kwargs) -> Any:
        """Create an environment instance.

        Args:
            **kwargs: Constructor arguments that override ``self.kwargs``.

        Returns:
            Whatever the resolved entry point returns.

        Raises:
            ValueError: If a string entry point is not "module.path:ClassName".
            ImportError: If a string entry point cannot be imported.
        """
        all_kwargs = {**self.kwargs, **kwargs}
        return self._resolve_entry_point()(**all_kwargs)

    def _resolve_entry_point(self) -> Callable:
        """Return the callable behind ``entry_point``, importing it if it is a string."""
        if not isinstance(self.entry_point, str):
            return self.entry_point
        parts = self.entry_point.split(":")
        # Exactly one colon with non-empty module and attribute names on each side.
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Invalid entry point {self.entry_point!r} for environment {self.id}; "
                "expected 'module.path:ClassName'."
            )
        module_path, class_name = parts
        try:
            module = importlib.import_module(module_path)
            return getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import {module_path}.{class_name}. Error: {e}") from e
def register(id: str, entry_point: Callable, default_wrappers: Optional[List[ta.Wrapper]]=None, **kwargs: Any):
    """Add a new EnvSpec to ``ENV_REGISTRY`` under ``id``.

    Args:
        id: Registry key for the environment.
        entry_point: Factory (or entry-point value) stored on the spec.
        default_wrappers: Wrappers stored on the spec; None for none.
        **kwargs: Default constructor kwargs stored on the spec.

    Raises:
        ValueError: If ``id`` is already present in the registry.
    """
    already_taken = id in ENV_REGISTRY
    if already_taken:
        raise ValueError(f"Environment {id} already registered.")
    spec = EnvSpec(
        id=id,
        entry_point=entry_point,
        default_wrappers=default_wrappers,
        kwargs=kwargs,
    )
    ENV_REGISTRY[id] = spec
def register_with_versions(id: str, entry_point: Callable, wrappers: Optional[Dict[str, List[ta.Wrapper]]]=None, **kwargs: Any):
    """Register an environment plus one variant per wrapper configuration.

    The base ``id`` is registered with ``wrappers["default"]`` (None if absent).
    Every other key of ``wrappers`` is treated as a suffix: ``f"{id}{key}"`` is
    registered with ``wrappers[key]``. A ``f"{id}-raw"`` variant is always
    registered as well; unless ``wrappers`` supplies a ``"-raw"`` entry it has
    no wrappers. All variants share ``entry_point`` and ``kwargs``.

    Args:
        id: Base registry key.
        entry_point: Factory stored on every variant's spec.
        wrappers: Mapping of version suffix (or ``"default"``) to wrapper list.
            None is treated as an empty mapping.
        **kwargs: Default constructor kwargs stored on every variant's spec.

    Raises:
        ValueError: If the base id or any derived id is already registered.
            In that case nothing is registered.
    """
    wrappers = wrappers or {}

    suffixes = [key for key in wrappers if key != "default"]
    if "-raw" not in suffixes:
        suffixes.append("-raw")

    # Compute every id up front so a collision leaves the registry untouched.
    specs_to_add = {id: wrappers.get("default")}
    for suffix in suffixes:
        specs_to_add[f"{id}{suffix}"] = wrappers.get(suffix)

    for env_id in specs_to_add:
        if env_id in ENV_REGISTRY:
            raise ValueError(f"Environment {env_id} already registered.")

    for env_id, env_wrappers in specs_to_add.items():
        ENV_REGISTRY[env_id] = EnvSpec(id=env_id, entry_point=entry_point, default_wrappers=env_wrappers, kwargs=kwargs)
def pprint_registry_detailed():
    """Print each registered environment with its entry point, kwargs and wrappers.

    Prints a single notice instead when the registry is empty.
    """
    if ENV_REGISTRY:
        print("Detailed Registered Environments:")
        for name, spec in ENV_REGISTRY.items():
            print(f" - {name}:")
            print(f" Entry Point: {spec.entry_point}")
            print(f" Kwargs: {spec.kwargs}")
            print(f" Wrappers: {spec.default_wrappers}")
        return
    print("No environments registered.")
def check_env_exists(env_id: str):
    """Print a confirmation if ``env_id`` is registered; raise ValueError otherwise."""
    if env_id in ENV_REGISTRY:
        print(f"Environment {env_id} is registered.")
        return
    raise ValueError(f"Environment {env_id} is not registered.")
def make(env_id: Union[str, List[str]], **kwargs) -> Any:
    """Create an environment instance using the registered ID.

    Args:
        env_id: A registered environment ID, or a non-empty list of IDs from
            which one is picked with ``random.choice``.
        **kwargs: Constructor arguments that override the spec's stored kwargs.

    Returns:
        The constructed environment with ``env_id`` and ``entry_point``
        attributes attached. It is then passed through each of the spec's
        ``default_wrappers`` in order, so the first wrapper is innermost.

    Raises:
        ValueError: If the list is empty, the ID is not registered, or a string
            entry point is not of the form "module.path:ClassName".
        ImportError: If a string entry point's module or attribute cannot be loaded.
    """
    # If env_id is a list, randomly select one environment ID.
    if isinstance(env_id, list):
        if not env_id:
            raise ValueError("Empty list of environment IDs provided.")
        env_id = random.choice(env_id)

    if env_id not in ENV_REGISTRY:
        raise ValueError(f"Environment {env_id} not found in registry.")
    env_spec = ENV_REGISTRY[env_id]

    # Resolve the entry point if it's a "module.path:ClassName" string.
    entry_point = env_spec.entry_point
    if isinstance(entry_point, str):
        parts = entry_point.split(":")
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Invalid entry point {entry_point!r} for environment {env_id}; "
                "expected 'module.path:ClassName'."
            )
        module_path, class_name = parts
        try:
            module = importlib.import_module(module_path)
            env_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            raise ImportError(f"Could not import {module_path}.{class_name}. Error: {e}") from e
    else:
        env_class = entry_point

    env = env_class(**{**env_spec.kwargs, **kwargs})

    # Record which registry entry produced this instance (relevant when the ID
    # was chosen at random from a list).
    env.env_id = env_id
    env.entry_point = env_spec.entry_point

    # Wrap the environment; None or an empty list means no wrapping.
    for wrapper in env_spec.default_wrappers or ():
        env = wrapper(env)
    return env