tripathysagar commited on
Commit
b5973a5
·
verified ·
1 Parent(s): cdf4ea5

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. client.py +10 -79
  2. server/requirements.txt +4 -6
client.py CHANGED
@@ -1,98 +1,29 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """Maze Environment Client."""
8
-
9
  from typing import Dict
10
-
11
  from openenv.core.client_types import StepResult
12
  from openenv.core.env_server.types import State
13
  from openenv.core import EnvClient
 
14
 
15
- from .models import MazeAction, MazeObservation
16
-
17
-
18
- class MazeEnv(
19
- EnvClient[MazeAction, MazeObservation]
20
- ):
21
- """
22
- Client for the Maze Environment.
23
-
24
- This client maintains a persistent WebSocket connection to the environment server,
25
- enabling efficient multi-step interactions with lower latency.
26
- Each client instance has its own dedicated environment session on the server.
27
-
28
- Example:
29
- >>> # Connect to a running server
30
- >>> with MazeEnv(base_url="http://localhost:8000") as client:
31
- ... result = client.reset()
32
- ... print(result.observation.echoed_message)
33
- ...
34
- ... result = client.step(MazeAction(message="Hello!"))
35
- ... print(result.observation.echoed_message)
36
-
37
- Example with Docker:
38
- >>> # Automatically start container and connect
39
- >>> client = MazeEnv.from_docker_image("maze-env:latest")
40
- >>> try:
41
- ... result = client.reset()
42
- ... result = client.step(MazeAction(message="Test"))
43
- ... finally:
44
- ... client.close()
45
- """
46
 
47
  def _step_payload(self, action: MazeAction) -> Dict:
48
- """
49
- Convert MazeAction to JSON payload for step message.
50
-
51
- Args:
52
- action: MazeAction instance
53
-
54
- Returns:
55
- Dictionary representation suitable for JSON encoding
56
- """
57
- return {
58
- "message": action.message,
59
- }
60
 
61
  def _parse_result(self, payload: Dict) -> StepResult[MazeObservation]:
62
- """
63
- Parse server response into StepResult[MazeObservation].
64
-
65
- Args:
66
- payload: JSON response data from server
67
-
68
- Returns:
69
- StepResult with MazeObservation
70
- """
71
  obs_data = payload.get("observation", {})
72
- observation = MazeObservation(
73
- echoed_message=obs_data.get("echoed_message", ""),
74
- message_length=obs_data.get("message_length", 0),
75
- done=payload.get("done", False),
76
- reward=payload.get("reward"),
77
- metadata=obs_data.get("metadata", {}),
78
- )
79
-
80
  return StepResult(
81
- observation=observation,
 
 
 
 
 
82
  reward=payload.get("reward"),
83
  done=payload.get("done", False),
84
  )
85
 
86
  def _parse_state(self, payload: Dict) -> State:
87
- """
88
- Parse server response into State object.
89
-
90
- Args:
91
- payload: JSON response from state request
92
-
93
- Returns:
94
- State object with episode_id and step_count
95
- """
96
  return State(
97
  episode_id=payload.get("episode_id"),
98
  step_count=payload.get("step_count", 0),
 
 
 
 
 
 
 
 
 
1
  from typing import Dict
 
2
  from openenv.core.client_types import StepResult
3
  from openenv.core.env_server.types import State
4
  from openenv.core import EnvClient
5
+ from models import MazeAction, MazeObservation, MazeState
6
 
7
+ class MazeEnv(EnvClient[MazeAction, MazeObservation, MazeState]):
8
+ """Client for the Maze Environment."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def _step_payload(self, action: MazeAction) -> Dict:
11
+ return {"direction": action.direction}
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def _parse_result(self, payload: Dict) -> StepResult[MazeObservation]:
 
 
 
 
 
 
 
 
 
14
  obs_data = payload.get("observation", {})
 
 
 
 
 
 
 
 
15
  return StepResult(
16
+ observation=MazeObservation(
17
+ position=obs_data.get("position", []),
18
+ grid_view=obs_data.get("grid_view", ""),
19
+ done=payload.get("done", False),
20
+ reward=payload.get("reward"),
21
+ ),
22
  reward=payload.get("reward"),
23
  done=payload.get("done", False),
24
  )
25
 
26
  def _parse_state(self, payload: Dict) -> State:
 
 
 
 
 
 
 
 
 
27
  return State(
28
  episode_id=payload.get("episode_id"),
29
  step_count=payload.get("step_count", 0),
server/requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- openenv[core]>=0.2.0
2
- fastapi>=0.115.0
3
- uvicorn>=0.24.0
4
-
5
-
6
-
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ mazelib