AIMLxDIV commited on
Commit
8ae0c62
·
unverified ·
2 Parent(s): e8fc831c2a73f0

Merge pull request #26 from ArshVermaGit/main

Browse files

feat: add API hardening, rate limiting, and resource management

Files changed (4) hide show
  1. app.py +135 -31
  2. codereview_env/config.py +23 -0
  3. requirements.txt +3 -0
  4. tests/test_api.py +4 -2
app.py CHANGED
@@ -1,57 +1,105 @@
1
  import uuid
2
- from typing import Dict
3
-
4
- from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
 
 
 
 
 
 
5
  from pydantic import BaseModel
 
 
 
6
 
7
  from codereview_env.models import (
8
  TaskId, Action, ResetResult, StepResult, EpisodeResult
9
  )
10
  from codereview_env.env import CodeReviewEnv
 
 
 
 
 
 
 
 
 
11
 
 
12
  app = FastAPI(
13
  title="AgentOrg CodeReview OpenEnv API",
14
  description=(
15
  "AI Senior Code Reviewer evaluation environment. "
16
  "Trains agents to detect bugs, security vulnerabilities, and architectural issues "
17
- "in realistic Python PRs grounded in real-world incident patterns."
18
  ),
19
  version="1.0.0",
20
  )
21
 
22
- # Simple in-memory storage for active episodes
23
- episodes: Dict[str, CodeReviewEnv] = {}
 
 
 
 
 
24
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class ResetRequest(BaseModel):
27
  task_id: TaskId
28
  seed: int = 42
29
 
30
-
31
  class ResetResponse(BaseModel):
32
  episode_id: str
33
  result: ResetResult
34
 
35
-
36
- # In-memory leaderboard
37
- leaderboard: Dict[TaskId, list] = {
38
  TaskId.BUG_DETECTION: [],
39
  TaskId.SECURITY_AUDIT: [],
40
  TaskId.ARCHITECTURAL_REVIEW: []
41
  }
42
 
43
-
44
  class SubmitScore(BaseModel):
45
  agent_name: str
46
  task_id: TaskId
47
  score: float
48
  seed: int
49
 
50
-
51
  # ── WebSocket clients ─────────────────────────────────────────────────────────
52
  clients = set()
53
 
54
-
55
  async def broadcast_event(data: dict):
56
  from fastapi.encoders import jsonable_encoder
57
  import json
@@ -64,30 +112,55 @@ async def broadcast_event(data: dict):
64
  dead.add(client)
65
  clients.difference_update(dead)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # ── Endpoints ─────────────────────────────────────────────────────────────────
69
 
70
  @app.get("/health")
71
  def health_check():
72
  return {
73
- "status": "ok",
74
- "version": "1.0.0",
75
  "env_ready": True,
 
76
  "active_episodes": len(episodes),
 
77
  }
78
 
79
-
80
  @app.post("/reset", response_model=ResetResponse)
81
- def reset_env(req: ResetRequest):
 
82
  episode_id = str(uuid.uuid4())
83
  env = CodeReviewEnv()
84
  result = env.reset(req.task_id, req.seed)
85
  episodes[episode_id] = env
 
86
  return ResetResponse(episode_id=episode_id, result=result)
87
 
88
-
89
  @app.post("/step/{episode_id}", response_model=StepResult)
90
- async def step_env(episode_id: str, action: Action):
 
91
  if episode_id not in episodes:
92
  raise HTTPException(status_code=404, detail="Episode not found")
93
 
@@ -99,29 +172,61 @@ async def step_env(episode_id: str, action: Action):
99
  except RuntimeError as e:
100
  raise HTTPException(status_code=400, detail=str(e))
101
 
102
-
103
  @app.get("/result/{episode_id}", response_model=EpisodeResult)
104
- def get_result(episode_id: str):
105
  if episode_id not in episodes:
106
  raise HTTPException(status_code=404, detail="Episode not found")
107
  return episodes[episode_id].get_final_result()
108
 
109
-
110
  @app.get("/leaderboard")
111
- def get_leaderboard():
112
- return leaderboard
113
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  @app.post("/submit")
116
- def submit_to_leaderboard(submission: SubmitScore):
 
