Spaces:
Configuration error
Configuration error
import importlib
import random
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import textarena as ta
# Global environment registry: maps each registered environment id to its EnvSpec.
# Written by register() / register_with_versions(); read by make(),
# check_env_exists() and pprint_registry_detailed().
ENV_REGISTRY: Dict[str, "EnvSpec"] = {}
@dataclass
class EnvSpec:
    """A specification for creating environments.

    Instances are stored in ``ENV_REGISTRY`` by ``register`` /
    ``register_with_versions`` and looked up by the module-level ``make``.
    """

    id: str  # registry key under which this spec is stored
    # Either a callable, or a "module.path:AttrName" string imported on demand.
    entry_point: Union[str, Callable]
    # Quoted forward reference: the annotation is not evaluated at class creation.
    default_wrappers: Optional[List["ta.Wrapper"]] = None
    # Default constructor kwargs; a fresh dict per instance (never shared).
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def make(self, **kwargs) -> Any:
        """Create an environment instance.

        Call-site ``kwargs`` override the stored ``self.kwargs``. Note that
        ``default_wrappers`` are NOT applied here; only the bare instance is built.

        Raises:
            ValueError: if a string ``entry_point`` is not of the form ``"module:attr"``.
            ImportError: if a string ``entry_point`` cannot be imported.
        """
        factory = self._resolve_entry_point()
        return factory(**{**self.kwargs, **kwargs})

    def _resolve_entry_point(self) -> Callable:
        """Return ``entry_point``, importing it first when it is a ``"module:attr"`` string."""
        if not isinstance(self.entry_point, str):
            return self.entry_point
        parts = self.entry_point.split(":")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid entry point {self.entry_point!r}; expected 'module.path:AttrName'."
            )
        module_path, attr_name = parts
        try:
            module = importlib.import_module(module_path)
            return getattr(module, attr_name)
        except (ModuleNotFoundError, AttributeError) as e:
            raise ImportError(f"Could not import {module_path}.{attr_name}. Error: {e}") from e
def register(id: str, entry_point: Callable, default_wrappers: Optional[List[ta.Wrapper]]=None, **kwargs: Any):
    """Store a new ``EnvSpec`` in ``ENV_REGISTRY`` under the key ``id``.

    Args:
        id: Registry key for the environment; must not already be registered.
        entry_point: What ``make`` calls to construct the environment.
        default_wrappers: Wrappers ``make`` applies, in order, to the new instance.
        **kwargs: Default constructor kwargs saved on the spec.

    Raises:
        ValueError: if ``id`` is already present in ``ENV_REGISTRY``.
    """
    already_taken = id in ENV_REGISTRY
    if already_taken:
        raise ValueError(f"Environment {id} already registered.")

    new_spec = EnvSpec(
        id=id,
        entry_point=entry_point,
        default_wrappers=default_wrappers,
        kwargs=kwargs,
    )
    ENV_REGISTRY[id] = new_spec
def register_with_versions(id: str, entry_point: Callable, wrappers: Optional[Dict[str, List[ta.Wrapper]]]=None, **kwargs: Any):
    """Register ``id`` plus one suffixed variant per entry in ``wrappers``.

    ``wrappers`` maps an id suffix (e.g. ``"-train"``) to the wrappers used for
    ``f"{id}{suffix}"``; the special key ``"default"`` supplies the wrappers for
    the bare ``id``. An ``f"{id}-raw"`` variant is always registered, with no
    wrappers unless ``wrappers`` has a ``"-raw"`` entry.

    Args:
        id: Base registry key.
        entry_point: Shared by every registered variant.
        wrappers: Suffix -> wrapper list; ``None`` means no wrappers for any variant.
        **kwargs: Default constructor kwargs, copied onto each variant's spec.

    Raises:
        ValueError: if ``id`` or any derived id is already registered; in that
            case nothing is registered.
    """
    wrappers = wrappers or {}  # original crashed on the declared default of None

    # Every non-"default" key is an id suffix; "-raw" is always included, once.
    suffixes = [key for key in wrappers if key != "default"]
    if "-raw" not in suffixes:
        suffixes.append("-raw")

    # Check the whole family up front so a collision cannot leave it half-registered
    # or silently overwrite an existing variant.
    for env_id in [id] + [f"{id}{suffix}" for suffix in suffixes]:
        if env_id in ENV_REGISTRY:
            raise ValueError(f"Environment {env_id} already registered.")

    # First register the default version.
    ENV_REGISTRY[id] = EnvSpec(
        id=id, entry_point=entry_point, default_wrappers=wrappers.get("default"), kwargs=dict(kwargs)
    )
    for suffix in suffixes:
        env_id = f"{id}{suffix}"
        # dict(kwargs): each spec gets its own copy rather than one shared dict.
        ENV_REGISTRY[env_id] = EnvSpec(
            id=env_id, entry_point=entry_point, default_wrappers=wrappers.get(suffix), kwargs=dict(kwargs)
        )
