"""Test the environment with multiple LLMs and capture detailed logs.
Usage:
python scripts/test_multi_model.py
"""
from __future__ import annotations
import os
import sys
import time
from typing import Any
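# Load HF_TOKEN / API_BASE_URL from a local .env file, if one exists.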
from dotenv import load_dotenv
load_dotenv()
# Add repo root so `import inference` (root-level module) resolves.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
from openai import OpenAI
from inference import (
    parse_action,
    serialize_observation,
    action_to_str,
    SYSTEM_PROMPT,
)
from triagesieve_env.models import ActionType, TriageSieveAction
from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODELS = [
"Qwen/Qwen2.5-72B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
]
DIFFICULTIES = ["easy", "medium", "hard"]
SEED = 42
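# Per-difficulty cap on LLM calls per episode; the env's own action budget may
# end an episode sooner.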
MAX_STEPS = {"easy": 8, "medium": 14, "hard": 20}
def run_episode(
client: OpenAI,
model_name: str,
seed: int,
difficulty: str,
max_steps: int,
) -> dict[str, Any]:
"""Run one episode synchronously, capturing detailed logs."""
env = TriageSieveEnvironment()
obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")
steps = []
last_reward = 0.0
episode_done = False
for step_num in range(1, max_steps + 1):
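        # Stop early if the env reported done or the action budget is exhausted.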
if episode_done or obs.action_budget_remaining <= 0:
break
obs_text = serialize_observation(obs)
        # Call the LLM; temperature=0 keeps decoding deterministic across runs.
user_content = f"Step {step_num} | Last reward: {last_reward:.2f}\n\n{obs_text}"
try:
completion = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
],
temperature=0.0,
max_tokens=512,
stream=False,
)
raw_response = (completion.choices[0].message.content or "").strip()
except Exception as exc:
raw_response = ""
print(f" [LLM ERROR] step {step_num}: {exc}")
# Parse action
action = parse_action(raw_response)
parse_ok = action is not None
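        # Unparseable output falls back to a no-op so the episode can proceed.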
if action is None:
action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})
# Step environment
obs = env.step(action)
reward = obs.reward if obs.reward is not None else 0.0
episode_done = obs.done
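        # Any result other than "ok" from the env is recorded as an action error.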
error_str = None if obs.last_action_result == "ok" else obs.last_action_result
steps.append({
"step": step_num,
"raw_llm": raw_response[:120],
"parsed": parse_ok,
"action": action_to_str(action),
"reward": reward,
"done": episode_done,
"error": error_str,
"budget_left": obs.action_budget_remaining,
})
last_reward = reward
# Send finish if not done
if not episode_done:
obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
reward = obs.reward if obs.reward is not None else 0.0
steps.append({
"step": len(steps) + 1,
"raw_llm": "(auto)",
"parsed": True,
"action": "finish_episode",
"reward": reward,
"done": True,
"error": None,
"budget_left": obs.action_budget_remaining,
})
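    # Use the last recorded reward as the final score (0.0 if no steps ran).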
final_score = steps[-1]["reward"] if steps else 0.0
return {
"model": model_name,
"difficulty": difficulty,
"seed": seed,
"final_score": final_score,
"total_steps": len(steps),
"parse_failures": sum(1 for s in steps if not s["parsed"]),
"invalid_actions": sum(1 for s in steps if s["error"] is not None),
"steps": steps,
}
def print_episode(result: dict[str, Any]) -> None:
"""Print a formatted episode trace."""
model_short = result["model"].split("/")[-1]
print(f"\n{'='*80}")
print(f" Model: {model_short} | Difficulty: {result['difficulty']} | Seed: {result['seed']}")
print(f"{'='*80}")
for s in result["steps"]:
parse_marker = "OK" if s["parsed"] else "PARSE_FAIL"
err = f" ERR: {s['error']}" if s["error"] else ""
print(
f" Step {s['step']:>2}: [{parse_marker:>10}] {s['action']:<40} "
f"reward={s['reward']:+.3f}{err}"
)
if not s["parsed"] and s["raw_llm"] != "(auto)":
# Show what the LLM actually said
print(f" LLM said: {s['raw_llm'][:100]}")
score = result["final_score"]
pf = result["parse_failures"]
ia = result["invalid_actions"]
print(f"\n Final Score: {score:.4f} | Parse Failures: {pf} | Invalid Actions: {ia}")
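    # Coarse quality bands: >= 0.5 is good, any positive score is partial credit.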
status = "GOOD" if score >= 0.5 else ("OK" if score > 0 else "BAD")
print(f" Verdict: {status}")
def main() -> None:
if not HF_TOKEN:
print("ERROR: HF_TOKEN not set")
sys.exit(1)
all_results = []
for model_name in MODELS:
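        # Fresh client per model; HF_TOKEN doubles as the OpenAI-compatible API key.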
client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)
model_short = model_name.split("/")[-1]
for diff in DIFFICULTIES:
print(f"\n>>> Running {model_short} / {diff} / seed={SEED} ...", flush=True)
start = time.time()
result = run_episode(
client=client,
model_name=model_name,
seed=SEED,
difficulty=diff,
max_steps=MAX_STEPS[diff],
)
elapsed = time.time() - start
result["elapsed_s"] = elapsed
all_results.append(result)
print_episode(result)
print(f" Time: {elapsed:.1f}s")
# Summary table
print(f"\n\n{'='*80}")
print("SUMMARY")
print(f"{'='*80}")
print(f" {'Model':<35} {'Diff':<8} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
print(f" {'-'*35} {'-'*8} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
for r in all_results:
model_short = r["model"].split("/")[-1][:35]
print(
f" {model_short:<35} {r['difficulty']:<8} {r['final_score']:>8.4f} "
f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} "
f"{r['elapsed_s']:>5.1f}s"
)
if __name__ == "__main__":
main()