File size: 6,722 Bytes
b89c8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7effbb
 
 
 
b89c8aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""Test the environment with multiple LLMs and capture detailed logs.

Usage:
    python scripts/test_multi_model.py
"""
from __future__ import annotations

import asyncio
import os
import sys
import time
from dataclasses import dataclass
from typing import Any

from dotenv import load_dotenv
load_dotenv()

# Add the repo root to sys.path so `import inference` (a root-level module)
# resolves. The path is made absolute and normalized: a relative "scripts/.."
# entry is interpreted against the current working directory at lookup time,
# and an unnormalized string never matches an equivalent entry already present
# in sys.path, so the membership check below would not prevent duplicates.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from openai import OpenAI
from inference import (
    get_model_action,
    parse_action,
    serialize_observation,
    action_to_str,
    SYSTEM_PROMPT,
)
from triagesieve_env.models import (
    ActionType,
    TriageSieveAction,
    TriageSieveObservation,
)
from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment


# Credentials / endpoint; both may come from the process env or a loaded .env.
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")

# Model ids sent as `model=` to the chat-completions endpoint, run in this order.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
]

# Every model plays one episode per difficulty, all with the same seed.
DIFFICULTIES = ["easy", "medium", "hard"]
SEED = 42
# Upper bound on LLM-driven steps per difficulty (keys mirror DIFFICULTIES).
MAX_STEPS = dict(easy=8, medium=14, hard=20)


def _reward_of(obs: TriageSieveObservation) -> float:
    """Return the observation's reward, treating a missing (``None``) reward as 0.0."""
    return obs.reward if obs.reward is not None else 0.0


def _error_of(obs: TriageSieveObservation) -> str | None:
    """Return ``obs.last_action_result`` unless it is the success marker ``"ok"``."""
    return None if obs.last_action_result == "ok" else obs.last_action_result


def _step_record(
    step_num: int,
    raw_llm: str,
    parsed: bool,
    action_str: str,
    obs: TriageSieveObservation,
) -> dict[str, Any]:
    """Build one per-step log entry from the observation returned by ``env.step``."""
    return {
        "step": step_num,
        "raw_llm": raw_llm,
        "parsed": parsed,
        "action": action_str,
        "reward": _reward_of(obs),
        "done": obs.done,
        "error": _error_of(obs),
        "budget_left": obs.action_budget_remaining,
    }


def _query_llm(client: OpenAI, model_name: str, user_content: str, step_num: int) -> str:
    """Ask ``model_name`` for the next action and return its stripped reply.

    Any exception from the request (or from reading the reply) is printed and
    turned into an empty string, so one failing call does not abort the sweep;
    the caller hands the empty reply to ``parse_action`` like any other text.
    """
    try:
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ],
            temperature=0.0,
            max_tokens=512,
            stream=False,
        )
        # `content` can be None (e.g. an empty completion); normalise to "".
        return (completion.choices[0].message.content or "").strip()
    except Exception as exc:  # deliberate best-effort boundary: log and continue
        print(f"    [LLM ERROR] step {step_num}: {exc}")
        return ""


def run_episode(
    client: OpenAI,
    model_name: str,
    seed: int,
    difficulty: str,
    max_steps: int,
) -> dict[str, Any]:
    """Run one episode synchronously, capturing detailed per-step logs.

    Each step sends the serialized observation (plus the previous reward) to
    the model, parses the reply into an action, and falls back to SKIP_TURN
    when parsing fails. The loop stops after ``max_steps`` steps, when the
    environment reports ``done``, or when the action budget is exhausted. If
    the episode is still open afterwards, a FINISH_EPISODE action is sent
    automatically and logged like any other step.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model_name: Model id passed as ``model=`` on each request.
        seed: Seed forwarded to ``env.reset``.
        difficulty: Difficulty forwarded to ``env.reset``.
        max_steps: Maximum number of LLM-driven steps.

    Returns:
        A dict with the run metadata, ``final_score`` (the reward of the last
        logged step), step/parse-failure/invalid-action counts, and the list
        of per-step records.
    """
    env = TriageSieveEnvironment()
    obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")

    steps: list[dict[str, Any]] = []
    last_reward = 0.0
    episode_done = False

    for step_num in range(1, max_steps + 1):
        if episode_done or obs.action_budget_remaining <= 0:
            break

        obs_text = serialize_observation(obs)
        user_content = f"Step {step_num} | Last reward: {last_reward:.2f}\n\n{obs_text}"
        raw_response = _query_llm(client, model_name, user_content, step_num)

        action = parse_action(raw_response)
        parse_ok = action is not None
        if action is None:
            # Unparseable reply: still advance the episode with a no-op action.
            action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})

        obs = env.step(action)
        record = _step_record(
            step_num, raw_response[:120], parse_ok, action_to_str(action), obs
        )
        steps.append(record)

        episode_done = record["done"]
        last_reward = record["reward"]

    if not episode_done:
        # Close the episode so the environment produces its terminal reward.
        # Read `done`/`error` from the observation (instead of assuming
        # success) so a rejected finish is reported as an invalid action.
        obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
        steps.append(_step_record(len(steps) + 1, "(auto)", True, "finish_episode", obs))

    final_score = steps[-1]["reward"] if steps else 0.0

    return {
        "model": model_name,
        "difficulty": difficulty,
        "seed": seed,
        "final_score": final_score,
        "total_steps": len(steps),
        "parse_failures": sum(1 for s in steps if not s["parsed"]),
        "invalid_actions": sum(1 for s in steps if s["error"] is not None),
        "steps": steps,
    }


