# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Benchmark helpers: list OpenAI-compatible `/v1/models`, run episodes per model."""
from __future__ import annotations
import json
import urllib.error
import urllib.request
from typing import Any, Dict, List, Optional
try:
from inference import (
_action_log_str,
_get_action_with_retry,
_one_line,
build_user_prompt,
)
from models import AegisObservation
except ImportError: # pragma: no cover — allow `python -m server.app` from package subdir
import sys
from pathlib import Path
_root = Path(__file__).resolve().parents[1]
if str(_root) not in sys.path:
sys.path.insert(0, str(_root))
from inference import (
_action_log_str,
_get_action_with_retry,
_one_line,
build_user_prompt,
)
from models import AegisObservation
# Sampling settings forwarded to `_get_action_with_retry` on every benchmark step;
# they are shared by all models so that only the model differs between runs.
TEMPERATURE: float = 0.2
MAX_TOKENS: int = 4_096  # presumably the completion token cap — confirm in inference._get_action_with_retry
def _model_ids_from(entries: Any, *keys: str) -> List[str]:
    """Extract stripped, non-empty string ids from a JSON list of dict entries.

    For each dict in ``entries``, the value of the first key in ``keys`` whose value
    is truthy is used. This mirrors ``item.get(k1) or item.get(k2)``. Non-dict
    entries, non-string values and whitespace-only strings are skipped. Anything
    other than a list (missing key, ``null``, scalar, object) yields ``[]``.
    """
    if not isinstance(entries, list):
        return []
    found: List[str] = []
    for item in entries:
        if not isinstance(item, dict):
            continue
        mid: Any = None
        for key in keys:
            mid = item.get(key)
            if mid:
                break
        if isinstance(mid, str) and mid.strip():
            found.append(mid.strip())
    return found


def fetch_model_ids(api_root: str, timeout_s: float = 45.0) -> List[str]:
    """
    GET {api_root}/models — OpenAI-compatible listing (Ollama exposes this at /v1/models).

    ``api_root`` may be a base URL (e.g. ``http://host:11434/v1``) or a URL that
    already ends in ``/models``. Surrounding whitespace and trailing slashes are
    ignored.

    Ids are read from ``payload["data"][i]["id"]``, falling back to ``["name"]``.
    If that yields nothing, the native Ollama ``/api/tags`` shape
    ``payload["models"][i]["name"]`` (or ``["model"]``) is tried as a fallback.

    Args:
        api_root: API base URL.
        timeout_s: Socket timeout in seconds, passed to ``urlopen``.

    Returns:
        Sorted, de-duplicated model ids. Empty if the listing contains none.

    Raises:
        RuntimeError: On any of the following:
            - an HTTP error status;
            - a connection or socket failure, including timeouts while reading
              the body;
            - a body that is not UTF-8 JSON;
            - a JSON body that is not an object.
    """
    root = api_root.strip().rstrip("/")
    url = root if root.endswith("/models") else f"{root}/models"
    req = urllib.request.Request(
        url,
        headers={"Accept": "application/json", "User-Agent": "aegis-env-benchmark/1.0"},
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            raw = resp.read()
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", errors="replace") if e.fp else ""
        raise RuntimeError(f"HTTP {e.code} listing models from {url}: {body or e.reason}") from e
    except OSError as e:
        # URLError is an OSError subclass. Catching OSError also covers socket errors
        # (e.g. TimeoutError) raised while reading the body, which urllib does not
        # wrap in URLError.
        raise RuntimeError(f"Failed to reach {url}: {e!s}") from e

    try:
        payload = json.loads(raw.decode("utf-8"))
    except ValueError as e:  # UnicodeDecodeError and JSONDecodeError are both ValueErrors
        raise RuntimeError(f"Invalid JSON listing models from {url}: {e!s}") from e
    if not isinstance(payload, dict):
        raise RuntimeError(
            f"Unexpected model listing from {url}: expected a JSON object, "
            f"got {type(payload).__name__}"
        )

    ids = _model_ids_from(payload.get("data"), "id", "name")
    if not ids:
        # Native Ollama `/api/tags` shape (optional fallback)
        ids = _model_ids_from(payload.get("models"), "name", "model")
    return sorted(set(ids))
def run_single_model_episode(
    env: Any,
    llm: OpenAI,  # NOTE(review): `OpenAI` is not imported in this module; it only works as a lazy annotation via `from __future__ import annotations`.
    model: str,
    task_name: str,
    max_steps: int,
    episode_seed: Optional[int],
) -> Dict[str, Any]:
    """
    Run one grading episode for ``model`` and return a summary dict.

    ``env`` is reset with ``episode_seed`` and ``task_name`` before the first step.
    The intent is that, across calls sharing the same env, task and seed, only the
    model varies.

    Each of at most ``max_steps`` steps works as follows:
      - the current observation and running history are turned into a prompt;
      - an action is requested from ``llm`` and applied with ``env.step``;
      - the loop stops early once the step result reports ``done``.

    If obtaining an action raises (including a bad ``max_score``), that step is
    scored 0.0, recorded in the history as ``parse_error`` and the env is not
    stepped. The episode then continues. Any other exception ends the episode. The
    summary then carries the rewards gathered so far, ``final_done=False`` and
    ``error="<ExcType>: <message>"``.

    Returns:
        Dict with keys ``model``, ``rewards``, ``total_reward``, ``steps``,
        ``final_done`` and ``error``. ``error`` is ``None`` when the episode ran
        without an unhandled exception.
    """
    rewards: List[float] = []
    history: List[str] = []

    def _summary(final_done: bool, error: Optional[str]) -> Dict[str, Any]:
        # Shared result shape for both the normal and the failed episode.
        return {
            "model": model,
            "rewards": rewards,
            "total_reward": float(sum(rewards)),
            "steps": len(rewards),
            "final_done": final_done,
            "error": error,
        }

    prev_action: Optional[str] = None
    prev_reward = 0.0
    obs: AegisObservation
    try:
        obs = env.reset(seed=episode_seed, task_name=task_name)
        for step in range(1, max_steps + 1):
            prompt = build_user_prompt(step, prev_action, prev_reward, history, obs)
            try:
                # Kept inside the try: a missing/invalid max_score is treated like a
                # failed LLM call (step scored 0.0), not as an episode-ending error.
                score_cap = float(obs.max_score) if obs.max_score else 1.0
                action, _ = _get_action_with_retry(
                    llm, model, prompt, TEMPERATURE, MAX_TOKENS, score_cap, llm_enabled=True
                )
            except Exception as exc:
                rewards.append(0.0)
                history.append(f"step={step} parse_error={_one_line(str(exc))}")
                prev_action, prev_reward = None, 0.0
                continue

            outcome = env.step(action)
            reward = float(getattr(outcome, "reward", None) or 0.0)
            rewards.append(reward)
            prev_action = _action_log_str(action)
            prev_reward = reward
            history.append(f"step={step} action={prev_action} reward={reward:.2f}")
            obs = outcome
            if getattr(outcome, "done", False):
                break
        return _summary(bool(getattr(obs, "done", False)), None)
    except Exception as exc:
        return _summary(False, f"{type(exc).__name__}: {exc}")
|