File size: 4,696 Bytes
fa01cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmark helpers: list OpenAI-compatible `/v1/models`, run episodes per model."""

from __future__ import annotations

import json
import urllib.error
import urllib.request
from typing import Any, Dict, List, Optional

try:
    from inference import (
        _action_log_str,
        _get_action_with_retry,
        _one_line,
        build_user_prompt,
    )
    from models import AegisObservation
except ImportError:  # pragma: no cover — allow `python -m server.app` from package subdir
    import sys
    from pathlib import Path

    _root = Path(__file__).resolve().parents[1]
    if str(_root) not in sys.path:
        sys.path.insert(0, str(_root))
    from inference import (
        _action_log_str,
        _get_action_with_retry,
        _one_line,
        build_user_prompt,
    )
    from models import AegisObservation

TEMPERATURE = 0.2
MAX_TOKENS = 4096


def fetch_model_ids(api_root: str, timeout_s: float = 45.0) -> List[str]:
    """
    GET {api_root}/models — OpenAI-compatible listing (Ollama exposes this at /v1/models).
    """
    root = api_root.strip().rstrip("/")
    url = root if root.endswith("/models") else f"{root}/models"
    req = urllib.request.Request(url, headers={"Accept": "application/json", "User-Agent": "aegis-env-benchmark/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            payload = json.loads(resp.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", errors="replace") if e.fp else ""
        raise RuntimeError(f"HTTP {e.code} listing models from {url}: {body or e.reason}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Failed to reach {url}: {e!s}") from e

    ids: List[str] = []
    for item in payload.get("data") or []:
        if isinstance(item, dict):
            mid = item.get("id") or item.get("name")
            if isinstance(mid, str) and mid.strip():
                ids.append(mid.strip())
    # Native Ollama `/api/tags` shape (optional fallback)
    if not ids:
        for item in payload.get("models") or []:
            if isinstance(item, dict):
                mid = item.get("name") or item.get("model")
                if isinstance(mid, str) and mid.strip():
                    ids.append(mid.strip())
    return sorted(set(ids))


def run_single_model_episode(
    env: Any,
    llm: OpenAI,
    model: str,
    task_name: str,
    max_steps: int,
    episode_seed: Optional[int],
) -> Dict[str, Any]:
    """
    One grading episode: only `model` changes vs other runs (same env instance, reset between models).
    """
    rewards: List[float] = []
    history: List[str] = []
    last_action: Optional[str] = None
    last_reward = 0.0
    obs: AegisObservation
    try:
        obs = env.reset(seed=episode_seed, task_name=task_name)
        for step in range(1, max_steps + 1):
            prompt = build_user_prompt(step, last_action, last_reward, history, obs)
            try:
                action, _text = _get_action_with_retry(
                    llm,
                    model,
                    prompt,
                    TEMPERATURE,
                    MAX_TOKENS,
                    float(obs.max_score) if obs.max_score else 1.0,
                    llm_enabled=True,
                )
            except Exception as e:
                rewards.append(0.0)
                history.append(f"step={step} parse_error={_one_line(str(e))}")
                last_action = None
                last_reward = 0.0
                continue

            out = env.step(action)
            r = float(getattr(out, "reward", None) or 0.0)
            rewards.append(r)
            last_action = _action_log_str(action)
            last_reward = r
            history.append(f"step={step} action={last_action} reward={r:.2f}")
            obs = out
            if bool(getattr(out, "done", False)):
                break

        return {
            "model": model,
            "rewards": rewards,
            "total_reward": float(sum(rewards)),
            "steps": len(rewards),
            "final_done": bool(getattr(obs, "done", False)),
            "error": None,
        }
    except Exception as e:
        return {
            "model": model,
            "rewards": rewards,
            "total_reward": float(sum(rewards)),
            "steps": len(rewards),
            "final_done": False,
            "error": f"{type(e).__name__}: {e}",
        }