KrisKeshav commited on
Commit
542893e
·
unverified ·
1 Parent(s): 81e328b

Enhance API endpoints with async locking and error handling

Browse files

Refactor /reset, /step, and /state endpoints to use async locks for thread safety. Update ResetRequest to handle optional seed.

Files changed (1) hide show
  1. src/api.py +25 -21
src/api.py CHANGED
@@ -8,32 +8,31 @@ Exposes HTTP endpoints for environment interaction:
8
  GET /health — Health check (returns 200)
9
  """
10
 
11
- from fastapi import FastAPI
 
 
 
12
  from pydantic import BaseModel
13
- from typing import Any, Dict
14
 
15
  from src.models import Observation, Action, Reward, State
16
  from src.env import PLLAttackEnv
17
 
18
-
19
  app = FastAPI(
20
  title="PLL Cyberattack Detection OpenEnv",
21
  description="OpenEnv for AI-driven cyberattack detection on SRF-PLLs",
22
  version="1.0.0",
23
  )
24
 
25
- # Global environment instance
26
  env = PLLAttackEnv()
 
27
 
28
 
29
  class ResetRequest(BaseModel):
30
- """Request body for /reset endpoint."""
31
  task_id: int = 0
32
- seed: int = None
33
 
34
 
35
  class StepResponse(BaseModel):
36
- """Response body for /step endpoint."""
37
  observation: Observation
38
  reward: Reward
39
  done: bool
@@ -41,31 +40,36 @@ class StepResponse(BaseModel):
41
 
42
 
43
  @app.post("/reset", response_model=Observation)
44
- async def reset(request: ResetRequest):
45
  """Reset the environment and return initial observation."""
46
- obs = env.reset(task_id=request.task_id, seed=request.seed)
47
- return obs
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  @app.post("/step", response_model=StepResponse)
51
  async def step(action: Action):
52
- """Submit an action and advance the environment one step."""
53
- obs, reward, done, info = env.step(action)
54
- return StepResponse(
55
- observation=obs,
56
- reward=reward,
57
- done=done,
58
- info=info,
59
- )
60
 
61
 
62
  @app.get("/state", response_model=State)
63
  async def get_state():
64
- """Return the current internal state."""
65
- return env.get_state()
66
 
67
 
68
  @app.get("/health")
69
  async def health():
70
- """Health check endpoint."""
71
  return {"status": "ok"}
 
8
  GET /health — Health check (returns 200)
9
  """
10
 
11
+ import asyncio
12
+ from typing import Any, Dict, Optional
13
+
14
+ from fastapi import FastAPI, HTTPException, Request
15
  from pydantic import BaseModel
 
16
 
17
  from src.models import Observation, Action, Reward, State
18
  from src.env import PLLAttackEnv
19
 
 
20
  app = FastAPI(
21
  title="PLL Cyberattack Detection OpenEnv",
22
  description="OpenEnv for AI-driven cyberattack detection on SRF-PLLs",
23
  version="1.0.0",
24
  )
25
 
 
26
  env = PLLAttackEnv()
27
+ env_lock = asyncio.Lock()
28
 
29
 
30
  class ResetRequest(BaseModel):
 
31
  task_id: int = 0
32
+ seed: Optional[int] = None
33
 
34
 
35
  class StepResponse(BaseModel):
 
36
  observation: Observation
37
  reward: Reward
38
  done: bool
 
40
 
41
 
42
  @app.post("/reset", response_model=Observation)
43
+ async def reset(req: Request):
44
  """Reset the environment and return initial observation."""
45
+ async with env_lock:
46
+ try:
47
+ body = await req.body()
48
+ if body:
49
+ data = await req.json()
50
+ request = ResetRequest(**data)
51
+ else:
52
+ request = ResetRequest()
53
+ except Exception:
54
+ request = ResetRequest()
55
+ return env.reset(task_id=request.task_id, seed=request.seed)
56
 
57
 
58
  @app.post("/step", response_model=StepResponse)
59
  async def step(action: Action):
60
+ async with env_lock:
61
+ if env.attack_generator is None:
62
+ raise HTTPException(status_code=400, detail="Call /reset before /step")
63
+ obs, reward, done, info = env.step(action)
64
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
 
 
 
65
 
66
 
67
  @app.get("/state", response_model=State)
68
  async def get_state():
69
+ async with env_lock:
70
+ return env.get_state()
71
 
72
 
73
  @app.get("/health")
74
  async def health():
 
75
  return {"status": "ok"}