| |
"""Evaluate the fine-tuned AGORA planner against reference allocations.

Runs the trained LLM planner on held-out planning examples and compares its
task allocations with the reference answers, reporting allocation accuracy,
assignment quality, and JSON response-format compliance. (A comparison with
AGORA's built-in heuristic engine is not performed by this script.)
| |
| Usage: |
| CUDA_VISIBLE_DEVICES=2 python scripts/eval_planner.py |
| CUDA_VISIBLE_DEVICES=2 python scripts/eval_planner.py --model /mnt/artifacts-datai/models/project_agora/agora-planner-v1/merged |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
# Put the repository's ``src`` directory at the front of the module search
# path so project packages are importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[1].joinpath("src")))

# Default artifact locations; MODEL_DIR and EVAL_DATA are the CLI defaults.
PROJECT = "project_agora"
ARTIFACTS = "/mnt/artifacts-datai"
MODEL_DIR = "/".join((ARTIFACTS, "models", PROJECT, "agora-planner-v1", "merged"))
EVAL_DATA = "/".join((ARTIFACTS, "logs", PROJECT, "planning_eval.jsonl"))
REPORT_DIR = "/".join((ARTIFACTS, "reports", PROJECT))

# Import-time side effect: make sure the report directory exists.
os.makedirs(REPORT_DIR, mode=0o777, exist_ok=True)
|
|
|
|
def load_eval_data(path: str) -> list[dict]:
    """Load evaluation examples from a JSONL file.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (e.g. a stray empty line at the end of the file) are
    skipped instead of aborting the run with a ``JSONDecodeError``.

    Args:
        path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        The parsed examples, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
| def extract_json_from_response(text: str) -> dict | None: |
| """Try to extract a JSON object from model response.""" |
| text = text.strip() |
| |
| try: |
| return json.loads(text) |
| except json.JSONDecodeError: |
| pass |
| |
| for start_marker in ["{", "```json\n", "```\n"]: |
| idx = text.find(start_marker) |
| if idx >= 0: |
| candidate = text[idx:] |
| if candidate.startswith("```"): |
| end = candidate.find("```", 3) |
| candidate = candidate[candidate.find("{"):end] if end > 0 else candidate[3:] |
| try: |
| return json.loads(candidate) |
| except json.JSONDecodeError: |
| |
| depth = 0 |
| for i, c in enumerate(candidate): |
| if c == "{": |
| depth += 1 |
| elif c == "}": |
| depth -= 1 |
| if depth == 0: |
| try: |
| return json.loads(candidate[:i + 1]) |
| except json.JSONDecodeError: |
| break |
| return None |
|
|
|
|
def _task_to_robot(assignments: object) -> dict:
    """Invert a ``{robot_id: [task_id, ...]}`` mapping into ``{task_id: robot_id}``.

    Task ids are normalised with ``str()`` so that ``1`` and ``"1"`` compare
    equal. Anything that is not a dict yields an empty mapping. Robot entries
    whose value is not a list are ignored. Malformed output therefore scores
    as "nothing assigned" instead of raising. If a task is listed under
    several robots, the last one wins.
    """
    task_map: dict = {}
    if not isinstance(assignments, dict):
        return task_map
    for robot_id, task_ids in assignments.items():
        if isinstance(task_ids, list):
            for tid in task_ids:
                task_map[str(tid)] = robot_id
    return task_map