117
  entries = leaderboard.get(submission.task_id, [])
118
  new_entry = submission.model_dump()
119
  entries.append(new_entry)
120
  entries.sort(key=lambda x: x["score"], reverse=True)
121
  rank = entries.index(new_entry) + 1 # capture rank before slicing
122
- leaderboard[submission.task_id] = entries[:5]
123
- return {"status": "submitted", "rank": rank if rank <= 5 else None}
124
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  @app.websocket("/ws/events")
127
  async def websocket_endpoint(websocket: WebSocket):
@@ -135,7 +240,6 @@ async def websocket_endpoint(websocket: WebSocket):
135
  finally:
136
  clients.discard(websocket)
137
 
138
-
139
  if __name__ == "__main__":
140
  import uvicorn
141
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import uuid
2
+ import logging
3
+ import asyncio
4
+ from typing import Dict, List, Optional
5
+ from datetime import datetime, timezone
6
+
7
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Security, Query, BackgroundTasks, Request
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.exceptions import RequestValidationError
10
+ from fastapi.security.api_key import APIKeyHeader
11
  from pydantic import BaseModel
12
+ from slowapi import Limiter, _rate_limit_exceeded_handler
13
+ from slowapi.util import get_remote_address
14
+ from slowapi.errors import RateLimitExceeded
15
 
16
  from codereview_env.models import (
17
  TaskId, Action, ResetResult, StepResult, EpisodeResult
18
  )
19
  from codereview_env.env import CodeReviewEnv
20
+ from codereview_env.config import get_settings
21
+
22
+ # ── Logging ───────────────────────────────────────────────────────────────────
23
+ settings = get_settings()
24
+ logging.basicConfig(
25
+ level=getattr(logging, settings.log_level),
26
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
27
+ )
28
+ logger = logging.getLogger("codereview_env")
29
 
30
+ # ── App Initialization ────────────────────────────────────────────────────────
31
  app = FastAPI(
32
  title="AgentOrg CodeReview OpenEnv API",
33
  description=(
34
  "AI Senior Code Reviewer evaluation environment. "
35
  "Trains agents to detect bugs, security vulnerabilities, and architectural issues "
36
+ "in realistic Python PRs."
37
  ),
38
  version="1.0.0",
39
  )
40
 
41
+ # ── Rate Limiting ─────────────────────────────────────────────────────────────
42
+ limiter = Limiter(key_func=get_remote_address, default_limits=[f"{settings.rate_limit_per_minute}/minute"])
43
+ app.state.limiter = limiter
44
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
45
+
46
+ # ── API Key Authentication ────────────────────────────────────────────────────
47
+ API_KEY_HEADER = APIKeyHeader(name="X-API-Key", auto_error=False)
48
 
49
+ async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
50
+ if not settings.api_key_enabled:
51
+ return # Auth disabled in development
52
+ if api_key != settings.api_key:
53
+ raise HTTPException(status_code=403, detail="Invalid or missing API key")
54
 
55
+ # ── Storage & TTL ─────────────────────────────────────────────────────────────
56
+ episodes: Dict[str, CodeReviewEnv] = {}
57
+ episode_timestamps: Dict[str, datetime] = {}
58
+
59
+ async def cleanup_expired_episodes():
60
+ """Remove episodes older than TTL."""
61
+ while True:
62
+ await asyncio.sleep(300) # run every 5 minutes
63
+ cutoff = datetime.now(timezone.utc).timestamp() - settings.episode_ttl_seconds
64
+ expired = [
65
+ eid for eid, ts in episode_timestamps.items()
66
+ if ts.timestamp() < cutoff
67
+ ]
68
+ for eid in expired:
69
+ episodes.pop(eid, None)
70
+ episode_timestamps.pop(eid, None)
71
+ if expired:
72
+ logger.info(f"Cleaned up {len(expired)} expired episodes")
73
+
74
+ @app.on_event("startup")
75
+ async def startup_event():
76
+ asyncio.create_task(cleanup_expired_episodes())
77
+ logger.info(f"CodeReview API started on port {settings.app_port}")
78
+
79
+ # ── Models ────────────────────────────────────────────────────────────────────
80
  class ResetRequest(BaseModel):
