DarshanScripts commited on
Commit
82dd7b3
·
verified ·
1 Parent(s): 5cd2ed8

Upload stratego/env/backup/registration.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego/env/backup/registration.py +95 -0
stratego/env/backup/registration.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, random, importlib
2
+ from typing import Any, Union, List, Callable, Dict, Tuple, Optional
3
+ from dataclasses import dataclass, field
4
+
5
+ import textarena as ta
6
+
7
+
8
+ # Global environment registry
9
+ ENV_REGISTRY: Dict[str, Callable] = {}
10
+
11
+ @dataclass
12
+ class EnvSpec:
13
+ """A specification for creating environments."""
14
+ id: str
15
+ entry_point: Callable
16
+ default_wrappers: Optional[List[ta.Wrapper]]
17
+ kwargs: Dict[str, Any] = field(default_factory=dict)
18
+
19
+ def make(self, **kwargs) -> Any:
20
+ """Create an environment instance."""
21
+ all_kwargs = {**self.kwargs, **kwargs}
22
+ return self.entry_point(**all_kwargs)
23
+
24
+ def register(id: str, entry_point: Callable, default_wrappers: Optional[List[ta.Wrapper]]=None, **kwargs: Any):
25
+ """Register an environment with a given ID."""
26
+ if id in ENV_REGISTRY:
27
+ raise ValueError(f"Environment {id} already registered.")
28
+ ENV_REGISTRY[id] = EnvSpec(id=id, entry_point=entry_point, default_wrappers=default_wrappers, kwargs=kwargs)
29
+
30
+ def register_with_versions(id: str, entry_point: Callable, wrappers: Optional[Dict[str, List[ta.Wrapper]]]=None, **kwargs: Any):
31
+ """Register an environment with a given ID."""
32
+ if id in ENV_REGISTRY: raise ValueError(f"Environment {id} already registered.")
33
+
34
+ # first register default version
35
+ ENV_REGISTRY[id] = EnvSpec(id=id, entry_point=entry_point, default_wrappers=wrappers.get("default"), kwargs=kwargs)
36
+ for wrapper_version_key in list(wrappers.keys())+["-raw"]:
37
+ if wrapper_version_key=="default": continue
38
+ ENV_REGISTRY[f"{id}{wrapper_version_key}"] = EnvSpec(id=f"{id}{wrapper_version_key}", entry_point=entry_point, default_wrappers=wrappers.get(wrapper_version_key), kwargs=kwargs)
39
+
40
+ def pprint_registry_detailed():
41
+ """Pretty print the registry with additional details like kwargs."""
42
+ if not ENV_REGISTRY:
43
+ print("No environments registered.")
44
+ else:
45
+ print("Detailed Registered Environments:")
46
+ for env_id, env_spec in ENV_REGISTRY.items():
47
+ print(f" - {env_id}:")
48
+ print(f" Entry Point: {env_spec.entry_point}")
49
+ print(f" Kwargs: {env_spec.kwargs}")
50
+ print(f" Wrappers: {env_spec.default_wrappers}")
51
+
52
+ def check_env_exists(env_id: str):
53
+ """Check if an environment exists in the registry."""
54
+ if env_id not in ENV_REGISTRY:
55
+ raise ValueError(f"Environment {env_id} is not registered.")
56
+ else:
57
+ print(f"Environment {env_id} is registered.")
58
+
59
+ def make(env_id: Union[str, List[str]], **kwargs) -> Any:
60
+ """Create an environment instance using the registered ID."""
61
+ # If env_id is a list, randomly select one environment ID
62
+ if isinstance(env_id, list):
63
+ if not env_id:
64
+ raise ValueError("Empty list of environment IDs provided.")
65
+ env_id = random.choice(env_id)
66
+
67
+ # Continue with the existing implementation
68
+ if env_id not in ENV_REGISTRY:
69
+ raise ValueError(f"Environment {env_id} not found in registry.")
70
+
71
+ env_spec = ENV_REGISTRY[env_id]
72
+
73
+ # Resolve the entry point if it's a string
74
+ if isinstance(env_spec.entry_point, str):
75
+ module_path, class_name = env_spec.entry_point.split(":")
76
+ try:
77
+ module = importlib.import_module(module_path)
78
+ env_class = getattr(module, class_name)
79
+ except (ModuleNotFoundError, AttributeError) as e:
80
+ raise ImportError(f"Could not import {module_path}.{class_name}. Error: {e}")
81
+ else:
82
+ env_class = env_spec.entry_point
83
+
84
+ env = env_class(**{**env_spec.kwargs, **kwargs})
85
+
86
+ # Dynamically attach the env_id
87
+ env.env_id = env_id
88
+ env.entry_point = env_spec.entry_point
89
+
90
+ # wrap the environment
91
+ if env_spec.default_wrappers is not None and len(env_spec.default_wrappers) > 0:
92
+ for wrapper in env_spec.default_wrappers:
93
+ env = wrapper(env)
94
+
95
+ return env