Arijit-07 commited on
Commit
fca2aa4
·
1 Parent(s): d50c3f9

Add WebSocket support at /ws endpoint

Browse files
Files changed (2) hide show
  1. openenv.yaml +5 -0
  2. server/app.py +85 -1
openenv.yaml CHANGED
@@ -124,6 +124,11 @@ reward:
124
  healthy services), excessive noops, and treating symptoms instead
125
  of root causes. Efficiency bonus for fast resolution.
126
 
 
 
 
 
 
127
  docker:
128
  base_image: python:3.11-slim
129
  port: 7860
 
124
  healthy services), excessive noops, and treating symptoms instead
125
  of root causes. Efficiency bonus for fast resolution.
126
 
127
+ websocket:
128
+ endpoint: /ws
129
+ protocol: json
130
+ commands: [reset, step, state]
131
+
132
  docker:
133
  base_image: python:3.11-slim
134
  port: 7860
server/app.py CHANGED
@@ -1,5 +1,5 @@
1
  from __future__ import annotations
2
- from fastapi import FastAPI, HTTPException, Request
3
  from fastapi.responses import HTMLResponse
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
@@ -297,3 +297,87 @@ def validate():
297
  _env._logic = old_logic
298
  all_ok = all(r.get("status") == "ok" and r.get("in_range") for r in results)
299
  return {"validation": "passed" if all_ok else "failed", "tasks": results}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
+ from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
3
  from fastapi.responses import HTMLResponse
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
 
297
  _env._logic = old_logic
298
  all_ok = all(r.get("status") == "ok" and r.get("in_range") for r in results)
299
  return {"validation": "passed" if all_ok else "failed", "tasks": results}
300
+
301
+
302
+ @app.websocket("/ws")
303
+ async def websocket_endpoint(websocket: WebSocket):
304
+ await websocket.accept()
305
+ # Independent environment instance for this connection
306
+ ws_env = DevOpsEnvironment()
307
+
308
+ try:
309
+ while True:
310
+ data = await websocket.receive_json()
311
+ command = data.get("command")
312
+
313
+ print(f"WebSocket received: {data}")
314
+
315
+ if command == "reset":
316
+ task_id = data.get("task_id", "easy")
317
+ seed = data.get("seed")
318
+ obs = await ws_env.reset(seed=seed, task_id=task_id)
319
+ await websocket.send_json({
320
+ "type": "observation",
321
+ "data": obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
322
+ })
323
+
324
+ elif command == "step":
325
+ if ws_env._logic is None:
326
+ await websocket.send_json({
327
+ "type": "error",
328
+ "message": "Call reset before step"
329
+ })
330
+ continue
331
+
332
+ action_data = data.get("action", {})
333
+ try:
334
+ action = Action(**action_data)
335
+ step_result = await ws_env.step(action)
336
+ await websocket.send_json({
337
+ "type": "step_result",
338
+ "data": {
339
+ "observation": step_result.observation.model_dump() if hasattr(step_result.observation, "model_dump") else step_result.observation.dict(),
340
+ "reward": step_result.reward,
341
+ "done": step_result.done,
342
+ "info": step_result.info
343
+ }
344
+ })
345
+ except Exception as e:
346
+ await websocket.send_json({
347
+ "type": "error",
348
+ "message": str(e)
349
+ })
350
+
351
+ elif command == "state":
352
+ if ws_env._logic is None:
353
+ await websocket.send_json({
354
+ "type": "error",
355
+ "message": "Call reset before state"
356
+ })
357
+ continue
358
+
359
+ state = ws_env.state
360
+ await websocket.send_json({
361
+ "type": "state",
362
+ "data": state.model_dump() if hasattr(state, "model_dump") else state.dict()
363
+ })
364
+
365
+ else:
366
+ await websocket.send_json({
367
+ "type": "error",
368
+ "message": f"Unrecognized command: {command}"
369
+ })
370
+
371
+ except WebSocketDisconnect:
372
+ print("WebSocket client disconnected")
373
+ except Exception as e:
374
+ print(f"WebSocket error: {e}")
375
+ try:
376
+ await websocket.send_json({
377
+ "type": "error",
378
+ "message": str(e)
379
+ })
380
+ except:
381
+ pass
382
+ await websocket.close()
383
+