Spaces:
Running
Running
| """ | |
| Agent Comparison β FSDS Cleaning Environment | |
| ============================================= | |
| Benchmarks four agents on the held-out evaluation set and prints a | |
| side-by-side comparison table. | |
| Agents evaluated | |
| ---------------- | |
| 1. RandomAgent β lower bound (uniform random actions) | |
| 2. HeuristicAgent β upper bound (scripted oracle policy) | |
| 3. SFT model β supervised fine-tuned checkpoint warm-start | |
| 4. GRPO model β RL-trained checkpoint (SFT β GRPO) | |
| Run in Colab after both training files have completed: | |
| - training_sft.py β ./data-cleaning-sft-final | |
| - training_colab.py β ./data-cleaning-grpo-final | |
| """ | |
# ── Cell 1 ▸ Install (skip if already installed) ──────────────────────
# %%
# !pip install -q "openenv-core[core]>=0.2.1"
# !pip install -q "git+https://huggingface.co/spaces/israaaML/fsds_cleaning_env"
# !pip uninstall -y vllm
# !pip install -q unsloth

# ── Cell 2 ▸ Imports & config ─────────────────────────────────────────
# %%
import json
from pathlib import Path

# Remote environment endpoint and local checkpoint locations.
ENV_URL = "https://israaaML-fsds-cleaning-env.hf.space"
SFT_MODEL_PATH = "./data-cleaning-sft-final"
GRPO_MODEL_PATH = "./data-cleaning-grpo-final"

# Episodes per eval task — raise for tighter estimates at the cost of runtime.
EPISODES_PER_TASK = 3
OUTPUT_FILE = "./results_comparison.json"
# ── Cell 3 ▸ Connect to environment & sanity check ────────────────────
# %%
from fsds_cleaning_env import FSDSCleaningEnv
from fsds_cleaning_env.evaluation_tasks import EVAL_TASKS
from fsds_cleaning_env.evaluate_agent import run_evaluation
from fsds_cleaning_env.agents import RandomAgent, HeuristicAgent, LLMAgent
from fsds_cleaning_env.metrics import aggregate_metrics, compute_episode_metrics

# Open a synchronous session, reset onto a known task, and echo what the
# server exposes — a broken deployment should fail here, not mid-benchmark.
with FSDSCleaningEnv(base_url=ENV_URL).sync() as env:
    env.reset(task_id="ecommerce_mobile")
    brief = env.call_tool("get_task_brief")
    print(f"Connected to env. Task: {brief.get('title')}")
    tasks_list = env.call_tool("list_tasks")
    print(f"Available tasks: {[item['task_id'] for item in tasks_list.get('tasks', [])]}")

# Echo the held-out evaluation scenarios shipped with the package.
print(f"\nEval tasks ({len(EVAL_TASKS)} scenarios):")
for task in EVAL_TASKS:
    print(f" {task.name} (task_id={task.task_id}, seed_index={task.eval_index})")
# ── Cell 4 ▸ Run all agents ───────────────────────────────────────────
# %%
# One spec per agent: (results key, progress label, factory). Factories keep
# the (potentially heavy) LLM agents un-constructed until their turn comes,
# and the single loop replaces four copy-pasted evaluation stanzas.
_AGENT_SPECS = [
    ("random", "RandomAgent", RandomAgent),        # lower bound
    ("heuristic", "HeuristicAgent", HeuristicAgent),  # upper bound (oracle)
    ("sft", f"SFT model ({SFT_MODEL_PATH})",
     lambda: LLMAgent(model_path=SFT_MODEL_PATH, temperature=0.0)),
    ("grpo", f"GRPO model ({GRPO_MODEL_PATH})",
     lambda: LLMAgent(model_path=GRPO_MODEL_PATH, temperature=0.0)),
]

