# aegis-env/server/benchmark.py
# Uploaded by NishithP2004 (folder upload using huggingface_hub).
# Commit fa01cfa (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmark helpers: list OpenAI-compatible `/v1/models`, run episodes per model."""
from __future__ import annotations
import json
import urllib.error
import urllib.request
from typing import Any, Dict, List, Optional
try:
from inference import (
_action_log_str,
_get_action_with_retry,
_one_line,
build_user_prompt,
)
from models import AegisObservation
except ImportError: # pragma: no cover — allow `python -m server.app` from package subdir
import sys
from pathlib import Path
_root = Path(__file__).resolve().parents[1]
if str(_root) not in sys.path:
sys.path.insert(0, str(_root))
from inference import (
_action_log_str,
_get_action_with_retry,
_one_line,
build_user_prompt,
)
from models import AegisObservation
# Generation settings forwarded unchanged to `_get_action_with_retry` on every
# step of `run_single_model_episode`, so all benchmarked models use the same values.
TEMPERATURE = 0.2
MAX_TOKENS = 4096
def fetch_model_ids(api_root: str, timeout_s: float = 45.0) -> List[str]:
    """
    List the model ids advertised by an OpenAI-compatible endpoint.

    Issues ``GET {api_root}/models`` (Ollama exposes this at ``/v1/models``).
    When ``api_root`` already ends in ``/models`` it is requested as-is.

    Two response shapes are understood:

    * OpenAI style ``{"data": [{"id": ...}, ...]}`` (``name`` is used when ``id``
      is missing or empty);
    * native Ollama ``/api/tags`` style ``{"models": [{"name": ...}, ...]}``
      (``model`` is used when ``name`` is missing or empty), consulted only
      when the OpenAI shape yields no ids.

    Args:
        api_root: Base URL of the API; surrounding whitespace and trailing
            slashes are ignored.
        timeout_s: Timeout in seconds handed to ``urllib.request.urlopen``.

    Returns:
        Sorted, de-duplicated list of stripped, non-empty model ids.

    Raises:
        RuntimeError: On an HTTP error status, when the URL cannot be reached,
            or when the JSON body is not an object. Malformed JSON is not
            wrapped and propagates as raised by ``json.loads``.
    """
    root = api_root.strip().rstrip("/")
    url = root if root.endswith("/models") else f"{root}/models"
    req = urllib.request.Request(
        url,
        headers={"Accept": "application/json", "User-Agent": "aegis-env-benchmark/1.0"},
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        # HTTPError must be handled first: it is a subclass of URLError.
        body = e.read().decode("utf-8", errors="replace") if e.fp else ""
        raise RuntimeError(f"HTTP {e.code} listing models from {url}: {body or e.reason}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Failed to reach {url}: {e!s}") from e

    # A list/null/scalar body would otherwise fail below with an opaque
    # AttributeError; report it the same way as every other listing failure.
    if not isinstance(payload, dict):
        raise RuntimeError(
            f"Unexpected response from {url}: expected a JSON object, got {type(payload).__name__}"
        )

    def _collect(section: Any, primary: str, secondary: str) -> List[str]:
        """Pull stripped, non-empty string ids out of one list-valued section."""
        if not isinstance(section, list):
            # Missing, null, or wrongly typed section: nothing to collect.
            return []
        found: List[str] = []
        for item in section:
            if not isinstance(item, dict):
                continue
            mid = item.get(primary) or item.get(secondary)
            if isinstance(mid, str) and mid.strip():
                found.append(mid.strip())
        return found

    ids = _collect(payload.get("data"), "id", "name")
    # Native Ollama `/api/tags` shape (optional fallback)
    if not ids:
        ids = _collect(payload.get("models"), "name", "model")
    return sorted(set(ids))
def _episode_summary(
    model: str,
    rewards: List[float],
    final_done: bool,
    error: Optional[str],
) -> Dict[str, Any]:
    """Build the result dict returned by `run_single_model_episode`."""
    return {
        "model": model,
        "rewards": rewards,
        "total_reward": float(sum(rewards)),
        "steps": len(rewards),
        "final_done": final_done,
        "error": error,
    }


def run_single_model_episode(
    env: Any,
    llm: Any,
    model: str,
    task_name: str,
    max_steps: int,
    episode_seed: Optional[int],
) -> Dict[str, Any]:
    """
    One grading episode: only `model` changes vs other runs (same env instance, reset between models).

    Args:
        env: Environment exposing ``reset(seed=..., task_name=...)`` and
            ``step(action)``; both results are read for ``max_score``,
            ``reward`` and ``done``.
        llm: Client object forwarded untouched to ``_get_action_with_retry``
            (presumably an OpenAI-compatible client; it is annotated ``Any``
            because this module does not import a client class).
        model: Model id to query; echoed back in the result.
        task_name: Task passed to ``env.reset``.
        max_steps: Upper bound on loop iterations (steps are numbered from 1).
        episode_seed: Seed passed to ``env.reset``; may be ``None``.

    Returns:
        Dict with keys ``model``, ``rewards`` (one float per attempted step),
        ``total_reward``, ``steps``, ``final_done`` and ``error``.
        ``error`` is ``None`` on a normal finish; otherwise it is
        ``"<ExceptionType>: <message>"``, ``final_done`` is ``False`` and the
        rewards gathered so far are still reported. This function does not
        raise for ``Exception`` subclasses.
    """
    rewards: List[float] = []
    history: List[str] = []
    last_action: Optional[str] = None
    last_reward = 0.0
    obs: AegisObservation
    try:
        obs = env.reset(seed=episode_seed, task_name=task_name)
        for step in range(1, max_steps + 1):
            prompt = build_user_prompt(step, last_action, last_reward, history, obs)
            try:
                action, _text = _get_action_with_retry(
                    llm,
                    model,
                    prompt,
                    TEMPERATURE,
                    MAX_TOKENS,
                    float(obs.max_score) if obs.max_score else 1.0,
                    llm_enabled=True,
                )
            except Exception as e:
                # No usable action: score the step as zero and retry on the
                # next iteration without stepping the environment.
                rewards.append(0.0)
                history.append(f"step={step} parse_error={_one_line(str(e))}")
                last_action = None
                last_reward = 0.0
                continue
            out = env.step(action)
            # A missing or None reward counts as 0.0.
            r = float(getattr(out, "reward", None) or 0.0)
            rewards.append(r)
            last_action = _action_log_str(action)
            last_reward = r
            history.append(f"step={step} action={last_action} reward={r:.2f}")
            # The step result becomes the observation for the next prompt.
            obs = out
            if bool(getattr(out, "done", False)):
                break
        return _episode_summary(model, rewards, bool(getattr(obs, "done", False)), None)
    except Exception as e:
        # Benchmark boundary: report the failure for this model instead of
        # aborting the whole multi-model run.
        return _episode_summary(model, rewards, False, f"{type(e).__name__}: {e}")