project_agora / scripts /eval_planner.py
ilessio-aiflowlab's picture
[AGORA] Full export: pth + safetensors + ONNX + TRT fp16 + TRT fp32
12d70dc verified
#!/usr/bin/env python3
"""Evaluate the fine-tuned AGORA planner against the heuristic baseline.
Compares task allocation accuracy, assignment quality, and response format
compliance between the trained LLM planner and AGORA's built-in heuristic engine.
Usage:
CUDA_VISIBLE_DEVICES=2 python scripts/eval_planner.py
CUDA_VISIBLE_DEVICES=2 python scripts/eval_planner.py --model /mnt/artifacts-datai/models/project_agora/agora-planner-v1/merged
"""
from __future__ import annotations
import json
import os
import sys
import time
from pathlib import Path
import torch
# Put "<directory above this script's folder>/src" at the front of the import path.
# NOTE(review): no import visible in this file depends on it — confirm it is still needed.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
# Project name used as a path component under the artifacts root.
PROJECT = "project_agora"
# Root directory under which the model, eval-data and report paths below are built.
ARTIFACTS = "/mnt/artifacts-datai"
# Default value of the --model command-line option.
MODEL_DIR = f"{ARTIFACTS}/models/{PROJECT}/agora-planner-v1/merged"
# Default value of the --eval-data command-line option (JSON Lines file).
EVAL_DATA = f"{ARTIFACTS}/logs/{PROJECT}/planning_eval.jsonl"
# Directory that receives the planner_eval.json report written by main().
REPORT_DIR = f"{ARTIFACTS}/reports/{PROJECT}"
# Runs at import time, so merely importing this module creates the report directory.
os.makedirs(REPORT_DIR, exist_ok=True)
def load_eval_data(path: str) -> list[dict]:
    """Load evaluation examples from a JSON Lines file.

    Args:
        path: Path to a UTF-8 encoded file holding one JSON document per line.

    Returns:
        The decoded examples in file order. Blank or whitespace-only lines
        are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``path:line_number`` so the offending
            record can be located.
    """
    examples = []
    # JSON text is UTF-8; do not depend on the locale's default encoding.
    with open(path, encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank separators and trailing empty lines are not records.
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, but with file/line context.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return examples
def extract_json_from_response(text: str) -> dict | None:
    """Extract the first JSON object found in a model response.

    The response may be a bare JSON document, or a JSON object surrounded by
    prose and/or Markdown code fences.

    Args:
        text: Raw text produced by the model.

    Returns:
        The first JSON object (as a ``dict``) that can be decoded from the
        text, or ``None`` when there is none. A response whose top-level
        JSON value is not an object (list, number, string, ``null``, ...)
        is never returned as-is; only an object embedded in it can match.
    """
    text = text.strip()

    # Fast path: the whole response is a single JSON document.
    try:
        whole = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(whole, dict):
            return whole

    # Otherwise try to decode an object at each "{" in turn. raw_decode()
    # parses one JSON value and ignores whatever follows it, so trailing
    # prose or a closing code fence is harmless, and braces inside string
    # literals are handled by the real JSON grammar instead of by counting.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start >= 0:
        try:
            # A value that starts with "{" can only decode to a dict.
            obj, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return obj
    return None
def _flatten_assignments(assignments: object) -> dict:
    """Invert a ``robot_id -> [task_id, ...]`` mapping into ``task_id -> robot_id``.

    Task ids are normalised to ``str`` so that ``1`` and ``"1"`` compare
    equal. Anything that is not a mapping of lists/tuples is ignored rather
    than raising, because the mapping can come from untrusted model output.
    If a task appears under several robots, the last one seen wins.
    """
    task_map = {}
    if not isinstance(assignments, dict):
        return task_map
    for robot_id, task_ids in assignments.items():
        if not isinstance(task_ids, (list, tuple)):
            continue
        for tid in task_ids:
            task_map[str(tid)] = robot_id
    return task_map


def score_allocation(predicted: dict, reference: dict) -> dict:
    """Score a predicted task allocation against the reference allocation.

    Both arguments are expected to carry an ``"assignments"`` entry mapping
    a robot id to a list of task ids. A missing or malformed entry is
    treated as "no tasks assigned".

    Args:
        predicted: Parsed model response.
        reference: Parsed reference response.

    Returns:
        A dict with:
            exact_match: 1.0 if both allocations map exactly the same tasks
                to the same robots, else 0.0.
            task_coverage: Fraction of reference tasks that appear in the
                prediction.
            robot_match_rate: Among the covered reference tasks, the fraction
                assigned to the same robot as in the reference.
            format_valid: Always True; this function only sees responses
                that were already parsed.
            ref_tasks / pred_tasks: Number of distinct tasks on each side.
    """
    ref_task_map = _flatten_assignments(
        reference.get("assignments") if isinstance(reference, dict) else None
    )
    pred_task_map = _flatten_assignments(
        predicted.get("assignments") if isinstance(predicted, dict) else None
    )

    # Nothing to allocate on either side counts as a perfect (vacuous) match.
    if not ref_task_map and not pred_task_map:
        return {
            "exact_match": 1.0,
            "task_coverage": 1.0,
            "robot_match_rate": 1.0,
            "format_valid": True,
            "ref_tasks": 0,
            "pred_tasks": 0,
        }

    # Task coverage: how many reference tasks are assigned in the prediction.
    covered = [t for t in ref_task_map if t in pred_task_map]
    coverage = len(covered) / max(len(ref_task_map), 1)

    # Robot match: among covered tasks, how many went to the same robot.
    robot_matches = sum(1 for t in covered if pred_task_map[t] == ref_task_map[t])
    robot_match_rate = robot_matches / max(len(covered), 1)

    # Exact match: identical task set and identical robot for every task.
    exact = ref_task_map == pred_task_map

    return {
        "exact_match": 1.0 if exact else 0.0,
        "task_coverage": coverage,
        "robot_match_rate": robot_match_rate,
        "format_valid": True,
        "ref_tasks": len(ref_task_map),
        "pred_tasks": len(pred_task_map),
    }
def evaluate_model(model_path: str, eval_data: list[dict], max_examples: int = 100) -> dict:
    """Run the fine-tuned model on eval data and compute aggregate metrics.

    Each example must have a ``"messages"`` list whose first three entries
    are, in order, the system message, the user message and the reference
    response; only their ``"content"`` fields are read.

    Args:
        model_path: Directory (or identifier) passed to ``from_pretrained``.
        eval_data: Evaluation examples as described above.
        max_examples: Upper bound on the number of examples evaluated
            (the first ``max_examples`` entries of ``eval_data``).

    Returns:
        A dict with ``total_examples``, the mean ``exact_match``,
        ``task_coverage`` and ``robot_match_rate`` over all examples,
        ``format_valid_rate``, ``format_failures``, ``avg_inference_time_s``
        and ``total_inference_time_s``.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from: {model_path}")
    # NOTE: trust_remote_code=True executes Python shipped with the
    # checkpoint; only point this at model directories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    if tokenizer.pad_token is None:
        # generate() needs a pad token id; fall back to EOS when none is set.
        tokenizer.pad_token = tokenizer.eos_token

    zero_scores = {"exact_match": 0.0, "task_coverage": 0.0, "robot_match_rate": 0.0}
    results = []
    total_time = 0.0
    format_failures = 0

    for i, example in enumerate(eval_data[:max_examples]):
        msgs = example["messages"]
        system_msg = msgs[0]["content"]
        user_msg = msgs[1]["content"]
        ref_response = msgs[2]["content"]
        ref_parsed = extract_json_from_response(ref_response)

        # Build prompt using chat template
        chat = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]
        prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        # The chat template already emits the special tokens it needs;
        # letting the tokenizer add them again can duplicate e.g. BOS.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            add_special_tokens=False,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        t0 = time.time()
        with torch.no_grad():
            # Low-temperature sampling (not greedy decoding): results can
            # vary slightly from run to run.
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
            )
        t1 = time.time()
        total_time += t1 - t0

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        pred_parsed = extract_json_from_response(generated)

        if not isinstance(pred_parsed, dict):
            # No JSON at all, or JSON that is not an object (list, number,
            # ...): either way the response is not a usable allocation, and
            # passing it on to score_allocation would raise.
            format_failures += 1
            results.append({**zero_scores, "format_valid": False})
        elif isinstance(ref_parsed, dict):
            # An empty reference object is still a valid reference, so test
            # the type rather than truthiness.
            results.append(score_allocation(pred_parsed, ref_parsed))
        else:
            # Reference could not be parsed: the prediction is well-formed
            # but cannot be compared, so it is recorded with zero scores.
            results.append({"format_valid": True, **zero_scores})

        if (i + 1) % 10 == 0:
            avg_time = total_time / (i + 1)
            print(f" [{i + 1}/{min(max_examples, len(eval_data))}] "
                  f"avg_time={avg_time:.2f}s/example, format_ok={len(results) - format_failures}/{len(results)}")

    # Aggregate metrics
    n = len(results)
    denom = max(n, 1)  # avoid division by zero when nothing was evaluated
    metrics = {
        "total_examples": n,
        "exact_match": sum(r["exact_match"] for r in results) / denom,
        "task_coverage": sum(r["task_coverage"] for r in results) / denom,
        "robot_match_rate": sum(r["robot_match_rate"] for r in results) / denom,
        "format_valid_rate": sum(1 for r in results if r["format_valid"]) / denom,
        "format_failures": format_failures,
        "avg_inference_time_s": total_time / denom,
        "total_inference_time_s": total_time,
    }
    return metrics
def main():
    """Command-line entry point.

    Parses ``--model``, ``--eval-data`` and ``--max-examples``, exits with
    status 1 if the model or the eval file does not exist, runs the
    evaluation, prints a summary and writes the metrics as JSON to
    ``planner_eval.json`` inside ``REPORT_DIR``.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Evaluate AGORA planner model")
    cli.add_argument("--model", default=MODEL_DIR, help="Model path")
    cli.add_argument("--eval-data", default=EVAL_DATA, help="Eval JSONL path")
    cli.add_argument("--max-examples", type=int, default=100, help="Max eval examples")
    args = cli.parse_args()

    # Fail fast on missing inputs, before the expensive model load.
    model_present = Path(args.model).exists()
    if not model_present:
        print(f"ERROR: Model not found at {args.model}")
        sys.exit(1)
    data_present = Path(args.eval_data).exists()
    if not data_present:
        print(f"ERROR: Eval data not found at {args.eval_data}")
        sys.exit(1)

    eval_data = load_eval_data(args.eval_data)
    print(f"Loaded {len(eval_data)} eval examples")

    # Run header.
    header_lines = (
        f"\n{'=' * 60}",
        "AGORA Planner Evaluation",
        f"{'=' * 60}",
        f"Model: {args.model}",
        f"Eval data: {args.eval_data}",
        f"Examples: {min(args.max_examples, len(eval_data))}",
        f"{'=' * 60}\n",
    )
    print(*header_lines, sep="\n")

    metrics = evaluate_model(args.model, eval_data, args.max_examples)

    # Human-readable summary.
    summary_lines = (
        f"\n{'=' * 60}",
        "EVALUATION RESULTS",
        f"{'=' * 60}",
        f"Total examples: {metrics['total_examples']}",
        f"Exact match rate: {metrics['exact_match']:.1%}",
        f"Task coverage: {metrics['task_coverage']:.1%}",
        f"Robot match rate: {metrics['robot_match_rate']:.1%}",
        f"Format valid rate: {metrics['format_valid_rate']:.1%}",
        f"Format failures: {metrics['format_failures']}",
        f"Avg inference time: {metrics['avg_inference_time_s']:.2f}s",
        f"{'=' * 60}",
    )
    print(*summary_lines, sep="\n")

    # Save report
    report_path = f"{REPORT_DIR}/planner_eval.json"
    Path(report_path).write_text(json.dumps(metrics, indent=2))
    print(f"\nReport saved to: {report_path}")


if __name__ == "__main__":
    main()