Commit Β·
4c8efe2
0
Parent(s):
first commit
Browse files- .gitignore +1 -0
- Dockerfile +19 -0
- README.md +202 -0
- __pycache__/environment.cpython-310.pyc +0 -0
- __pycache__/graders.cpython-310.pyc +0 -0
- __pycache__/models.cpython-310.pyc +0 -0
- __pycache__/simulator.cpython-310.pyc +0 -0
- __pycache__/tasks.cpython-310.pyc +0 -0
- environment.py +219 -0
- graders.py +125 -0
- inference.py +198 -0
- models.py +77 -0
- openenv.yaml +118 -0
- requirements.txt +7 -0
- simulator.py +78 -0
- tasks.py +94 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.venv
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Slim Python base keeps the image small; 3.11 matches the app's syntax needs.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first (layer cache)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# HuggingFace Spaces requires port 7860
EXPOSE 7860

# Healthcheck so orchestrators know when the app is ready
# (the CMD on the continuation line below belongs to HEALTHCHECK, not the container)
HEALTHCHECK --interval=10s --timeout=5s --start-period=15s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')"

# Container entry point: serve the FastAPI app defined in environment.py.
CMD ["uvicorn", "environment:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Adaptive Traffic Controller
|
| 3 |
+
emoji: π¦
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- reinforcement-learning
|
| 11 |
+
- traffic-control
|
| 12 |
+
- llm-agent
|
| 13 |
+
license: mit
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# Adaptive Backend Traffic Controller
|
| 17 |
+
|
| 18 |
+
An **OpenEnv**-compatible reinforcement learning environment where an LLM agent learns to prevent backend server crashes by intelligently throttling incoming traffic in real-time.
|
| 19 |
+
|
| 20 |
+
Built for the **Scaler × Meta PyTorch Hackathon**.
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Overview
|
| 25 |
+
|
| 26 |
+
The environment simulates a backend server receiving variable traffic. The agent observes system metrics every time step and chooses a throttling action to keep the server healthy. The server's physics are modelled realistically: CPU and memory track load linearly, latency spikes superlinearly, and sustained overload causes crashes.
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## Observation Space
|
| 31 |
+
|
| 32 |
+
| Field | Type | Range | Description |
|
| 33 |
+
|-------|------|--------|-------------|
|
| 34 |
+
| `cpu_usage` | float | 0.0 β 1.0 | CPU utilization fraction |
|
| 35 |
+
| `memory_usage` | float | 0.0 β 1.0 | Memory utilization fraction |
|
| 36 |
+
| `request_rate` | float | β₯ 0 | Incoming requests per second |
|
| 37 |
+
| `queue_length` | int | 0 β 500 | Pending requests in backlog |
|
| 38 |
+
| `avg_latency` | float | β₯ 0 | Average response latency (ms) |
|
| 39 |
+
| `step` | int | β₯ 0 | Current episode step |
|
| 40 |
+
| `crashed` | bool | β | Whether the server crashed this step |
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Action Space
|
| 45 |
+
|
| 46 |
+
| Action | Accept Rate | Description |
|
| 47 |
+
|--------|------------|-------------|
|
| 48 |
+
| `allow_all` | 100% | Safe load β accept all requests |
|
| 49 |
+
| `throttle_70` | 70% | Moderate load β drop 30% |
|
| 50 |
+
| `throttle_40` | 40% | High load β drop 60% |
|
| 51 |
+
| `drop_aggressive` | 20% | Imminent crash β drop 80% |
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## Tasks
|
| 56 |
+
|
| 57 |
+
### Task Easy β Single Spike
|
| 58 |
+
- Traffic: 40 req/s baseline β 160 req/s spike at step 10 for 5 steps β back to 40
|
| 59 |
+
- Episode length: 30 steps
|
| 60 |
+
- Scoring:
|
| 61 |
+
- `1.0` β no crash AND avg latency < 300 ms
|
| 62 |
+
- `0.5` β no crash, but avg latency β₯ 300 ms
|
| 63 |
+
- `0.0` β any crash
|
| 64 |
+
|
| 65 |
+
### Task Medium β Multiple Spikes
|
| 66 |
+
- Traffic: 50 req/s baseline with 3 spikes of 150 req/s at steps 5, 15, 25 (3 steps each)
|
| 67 |
+
- Episode length: 40 steps
|
| 68 |
+
- Scoring: `(steps_without_crash / total_steps) × latency_factor`
|
| 69 |
+
- `latency_factor` = 1.0 at β€ 200 ms, 0.5 at β₯ 600 ms, linear between
|
| 70 |
+
|
| 71 |
+
### Task Hard β Sustained Overload
|
| 72 |
+
- Traffic: ramps 60 β 200 req/s over 20 steps, stays at 200 for 20 steps, drops to 80
|
| 73 |
+
- Episode length: 50 steps
|
| 74 |
+
- Scoring: `throughput_ratio × 0.7 + queue_factor × 0.3`
|
| 75 |
+
- `throughput_ratio` = total allowed / total incoming
|
| 76 |
+
- `queue_factor` = fraction of steps with queue < 100
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## API Endpoints
|
| 81 |
+
|
| 82 |
+
| Method | Path | Description |
|
| 83 |
+
|--------|------|-------------|
|
| 84 |
+
| `POST` | `/reset` | Reset environment, returns initial state |
|
| 85 |
+
| `POST` | `/step` | Execute action, returns state/reward/done/info |
|
| 86 |
+
| `GET` | `/state` | Current server state |
|
| 87 |
+
| `GET` | `/tasks` | List all 3 tasks |
|
| 88 |
+
| `GET` | `/openenv.yaml` | OpenEnv specification |
|
| 89 |
+
| `GET` | `/health` | Liveness probe |
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## Setup
|
| 94 |
+
|
| 95 |
+
### Local (Python)
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
pip install -r requirements.txt
|
| 99 |
+
|
| 100 |
+
# Start the environment server
|
| 101 |
+
uvicorn environment:app --host 0.0.0.0 --port 7860
|
| 102 |
+
|
| 103 |
+
# In another terminal, run a quick smoke test
|
| 104 |
+
curl -s localhost:7860/health
|
| 105 |
+
curl -s -X POST localhost:7860/reset -H "Content-Type: application/json" \
|
| 106 |
+
-d '{"task_id": "task_easy"}' | python -m json.tool
|
| 107 |
+
curl -s -X POST localhost:7860/step -H "Content-Type: application/json" \
|
| 108 |
+
-d '{"action": "throttle_70"}' | python -m json.tool
|
| 109 |
+
curl -s localhost:7860/tasks | python -m json.tool
|
| 110 |
+
curl -s localhost:7860/openenv.yaml
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### Docker
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
docker build -t traffic-controller .
|
| 117 |
+
docker run -p 7860:7860 traffic-controller
|
| 118 |
+
|
| 119 |
+
# Same smoke tests work on localhost:7860
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
---
|
| 123 |
+
|
| 124 |
+
## Running Inference
|
| 125 |
+
|
| 126 |
+
Set the three required environment variables then run `inference.py`:
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
export API_BASE_URL="https://api-inference.huggingface.co/models/<your-model>/v1"
|
| 130 |
+
export MODEL_NAME="meta-llama/Llama-3.1-8B-Instruct"
|
| 131 |
+
export HF_TOKEN="hf_..."
|
| 132 |
+
export ENV_URL="http://localhost:7860" # optional, defaults to this
|
| 133 |
+
|
| 134 |
+
python inference.py
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Expected output:
|
| 138 |
+
|
| 139 |
+
```
|
| 140 |
+
Environment URL : http://localhost:7860
|
| 141 |
+
Model : meta-llama/Llama-3.1-8B-Instruct
|
| 142 |
+
API base : https://api-inference.huggingface.co/...
|
| 143 |
+
|
| 144 |
+
Health check OK
|
| 145 |
+
|
| 146 |
+
=== TASK_EASY ===
|
| 147 |
+
Starting task_easy (max_steps=30)
|
| 148 |
+
step= 1 action=allow_all reward=+0.950 latency= 56.5ms queue= 0 cpu=0.54
|
| 149 |
+
...
|
| 150 |
+
task_easy done β total_reward=27.3, score=1.000
|
| 151 |
+
|
| 152 |
+
=== RESULTS ===
|
| 153 |
+
task_easy : 1.000
|
| 154 |
+
task_medium : 0.875
|
| 155 |
+
task_hard : 0.623
|
| 156 |
+
Overall : 0.833
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
---
|
| 160 |
+
|
| 161 |
+
## Baseline Scores
|
| 162 |
+
|
| 163 |
+
Measured on the deterministic simulator. Scores are in **0.0 β 1.0**.
|
| 164 |
+
|
| 165 |
+
| Agent | task_easy | task_medium | task_hard | Overall |
|
| 166 |
+
|-------|-----------|-------------|-----------|---------|
|
| 167 |
+
| **Always allow_all** (naive) | 0.000 π₯ | 0.833 | 0.300 π₯ | 0.378 |
|
| 168 |
+
| **Always drop_aggressive** (conservative) | 1.000 | 1.000 | 0.440 | 0.813 |
|
| 169 |
+
| **Rule-based heuristic** | 1.000 | 1.000 | 0.500 | 0.833 |
|
| 170 |
+
| **LLM agent** (target) | β₯ 0.9 | β₯ 0.9 | β₯ 0.6 | β₯ 0.8 |
|
| 171 |
+
|
| 172 |
+
π₯ = server crash occurred during episode
|
| 173 |
+
|
| 174 |
+
**Key insight:** The hard task is the differentiator β naive and conservative agents score β€ 0.44 because sustained 200 req/s overload requires balancing throughput (don't drop too much) against stability (don't let load crash the server). A smart LLM agent should outperform all rule-based baselines here.
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
## Infrastructure
|
| 179 |
+
|
| 180 |
+
- Port: **7860** (HuggingFace Spaces)
|
| 181 |
+
- CPU: 2 vCPU
|
| 182 |
+
- Memory: 8 GB
|
| 183 |
+
- GPU: not required
|
| 184 |
+
- Inference timeout: < 20 minutes total
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## Project Structure
|
| 189 |
+
|
| 190 |
+
```
|
| 191 |
+
.
|
| 192 |
+
βββ environment.py # FastAPI app + episode logic
|
| 193 |
+
βββ tasks.py # Traffic patterns + task metadata
|
| 194 |
+
βββ graders.py # Per-task scoring functions
|
| 195 |
+
βββ simulator.py # Backend physics (latency, CPU, memory, crash)
|
| 196 |
+
βββ models.py # Pydantic models (state, action, request/response)
|
| 197 |
+
βββ inference.py # LLM agent runner
|
| 198 |
+
βββ openenv.yaml # OpenEnv spec
|
| 199 |
+
βββ Dockerfile
|
| 200 |
+
βββ requirements.txt
|
| 201 |
+
βββ README.md
|
| 202 |
+
```
|
__pycache__/environment.cpython-310.pyc
ADDED
|
Binary file (5.5 kB). View file
|
|
|
__pycache__/graders.cpython-310.pyc
ADDED
|
Binary file (3.59 kB). View file
|
|
|
__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (3.19 kB). View file
|
|
|
__pycache__/simulator.cpython-310.pyc
ADDED
|
Binary file (2.05 kB). View file
|
|
|
__pycache__/tasks.cpython-310.pyc
ADDED
|
Binary file (2.4 kB). View file
|
|
|
environment.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adaptive Traffic Controller β OpenEnv-compatible FastAPI environment.
|
| 3 |
+
|
| 4 |
+
Endpoints
|
| 5 |
+
---------
|
| 6 |
+
POST /reset reset env, return initial state
|
| 7 |
+
POST /step take action, return (state, reward, done, info)
|
| 8 |
+
GET /state current state
|
| 9 |
+
GET /tasks list all tasks
|
| 10 |
+
GET /openenv.yaml OpenEnv spec
|
| 11 |
+
GET /health liveness probe
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from contextlib import asynccontextmanager
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
from fastapi import FastAPI, HTTPException
|
| 23 |
+
from fastapi.responses import PlainTextResponse
|
| 24 |
+
|
| 25 |
+
from graders import grade
|
| 26 |
+
from models import (
|
| 27 |
+
Action,
|
| 28 |
+
ACTION_ACCEPT_RATE,
|
| 29 |
+
EpisodeStep,
|
| 30 |
+
HealthResponse,
|
| 31 |
+
ResetRequest,
|
| 32 |
+
ResetResponse,
|
| 33 |
+
ServerState,
|
| 34 |
+
StepRequest,
|
| 35 |
+
StepResponse,
|
| 36 |
+
TaskListResponse,
|
| 37 |
+
)
|
| 38 |
+
from simulator import compute_next_state, initial_state
|
| 39 |
+
from tasks import EPISODE_LENGTHS, TASK_METADATA, TRAFFIC_PATTERNS
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# In-memory session state
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
class EnvSession:
    """Mutable state for the single in-memory episode served by this app.

    There is exactly one instance per process (module-level SESSION), so the
    environment is single-tenant: concurrent clients share one episode.
    """

    def __init__(self) -> None:
        # Defaults describe a fresh "task_easy" episode before any /reset.
        self.task_id: str = "task_easy"
        self.state: ServerState = initial_state()
        self.step: int = 0
        self.done: bool = False
        self.history: list[EpisodeStep] = []

    def reset(self, task_id: str) -> ServerState:
        """Start a new episode of *task_id* and return its initial state."""
        traffic_fn = TRAFFIC_PATTERNS[task_id]
        # Seed the initial state with the task's step-0 traffic rate so the
        # agent's first observation already shows real incoming traffic.
        first_incoming = traffic_fn(0)
        self.task_id = task_id
        self.state = initial_state(first_incoming)
        self.step = 0
        self.done = False
        self.history = []
        return self.state

    def step_env(self, action: Action) -> tuple[ServerState, float, bool, dict[str, Any]]:
        """Advance the simulation one step under *action*.

        Returns ``(state, reward, done, info)``; on the terminal step the
        info dict additionally carries ``final_score`` and ``episode_done``.
        Raises ValueError if called after the episode has finished.
        """
        if self.done:
            raise ValueError("Episode is done. Call /reset to start a new episode.")

        task_id = self.task_id
        traffic_fn = TRAFFIC_PATTERNS[task_id]
        max_steps = EPISODE_LENGTHS[task_id]

        # Apply the chosen throttle to this step's incoming traffic.
        incoming = traffic_fn(self.step)
        accept_rate = ACTION_ACCEPT_RATE[action]
        allowed = incoming * accept_rate

        next_state, crashed = compute_next_state(self.state, allowed, incoming)
        next_state.step = self.step + 1

        # --- Reward shaping ---
        reward = _compute_reward(
            incoming=incoming,
            allowed=allowed,
            latency=next_state.avg_latency,
            crashed=crashed,
            queue=next_state.queue_length,
        )

        ep_step = EpisodeStep(
            step=self.step,
            state=next_state,
            action=action,
            reward=reward,
            incoming_requests=incoming,
            allowed_requests=allowed,
            crashed=crashed,
        )
        self.history.append(ep_step)

        self.step += 1
        self.state = next_state
        # Episode ends on crash or when the task's step budget is exhausted.
        self.done = crashed or (self.step >= max_steps)

        # Expose the *upcoming* incoming rate so the agent can react proactively.
        # This mirrors real monitoring: you see current traffic flow before deciding
        # the next throttle level.
        # NOTE: self.state and next_state alias the same object here, so the
        # returned state also carries this updated request_rate.
        if not self.done:
            upcoming = traffic_fn(self.step)
            self.state.request_rate = round(upcoming, 2)

        info: dict[str, Any] = {
            "incoming_requests": incoming,
            "allowed_requests": allowed,
            "accept_rate": accept_rate,
            "crashed": crashed,
            "episode_step": self.step,
            "max_steps": max_steps,
        }

        if self.done:
            # Terminal step: attach the task's graded score for the episode.
            final_score = grade(task_id, self.history)
            info["final_score"] = final_score
            info["episode_done"] = True

        return next_state, reward, self.done, info
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _compute_reward(
    incoming: float,
    allowed: float,
    latency: float,
    crashed: bool,
    queue: int,
) -> float:
    """Shaped per-step reward: throughput minus latency and queue penalties.

    A crash short-circuits to a flat -10.0. Otherwise the reward is the
    fraction of incoming traffic that was let through, reduced by a smooth
    latency penalty (kicks in above 200 ms) and a backlog penalty
    (saturates at a queue of 500), rounded to 4 decimal places.
    """
    if crashed:
        return -10.0

    # Fraction of incoming traffic that was let through (guard div-by-zero).
    throughput = allowed / max(incoming, 1.0)

    # Latency penalty ramps linearly: 0 at 200 ms, 1.0 at 1000 ms.
    over_ms = latency - 200.0
    lat_pen = over_ms / 800.0 if over_ms > 0.0 else 0.0

    # Backlog penalty saturates once the queue reaches 500 pending requests.
    q_pen = min(1.0, queue / 500.0)

    return round(throughput - 0.5 * lat_pen - 0.3 * q_pen, 4)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# App lifecycle
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
# Single module-level session: the environment is single-tenant by design.
SESSION = EnvSession()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Ensure every server starts with a freshly reset default task.
    SESSION.reset("task_easy")
    yield


app = FastAPI(
    title="Adaptive Traffic Controller",
    description="OpenEnv environment for LLM-based backend traffic control",
    version="1.0.0",
    lifespan=lifespan,
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# Endpoints
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
|
| 174 |
+
@app.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """Liveness probe (also used by the Docker HEALTHCHECK)."""
    return HealthResponse(status="ok")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@app.post("/reset", response_model=ResetResponse)
async def reset(body: ResetRequest = ResetRequest()) -> ResetResponse:
    """Reset the environment for the requested task; 400 on unknown task_id.

    Returns the initial state plus the task's episode length.
    """
    # NOTE(review): the ResetRequest() default is constructed once at import
    # time and shared across requests — harmless while never mutated here,
    # but worth confirming against FastAPI's recommended Body(...) pattern.
    if body.task_id not in TRAFFIC_PATTERNS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id {body.task_id!r}. "
            f"Valid: {list(TRAFFIC_PATTERNS.keys())}",
        )
    state = SESSION.reset(body.task_id)
    return ResetResponse(
        state=state,
        task_id=body.task_id,
        max_steps=EPISODE_LENGTHS[body.task_id],
    )
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@app.post("/step", response_model=StepResponse)
async def step(body: StepRequest) -> StepResponse:
    """Apply one throttling action to the live episode.

    Translates the session's "episode already done" ValueError into HTTP 400.
    """
    try:
        state, reward, done, info = SESSION.step_env(body.action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return StepResponse(state=state, reward=reward, done=done, info=info)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@app.get("/state", response_model=ServerState)
async def get_state() -> ServerState:
    """Return the current (possibly mid-episode) server state."""
    return SESSION.state


@app.get("/tasks", response_model=TaskListResponse)
async def list_tasks() -> TaskListResponse:
    """List metadata for all available tasks."""
    return TaskListResponse(tasks=TASK_METADATA)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
@app.get("/openenv.yaml", response_class=PlainTextResponse)
async def get_openenv_yaml() -> str:
    """Serve the OpenEnv spec file verbatim as plain text; 404 if missing."""
    yaml_path = Path(__file__).parent / "openenv.yaml"
    if not yaml_path.exists():
        raise HTTPException(status_code=404, detail="openenv.yaml not found")
    return yaml_path.read_text()
|
graders.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graders β deterministic scoring for each task.
|
| 3 |
+
|
| 4 |
+
Each grader receives the full episode history and returns a float in [0.0, 1.0].
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from models import EpisodeStep
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
# Task Easy β Single Spike
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
|
| 16 |
+
def grade_task_easy(history: list[EpisodeStep]) -> float:
    """Grade the single-spike task.

    Returns 1.0 for a crash-free episode whose mean latency stays under
    300 ms, 0.5 for a crash-free but slow episode, and 0.0 for an empty
    history or any crash.
    """
    if not history:
        return 0.0

    if any(entry.crashed for entry in history):
        return 0.0

    mean_latency = sum(entry.state.avg_latency for entry in history) / len(history)
    return 1.0 if mean_latency < 300.0 else 0.5
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Task Medium β Multiple Spikes
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
def grade_task_medium(history: list[EpisodeStep]) -> float:
    """Grade the multiple-spikes task.

    Score = (crash-free step fraction) * latency_factor, where
    latency_factor is 1.0 at mean latency <= 200 ms, 0.5 at >= 600 ms,
    and linearly interpolated in between. Empty history scores 0.0.
    """
    if not history:
        return 0.0

    total = len(history)
    clean_steps = sum(1 for entry in history if not entry.crashed)
    base = clean_steps / total

    mean_latency = sum(entry.state.avg_latency for entry in history) / total

    lo_ms, hi_ms = 200.0, 600.0
    if mean_latency <= lo_ms:
        factor = 1.0
    elif mean_latency >= hi_ms:
        factor = 0.5
    else:
        factor = 1.0 - 0.5 * (mean_latency - lo_ms) / (hi_ms - lo_ms)

    return round(base * factor, 4)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# Task Hard β Sustained Overload
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
def grade_task_hard(history: list[EpisodeStep]) -> float:
    """Grade the sustained-overload task.

    Crash-free episodes score ``throughput * 0.7 + queue_factor * 0.3``;
    crashed episodes get reduced partial credit
    (``throughput * 0.3 * queue_factor``). Result is clamped to [0, 1]
    and rounded to 4 decimal places. Empty history scores 0.0.
    """
    if not history:
        return 0.0

    incoming_total = sum(entry.incoming_requests for entry in history)
    allowed_total = sum(entry.allowed_requests for entry in history)
    # Fraction of incoming traffic the agent let through, capped at 1.0.
    throughput = 0.0 if incoming_total == 0 else min(1.0, allowed_total / incoming_total)

    survived = not any(entry.crashed for entry in history)

    # Partial credit for keeping the backlog under control.
    calm_steps = sum(1 for entry in history if entry.state.queue_length < 100)
    queue_factor = calm_steps / len(history)

    # Throughput matters most; a crash gates the score hard; queue control
    # acts as a tie-breaker bonus.
    if survived:
        score = throughput * 0.7 + queue_factor * 0.3
    else:
        # Still give partial credit for throughput management even with a crash.
        score = throughput * 0.3 * queue_factor

    return round(min(1.0, max(0.0, score)), 4)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# Dispatcher
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
# Registry mapping each task id to its scoring function.
GRADERS = {
    "task_easy": grade_task_easy,
    "task_medium": grade_task_medium,
    "task_hard": grade_task_hard,
}


def grade(task_id: str, history: list[EpisodeStep]) -> float:
    """Dispatch *history* to the grader registered for *task_id*.

    Raises ValueError for an unregistered task id.
    """
    if task_id not in GRADERS:
        raise ValueError(f"Unknown task_id: {task_id!r}")
    return GRADERS[task_id](history)
|
inference.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM agent runner for the Adaptive Traffic Controller environment.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
API_BASE_URL=<url> MODEL_NAME=<model> HF_TOKEN=<token> python inference.py
|
| 6 |
+
|
| 7 |
+
Environment variables:
|
| 8 |
+
API_BASE_URL β OpenAI-compatible base URL (e.g. HuggingFace TGI endpoint)
|
| 9 |
+
MODEL_NAME β Model identifier
|
| 10 |
+
HF_TOKEN β API key / HuggingFace token
|
| 11 |
+
ENV_URL β (optional) Traffic controller environment URL, default http://localhost:7860
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
import httpx
|
| 21 |
+
from openai import OpenAI
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Configuration
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
# Required: OpenAI-compatible endpoint, model id, and auth token.
# A missing variable fails fast at import time with a KeyError.
API_BASE_URL: str = os.environ["API_BASE_URL"]
MODEL_NAME: str = os.environ["MODEL_NAME"]
HF_TOKEN: str = os.environ["HF_TOKEN"]
# Optional: where the traffic-controller environment is served.
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860")

# Actions the environment accepts; anything else from the LLM is rejected.
VALID_ACTIONS = {"allow_all", "throttle_70", "throttle_40", "drop_aggressive"}
# Moderate fallback used when the LLM keeps returning invalid actions.
DEFAULT_ACTION = "throttle_70"
# LLM call attempts before falling back to DEFAULT_ACTION.
MAX_RETRIES = 3

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Prompts
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
SYSTEM_PROMPT = """You are a backend traffic controller agent.
|
| 43 |
+
Your goal: prevent server crashes while maximizing throughput.
|
| 44 |
+
|
| 45 |
+
Server state fields:
|
| 46 |
+
cpu_usage β fraction 0.0β1.0 (danger above 0.8)
|
| 47 |
+
memory_usage β fraction 0.0β1.0 (danger above 0.8)
|
| 48 |
+
request_rate β incoming requests per second
|
| 49 |
+
queue_length β pending requests (danger above 200)
|
| 50 |
+
avg_latency β milliseconds (danger above 400ms)
|
| 51 |
+
|
| 52 |
+
Available actions (choose exactly one):
|
| 53 |
+
allow_all β accept 100% of requests (use when load is safe)
|
| 54 |
+
throttle_70 β accept 70%, drop 30% (use when load is moderate)
|
| 55 |
+
throttle_40 β accept 40%, drop 60% (use when load is high)
|
| 56 |
+
drop_aggressive β accept 20%, drop 80% (use when crash is imminent)
|
| 57 |
+
|
| 58 |
+
Decision heuristics:
|
| 59 |
+
- cpu < 0.6 AND latency < 200ms AND queue < 50 β allow_all
|
| 60 |
+
- cpu < 0.75 OR latency < 300ms β throttle_70
|
| 61 |
+
- cpu < 0.9 OR latency < 500ms OR queue < 150 β throttle_40
|
| 62 |
+
- otherwise β drop_aggressive
|
| 63 |
+
|
| 64 |
+
Respond with ONLY the action name, nothing else. No punctuation, no explanation."""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _format_state(state: dict) -> str:
    """Render the server-state dict as one compact, LLM-readable line."""
    fields = [
        f"cpu_usage={state['cpu_usage']:.3f}",
        f"memory_usage={state['memory_usage']:.3f}",
        f"request_rate={state['request_rate']:.1f} req/s",
        f"queue_length={state['queue_length']}",
        f"avg_latency={state['avg_latency']:.1f}ms",
        f"step={state.get('step', '?')}",
    ]
    return " ".join(fields)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# LLM interaction
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
def get_action(state: dict) -> str:
    """Query the LLM for a throttling action given the current server state.

    Retries up to MAX_RETRIES times on API errors or invalid responses,
    then falls back to DEFAULT_ACTION. Always returns a member of
    VALID_ACTIONS.
    """
    user_msg = f"Current server state: {_format_state(state)}\nChoose action:"

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # temperature=0 for deterministic decisions; tiny max_tokens since
            # we only want a single action name back.
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=20,
                temperature=0.0,
            )
            raw = response.choices[0].message.content.strip().lower()
            # Normalise: strip punctuation, take first token
            action = raw.split()[0].rstrip(".,;:!") if raw.split() else ""
            if action in VALID_ACTIONS:
                return action
            print(f" [warn] LLM returned invalid action {raw!r}, attempt {attempt}/{MAX_RETRIES}")
        except Exception as exc:
            print(f" [warn] LLM call failed ({exc}), attempt {attempt}/{MAX_RETRIES}")
        # Brief pause before retrying (also runs after the final attempt).
        time.sleep(1)

    print(f" [warn] falling back to default action: {DEFAULT_ACTION}")
    return DEFAULT_ACTION
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# Episode runner
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
def run_task(task_id: str, env_url: str) -> float:
    """Run one full episode for task_id and return the final graded score.

    Drives the environment's /reset and /step endpoints, asking the LLM for
    an action at every step until the environment reports done.

    Args:
        task_id: Task identifier understood by the environment.
        env_url: Base URL of the environment server.

    Returns:
        The final_score reported in the terminal step's info (0.0 if absent).
    """
    # Context manager guarantees the HTTP client is closed even if a request
    # or the LLM call raises; the original leaked the client on error.
    with httpx.Client(base_url=env_url, timeout=30.0) as http:
        # Reset environment
        reset_resp = http.post("/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        data = reset_resp.json()
        state = data["state"]
        max_steps = data["max_steps"]

        print(f" Starting {task_id} (max_steps={max_steps})")

        total_reward = 0.0
        final_score = 0.0
        step = 0

        while True:
            action = get_action(state)
            step_resp = http.post("/step", json={"action": action})
            step_resp.raise_for_status()
            result = step_resp.json()

            state = result["state"]
            reward = result["reward"]
            done = result["done"]
            info = result["info"]

            total_reward += reward
            step += 1

            crashed = info.get("crashed", False)
            print(
                f" step={step:3d} action={action:<18s} "
                f"reward={reward:+.3f} latency={state['avg_latency']:6.1f}ms "
                f"queue={state['queue_length']:4d} cpu={state['cpu_usage']:.2f}"
                + (" [CRASH]" if crashed else "")
            )

            if done:
                final_score = info.get("final_score", 0.0)
                break

        print(f" {task_id} done β total_reward={total_reward:.3f}, score={final_score:.3f}")
        return final_score
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Entry point
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
def main() -> None:
    """Entry point: probe the environment, run every task, print a summary."""
    env_url = ENV_URL
    print(f"Environment URL : {env_url}")
    print(f"Model : {MODEL_NAME}")
    print(f"API base : {API_BASE_URL}")
    print()

    # Quick health check — bail out early if the environment is unreachable.
    try:
        probe = httpx.get(f"{env_url}/health", timeout=10.0)
        probe.raise_for_status()
    except Exception as exc:
        print(f"[ERROR] Environment not reachable at {env_url}: {exc}")
        sys.exit(1)
    else:
        print("Health check OK\n")

    results: dict[str, float] = {}
    for tid in ("task_easy", "task_medium", "task_hard"):
        print(f"=== {tid.upper()} ===")
        results[tid] = run_task(tid, env_url)
        print()

    print("=== RESULTS ===")
    for tid, score in results.items():
        print(f" {tid:<15s}: {score:.3f}")
    mean_score = sum(results.values()) / len(results)
    print(f" {'Overall':<15s}: {mean_score:.3f}")


if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Action(str, Enum):
    """Discrete throttling actions the agent may choose each step.

    Inherits from str so members serialize transparently in JSON bodies.
    Accept rates per action live in ACTION_ACCEPT_RATE.
    """

    allow_all = "allow_all"              # let all incoming traffic through
    throttle_70 = "throttle_70"          # admit 70% of traffic
    throttle_40 = "throttle_40"          # admit 40% of traffic
    drop_aggressive = "drop_aggressive"  # admit only 20% of traffic
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Fraction of incoming requests each throttling action admits to the backend.
ACTION_ACCEPT_RATE: dict[Action, float] = {
    action: rate
    for action, rate in (
        (Action.allow_all, 1.0),
        (Action.throttle_70, 0.7),
        (Action.throttle_40, 0.4),
        (Action.drop_aggressive, 0.2),
    )
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ServerState(BaseModel):
    """Snapshot of backend health metrics; the agent's per-step observation."""

    cpu_usage: float = Field(..., ge=0.0, le=1.0, description="CPU utilization 0β1")
    memory_usage: float = Field(..., ge=0.0, le=1.0, description="Memory utilization 0β1")
    request_rate: float = Field(..., ge=0.0, description="Incoming requests per second")
    queue_length: int = Field(..., ge=0, description="Pending requests in queue")
    avg_latency: float = Field(..., ge=0.0, description="Average latency in milliseconds")
    step: int = Field(default=0, description="Current episode step")
    crashed: bool = Field(default=False, description="Whether server has crashed")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ResetRequest(BaseModel):
    """Body of POST /reset: selects which task to run."""

    task_id: str = Field(default="task_easy", description="Task to run")


class ResetResponse(BaseModel):
    """Reply to POST /reset: initial observation plus episode parameters."""

    state: ServerState
    task_id: str
    max_steps: int


class StepRequest(BaseModel):
    """Body of POST /step: the throttling action chosen by the agent."""

    action: Action


class StepResponse(BaseModel):
    """Reply to POST /step: the usual (state, reward, done, info) tuple."""

    state: ServerState
    reward: float
    done: bool
    info: dict[str, Any] = Field(default_factory=dict)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class EpisodeStep(BaseModel):
    """Record of a single simulated step (state, action, reward, traffic)."""

    step: int
    state: ServerState
    action: Action
    reward: float
    incoming_requests: float
    allowed_requests: float
    crashed: bool


class TaskInfo(BaseModel):
    """Static metadata describing one benchmark task."""

    id: str
    description: str
    episode_length: int
    difficulty: str


class TaskListResponse(BaseModel):
    """Reply to GET /tasks: every available task."""

    tasks: list[TaskInfo]


class HealthResponse(BaseModel):
    """Reply to GET /health liveness probe."""

    status: str = "ok"
|
openenv.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: adaptive-traffic-controller
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
LLM agent controls backend traffic throttling to prevent server crashes.
|
| 5 |
+
The agent observes real-time server metrics and chooses a throttling action
|
| 6 |
+
each step to keep CPU, memory, and latency within safe bounds.
|
| 7 |
+
|
| 8 |
+
observation_space:
|
| 9 |
+
cpu_usage:
|
| 10 |
+
type: float
|
| 11 |
+
range: [0.0, 1.0]
|
| 12 |
+
description: CPU utilization as a fraction of total capacity
|
| 13 |
+
memory_usage:
|
| 14 |
+
type: float
|
| 15 |
+
range: [0.0, 1.0]
|
| 16 |
+
description: Memory utilization as a fraction of total capacity
|
| 17 |
+
request_rate:
|
| 18 |
+
type: float
|
| 19 |
+
unit: requests/sec
|
| 20 |
+
description: Current incoming request rate
|
| 21 |
+
queue_length:
|
| 22 |
+
type: int
|
| 23 |
+
range: [0, 500]
|
| 24 |
+
description: Number of pending requests waiting to be processed
|
| 25 |
+
avg_latency:
|
| 26 |
+
type: float
|
| 27 |
+
unit: milliseconds
|
| 28 |
+
description: Average response latency for processed requests
|
| 29 |
+
step:
|
| 30 |
+
type: int
|
| 31 |
+
description: Current step index within the episode
|
| 32 |
+
crashed:
|
| 33 |
+
type: bool
|
| 34 |
+
description: Whether the server has crashed this step
|
| 35 |
+
|
| 36 |
+
action_space:
|
| 37 |
+
type: discrete
|
| 38 |
+
actions:
|
| 39 |
+
- id: allow_all
|
| 40 |
+
accept_rate: 1.0
|
| 41 |
+
description: Accept 100% of incoming requests
|
| 42 |
+
- id: throttle_70
|
| 43 |
+
accept_rate: 0.7
|
| 44 |
+
description: Accept 70%, drop 30% of incoming requests
|
| 45 |
+
- id: throttle_40
|
| 46 |
+
accept_rate: 0.4
|
| 47 |
+
description: Accept 40%, drop 60% of incoming requests
|
| 48 |
+
- id: drop_aggressive
|
| 49 |
+
accept_rate: 0.2
|
| 50 |
+
description: Accept 20%, drop 80% of incoming requests
|
| 51 |
+
|
| 52 |
+
tasks:
|
| 53 |
+
- id: task_easy
|
| 54 |
+
difficulty: easy
|
| 55 |
+
episode_length: 30
|
| 56 |
+
description: >
|
| 57 |
+
Single spike: baseline 40 req/s, spike to 160 req/s at step 10 for
|
| 58 |
+
5 steps, return to 40. Agent must detect spike, throttle, and recover.
|
| 59 |
+
grading:
|
| 60 |
+
full_score: "no crash AND avg_latency < 300ms"
|
| 61 |
+
partial_score: "no crash but avg_latency >= 300ms β 0.5"
|
| 62 |
+
zero_score: "any crash β 0.0"
|
| 63 |
+
|
| 64 |
+
- id: task_medium
|
| 65 |
+
difficulty: medium
|
| 66 |
+
episode_length: 40
|
| 67 |
+
description: >
|
| 68 |
+
Three traffic spikes of 150 req/s at steps 5, 15, 25 (3 steps each),
|
| 69 |
+
baseline 50 req/s. Agent must handle repeated bursts.
|
| 70 |
+
grading:
|
| 71 |
+
formula: "score = (steps_without_crash / total_steps) * latency_factor"
|
| 72 |
+
latency_factor: "1.0 at <=200ms, 0.5 at >=600ms, linear between"
|
| 73 |
+
|
| 74 |
+
- id: task_hard
|
| 75 |
+
difficulty: hard
|
| 76 |
+
episode_length: 50
|
| 77 |
+
description: >
|
| 78 |
+
Sustained overload: traffic ramps 60β200 req/s over 20 steps, stays
|
| 79 |
+
at 200 for 20 steps, then drops to 80. Agent must balance throughput
|
| 80 |
+
vs. stability under prolonged high load.
|
| 81 |
+
grading:
|
| 82 |
+
formula: "score = throughput_ratio * 0.7 + queue_factor * 0.3"
|
| 83 |
+
throughput_ratio: "total_allowed / total_incoming"
|
| 84 |
+
stability_bonus: "crash zeroes out primary score (partial credit * 0.3)"
|
| 85 |
+
queue_factor: "fraction of steps with queue_length < 100"
|
| 86 |
+
|
| 87 |
+
endpoints:
|
| 88 |
+
reset:
|
| 89 |
+
method: POST
|
| 90 |
+
path: /reset
|
| 91 |
+
description: Reset environment, returns initial state
|
| 92 |
+
step:
|
| 93 |
+
method: POST
|
| 94 |
+
path: /step
|
| 95 |
+
description: Execute action, returns next state, reward, done flag, and info
|
| 96 |
+
state:
|
| 97 |
+
method: GET
|
| 98 |
+
path: /state
|
| 99 |
+
description: Get current server state
|
| 100 |
+
tasks:
|
| 101 |
+
method: GET
|
| 102 |
+
path: /tasks
|
| 103 |
+
description: List all available tasks
|
| 104 |
+
spec:
|
| 105 |
+
method: GET
|
| 106 |
+
path: /openenv.yaml
|
| 107 |
+
description: This OpenEnv specification file
|
| 108 |
+
health:
|
| 109 |
+
method: GET
|
| 110 |
+
path: /health
|
| 111 |
+
description: Liveness probe
|
| 112 |
+
|
| 113 |
+
infrastructure:
|
| 114 |
+
port: 7860
|
| 115 |
+
cpu: 2
|
| 116 |
+
memory_gb: 8
|
| 117 |
+
gpu_required: false
|
| 118 |
+
max_inference_minutes: 20
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.111.0
|
| 2 |
+
uvicorn[standard]>=0.29.0
|
| 3 |
+
pydantic>=2.7.0
|
| 4 |
+
openai>=1.30.0
|
| 5 |
+
httpx>=0.27.0
|
| 6 |
+
numpy>=1.26.0
|
| 7 |
+
pyyaml>=6.0.1
|
simulator.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Backend simulation math β models how a real server responds to load."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from models import ServerState
|
| 6 |
+
|
| 7 |
+
MAX_CAPACITY = 100.0 # requests/sec the backend can handle at full health
|
| 8 |
+
BASE_LATENCY = 50.0 # milliseconds at zero load
|
| 9 |
+
MAX_QUEUE = 500
|
| 10 |
+
CRASH_LOAD_RATIO = 1.3 # server crashes when 30% or more over capacity
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_next_state(
    current_state: ServerState,
    allowed_requests: float,
    incoming_requests: float,
) -> tuple[ServerState, bool]:
    """
    Advance the simulated backend by one time step.

    Returns (next_state, crashed).

    The environment exposes the *upcoming* request_rate in the observation so
    the agent can react before overload happens (see environment.py).
    Crash fires when allowed traffic exceeds 130% of capacity in a single step.
    """
    load = allowed_requests / MAX_CAPACITY

    # Latency rises superlinearly with load: quadratic below capacity,
    # cubic once the backend is saturated.
    exponent = 2 if load <= 1.0 else 3
    latency_ms = BASE_LATENCY * (1.0 + load ** exponent)

    # Queue grows by the overflow above capacity and drains at 30% of the
    # spare capacity when running below capacity.
    growth = max(0.0, allowed_requests - MAX_CAPACITY)
    drain = max(0.0, (MAX_CAPACITY - allowed_requests) * 0.3)
    backlog = current_state.queue_length + growth - drain
    queue_len = int(min(MAX_QUEUE, max(0.0, backlog)))

    # A single step more than 30% over capacity crashes the server.
    crashed = load > CRASH_LOAD_RATIO

    # Each queued request adds 0.5 ms to observed latency (uses the *new*
    # queue length, matching the original statement order).
    latency_ms += queue_len * 0.5

    # CPU and memory are affine functions of load, clamped at 100%.
    cpu_util = min(1.0, 0.3 + load * 0.6)
    mem_util = min(1.0, 0.2 + load * 0.4)

    next_state = ServerState(
        cpu_usage=round(cpu_util, 4),
        memory_usage=round(mem_util, 4),
        request_rate=round(incoming_requests, 2),
        queue_length=queue_len,
        avg_latency=round(latency_ms, 2),
        step=current_state.step + 1,
        crashed=crashed,
    )
    return next_state, crashed
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def initial_state(incoming_requests: float = 40.0) -> ServerState:
    """Build the pristine server state at the start of an episode.

    Uses the same latency/cpu/memory formulas as compute_next_state, with an
    empty queue and step 0.
    """
    load = incoming_requests / MAX_CAPACITY
    return ServerState(
        cpu_usage=round(min(1.0, 0.3 + load * 0.6), 4),
        memory_usage=round(min(1.0, 0.2 + load * 0.4), 4),
        request_rate=round(incoming_requests, 2),
        queue_length=0,
        avg_latency=round(BASE_LATENCY * (1.0 + load ** 2), 2),
        step=0,
        crashed=False,
    )
|
tasks.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task definitions β each describes a traffic pattern and episode parameters."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations

from collections.abc import Callable

from models import TaskInfo
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# ---------------------------------------------------------------------------
|
| 9 |
+
# Traffic pattern generators
|
| 10 |
+
# Each returns incoming request rate (req/s) for a given step index (0-based).
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
|
| 13 |
+
def traffic_easy(step: int) -> float:
    """
    Task Easy — Single Spike
    Baseline 40 req/s, spiking to 160 req/s for steps 10 through 14.
    """
    return 160.0 if step in range(10, 15) else 40.0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def traffic_medium(step: int) -> float:
    """
    Task Medium — Multiple Spikes
    Baseline 50 req/s, with 150 req/s bursts covering steps 5-7, 15-17, 25-27.
    """
    spike_starts = (5, 15, 25)
    if any(start <= step < start + 3 for start in spike_starts):
        return 150.0
    return 50.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def traffic_hard(step: int) -> float:
    """
    Task Hard — Sustained Overload
    Ramps 60 -> 200 req/s over steps 0-19, holds 200 through step 39,
    then settles at 80 for the remaining steps.
    """
    if step >= 40:
        return 80.0
    if step >= 20:
        return 200.0
    # Linear interpolation from 60 at step 0 to 200 at step 19.
    return 60.0 + (200.0 - 60.0) * (step / 19.0)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Maps task_id -> traffic generator (step index -> incoming req/s).
# Annotation fixed: the builtin `callable` is a predicate function, not a
# type; Callable[[int], float] is the correct parameterised annotation.
TRAFFIC_PATTERNS: dict[str, Callable[[int], float]] = {
    "task_easy": traffic_easy,
    "task_medium": traffic_medium,
    "task_hard": traffic_hard,
}

# Maps task_id -> number of steps in one episode. Keep in sync with the
# episode_length values in TASK_METADATA.
EPISODE_LENGTHS: dict[str, int] = {
    "task_easy": 30,
    "task_medium": 40,
    "task_hard": 50,
}
|
| 62 |
+
|
| 63 |
+
# Static catalogue of the three benchmark tasks (served to clients as task
# listings). Episode lengths mirror EPISODE_LENGTHS — keep the two in sync.
TASK_METADATA: list[TaskInfo] = [
    TaskInfo(
        id="task_easy",
        description=(
            "Single traffic spike: baseline 40 req/s rising to 160 req/s at step 10 "
            "for 5 steps, then back to 40. Agent must detect and throttle the spike "
            "without crashing the server."
        ),
        episode_length=30,
        difficulty="easy",
    ),
    TaskInfo(
        id="task_medium",
        description=(
            "Three traffic spikes of 150 req/s at steps 5, 15, and 25 (3 steps each), "
            "baseline 50 req/s. Agent must handle repeated bursts while maintaining "
            "throughput between spikes."
        ),
        episode_length=40,
        difficulty="medium",
    ),
    TaskInfo(
        id="task_hard",
        description=(
            "Sustained overload: traffic ramps from 60 β 200 req/s over 20 steps, "
            "stays at 200 for 20 more steps, then drops to 80. Agent must balance "
            "throughput vs. stability under prolonged high load."
        ),
        episode_length=50,
        difficulty="hard",
    ),
]
|