Spaces:
Configuration error
Configuration error
Upload stratego/env/backup/registration.py with huggingface_hub
Browse files
stratego/env/backup/registration.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re, random, importlib
|
| 2 |
+
from typing import Any, Union, List, Callable, Dict, Tuple, Optional
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
|
| 5 |
+
import textarena as ta
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Global environment registry
|
| 9 |
+
ENV_REGISTRY: Dict[str, Callable] = {}
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class EnvSpec:
|
| 13 |
+
"""A specification for creating environments."""
|
| 14 |
+
id: str
|
| 15 |
+
entry_point: Callable
|
| 16 |
+
default_wrappers: Optional[List[ta.Wrapper]]
|
| 17 |
+
kwargs: Dict[str, Any] = field(default_factory=dict)
|
| 18 |
+
|
| 19 |
+
def make(self, **kwargs) -> Any:
|
| 20 |
+
"""Create an environment instance."""
|
| 21 |
+
all_kwargs = {**self.kwargs, **kwargs}
|
| 22 |
+
return self.entry_point(**all_kwargs)
|
| 23 |
+
|
| 24 |
+
def register(id: str, entry_point: Callable, default_wrappers: Optional[List[ta.Wrapper]]=None, **kwargs: Any):
|
| 25 |
+
"""Register an environment with a given ID."""
|
| 26 |
+
if id in ENV_REGISTRY:
|
| 27 |
+
raise ValueError(f"Environment {id} already registered.")
|
| 28 |
+
ENV_REGISTRY[id] = EnvSpec(id=id, entry_point=entry_point, default_wrappers=default_wrappers, kwargs=kwargs)
|
| 29 |
+
|
| 30 |
+
def register_with_versions(id: str, entry_point: Callable, wrappers: Optional[Dict[str, List[ta.Wrapper]]]=None, **kwargs: Any):
|
| 31 |
+
"""Register an environment with a given ID."""
|
| 32 |
+
if id in ENV_REGISTRY: raise ValueError(f"Environment {id} already registered.")
|
| 33 |
+
|
| 34 |
+
# first register default version
|
| 35 |
+
ENV_REGISTRY[id] = EnvSpec(id=id, entry_point=entry_point, default_wrappers=wrappers.get("default"), kwargs=kwargs)
|
| 36 |
+
for wrapper_version_key in list(wrappers.keys())+["-raw"]:
|
| 37 |
+
if wrapper_version_key=="default": continue
|
| 38 |
+
ENV_REGISTRY[f"{id}{wrapper_version_key}"] = EnvSpec(id=f"{id}{wrapper_version_key}", entry_point=entry_point, default_wrappers=wrappers.get(wrapper_version_key), kwargs=kwargs)
|
| 39 |
+
|
| 40 |
+
def pprint_registry_detailed():
|
| 41 |
+
"""Pretty print the registry with additional details like kwargs."""
|
| 42 |
+
if not ENV_REGISTRY:
|
| 43 |
+
print("No environments registered.")
|
| 44 |
+
else:
|
| 45 |
+
print("Detailed Registered Environments:")
|
| 46 |
+
for env_id, env_spec in ENV_REGISTRY.items():
|
| 47 |
+
print(f" - {env_id}:")
|
| 48 |
+
print(f" Entry Point: {env_spec.entry_point}")
|
| 49 |
+
print(f" Kwargs: {env_spec.kwargs}")
|
| 50 |
+
print(f" Wrappers: {env_spec.default_wrappers}")
|
| 51 |
+
|
| 52 |
+
def check_env_exists(env_id: str):
|
| 53 |
+
"""Check if an environment exists in the registry."""
|
| 54 |
+
if env_id not in ENV_REGISTRY:
|
| 55 |
+
raise ValueError(f"Environment {env_id} is not registered.")
|
| 56 |
+
else:
|
| 57 |
+
print(f"Environment {env_id} is registered.")
|
| 58 |
+
|
| 59 |
+
def make(env_id: Union[str, List[str]], **kwargs) -> Any:
|
| 60 |
+
"""Create an environment instance using the registered ID."""
|
| 61 |
+
# If env_id is a list, randomly select one environment ID
|
| 62 |
+
if isinstance(env_id, list):
|
| 63 |
+
if not env_id:
|
| 64 |
+
raise ValueError("Empty list of environment IDs provided.")
|
| 65 |
+
env_id = random.choice(env_id)
|
| 66 |
+
|
| 67 |
+
# Continue with the existing implementation
|
| 68 |
+
if env_id not in ENV_REGISTRY:
|
| 69 |
+
raise ValueError(f"Environment {env_id} not found in registry.")
|
| 70 |
+
|
| 71 |
+
env_spec = ENV_REGISTRY[env_id]
|
| 72 |
+
|
| 73 |
+
# Resolve the entry point if it's a string
|
| 74 |
+
if isinstance(env_spec.entry_point, str):
|
| 75 |
+
module_path, class_name = env_spec.entry_point.split(":")
|
| 76 |
+
try:
|
| 77 |
+
module = importlib.import_module(module_path)
|
| 78 |
+
env_class = getattr(module, class_name)
|
| 79 |
+
except (ModuleNotFoundError, AttributeError) as e:
|
| 80 |
+
raise ImportError(f"Could not import {module_path}.{class_name}. Error: {e}")
|
| 81 |
+
else:
|
| 82 |
+
env_class = env_spec.entry_point
|
| 83 |
+
|
| 84 |
+
env = env_class(**{**env_spec.kwargs, **kwargs})
|
| 85 |
+
|
| 86 |
+
# Dynamically attach the env_id
|
| 87 |
+
env.env_id = env_id
|
| 88 |
+
env.entry_point = env_spec.entry_point
|
| 89 |
+
|
| 90 |
+
# wrap the environment
|
| 91 |
+
if env_spec.default_wrappers is not None and len(env_spec.default_wrappers) > 0:
|
| 92 |
+
for wrapper in env_spec.default_wrappers:
|
| 93 |
+
env = wrapper(env)
|
| 94 |
+
|
| 95 |
+
return env
|