def score_allocation(predicted: dict, reference: dict) -> dict:
    """Score a predicted allocation against the reference.

    Both arguments are parsed planner responses. Their ``"assignments"`` entry
    is read as a mapping from robot id to a list of task ids. Reference and
    prediction are normalised identically (see ``_task_to_robot``).

    Returns:
        A dict with:

        - ``exact_match``: 1.0 if both task->robot mappings are identical, else 0.0.
        - ``task_coverage``: fraction of reference tasks the prediction assigns.
        - ``robot_match_rate``: among covered reference tasks, the fraction
          assigned to the same robot as in the reference.
        - ``format_valid``: always True here; unparseable output is handled
          by the caller.
        - ``ref_tasks`` / ``pred_tasks``: number of distinct tasks on each side.
    """
    ref_assignments = reference.get("assignments") if isinstance(reference, dict) else None
    pred_assignments = predicted.get("assignments") if isinstance(predicted, dict) else None
    ref_task_map = _task_to_robot(ref_assignments)
    pred_task_map = _task_to_robot(pred_assignments)

    # Nothing to allocate on either side counts as a perfect (trivial) match.
    if not ref_task_map and not pred_task_map:
        return {
            "exact_match": 1.0,
            "task_coverage": 1.0,
            "robot_match_rate": 1.0,
            "format_valid": True,
            "ref_tasks": 0,
            "pred_tasks": 0,
        }

    covered = [t for t in ref_task_map if t in pred_task_map]
    coverage = len(covered) / max(len(ref_task_map), 1)

    # Robot accuracy is measured only over tasks the prediction actually assigned.
    robot_matches = sum(1 for t in covered if pred_task_map[t] == ref_task_map[t])
    robot_match_rate = robot_matches / max(len(covered), 1)

    return {
        "exact_match": 1.0 if ref_task_map == pred_task_map else 0.0,
        "task_coverage": coverage,
        "robot_match_rate": robot_match_rate,
        "format_valid": True,
        "ref_tasks": len(ref_task_map),
        "pred_tasks": len(pred_task_map),
    }
|
|
|
|
def _generate_response(model, tokenizer, system_msg: str, user_msg: str) -> tuple[str, float]:
    """Generate the model's reply to one system/user prompt pair.

    Returns:
        The decoded completion (prompt tokens sliced off, special tokens
        removed) and the wall-clock seconds spent in ``model.generate``.
    """
    chat = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered chat template typically already contains the model's special
    # tokens (e.g. BOS); letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        # NOTE(review): with the default right-side truncation, prompts longer
        # than 2048 tokens lose their tail (including the generation prompt).
        truncation=True,
        max_length=2048,
        add_special_tokens=False,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    started = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    elapsed = time.perf_counter() - started

    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return generated, elapsed


def _score_response(generated: str, ref_parsed: dict | None) -> dict:
    """Turn one generated response into a per-example metrics dict."""
    pred_parsed = extract_json_from_response(generated)
    if pred_parsed is None:
        return {
            "exact_match": 0.0,
            "task_coverage": 0.0,
            "robot_match_rate": 0.0,
            "format_valid": False,
        }
    if ref_parsed is None:
        # Well-formed response, but the reference itself could not be parsed,
        # so there is nothing to score against.
        return {"format_valid": True, "exact_match": 0.0, "task_coverage": 0.0, "robot_match_rate": 0.0}
    return score_allocation(pred_parsed, ref_parsed)


def _summarize_results(results: list[dict], format_failures: int, total_time: float) -> dict:
    """Average per-example metrics and attach failure counts and timing."""
    n = len(results)
    denom = max(n, 1)
    return {
        "total_examples": n,
        "exact_match": sum(r["exact_match"] for r in results) / denom,
        "task_coverage": sum(r["task_coverage"] for r in results) / denom,
        "robot_match_rate": sum(r["robot_match_rate"] for r in results) / denom,
        "format_valid_rate": sum(1 for r in results if r["format_valid"]) / denom,
        "format_failures": format_failures,
        "avg_inference_time_s": total_time / denom,
        "total_inference_time_s": total_time,
    }


