Spaces:
Running
Running
| """Test the environment with multiple LLMs and capture detailed logs. | |
| Usage: | |
| python scripts/test_multi_model.py | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| import sys | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # Add repo root so `import inference` (root-level module) resolves. | |
| _REPO_ROOT = os.path.join(os.path.dirname(__file__), "..") | |
| if _REPO_ROOT not in sys.path: | |
| sys.path.insert(0, _REPO_ROOT) | |
| from openai import OpenAI | |
| from inference import ( | |
| get_model_action, | |
| parse_action, | |
| serialize_observation, | |
| action_to_str, | |
| SYSTEM_PROMPT, | |
| ) | |
| from triagesieve_env.models import ( | |
| ActionType, | |
| TriageSieveAction, | |
| TriageSieveObservation, | |
| ) | |
| from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment | |
# Credentials and endpoint for the OpenAI-compatible router.
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")

# Model IDs to evaluate, as accepted by the router's chat-completions API.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
]

# Every (model, difficulty) pair is run once, all at the same fixed seed.
DIFFICULTIES = ["easy", "medium", "hard"]
SEED = 42

# Per-difficulty cap on the number of LLM-driven steps in one episode.
MAX_STEPS = {"easy": 8, "medium": 14, "hard": 20}
def run_episode(
    client: OpenAI,
    model_name: str,
    seed: int,
    difficulty: str,
    max_steps: int,
) -> dict[str, Any]:
    """Play a single episode against the environment, recording a per-step trace.

    Returns a summary dict with the final score, full step log, and counts of
    parse failures and invalid actions.
    """
    env = TriageSieveEnvironment()
    obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")

    trace: list[dict[str, Any]] = []
    prev_reward = 0.0
    done = False

    for turn in range(1, max_steps + 1):
        # Stop as soon as the episode ends or the action budget runs out.
        if done or obs.action_budget_remaining <= 0:
            break

        prompt = (
            f"Step {turn} | Last reward: {prev_reward:.2f}\n\n"
            f"{serialize_observation(obs)}"
        )

        # Query the model; treat any transport/API failure as an empty reply.
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=512,
                stream=False,
            )
            reply = (completion.choices[0].message.content or "").strip()
        except Exception as exc:
            reply = ""
            print(f" [LLM ERROR] step {turn}: {exc}")

        parsed_action = parse_action(reply)
        parsed_ok = parsed_action is not None
        if not parsed_ok:
            # Unparseable output still consumes the turn: substitute a no-op.
            parsed_action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})

        obs = env.step(parsed_action)
        step_reward = 0.0 if obs.reward is None else obs.reward
        done = obs.done

        trace.append({
            "step": turn,
            "raw_llm": reply[:120],
            "parsed": parsed_ok,
            "action": action_to_str(parsed_action),
            "reward": step_reward,
            "done": done,
            "error": None if obs.last_action_result == "ok" else obs.last_action_result,
            "budget_left": obs.action_budget_remaining,
        })
        prev_reward = step_reward

    if not done:
        # Force termination so the environment produces a final score.
        obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
        trace.append({
            "step": len(trace) + 1,
            "raw_llm": "(auto)",
            "parsed": True,
            "action": "finish_episode",
            "reward": 0.0 if obs.reward is None else obs.reward,
            "done": True,
            "error": None,
            "budget_left": obs.action_budget_remaining,
        })

    return {
        "model": model_name,
        "difficulty": difficulty,
        "seed": seed,
        # Reward of the last step is the episode's final score.
        "final_score": trace[-1]["reward"] if trace else 0.0,
        "total_steps": len(trace),
        "parse_failures": sum(1 for entry in trace if not entry["parsed"]),
        "invalid_actions": sum(1 for entry in trace if entry["error"] is not None),
        "steps": trace,
    }
def print_episode(result: dict[str, Any]) -> None:
    """Render one episode's step-by-step trace, score, and verdict to stdout."""
    short_name = result["model"].split("/")[-1]
    print(f"\n{'='*80}")
    print(f" Model: {short_name} | Difficulty: {result['difficulty']} | Seed: {result['seed']}")
    print(f"{'='*80}")

    for entry in result["steps"]:
        marker = "PARSE_FAIL" if not entry["parsed"] else "OK"
        suffix = f" ERR: {entry['error']}" if entry["error"] else ""
        print(
            f" Step {entry['step']:>2}: [{marker:>10}] {entry['action']:<40} "
            f"reward={entry['reward']:+.3f}{suffix}"
        )
        if entry["raw_llm"] != "(auto)" and not entry["parsed"]:
            # Surface the raw model reply so parse failures are debuggable.
            print(f" LLM said: {entry['raw_llm'][:100]}")

    score = result["final_score"]
    pf = result["parse_failures"]
    ia = result["invalid_actions"]
    print(f"\n Final Score: {score:.4f} | Parse Failures: {pf} | Invalid Actions: {ia}")

    # GOOD at >= 0.5, BAD at <= 0, OK in between.
    verdict = "BAD" if score <= 0 else ("GOOD" if score >= 0.5 else "OK")
    print(f" Verdict: {verdict}")
def main() -> None:
    """Run every model x difficulty combination, printing traces and a summary table.

    Exits with status 1 if HF_TOKEN is not set. Results (including wall-clock
    time per episode) are accumulated and rendered as a final summary table.
    """
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN not set")
        sys.exit(1)

    # Fix: the original rebuilt the client inside the per-model loop with
    # identical arguments. The model is passed per request, so a single
    # client serves every model/difficulty combination.
    client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)

    all_results: list[dict[str, Any]] = []
    for model_name in MODELS:
        model_short = model_name.split("/")[-1]
        for diff in DIFFICULTIES:
            print(f"\n>>> Running {model_short} / {diff} / seed={SEED} ...", flush=True)
            start = time.time()
            result = run_episode(
                client=client,
                model_name=model_name,
                seed=SEED,
                difficulty=diff,
                max_steps=MAX_STEPS[diff],
            )
            elapsed = time.time() - start
            result["elapsed_s"] = elapsed
            all_results.append(result)
            print_episode(result)
            print(f" Time: {elapsed:.1f}s")

    # Summary table
    print(f"\n\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f" {'Model':<35} {'Diff':<8} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
    print(f" {'-'*35} {'-'*8} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
    for r in all_results:
        model_short = r["model"].split("/")[-1][:35]
        print(
            f" {model_short:<35} {r['difficulty']:<8} {r['final_score']:>8.4f} "
            f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} "
            f"{r['elapsed_s']:>5.1f}s"
        )
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()