File size: 6,573 Bytes
b89c8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
a7effbb
 
 
 
b89c8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Comprehensive multi-model test with HF PRO token.

Tests 4 models x 3 difficulties x 2 seeds per difficulty = 24 episodes.
"""
from __future__ import annotations

import os
import sys
import time
from typing import Any

from dotenv import load_dotenv
load_dotenv(override=True)

# Add repo root so `import inference` (root-level module) resolves.
# Resolve to an absolute, normalised path: otherwise the membership check below
# compares an un-normalised "<dir>/.." string (so the same root can be inserted
# twice), and a relative __file__ would tie the entry to the current directory.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from openai import OpenAI
from inference import parse_action, serialize_observation, action_to_str, SYSTEM_PROMPT
from triagesieve_env.models import ActionType, TriageSieveAction
from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment

# Credentials / endpoint, read from the process environment (which the
# load_dotenv call above may have populated from a .env file).
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")

# Models evaluated, in run order (large and small variants of two families).
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

# Per-difficulty cap on the number of model turns run_episode will play.
_MAX_STEPS_BY_DIFFICULTY = {"easy": 8, "medium": 14, "hard": 20}

# Episode configurations, evaluated in this order for every model.
CONFIGS = [
    {"difficulty": level, "seed": rng_seed, "max_steps": _MAX_STEPS_BY_DIFFICULTY[level]}
    for level, rng_seed in (
        ("easy", 42),
        ("easy", 7),
        ("medium", 42),
        ("medium", 2),
        ("hard", 42),
        ("hard", 1),
    )
]


def _query_model(client: OpenAI, model_name: str, user_content: str) -> tuple[str, str | None]:
    """Ask *model_name* for its next action and return ``(raw_reply, api_error)``.

    ``raw_reply`` is the stripped message text, or ``""`` when the call failed
    or the model returned no content. ``api_error`` is ``None`` on success and
    otherwise a short ``"ExcType: message"`` description of the failure.
    """
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            temperature=0.0,
            max_tokens=512,
        )
        raw = (response.choices[0].message.content or "").strip()
    except Exception as exc:  # any transport/API/response-shape failure must not abort the episode
        return "", f"{type(exc).__name__}: {exc}"
    return raw, None


def run_episode(client: OpenAI, model_name: str, seed: int, difficulty: str, max_steps: int) -> dict[str, Any]:
    """Play one ``eval_strict`` episode with *model_name* and collect per-step telemetry.

    Each turn, the serialized observation (prefixed with the step number and the
    previous reward) is sent to the model and its reply is parsed with
    ``parse_action``. Replies that cannot be parsed fall back to ``SKIP_TURN``;
    this includes the empty reply used when the API call itself failed. Such
    failures are also recorded in the step's ``api_error`` field and reported
    on stderr. The loop stops when the episode is done, the observation's
    action budget reaches zero, or *max_steps* turns have been played. If the
    episode is still open at that point, a ``FINISH_EPISODE`` action is issued.

    Returns a dict with the short model name, difficulty, seed, ``final_score``
    (the reward of the last recorded step), ``total_steps``, ``parse_failures``,
    ``invalid_actions``, ``api_errors`` and the list of per-step records.
    """
    env = TriageSieveEnvironment()
    obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")

    steps: list[dict[str, Any]] = []
    last_reward = 0.0
    episode_done = False

    for step_num in range(1, max_steps + 1):
        if episode_done or obs.action_budget_remaining <= 0:
            break

        obs_text = serialize_observation(obs)
        user_content = f"Step {step_num} | Last reward: {last_reward:.2f}\n\n{obs_text}"
        raw, api_error = _query_model(client, model_name, user_content)
        if api_error is not None:
            # Make infrastructure failures visible instead of letting them pass
            # silently as model parse failures.
            print(f"  [warn] API call failed at step {step_num}: {api_error}", file=sys.stderr, flush=True)

        action = parse_action(raw)
        parsed = action is not None
        if action is None:
            action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})

        obs = env.step(action)
        reward = obs.reward if obs.reward is not None else 0.0
        episode_done = obs.done
        error = None if obs.last_action_result == "ok" else obs.last_action_result

        steps.append({
            "step": step_num,
            "raw": raw[:150] if raw else "(empty)",
            "parsed": parsed,
            "action": action_to_str(action),
            "reward": reward,
            "done": episode_done,
            "error": error,
            "api_error": api_error,
        })
        last_reward = reward

    if not episode_done:
        # Force a terminal action so the episode is always scored.
        obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
        reward = obs.reward if obs.reward is not None else 0.0
        steps.append({
            "step": len(steps) + 1, "raw": "(auto)", "parsed": True,
            "action": "finish_episode", "reward": reward,
            # Report what the environment actually returned for the forced
            # finish instead of assuming it succeeded.
            "done": obs.done,
            "error": None if obs.last_action_result == "ok" else obs.last_action_result,
            "api_error": None,
        })

    final_score = steps[-1]["reward"] if steps else 0.0

    return {
        "model": model_name.split("/")[-1],
        "difficulty": difficulty,
        "seed": seed,
        "final_score": final_score,
        "total_steps": len(steps),
        "parse_failures": sum(1 for s in steps if not s["parsed"]),
        "invalid_actions": sum(1 for s in steps if s["error"]),
        "api_errors": sum(1 for s in steps if s["api_error"]),
        "steps": steps,
    }


def print_episode(r: dict[str, Any]) -> None:
    """Pretty-print one episode summary as returned by ``run_episode``.

    Prints a banner (model / difficulty / seed), then one line per step with
    the parse status, action, signed reward and an error truncated to 50
    characters. Steps whose reply failed to parse also get the reply echoed,
    truncated to 100 characters, unless it is a placeholder. A closing score
    line follows.
    """
    banner = "=" * 80
    print()
    print(banner)
    print(f"  {r['model']}  |  {r['difficulty']}  |  seed={r['seed']}")
    print(banner)

    for rec in r["steps"]:
        status = "FAIL" if not rec["parsed"] else "OK"
        error_text = rec["error"]
        tail = f"  ERR: {error_text[:50]}" if error_text else ""
        print(f"  Step {rec['step']:>2}: [{status:>4}] {rec['action']:<45} reward={rec['reward']:+.4f}{tail}")

        # Echo the model's reply only when it failed to parse and holds real text.
        if rec["parsed"] or rec["raw"] in ("(auto)", "(empty)"):
            continue
        print(f"         LLM: {rec['raw'][:100]}")

    print()
    print(f"  SCORE: {r['final_score']:.4f}  |  Parse fails: {r['parse_failures']}  |  Invalid: {r['invalid_actions']}")


def _print_summary_table(results: list[dict[str, Any]]) -> None:
    """Print one fixed-width row per completed episode."""
    print(f"\n\n{'='*100}")
    print("FULL SUMMARY")
    print(f"{'='*100}")
    print(f"  {'Model':<30} {'Diff':<8} {'Seed':>4} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
    print(f"  {'-'*30} {'-'*8} {'-'*4} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
    for r in results:
        print(
            f"  {r['model']:<30} {r['difficulty']:<8} {r['seed']:>4} {r['final_score']:>8.4f} "
            f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} {r['time']:>5.1f}s"
        )


def _print_aggregate(results: list[dict[str, Any]], failures: list[str]) -> None:
    """Print cross-episode totals plus any crashed episodes; safe when *results* is empty."""
    print("\n  --- Aggregate ---")
    print(f"  Total episodes: {len(results)}")
    if failures:
        print(f"  Crashed episodes: {len(failures)}")
        for msg in failures:
            print(f"    - {msg}")
    if not results:
        # min()/max() and the mean would raise on an empty score list.
        print("  No completed episodes; nothing to aggregate.")
        return

    scores = [r["final_score"] for r in results]
    parse_fails = sum(r["parse_failures"] for r in results)
    invalid = sum(r["invalid_actions"] for r in results)
    negative = sum(1 for s in scores if s < 0)
    print(f"  Score range: [{min(scores):.4f}, {max(scores):.4f}]")
    print(f"  Mean score: {sum(scores) / len(scores):.4f}")
    print(f"  Total parse failures: {parse_fails}")
    print(f"  Total invalid actions: {invalid}")
    print(f"  Negative scores (bug indicator): {negative}")
    print(f"  Episodes with score > 0: {sum(1 for s in scores if s > 0)}/{len(scores)}")


def main() -> None:
    """Run every model in MODELS against every entry in CONFIGS and print a report.

    Prints each episode's trace as it finishes, then a summary table and
    aggregate statistics. An exception raised by one episode is reported with
    its traceback and recorded, and the sweep moves on to the next
    configuration, so results already collected are not lost.

    Exits with status 1 if ``HF_TOKEN`` is unset or if any episode crashed.
    """
    import traceback

    if not HF_TOKEN:
        # sys.exit with a message prints it to stderr and exits with status 1.
        sys.exit("ERROR: HF_TOKEN not set")

    client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)
    all_results: list[dict[str, Any]] = []
    failures: list[str] = []

    for model_name in MODELS:
        model_short = model_name.split("/")[-1]
        for cfg in CONFIGS:
            label = f"{model_short} / {cfg['difficulty']} / seed={cfg['seed']}"
            print(f"\n>>> {label} ...", flush=True)
            # perf_counter is monotonic, so durations are immune to clock adjustments.
            t0 = time.perf_counter()
            try:
                result = run_episode(client, model_name, cfg["seed"], cfg["difficulty"], cfg["max_steps"])
            except Exception as exc:
                traceback.print_exc()
                msg = f"{label}: {type(exc).__name__}: {exc}"
                print(f"  EPISODE CRASHED: {msg}", file=sys.stderr, flush=True)
                failures.append(msg)
                continue
            result["time"] = time.perf_counter() - t0
            all_results.append(result)
            print_episode(result)
            print(f"  Time: {result['time']:.1f}s")

    _print_summary_table(all_results)
    _print_aggregate(all_results, failures)

    if failures:
        sys.exit(1)


# Run the full model x config sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()