# triagesieve_env/scripts/test_multi_model_pro.py
"""Comprehensive multi-model test with HF PRO token.
Tests 4 models x 3 difficulties x 2 seeds = 24 episodes.
"""
from __future__ import annotations
import os
import sys
import time
from typing import Any
from dotenv import load_dotenv
load_dotenv(override=True)
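
# Typical invocation (assumes the repo root is the working directory and that
# HF_TOKEN, plus optionally API_BASE_URL, is set in the environment or a .env file):
#   e.g.: python scripts/test_multi_model_pro.py
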
# Add repo root so `import inference` (root-level module) resolves.
_REPO_ROOT = os.path.join(os.path.dirname(__file__), "..")
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
from openai import OpenAI
from inference import parse_action, serialize_observation, action_to_str, SYSTEM_PROMPT
from triagesieve_env.models import ActionType, TriageSieveAction
from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
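
# BASE_URL defaults to the Hugging Face router, used below through the OpenAI client's
# chat completions interface; set API_BASE_URL to point the client at another endpoint.

# Models under test: two ~70B-class and two ~7-8B-class instruct checkpoints.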
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]
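
# Two eval seeds per difficulty; the per-episode step cap scales with difficulty.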
CONFIGS = [
    {"difficulty": "easy", "seed": 42, "max_steps": 8},
    {"difficulty": "easy", "seed": 7, "max_steps": 8},
    {"difficulty": "medium", "seed": 42, "max_steps": 14},
    {"difficulty": "medium", "seed": 2, "max_steps": 14},
    {"difficulty": "hard", "seed": 42, "max_steps": 20},
    {"difficulty": "hard", "seed": 1, "max_steps": 20},
]
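
# 4 models x 6 (difficulty, seed) configs = 24 episodes, as noted in the module docstring.
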
def run_episode(client: OpenAI, model_name: str, seed: int, difficulty: str, max_steps: int) -> dict[str, Any]:
    """Run one episode of `model_name` against the environment and return per-step stats."""
    env = TriageSieveEnvironment()
    obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")
    steps: list[dict[str, Any]] = []
    last_reward = 0.0
    episode_done = False
    for step_num in range(1, max_steps + 1):
        if episode_done or obs.action_budget_remaining <= 0:
            break
        obs_text = serialize_observation(obs)
        user_content = f"Step {step_num} | Last reward: {last_reward:.2f}\n\n{obs_text}"
        try:
            r = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                temperature=0.0,
                max_tokens=512,
            )
            raw = (r.choices[0].message.content or "").strip()
        except Exception:
            # API failures degrade to an empty completion, which is recorded as a parse failure.
            raw = ""
        action = parse_action(raw)
        parsed = action is not None
        if action is None:
            # Unparseable output: burn the turn with a no-op so the episode can continue.
            action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})
        obs = env.step(action)
        reward = obs.reward if obs.reward is not None else 0.0
        episode_done = obs.done
        error = None if obs.last_action_result == "ok" else obs.last_action_result
        steps.append({
            "step": step_num,
            "raw": raw[:150] if raw else "(empty)",
            "parsed": parsed,
            "action": action_to_str(action),
            "reward": reward,
            "done": episode_done,
            "error": error,
        })
        last_reward = reward
    if not episode_done:
        # Force termination so a final score is computed even when the model
        # never issues finish_episode within its step budget.
        obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
        reward = obs.reward if obs.reward is not None else 0.0
        steps.append({
            "step": len(steps) + 1, "raw": "(auto)", "parsed": True,
            "action": "finish_episode", "reward": reward, "done": True, "error": None,
        })
    final_score = steps[-1]["reward"] if steps else 0.0
    return {
        "model": model_name.split("/")[-1],
        "difficulty": difficulty,
        "seed": seed,
        "final_score": final_score,
        "total_steps": len(steps),
        "parse_failures": sum(1 for s in steps if not s["parsed"]),
        "invalid_actions": sum(1 for s in steps if s["error"]),
        "steps": steps,
    }

def print_episode(r: dict[str, Any]) -> None:
    print(f"\n{'='*80}")
    print(f" {r['model']} | {r['difficulty']} | seed={r['seed']}")
    print(f"{'='*80}")
    for s in r["steps"]:
        p = "OK" if s["parsed"] else "FAIL"
        err = f" ERR: {s['error'][:50]}" if s["error"] else ""
        print(f" Step {s['step']:>2}: [{p:>4}] {s['action']:<45} reward={s['reward']:+.4f}{err}")
        if not s["parsed"] and s["raw"] != "(auto)" and s["raw"] != "(empty)":
            print(f" LLM: {s['raw'][:100]}")
    print(f"\n SCORE: {r['final_score']:.4f} | Parse fails: {r['parse_failures']} | Invalid: {r['invalid_actions']}")

def main() -> None:
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN not set")
        sys.exit(1)
    client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)
    all_results: list[dict[str, Any]] = []
    for model_name in MODELS:
        for cfg in CONFIGS:
            model_short = model_name.split("/")[-1]
            print(f"\n>>> {model_short} / {cfg['difficulty']} / seed={cfg['seed']} ...", flush=True)
            t0 = time.time()
            result = run_episode(client, model_name, cfg["seed"], cfg["difficulty"], cfg["max_steps"])
            result["time"] = time.time() - t0
            all_results.append(result)
            print_episode(result)
            print(f" Time: {result['time']:.1f}s")
    # Summary
    print(f"\n\n{'='*100}")
    print("FULL SUMMARY")
    print(f"{'='*100}")
    print(f" {'Model':<30} {'Diff':<8} {'Seed':>4} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
    print(f" {'-'*30} {'-'*8} {'-'*4} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
    for r in all_results:
        print(
            f" {r['model']:<30} {r['difficulty']:<8} {r['seed']:>4} {r['final_score']:>8.4f} "
            f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} {r['time']:>5.1f}s"
        )
    # Aggregate stats
    print("\n --- Aggregate ---")
    scores = [r["final_score"] for r in all_results]
    parse_fails = sum(r["parse_failures"] for r in all_results)
    invalid = sum(r["invalid_actions"] for r in all_results)
    crashes = sum(1 for r in all_results if r["final_score"] < 0)
    print(f" Total episodes: {len(all_results)}")
    print(f" Score range: [{min(scores):.4f}, {max(scores):.4f}]")
    print(f" Mean score: {sum(scores)/len(scores):.4f}")
    print(f" Total parse failures: {parse_fails}")
    print(f" Total invalid actions: {invalid}")
    print(f" Negative scores (bug indicator): {crashes}")
    print(f" Episodes with score > 0: {sum(1 for s in scores if s > 0)}/{len(scores)}")

if __name__ == "__main__":
    main()