def pprint_registry_detailed():
    """Print every registered environment with its entry point, kwargs and wrappers.

    Output goes to stdout. When ``ENV_REGISTRY`` is empty, a single placeholder
    line is printed instead.
    """
    if not ENV_REGISTRY:
        print("No environments registered.")
        return

    print("Detailed Registered Environments:")
    for name, spec in ENV_REGISTRY.items():
        print(f" - {name}:")
        print(f" Entry Point: {spec.entry_point}")
        print(f" Kwargs: {spec.kwargs}")
        print(f" Wrappers: {spec.default_wrappers}")
def check_env_exists(env_id: str):
    """Confirm that ``env_id`` is registered, printing a confirmation line if so.

    Raises:
        ValueError: if ``env_id`` is not a key of ``ENV_REGISTRY``.
    """
    if env_id in ENV_REGISTRY:
        print(f"Environment {env_id} is registered.")
        return
    raise ValueError(f"Environment {env_id} is not registered.")
def _resolve_entry_point(entry_point: Union[str, Callable]) -> Callable:
    """Return ``entry_point`` itself, or import it when it is a ``"module.path:AttrName"`` string.

    Raises:
        ValueError: if a string entry point does not contain exactly one ``":"``.
        ImportError: if the module cannot be found or lacks the named attribute
            (the original error is chained as ``__cause__``).
    """
    if not isinstance(entry_point, str):
        return entry_point
    parts = entry_point.split(":")
    if len(parts) != 2:
        raise ValueError(f"Invalid entry point {entry_point!r}; expected 'module.path:ClassName'.")
    module_path, class_name = parts
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ModuleNotFoundError, AttributeError) as e:
        raise ImportError(f"Could not import {module_path}.{class_name}. Error: {e}") from e


def make(env_id: Union[str, List[str]], **kwargs) -> Any:
    """Create an environment instance using the registered ID.

    Args:
        env_id: A registered id, or a non-empty list of ids from which one is
            picked with ``random.choice``.
        **kwargs: Constructor kwargs; they override the spec's stored kwargs.

    Returns:
        The constructed environment, with ``env_id`` and ``entry_point``
        attributes attached, wrapped by each of the spec's ``default_wrappers``
        in order (each wrapper receives the previous result).

    Raises:
        ValueError: on an empty id list, an unregistered id, or a malformed
            string entry point.
        ImportError: if a string entry point cannot be imported.
    """
    # If env_id is a list, randomly select one environment ID.
    if isinstance(env_id, list):
        if not env_id:
            raise ValueError("Empty list of environment IDs provided.")
        env_id = random.choice(env_id)

    if env_id not in ENV_REGISTRY:
        raise ValueError(f"Environment {env_id} not found in registry.")
    env_spec = ENV_REGISTRY[env_id]

    env_class = _resolve_entry_point(env_spec.entry_point)
    env = env_class(**{**env_spec.kwargs, **kwargs})

    # Dynamically attach the id and entry point so callers can inspect them later.
    env.env_id = env_id
    env.entry_point = env_spec.entry_point

    # None and [] both mean "no wrappers".
    for wrapper in env_spec.default_wrappers or ():
        env = wrapper(env)
    return env