Spaces:
Running
Running
| """Comprehensive multi-model test with HF PRO token. | |
| Tests 4 models x 3 difficulties x 2 seeds = 24 episodes. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| import time | |
| from typing import Any | |
| from dotenv import load_dotenv | |
| load_dotenv(override=True) | |
| # Add repo root so `import inference` (root-level module) resolves. | |
| _REPO_ROOT = os.path.join(os.path.dirname(__file__), "..") | |
| if _REPO_ROOT not in sys.path: | |
| sys.path.insert(0, _REPO_ROOT) | |
| from openai import OpenAI | |
| from inference import parse_action, serialize_observation, action_to_str, SYSTEM_PROMPT | |
| from triagesieve_env.models import ActionType, TriageSieveAction | |
| from triagesieve_env.server.triagesieve_env_environment import TriageSieveEnvironment | |
# Credentials and endpoint are read from the process environment.
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")

# Model identifiers passed verbatim as the `model` argument of each request.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]

# One row per episode configuration: (difficulty, seed, max_steps).
_CONFIG_ROWS = (
    ("easy", 42, 8),
    ("easy", 7, 8),
    ("medium", 42, 14),
    ("medium", 2, 14),
    ("hard", 42, 20),
    ("hard", 1, 20),
)

# Expanded into the dict form that main() indexes by key.
CONFIGS = [
    {"difficulty": level, "seed": seed_value, "max_steps": step_cap}
    for level, seed_value, step_cap in _CONFIG_ROWS
]
def run_episode(client: OpenAI, model_name: str, seed: int, difficulty: str, max_steps: int) -> dict[str, Any]:
    """Run one environment episode driven by a chat model and collect a trace.

    A fresh ``TriageSieveEnvironment`` is reset in ``eval_strict`` mode. On each
    step the serialized observation is sent to the model (temperature 0.0, at
    most 512 completion tokens), the reply is parsed into an action, and the
    action is applied to the environment. A reply that cannot be parsed -- or a
    failed API call -- is replaced by a ``SKIP_TURN`` action. If the loop ends
    without the environment reporting ``done``, a ``FINISH_EPISODE`` action is
    submitted automatically and recorded as an ``(auto)`` step.

    Args:
        client: OpenAI-compatible client used for the chat completion calls.
        model_name: Full model identifier passed to the API.
        seed: Seed forwarded to ``env.reset``.
        difficulty: Difficulty label forwarded to ``env.reset``.
        max_steps: Upper bound on model-driven steps (the automatic finish
            step is not counted against it).

    Returns:
        A dict with the short model name (text after the last ``/``), the
        difficulty and seed, ``final_score`` (reward of the last recorded
        step), ``total_steps``, ``parse_failures``, ``invalid_actions``,
        ``api_errors`` and the per-step ``steps`` list.
    """
    env = TriageSieveEnvironment()
    obs = env.reset(seed=seed, difficulty=difficulty, mode="eval_strict")
    steps: list[dict[str, Any]] = []
    last_reward = 0.0
    episode_done = False

    for step_num in range(1, max_steps + 1):
        if episode_done or obs.action_budget_remaining <= 0:
            break

        obs_text = serialize_observation(obs)
        user_content = f"Step {step_num} | Last reward: {last_reward:.2f}\n\n{obs_text}"

        api_error: str | None = None
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                temperature=0.0,
                max_tokens=512,
            )
            raw = (response.choices[0].message.content or "").strip()
        except Exception as exc:
            # Deliberately best-effort: a failed request must not abort the
            # whole sweep, so the turn is skipped. The failure is reported and
            # recorded so it can be told apart from a genuine parse failure.
            raw = ""
            api_error = f"{type(exc).__name__}: {exc}"
            print(
                f" [warn] LLM request failed at step {step_num} ({model_name}): {api_error}",
                file=sys.stderr,
                flush=True,
            )

        action = parse_action(raw)
        parsed = action is not None
        if action is None:
            action = TriageSieveAction(action_type=ActionType.SKIP_TURN, metadata={})

        obs = env.step(action)
        reward = obs.reward if obs.reward is not None else 0.0
        episode_done = obs.done
        # Anything other than the literal "ok" is kept as the step's error.
        error = None if obs.last_action_result == "ok" else obs.last_action_result

        steps.append({
            "step": step_num,
            "raw": raw[:150] if raw else "(empty)",
            "parsed": parsed,
            "action": action_to_str(action),
            "reward": reward,
            "done": episode_done,
            "error": error,
            "api_error": api_error,
        })
        last_reward = reward

    if not episode_done:
        # Step cap or action budget was exhausted: close the episode explicitly
        # so the environment can emit its final reward.
        obs = env.step(TriageSieveAction(action_type=ActionType.FINISH_EPISODE, metadata={}))
        reward = obs.reward if obs.reward is not None else 0.0
        steps.append({
            "step": len(steps) + 1, "raw": "(auto)", "parsed": True,
            "action": "finish_episode", "reward": reward, "done": True, "error": None,
            "api_error": None,
        })

    final_score = steps[-1]["reward"] if steps else 0.0
    return {
        "model": model_name.split("/")[-1],
        "difficulty": difficulty,
        "seed": seed,
        "final_score": final_score,
        "total_steps": len(steps),
        "parse_failures": sum(1 for s in steps if not s["parsed"]),
        "invalid_actions": sum(1 for s in steps if s["error"]),
        "api_errors": sum(1 for s in steps if s["api_error"]),
        "steps": steps,
    }
def print_episode(r: dict[str, Any]) -> None:
    """Print a banner, one line per step, and a score footer for an episode.

    Args:
        r: Episode result dict. Reads the keys ``model``, ``difficulty``,
            ``seed``, ``steps``, ``final_score``, ``parse_failures`` and
            ``invalid_actions``; each step reads ``step``, ``parsed``,
            ``action``, ``reward``, ``error`` and ``raw``.
    """
    print(f"\n{'='*80}")
    print(f" {r['model']} | {r['difficulty']} | seed={r['seed']}")
    print(f"{'='*80}")
    for s in r["steps"]:
        p = "OK" if s["parsed"] else "FAIL"
        # The error value is whatever the environment reported and may not be
        # a str; coerce before slicing so a non-string cannot raise TypeError.
        err = f" ERR: {str(s['error'])[:50]}" if s["error"] else ""
        print(f" Step {s['step']:>2}: [{p:>4}] {s['action']:<45} reward={s['reward']:+.4f}{err}")
        # Show the raw model text only for real, non-empty replies that failed to parse.
        if not s["parsed"] and s["raw"] != "(auto)" and s["raw"] != "(empty)":
            print(f" LLM: {s['raw'][:100]}")
    print(f"\n SCORE: {r['final_score']:.4f} | Parse fails: {r['parse_failures']} | Invalid: {r['invalid_actions']}")
def _print_summary(all_results: list[dict[str, Any]]) -> None:
    """Print the per-episode table and aggregate statistics for a sweep."""
    print(f"\n\n{'='*100}")
    print("FULL SUMMARY")
    print(f"{'='*100}")
    print(f" {'Model':<30} {'Diff':<8} {'Seed':>4} {'Score':>8} {'Steps':>6} {'Parse':>6} {'Invalid':>8} {'Time':>6}")
    print(f" {'-'*30} {'-'*8} {'-'*4} {'-'*8} {'-'*6} {'-'*6} {'-'*8} {'-'*6}")
    for r in all_results:
        print(
            f" {r['model']:<30} {r['difficulty']:<8} {r['seed']:>4} {r['final_score']:>8.4f} "
            f"{r['total_steps']:>6} {r['parse_failures']:>6} {r['invalid_actions']:>8} {r['time']:>5.1f}s"
        )

    # Aggregate stats
    print("\n --- Aggregate ---")
    print(f" Total episodes: {len(all_results)}")
    if not all_results:
        # min()/max() and the mean are undefined for an empty sweep.
        print(" No episodes were run.")
        return

    scores = [r["final_score"] for r in all_results]
    parse_fails = sum(r["parse_failures"] for r in all_results)
    invalid = sum(r["invalid_actions"] for r in all_results)
    negative_scores = sum(1 for s in scores if s < 0)
    print(f" Score range: [{min(scores):.4f}, {max(scores):.4f}]")
    print(f" Mean score: {sum(scores)/len(scores):.4f}")
    print(f" Total parse failures: {parse_fails}")
    print(f" Total invalid actions: {invalid}")
    print(f" Negative scores (bug indicator): {negative_scores}")
    print(f" Episodes with score > 0: {sum(1 for s in scores if s > 0)}/{len(scores)}")


def main() -> None:
    """Run every model in MODELS against every entry in CONFIGS and report.

    Exits with status 1 when ``HF_TOKEN`` is not set. Otherwise each episode is
    run, timed, and printed as it finishes, followed by a summary table and
    aggregate statistics over all episodes.
    """
    if not HF_TOKEN:
        print("ERROR: HF_TOKEN not set")
        sys.exit(1)

    client = OpenAI(base_url=BASE_URL, api_key=HF_TOKEN)
    all_results: list[dict[str, Any]] = []

    for model_name in MODELS:
        model_short = model_name.split("/")[-1]
        for cfg in CONFIGS:
            print(f"\n>>> {model_short} / {cfg['difficulty']} / seed={cfg['seed']} ...", flush=True)
            # perf_counter is monotonic, so the elapsed time cannot be skewed
            # by system clock adjustments during a long episode.
            t0 = time.perf_counter()
            result = run_episode(client, model_name, cfg["seed"], cfg["difficulty"], cfg["max_steps"])
            result["time"] = time.perf_counter() - t0
            all_results.append(result)
            print_episode(result)
            print(f" Time: {result['time']:.1f}s")

    _print_summary(all_results)


if __name__ == "__main__":
    main()