results = {}
for idx, (key, label, make_agent) in enumerate(_AGENT_SPECS, start=1):
    print(f"\n[{idx}/{len(_AGENT_SPECS)}] Evaluating {label} …")
    results[key] = run_evaluation(
        make_agent(),
        base_url=ENV_URL,
        max_episodes_per_task=EPISODES_PER_TASK,
    )
    agg = results[key]["aggregate"]
    print(
        f" success={agg['success_rate']:.0%}"
        f" return={agg['avg_return']:.4f}"
        f" steps={agg['avg_steps']:.1f}"
    )
| # ββ Cell 5 βΈ Comparison table βββββββββββββββββββββββββββββββββββββββββ | |
| # %% | |
| AGENTS = [ | |
| ("Random", "random"), | |
| ("Heuristic", "heuristic"), | |
| ("SFT", "sft"), | |
| ("GRPO", "grpo"), | |
| ] | |
| COL_W = 12 | |
| def _col(v, w=COL_W): | |
| return str(v).center(w) | |
# Build the fixed-width header once; the separator matches its exact length.
header = (
    f"{'Agent':<14}"
    + _col("Success %")
    + _col("Avg Return")
    + _col("Avg Steps")
    + _col("Avg Invalid")
    + _col("Episodes")
)
sep = "-" * len(header)

print("\n" + sep)
print(" FSDS Cleaning Agent Benchmark")
print(sep)
print(header)
print(sep)
for label, key in AGENTS:
    agent_result = results.get(key)  # single lookup instead of `in` + index
    if agent_result is None:  # agent skipped or failed to run
        continue
    agg = agent_result["aggregate"]
    print(
        f"{label:<14}"
        + _col(f"{agg['success_rate']:.1%}")
        + _col(f"{agg['avg_return']:.4f}")
        + _col(f"{agg['avg_steps']:.1f}")
        + _col(f"{agg['avg_invalid_actions']:.2f}")
        + _col(agg["episodes"])
    )
print(sep)

# Headline comparison: did RL (GRPO) actually improve on the SFT warm-start?
if "sft" in results and "grpo" in results:
    sft_agg = results["sft"]["aggregate"]
    grpo_agg = results["grpo"]["aggregate"]
    print(f"\nGRPO vs SFT — success rate delta : {grpo_agg['success_rate'] - sft_agg['success_rate']:+.1%}")
    print(f"GRPO vs SFT — avg return delta : {grpo_agg['avg_return'] - sft_agg['avg_return']:+.4f}")
# ── Cell 6 ▸ Per-task breakdown ───────────────────────────────────────
# %%
# Per-task success rates, one column per evaluated agent.
print("\n=== Per-task success rates ===")

# Agents actually present in `results`, in display order.
present = [(label, key) for label, key in AGENTS if key in results]
# Union of task ids across every agent's episodes, so the table works even
# when some agent (e.g. heuristic) was skipped.
task_ids = sorted({
    ep["task_id"]
    for key in results
    for ep in results[key]["episodes"]
})

# Header
print(f"\n{'Task':<30}" + "".join(f"{lbl:>12}" for lbl, _ in present))
print("-" * (30 + 12 * len(present)))
for tid in task_ids:
    row = f"{tid:<30}"
    for _, key in present:
        eps = [e for e in results[key]["episodes"] if e["task_id"] == tid]
        if not eps:
            row += f"{'N/A':>12}"
        else:
            sr = sum(1 for e in eps if e.get("success", False)) / len(eps)
            row += f"{sr:>12.0%}"  # 12-wide: aligns with header and N/A cells
    print(row)
# ── Cell 7 ▸ Save results ─────────────────────────────────────────────
# %%
# Persist raw per-episode + aggregate results for later analysis.
# Explicit UTF-8: Path.write_text otherwise uses the locale-dependent default,
# and ensure_ascii=False keeps any non-ASCII task text readable in the file.
Path(OUTPUT_FILE).write_text(
    json.dumps(results, indent=2, ensure_ascii=False),
    encoding="utf-8",
)
print(f"\nFull results saved to {OUTPUT_FILE}")