Upload folder using huggingface_hub
Browse files- server/app.py +13 -59
server/app.py
CHANGED
|
@@ -9,7 +9,7 @@ from pathlib import Path
|
|
| 9 |
# Add parent directory to path
|
| 10 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 11 |
|
| 12 |
-
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import JSONResponse
|
| 14 |
from pydantic import BaseModel
|
| 15 |
from typing import Dict, Any, Optional
|
|
@@ -26,37 +26,14 @@ app = FastAPI(
|
|
| 26 |
version="1.0.0"
|
| 27 |
)
|
| 28 |
|
| 29 |
-
# Global environment instance
|
| 30 |
env_instance: Optional[MyEnvEnvironment] = None
|
| 31 |
|
| 32 |
|
| 33 |
-
# Request
|
| 34 |
-
class ResetRequest(BaseModel):
|
| 35 |
-
task_id: int = 1
|
| 36 |
-
|
| 37 |
-
class Config:
|
| 38 |
-
json_schema_extra = {
|
| 39 |
-
"example": {
|
| 40 |
-
"task_id": 1
|
| 41 |
-
}
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
|
| 45 |
class StepRequest(BaseModel):
|
| 46 |
action_id: int
|
| 47 |
-
task_id: int = 1
|
| 48 |
-
|
| 49 |
-
class Config:
|
| 50 |
-
json_schema_extra = {
|
| 51 |
-
"example": {
|
| 52 |
-
"action_id": 0,
|
| 53 |
-
"task_id": 1
|
| 54 |
-
}
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class StateResponse(BaseModel):
|
| 59 |
-
state: Dict[str, Any]
|
| 60 |
|
| 61 |
|
| 62 |
# Health check endpoint
|
|
@@ -87,12 +64,13 @@ async def health():
|
|
| 87 |
|
| 88 |
|
| 89 |
@app.post("/reset")
|
| 90 |
-
async def reset(
|
| 91 |
"""
|
| 92 |
Reset the environment for a specific task.
|
|
|
|
| 93 |
|
| 94 |
Args:
|
| 95 |
-
|
| 96 |
|
| 97 |
Returns:
|
| 98 |
Initial observation after reset
|
|
@@ -100,23 +78,12 @@ async def reset(request: ResetRequest) -> Dict[str, Any]:
|
|
| 100 |
global env_instance
|
| 101 |
|
| 102 |
try:
|
| 103 |
-
# Validate task_id
|
| 104 |
-
if request.task_id not in [1, 2, 3]:
|
| 105 |
-
raise HTTPException(
|
| 106 |
-
status_code=400,
|
| 107 |
-
detail=f"Invalid task_id: {request.task_id}. Must be 1, 2, or 3."
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
# Create new environment instance
|
| 111 |
env_instance = MyEnvEnvironment()
|
| 112 |
-
obs = env_instance.reset(task_id=
|
| 113 |
|
| 114 |
-
# Return observation
|
| 115 |
-
return
|
| 116 |
-
"observation": obs.model_dump(),
|
| 117 |
-
"task_id": request.task_id,
|
| 118 |
-
"message": "Environment reset successfully"
|
| 119 |
-
}
|
| 120 |
|
| 121 |
except Exception as e:
|
| 122 |
raise HTTPException(
|
|
@@ -156,26 +123,14 @@ async def step(request: StepRequest) -> Dict[str, Any]:
|
|
| 156 |
# Create action
|
| 157 |
action = IncidentAction(
|
| 158 |
action_id=request.action_id,
|
| 159 |
-
task_id=request.task_id
|
| 160 |
)
|
| 161 |
|
| 162 |
# Execute step
|
| 163 |
obs = env_instance.step(action)
|
| 164 |
|
| 165 |
-
#
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
return {
|
| 169 |
-
"observation": obs.model_dump(),
|
| 170 |
-
"action_taken": {
|
| 171 |
-
"action_id": request.action_id,
|
| 172 |
-
"action_name": action_name
|
| 173 |
-
},
|
| 174 |
-
"reward": obs.reward,
|
| 175 |
-
"done": obs.done,
|
| 176 |
-
"total_reward": env_instance.total_reward,
|
| 177 |
-
"incident_resolved": env_instance.incident_resolved
|
| 178 |
-
}
|
| 179 |
|
| 180 |
except Exception as e:
|
| 181 |
raise HTTPException(
|
|
@@ -227,7 +182,6 @@ async def get_actions() -> Dict[str, Any]:
|
|
| 227 |
Dictionary of action IDs and names
|
| 228 |
"""
|
| 229 |
try:
|
| 230 |
-
# Create temporary instance to get action names
|
| 231 |
temp_env = MyEnvEnvironment()
|
| 232 |
|
| 233 |
return {
|
|
|
|
| 9 |
# Add parent directory to path
|
| 10 |
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 11 |
|
| 12 |
+
from fastapi import FastAPI, HTTPException, Query
|
| 13 |
from fastapi.responses import JSONResponse
|
| 14 |
from pydantic import BaseModel
|
| 15 |
from typing import Dict, Any, Optional
|
|
|
|
| 26 |
version="1.0.0"
|
| 27 |
)
|
| 28 |
|
| 29 |
+
# Global environment instance
|
| 30 |
env_instance: Optional[MyEnvEnvironment] = None
|
| 31 |
|
| 32 |
|
| 33 |
+
# Request Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
class StepRequest(BaseModel):
|
| 35 |
action_id: int
|
| 36 |
+
task_id: Optional[int] = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
# Health check endpoint
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
@app.post("/reset")
|
| 67 |
+
async def reset(task_id: int = Query(default=1, ge=1, le=3)) -> Dict[str, Any]:
|
| 68 |
"""
|
| 69 |
Reset the environment for a specific task.
|
| 70 |
+
OpenEnv standard endpoint.
|
| 71 |
|
| 72 |
Args:
|
| 73 |
+
task_id: Task difficulty (1=easy, 2=medium, 3=hard)
|
| 74 |
|
| 75 |
Returns:
|
| 76 |
Initial observation after reset
|
|
|
|
| 78 |
global env_instance
|
| 79 |
|
| 80 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Create new environment instance
|
| 82 |
env_instance = MyEnvEnvironment()
|
| 83 |
+
obs = env_instance.reset(task_id=task_id)
|
| 84 |
|
| 85 |
+
# Return observation in OpenEnv format
|
| 86 |
+
return obs.model_dump()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
except Exception as e:
|
| 89 |
raise HTTPException(
|
|
|
|
| 123 |
# Create action
|
| 124 |
action = IncidentAction(
|
| 125 |
action_id=request.action_id,
|
| 126 |
+
task_id=request.task_id if request.task_id else env_instance.task_id
|
| 127 |
)
|
| 128 |
|
| 129 |
# Execute step
|
| 130 |
obs = env_instance.step(action)
|
| 131 |
|
| 132 |
+
# Return observation in OpenEnv format
|
| 133 |
+
return obs.model_dump()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
except Exception as e:
|
| 136 |
raise HTTPException(
|
|
|
|
| 182 |
Dictionary of action IDs and names
|
| 183 |
"""
|
| 184 |
try:
|
|
|
|
| 185 |
temp_env = MyEnvEnvironment()
|
| 186 |
|
| 187 |
return {
|