albertvillanova HF Staff commited on
Commit
afd99d9
·
verified ·
1 Parent(s): 1600e60

Create environments.py

Browse files
Files changed (1) hide show
  1. environments.py +121 -0
environments.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import uuid
3
+ from abc import ABC, abstractmethod
4
+ from functools import wraps
5
+
6
+ import gradio as gr
7
+ from gradio_client import Client
8
+
9
+
10
+ __version__ = "0.1.0"
11
+ __all__ = ["Environment", "load", "register_env"]
12
+
13
+
14
+ class Environment(ABC):
15
+ @abstractmethod
16
+ def reset(self, *args, **kwargs):
17
+ pass
18
+
19
+ @abstractmethod
20
+ def step(self, *args, **kwargs):
21
+ pass
22
+
23
+
24
+ class _RemoteEnvironment(Environment):
25
+ def __init__(self, env_id: str):
26
+ username, repo = env_id.split("/")
27
+ self.client = Client(f"https://{username}-{repo}.hf.space/")
28
+ self.session_id = self.client.predict(api_name="/init")
29
+
30
+ def reset(self, *args, **kwargs):
31
+ return self.client.predict(self.session_id, api_name="/reset", *args, **kwargs)
32
+
33
+ def step(self, *args, **kwargs):
34
+ return self.client.predict(self.session_id, api_name="/step", *args, **kwargs)
35
+
36
+
37
+ def load(env_id: str) -> _RemoteEnvironment:
38
+ return _RemoteEnvironment(env_id)
39
+
40
+
41
+ def bind_method_to_session(method, registry: dict):
42
+ sig = inspect.signature(method)
43
+ params = list(sig.parameters.values())
44
+
45
+ @wraps(method)
46
+ def wrapper(session_id: str, *args, **kwargs):
47
+ instance = registry.get(session_id)
48
+ if instance is None:
49
+ raise ValueError(f"Invalid session_id: {session_id}")
50
+ m = getattr(instance, method.__func__.__name__)
51
+ return m(*args, **kwargs)
52
+
53
+ # --- update __annotations__ ---
54
+ wrapper.__annotations__ = method.__annotations__.copy()
55
+ wrapper.__annotations__["session_id"] = str
56
+
57
+ # --- build signature ---
58
+ new_params = (
59
+ inspect.Parameter(
60
+ "session_id",
61
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
62
+ annotation=str,
63
+ ),
64
+ *params,
65
+ )
66
+ wrapper.__signature__ = inspect.Signature(
67
+ parameters=new_params,
68
+ return_annotation=sig.return_annotation,
69
+ )
70
+
71
+ return wrapper
72
+
73
+
74
+ def register_env(env_cls: type[Environment]) -> gr.Blocks:
75
+ """
76
+ Register an environment class with Gradio APIs.
77
+
78
+ Example:
79
+
80
+ ```python
81
+ from environments import register_env, Environment
82
+ import gradio as gr
83
+
84
+ class MyEnvironmentClass(Environment):
85
+ def reset(self) -> str:
86
+ return "Reset called!"
87
+
88
+ def step(self, action: str) -> str:
89
+ return f"Step called with action: {action}!"
90
+
91
+ with gr.Blocks() as demo:
92
+ register_env(MyEnvironmentClass)
93
+
94
+ demo.launch(mcp_server_name=True)
95
+ ```
96
+ """
97
+ sessions = {}
98
+
99
+ def init_env() -> str:
100
+ """
101
+ Initialize a new environment instance and return a session ID.
102
+ Returns:
103
+ A unique session ID for the new environment instance.
104
+ """
105
+ session_id = str(uuid.uuid4())
106
+ env = env_cls()
107
+ sessions[session_id] = env
108
+ return session_id
109
+
110
+ # Bind methods to session dict
111
+ reset_api = bind_method_to_session(env_cls().reset, sessions)
112
+ step_api = bind_method_to_session(env_cls().step, sessions)
113
+
114
+ # Create Gradio APIs
115
+ gr.api(
116
+ init_env,
117
+ api_name="init",
118
+ api_description="Initialize a new environment session",
119
+ )
120
+ gr.api(reset_api, api_name="reset")
121
+ gr.api(step_api, api_name="step")