| """ |
| LLM agent runner for the Adaptive Traffic Controller environment. |
| |
| Usage: |
| API_BASE_URL=<url> MODEL_NAME=<model> HF_TOKEN=<token> python inference.py |
| |
| Environment variables: |
| API_BASE_URL β OpenAI-compatible base URL (e.g. HuggingFace TGI endpoint) |
| MODEL_NAME β Model identifier |
| HF_TOKEN β API key / HuggingFace token |
| ENV_URL β (optional) Traffic controller environment URL, default http://localhost:7860 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import time |
|
|
| import httpx |
| from openai import OpenAI |
|
|
| |
| |
| |
|
|
# --- Configuration (read once at import time) --------------------------------

API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN: str = os.environ["HF_TOKEN"]  # required — raises KeyError when unset
LOCAL_IMAGE_NAME: str | None = os.environ.get("LOCAL_IMAGE_NAME")
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860")

# Actions the traffic-controller environment accepts; anything else from the
# LLM is rejected and eventually replaced by DEFAULT_ACTION.
VALID_ACTIONS = {"allow_all", "throttle_70", "throttle_40", "drop_aggressive"}
DEFAULT_ACTION = "throttle_70"
MAX_RETRIES = 3

# Single shared OpenAI-compatible client, reused for every completion request.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
|
| |
| |
| |
|
|
# System prompt sent to the LLM every step. {capacity} is filled per-task from
# the environment config. NOTE: the previous version contained mojibake ("β"
# where separator arrows/dashes once were), which was sent verbatim to the
# model; separators are now plain ASCII so the prompt survives any encoding.
SYSTEM_PROMPT_TEMPLATE = """You are a backend traffic controller agent.
Your goal: prevent server crashes while maximizing throughput.
Server capacity: {capacity} req/s.

Server state fields:
cpu_usage - fraction 0.0-1.0 (danger above 0.8)
memory_usage - fraction 0.0-1.0 (danger above 0.8)
request_rate - incoming requests per second (compare against capacity)
queue_length - pending requests (danger above 200)
avg_latency - milliseconds (danger above 400ms)

Available actions (choose exactly one):
allow_all - accept 100% of requests (use when load is safe)
throttle_70 - accept 70%, drop 30% (use when load is moderate)
throttle_40 - accept 40%, drop 60% (use when load is high)
drop_aggressive - accept 20%, drop 80% (use when crash is imminent)

Decision heuristics (relative to server capacity):
- request_rate < 70% capacity AND cpu < 0.6 AND latency < 200ms -> allow_all
- request_rate < 100% capacity -> throttle_70
- request_rate < 130% capacity -> throttle_40
- otherwise -> drop_aggressive

Respond with ONLY the action name, nothing else. No punctuation, no explanation."""
|
|
|
|
| def _format_state(state: dict) -> str: |
| return ( |
| f"cpu_usage={state['cpu_usage']:.3f} " |
| f"memory_usage={state['memory_usage']:.3f} " |
| f"request_rate={state['request_rate']:.1f} req/s " |
| f"queue_length={state['queue_length']} " |
| f"avg_latency={state['avg_latency']:.1f}ms " |
| f"step={state.get('step', '?')}" |
| ) |
|
|
|
|
| |
| |
| |
|
|
def get_action(state: dict, system_prompt: str) -> str:
    """Query the LLM for a throttling action given the current server state.

    Retries up to MAX_RETRIES times on API errors or unparseable replies,
    then falls back to DEFAULT_ACTION so the episode can always continue.

    Fixes vs. previous version: `message.content` can be None with some
    OpenAI-compatible backends, which raised an AttributeError that was
    misreported as an API failure; also avoids splitting the reply twice.
    """
    user_msg = f"Current server state: {_format_state(state)}\nChoose action:"

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=20,
                temperature=0.0,  # deterministic: we want the single best action
            )
            # content may legitimately be None; treat that as an empty reply.
            raw = (response.choices[0].message.content or "").strip().lower()
            tokens = raw.split()
            # First whitespace token, stripped of trailing punctuation the
            # model sometimes appends despite the prompt instructions.
            action = tokens[0].rstrip(".,;:!") if tokens else ""
            if action in VALID_ACTIONS:
                return action
            print(f" [warn] LLM returned invalid action {raw!r}, attempt {attempt}/{MAX_RETRIES}")
        except Exception as exc:
            print(f" [warn] LLM call failed ({exc}), attempt {attempt}/{MAX_RETRIES}")
            time.sleep(1)  # brief back-off before retrying the API

    print(f" [warn] falling back to default action: {DEFAULT_ACTION}")
    return DEFAULT_ACTION
|
|
|
|
| |
| |
| |
|
|
def run_task(task_id: str, env_url: str) -> float:
    """Run one full episode for task_id and return the final graded score.

    Fix vs. previous version: the httpx.Client was only closed on the success
    path, leaking the connection pool whenever a request or raise_for_status()
    failed mid-episode; a context manager now guarantees cleanup.
    """
    with httpx.Client(base_url=env_url, timeout=30.0) as http:
        # Reset the environment to the requested task and grab the first state.
        reset_resp = http.post("/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        data = reset_resp.json()
        state = data["state"]
        max_steps = data["max_steps"]

        # server_capacity feeds the prompt heuristics; 100.0 mirrors the
        # environment's default when no config is returned.
        capacity = data.get("config", {}).get("server_capacity", 100.0)
        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(capacity=capacity)

        print(f"[START] task={task_id} max_steps={max_steps} model={MODEL_NAME}")

        total_reward = 0.0
        final_score = 0.0
        step = 0

        while True:
            action = get_action(state, system_prompt)
            step_resp = http.post("/step", json={"action": action})
            step_resp.raise_for_status()
            result = step_resp.json()

            state = result["state"]
            reward = result["reward"]
            done = result["done"]
            info = result["info"]

            total_reward += reward
            step += 1

            crashed = info.get("crashed", False)
            print(
                f"[STEP] task={task_id} step={step:3d} action={action:<18s} "
                f"reward={reward:+.3f} latency={state['avg_latency']:6.1f}ms "
                f"queue={state['queue_length']:4d} cpu={state['cpu_usage']:.2f}"
                + (" crashed=true" if crashed else " crashed=false")
            )

            if done:
                final_score = info.get("final_score", 0.0)
                break

    print(f"[END] task={task_id} total_reward={total_reward:.3f} score={final_score:.3f}")
    return final_score
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """Health-check the environment, run every task, and print a summary."""
    env_url = ENV_URL
    print(f"Environment URL : {env_url}")
    print(f"Model : {MODEL_NAME}")
    print(f"API base : {API_BASE_URL}")
    print()

    # Fail fast when the environment server is unreachable.
    try:
        health = httpx.get(f"{env_url}/health", timeout=10.0)
        health.raise_for_status()
        print("Health check OK\n")
    except Exception as exc:
        print(f"[ERROR] Environment not reachable at {env_url}: {exc}")
        sys.exit(1)

    results: dict[str, float] = {}
    for task_id in ("task_easy", "task_medium", "task_hard"):
        print(f"=== {task_id.upper()} ===")
        results[task_id] = run_task(task_id, env_url)
        print()

    # Per-task scores followed by the unweighted mean across tasks.
    print("=== RESULTS ===")
    for task_id, score in results.items():
        print(f" {task_id:<15s}: {score:.3f}")
    overall = sum(results.values()) / len(results)
    print(f" {'Overall':<15s}: {overall:.3f}")
|
|
|
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|