def evaluate_model(
    model_path: str,
    eval_data: list[dict],
    max_examples: int = 100,
    seed: int | None = 0,
) -> dict:
    """Run the fine-tuned model on eval data and compute aggregate metrics.

    Each example's ``messages`` list is read positionally:

    - index 0 is the system prompt;
    - index 1 is the user prompt;
    - index 2 is the reference assistant response.

    The model is prompted with the first two. Its reply is scored against the
    JSON parsed from the third.

    Args:
        model_path: Path handed to ``from_pretrained`` for model and tokenizer.
        eval_data: Examples as returned by ``load_eval_data``.
        max_examples: Only the first ``max_examples`` examples are evaluated.
        seed: Seed passed to ``torch.manual_seed`` before generation so the
            sampled decoding is repeatable (on the same hardware/software
            stack). ``None`` leaves the RNG state untouched.

    Returns:
        Averages of the per-example metrics, the number of format failures,
        and average/total inference time in seconds.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from: {model_path}")
    # NOTE(review): trust_remote_code executes code shipped with the model;
    # only point this at trusted checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

    # generate() needs a pad id; fall back to EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if seed is not None:
        torch.manual_seed(seed)

    examples = eval_data[:max_examples]
    results: list[dict] = []
    total_time = 0.0
    format_failures = 0

    for i, example in enumerate(examples, start=1):
        msgs = example["messages"]
        system_msg = msgs[0]["content"]
        user_msg = msgs[1]["content"]
        ref_parsed = extract_json_from_response(msgs[2]["content"])

        generated, elapsed = _generate_response(model, tokenizer, system_msg, user_msg)
        total_time += elapsed

        score = _score_response(generated, ref_parsed)
        if not score["format_valid"]:
            format_failures += 1
        results.append(score)

        if i % 10 == 0:
            print(f" [{i}/{len(examples)}] "
                  f"avg_time={total_time / i:.2f}s/example, "
                  f"format_ok={len(results) - format_failures}/{len(results)}")

    return _summarize_results(results, format_failures, total_time)
|
|
|
|
def main():
    """Command-line entry point: evaluate the planner and write a JSON report.

    Exits with status 1 (after printing to stderr) if the model path or the
    evaluation file does not exist.

    The report contains the aggregate metrics from ``evaluate_model`` plus the
    model and eval-data paths used. It is written to ``--report``, which
    defaults to ``planner_eval.json`` under ``REPORT_DIR``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate AGORA planner model")
    parser.add_argument("--model", default=MODEL_DIR, help="Model path")
    parser.add_argument("--eval-data", default=EVAL_DATA, help="Eval JSONL path")
    parser.add_argument("--max-examples", type=int, default=100, help="Max eval examples")
    parser.add_argument(
        "--report",
        default=f"{REPORT_DIR}/planner_eval.json",
        help="Path of the JSON metrics report to write",
    )
    args = parser.parse_args()

    if args.max_examples < 0:
        parser.error("--max-examples must be >= 0")

    for label, path in (("Model", args.model), ("Eval data", args.eval_data)):
        if not Path(path).exists():
            print(f"ERROR: {label} not found at {path}", file=sys.stderr)
            sys.exit(1)

    eval_data = load_eval_data(args.eval_data)
    print(f"Loaded {len(eval_data)} eval examples")

    print(f"\n{'=' * 60}")
    print("AGORA Planner Evaluation")
    print(f"{'=' * 60}")
    print(f"Model: {args.model}")
    print(f"Eval data: {args.eval_data}")
    print(f"Examples: {min(args.max_examples, len(eval_data))}")
    print(f"{'=' * 60}\n")

    metrics = evaluate_model(args.model, eval_data, args.max_examples)

    print(f"\n{'=' * 60}")
    print("EVALUATION RESULTS")
    print(f"{'=' * 60}")
    print(f"Total examples: {metrics['total_examples']}")
    print(f"Exact match rate: {metrics['exact_match']:.1%}")
    print(f"Task coverage: {metrics['task_coverage']:.1%}")
    print(f"Robot match rate: {metrics['robot_match_rate']:.1%}")
    print(f"Format valid rate: {metrics['format_valid_rate']:.1%}")
    print(f"Format failures: {metrics['format_failures']}")
    print(f"Avg inference time: {metrics['avg_inference_time_s']:.2f}s")
    print(f"{'=' * 60}")

    # Record what was evaluated so the report is self-describing.
    report = {**metrics, "model": args.model, "eval_data": args.eval_data}
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    with report_path.open("w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"\nReport saved to: {report_path}")


if __name__ == "__main__":
    main()
|
|