def print_episode(result: dict[str, Any]) -> None:
    """Render one episode result as a human-readable trace on stdout.

    Prints a banner with the model name (org prefix dropped), difficulty and
    seed, then one line per recorded step. Under each step whose LLM reply
    failed to parse, it also echoes that reply, except for the auto-inserted
    "(auto)" step. It ends with the final score, the failure counts and a
    coarse verdict: GOOD at >= 0.5, OK above 0, BAD otherwise.
    """
    banner = "=" * 80
    short_name = result["model"].rsplit("/", 1)[-1]
    print(f"\n{banner}")
    print(f"  Model: {short_name}  |  Difficulty: {result['difficulty']}  |  Seed: {result['seed']}")
    print(banner)

    for entry in result["steps"]:
        tag = "PARSE_FAIL" if not entry["parsed"] else "OK"
        error_note = f" ERR: {entry['error']}" if entry["error"] else ""
        print(
            f"  Step {entry['step']:>2}: [{tag:>10}] {entry['action']:<40} "
            f"reward={entry['reward']:+.3f}{error_note}"
        )
        # Only genuine LLM output that failed to parse is worth echoing.
        if entry["parsed"] or entry["raw_llm"] == "(auto)":
            continue
        print(f"           LLM said: {entry['raw_llm'][:100]}")

    final = result["final_score"]
    n_parse_fail = result["parse_failures"]
    n_invalid = result["invalid_actions"]
    print(f"\n  Final Score: {final:.4f}  |  Parse Failures: {n_parse_fail}  |  Invalid Actions: {n_invalid}")

    if final >= 0.5:
        verdict = "GOOD"
    elif final > 0:
        verdict = "OK"
    else:
        verdict = "BAD"
    print(f"  Verdict: {verdict}")


def _print_summary(results: list[dict[str, Any]]) -> None:
    """Print an aligned summary table with one row per episode result."""
    rule = "=" * 80
    print(f"\n\n{rule}")
    print("SUMMARY")
    print(rule)
    print(f"  {'Model':<35} {'Diff':<8} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
    print(f"  {'-'*35} {'-'*8} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
    for r in results:
        # Truncate long model ids so the columns stay aligned.
        model_short = r["model"].split("/")[-1][:35]
        print(
            f"  {model_short:<35} {r['difficulty']:<8} {r['final_score']:>8.4f} "
            f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} "
            f"{r['elapsed_s']:>5.1f}s"
        )


def main() -> None:
    """Run every model in MODELS on every difficulty, printing traces and a summary.

    Exits with status 1 (message on stderr) when HF_TOKEN is not configured.
    """
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN not set", file=sys.stderr)
        sys.exit(1)

    # The client depends only on BASE_URL/HF_TOKEN; the model is chosen per
    # request via `model=`, so one client serves the whole sweep.
    client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)
    all_results: list[dict[str, Any]] = []

    for model_name in MODELS:
        model_short = model_name.split("/")[-1]

        for diff in DIFFICULTIES:
            print(f"\n>>> Running {model_short} / {diff} / seed={SEED} ...", flush=True)
            # perf_counter is monotonic, so durations are immune to clock changes.
            start = time.perf_counter()
            result = run_episode(
                client=client,
                model_name=model_name,
                seed=SEED,
                difficulty=diff,
                max_steps=MAX_STEPS[diff],
            )
            elapsed = time.perf_counter() - start
            result["elapsed_s"] = elapsed
            all_results.append(result)
            print_episode(result)
            print(f"  Time: {elapsed:.1f}s")

    _print_summary(all_results)


if __name__ == "__main__":
    main()