81
  task_id: TaskId
82
  seed: int = 42
83
 
 
84
  class ResetResponse(BaseModel):
85
  episode_id: str
86
  result: ResetResult
87
 
88
+ leaderboard: Dict[TaskId, List[dict]] = {
 
 
89
  TaskId.BUG_DETECTION: [],
90
  TaskId.SECURITY_AUDIT: [],
91
  TaskId.ARCHITECTURAL_REVIEW: []
92
  }
93
 
 
94
  class SubmitScore(BaseModel):
95
  agent_name: str
96
  task_id: TaskId
97
  score: float
98
  seed: int
99
 
 
100
  # ── WebSocket clients ─────────────────────────────────────────────────────────
101
  clients = set()
102
 
 
103
  async def broadcast_event(data: dict):
104
  from fastapi.encoders import jsonable_encoder
105
  import json
 
112
  dead.add(client)
113
  clients.difference_update(dead)
114
 
115
+ # ── Error Handlers ────────────────────────────────────────────────────────────
116
+ @app.exception_handler(RequestValidationError)
117
+ async def validation_exception_handler(request, exc):
118
+ return JSONResponse(
119
+ status_code=422,
120
+ content={
121
+ "error": "validation_error",
122
+ "detail": str(exc),
123
+ "status_code": 422
124
+ }
125
+ )
126
+
127
+ @app.exception_handler(HTTPException)
128
+ async def http_exception_handler(request, exc):
129
+ logger.warning(f"HTTP {exc.status_code}: {exc.detail} \u2014 {request.url}")
130
+ return JSONResponse(
131
+ status_code=exc.status_code,
132
+ content={
133
+ "error": exc.detail,
134
+ "status_code": exc.status_code
135
+ }
136
+ )
137
 
138
  # ── Endpoints ─────────────────────────────────────────────────────────────────
139
 
140
  @app.get("/health")
141
  def health_check():
142
  return {
143
+ "status": "ok",
144
+ "version": "1.0.0",
145
  "env_ready": True,
146
+ "env": settings.app_env,
147
  "active_episodes": len(episodes),
148
+ "auth_enabled": settings.api_key_enabled
149
  }
150
 
 
151
  @app.post("/reset", response_model=ResetResponse)
152
+ @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
153
+ def reset_env(request: Request, req: ResetRequest, _: None = Depends(verify_api_key)):
154
  episode_id = str(uuid.uuid4())
155
  env = CodeReviewEnv()
156
  result = env.reset(req.task_id, req.seed)
157
  episodes[episode_id] = env
158
+ episode_timestamps[episode_id] = datetime.now(timezone.utc)
159
  return ResetResponse(episode_id=episode_id, result=result)
160
 
 
161
  @app.post("/step/{episode_id}", response_model=StepResult)
162
+ @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
163
+ async def step_env(request: Request, episode_id: str, action: Action, _: None = Depends(verify_api_key)):
164
  if episode_id not in episodes:
165
  raise HTTPException(status_code=404, detail="Episode not found")
166
 
 
172
  except RuntimeError as e:
173
  raise HTTPException(status_code=400, detail=str(e))
174
 
 
175
  @app.get("/result/{episode_id}", response_model=EpisodeResult)
176
+ def get_result(episode_id: str, _: None = Depends(verify_api_key)):
177
  if episode_id not in episodes:
178
  raise HTTPException(status_code=404, detail="Episode not found")
179
  return episodes[episode_id].get_final_result()
180
 
 
181
  @app.get("/leaderboard")
182
+ def get_leaderboard(
183
+ task_id: Optional[TaskId] = None,
184
+ limit: int = Query(default=10, ge=1, le=50),
185
+ offset: int = Query(default=0, ge=0)
186
+ ):
187
+ if task_id:
188
+ entries = leaderboard.get(task_id, [])
189
+ return {
190
+ "task_id": task_id,
191
+ "entries": entries[offset:offset+limit],
192
+ "total": len(entries)
193
+ }
194
+ return {
195
+ task: {
196
+ "entries": entries[offset:offset+limit],
197
+ "total": len(entries)
198
+ }
199
+ for task, entries in leaderboard.items()
200
+ }
201
 
