arya89 commited on
Commit
cc49071
·
verified ·
1 Parent(s): d02897f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +13 -59
server/app.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  # Add parent directory to path
10
  sys.path.insert(0, str(Path(__file__).parent.parent))
11
 
12
- from fastapi import FastAPI, HTTPException
13
  from fastapi.responses import JSONResponse
14
  from pydantic import BaseModel
15
  from typing import Dict, Any, Optional
@@ -26,37 +26,14 @@ app = FastAPI(
26
  version="1.0.0"
27
  )
28
 
29
- # Global environment instance (stateful for demo purposes)
30
  env_instance: Optional[MyEnvEnvironment] = None
31
 
32
 
33
- # Request/Response Models
34
- class ResetRequest(BaseModel):
35
- task_id: int = 1
36
-
37
- class Config:
38
- json_schema_extra = {
39
- "example": {
40
- "task_id": 1
41
- }
42
- }
43
-
44
-
45
  class StepRequest(BaseModel):
46
  action_id: int
47
- task_id: int = 1
48
-
49
- class Config:
50
- json_schema_extra = {
51
- "example": {
52
- "action_id": 0,
53
- "task_id": 1
54
- }
55
- }
56
-
57
-
58
- class StateResponse(BaseModel):
59
- state: Dict[str, Any]
60
 
61
 
62
  # Health check endpoint
@@ -87,12 +64,13 @@ async def health():
87
 
88
 
89
  @app.post("/reset")
90
- async def reset(request: ResetRequest) -> Dict[str, Any]:
91
  """
92
  Reset the environment for a specific task.
 
93
 
94
  Args:
95
- request: ResetRequest with task_id (1=easy, 2=medium, 3=hard)
96
 
97
  Returns:
98
  Initial observation after reset
@@ -100,23 +78,12 @@ async def reset(request: ResetRequest) -> Dict[str, Any]:
100
  global env_instance
101
 
102
  try:
103
- # Validate task_id
104
- if request.task_id not in [1, 2, 3]:
105
- raise HTTPException(
106
- status_code=400,
107
- detail=f"Invalid task_id: {request.task_id}. Must be 1, 2, or 3."
108
- )
109
-
110
  # Create new environment instance
111
  env_instance = MyEnvEnvironment()
112
- obs = env_instance.reset(task_id=request.task_id)
113
 
114
- # Return observation as dict
115
- return {
116
- "observation": obs.model_dump(),
117
- "task_id": request.task_id,
118
- "message": "Environment reset successfully"
119
- }
120
 
121
  except Exception as e:
122
  raise HTTPException(
@@ -156,26 +123,14 @@ async def step(request: StepRequest) -> Dict[str, Any]:
156
  # Create action
157
  action = IncidentAction(
158
  action_id=request.action_id,
159
- task_id=request.task_id
160
  )
161
 
162
  # Execute step
163
  obs = env_instance.step(action)
164
 
165
- # Get action name
166
- action_name = env_instance.ACTION_NAMES.get(request.action_id, "unknown")
167
-
168
- return {
169
- "observation": obs.model_dump(),
170
- "action_taken": {
171
- "action_id": request.action_id,
172
- "action_name": action_name
173
- },
174
- "reward": obs.reward,
175
- "done": obs.done,
176
- "total_reward": env_instance.total_reward,
177
- "incident_resolved": env_instance.incident_resolved
178
- }
179
 
180
  except Exception as e:
181
  raise HTTPException(
@@ -227,7 +182,6 @@ async def get_actions() -> Dict[str, Any]:
227
  Dictionary of action IDs and names
228
  """
229
  try:
230
- # Create temporary instance to get action names
231
  temp_env = MyEnvEnvironment()
232
 
233
  return {
 
9
  # Add parent directory to path
10
  sys.path.insert(0, str(Path(__file__).parent.parent))
11
 
12
+ from fastapi import FastAPI, HTTPException, Query
13
  from fastapi.responses import JSONResponse
14
  from pydantic import BaseModel
15
  from typing import Dict, Any, Optional
 
26
  version="1.0.0"
27
  )
28
 
29
+ # Global environment instance
30
  env_instance: Optional[MyEnvEnvironment] = None
31
 
32
 
33
+ # Request Models
 
 
 
 
 
 
 
 
 
 
 
34
  class StepRequest(BaseModel):
35
  action_id: int
36
+ task_id: Optional[int] = 1
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  # Health check endpoint
 
64
 
65
 
66
  @app.post("/reset")
67
+ async def reset(task_id: int = Query(default=1, ge=1, le=3)) -> Dict[str, Any]:
68
  """
69
  Reset the environment for a specific task.
70
+ OpenEnv standard endpoint.
71
 
72
  Args:
73
+ task_id: Task difficulty (1=easy, 2=medium, 3=hard)
74
 
75
  Returns:
76
  Initial observation after reset
 
78
  global env_instance
79
 
80
  try:
 
 
 
 
 
 
 
81
  # Create new environment instance
82
  env_instance = MyEnvEnvironment()
83
+ obs = env_instance.reset(task_id=task_id)
84
 
85
+ # Return observation in OpenEnv format
86
+ return obs.model_dump()
 
 
 
 
87
 
88
  except Exception as e:
89
  raise HTTPException(
 
123
  # Create action
124
  action = IncidentAction(
125
  action_id=request.action_id,
126
+ task_id=request.task_id if request.task_id else env_instance.task_id
127
  )
128
 
129
  # Execute step
130
  obs = env_instance.step(action)
131
 
132
+ # Return observation in OpenEnv format
133
+ return obs.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  except Exception as e:
136
  raise HTTPException(
 
182
  Dictionary of action IDs and names
183
  """
184
  try:
 
185
  temp_env = MyEnvEnvironment()
186
 
187
  return {