202
  @app.post("/submit")
203
+ @limiter.limit(f"{settings.rate_limit_per_minute}/minute")
204
+ def submit_to_leaderboard(request: Request, submission: SubmitScore, _: None = Depends(verify_api_key)):
205
  entries = leaderboard.get(submission.task_id, [])
206
  new_entry = submission.model_dump()
207
  entries.append(new_entry)
208
  entries.sort(key=lambda x: x["score"], reverse=True)
209
  rank = entries.index(new_entry) + 1 # capture rank before slicing
210
+ leaderboard[submission.task_id] = entries[:settings.leaderboard_max_entries]
211
+ in_top_n = rank <= settings.leaderboard_max_entries
212
+ return {"status": "submitted", "rank": rank if in_top_n else None}
213
+
214
+ @app.get("/episodes")
215
+ def list_episodes(
216
+ _: None = Depends(verify_api_key),
217
+ limit: int = Query(default=20, ge=1, le=100)
218
+ ):
219
+ episode_list = [
220
+ {
221
+ "episode_id": eid,
222
+ "task_id": env.task_id,
223
+ "step_count": env.observation.step_count,
224
+ "done": env.done,
225
+ "created_at": episode_timestamps.get(eid, "").isoformat() if episode_timestamps.get(eid) else ""
226
+ }
227
+ for eid, env in list(episodes.items())[:limit]
228
+ ]
229
+ return {"episodes": episode_list, "total": len(episodes)}
230
 
231
  @app.websocket("/ws/events")
232
  async def websocket_endpoint(websocket: WebSocket):
 
240
  finally:
241
  clients.discard(websocket)
242
 
 
243
  if __name__ == "__main__":
244
  import uvicorn
245
+ uvicorn.run(app, host=settings.app_host, port=settings.app_port)
codereview_env/config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from pydantic_settings import BaseSettings, SettingsConfigDict
3
+
4
+ class Settings(BaseSettings):
5
+ model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
6
+
7
+ app_host: str = "0.0.0.0"
8
+ app_port: int = 7860
9
+ app_env: str = "development"
10
+
11
+ api_key: str = "changeme"
12
+ api_key_enabled: bool = False
13
+
14
+ leaderboard_max_entries: int = 10
15
+
16
+ log_level: str = "INFO"
17
+
18
+ episode_ttl_seconds: int = 3600 # episodes expire after 1 hour
19
+ rate_limit_per_minute: int = 60 # requests per minute per IP
20
+
21
+ @lru_cache
22
+ def get_settings() -> Settings:
23
+ return Settings()
requirements.txt CHANGED
@@ -6,3 +6,6 @@ requests>=2.31.0
6
  websockets>=12.0
7
  httpx<0.28.0
8
  openai>=1.0.0
 
 
 
 
6
  websockets>=12.0
7
  httpx<0.28.0
8
  openai>=1.0.0
9
+ pydantic-settings==2.2.1
10
+ slowapi==0.1.9
11
+ python-dotenv==1.0.1
tests/test_api.py CHANGED
@@ -50,8 +50,10 @@ def test_api_leaderboard():
50
  # Check leaderboard
51
  lb_resp = client.get("/leaderboard")
52
  assert lb_resp.status_code == 200
53
- assert len(lb_resp.json()["bug_detection"]) > 0
54
- assert lb_resp.json()["bug_detection"][0]["agent_name"] == "test_agent"
 
 
55
 
56
  def test_api_invalid_episode():
57
  client = TestClient(app)
 
50
  # Check leaderboard
51
  lb_resp = client.get("/leaderboard")
52
  assert lb_resp.status_code == 200
53
+ lb_data = lb_resp.json()
54
+ bug_entries = lb_data["bug_detection"]["entries"]
55
+ assert len(bug_entries) > 0
56
+ assert bug_entries[0]["agent_name"] == "test_agent"
57
 
58
  def test_api_invalid_episode():
59
  client = TestClient(app)