| """ |
| VLN Waypoint Prediction Evaluation — vLLM accelerated version |
| Uses vLLM offline batch inference for much faster evaluation. |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import time |
| import logging |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| from PIL import Image |
|
|
# Root logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Names of the six per-waypoint delta components. The order matters: it is the
# positional order used when a waypoint is given as a plain list.
DIMS = ["dx", "dy", "dz", "dpitch", "dyaw", "droll"]

# Number of waypoint slots the report printer iterates over.
NUM_WAYPOINTS = 5
|
|
|
|
def load_val_data(val_path: str) -> List[Dict]:
    """Load validation samples from a JSON Lines file.

    Args:
        val_path: Path to a UTF-8 text file holding one JSON value per line.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # Explicit encoding so the result does not depend on the platform locale.
    with open(val_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines (typically a trailing newline at EOF) carry no
            # record; json.loads("") would raise on them.
            if not line:
                continue
            data.append(json.loads(line))
    logger.info(f"Loaded {len(data)} validation samples")
    return data
|
|
|
|
def _find_waypoint_object(text: str) -> Optional[Dict]:
    """Return the first top-level JSON object in *text* that has a "waypoint_deltas" key."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, end = decoder.raw_decode(text, start)
        except ValueError:
            # Nothing decodable starts at this brace; try the next one.
            start = text.find("{", start + 1)
            continue
        if "waypoint_deltas" in candidate:
            return candidate
        # Resume after the decoded object so its nested objects are not
        # treated as candidates.
        start = text.find("{", end)
    return None


def parse_waypoints(text: str) -> Optional[List[Dict]]:
    """Extract waypoint deltas from model output or ground-truth text.

    Anything up to and including the last ``</think>`` tag is discarded. The
    remaining text is scanned for a JSON object with a ``"waypoint_deltas"``
    list. Each entry may be a dict keyed by the names in ``DIMS`` (missing
    keys default to 0.0) or a list/tuple with at least ``len(DIMS)`` values in
    ``DIMS`` order.

    Unlike a greedy first-brace-to-last-brace match, the scan decodes the
    object where it starts, so braces in trailing text do not break parsing.

    Args:
        text: Raw text to parse.

    Returns:
        One ``{dim: float}`` dict per waypoint, or ``None`` when *text* is not
        a string, no suitable object is found, the list is empty, an entry has
        an unsupported shape, or any value is not a finite number.
    """
    import math  # function-scope: the module header does not import math

    if not isinstance(text, str):
        return None
    if "</think>" in text:
        text = text.split("</think>")[-1]

    obj = _find_waypoint_object(text)
    if obj is None:
        return None
    deltas = obj["waypoint_deltas"]
    if not isinstance(deltas, list) or not deltas:
        return None

    result = []
    for d in deltas:
        try:
            if isinstance(d, dict):
                values = [float(d.get(dim, 0.0)) for dim in DIMS]
            elif isinstance(d, (list, tuple)) and len(d) >= len(DIMS):
                values = [float(d[i]) for i in range(len(DIMS))]
            else:
                return None
        except (TypeError, ValueError, OverflowError):
            # Non-numeric entry, or an integer too large for a float.
            return None
        # json accepts NaN/Infinity literals; a single one would turn every
        # aggregated mean into NaN, so treat it as unparseable.
        if not all(math.isfinite(v) for v in values):
            return None
        result.append(dict(zip(DIMS, values)))
    return result
|
|
|
|
def build_vllm_inputs(item: Dict) -> List[Dict]:
    """Build the chat-message list for one sample.

    Messages are copied in order up to, but not including, the first assistant
    turn. When the sample lists image paths, every user turn becomes a
    multi-part message: one ``image_pil`` part per image that could be opened,
    followed by a ``text`` part with the original content.

    Args:
        item: Sample dict with a ``"messages"`` list of ``{"role", "content"}``
            dicts and an optional ``"images"`` list of file paths.

    Returns:
        A list of chat message dicts.
    """
    image_paths = item.get("images", [])

    chat_messages = []
    for msg in item["messages"]:
        role = msg["role"]
        if role == "assistant":
            # Everything from the first assistant turn on is the target.
            break
        if role != "user" or not image_paths:
            chat_messages.append({"role": role, "content": msg["content"]})
            continue

        content_parts = []
        for p in image_paths:
            try:
                # convert() returns an independent, fully loaded image, so the
                # source file handle can be closed as soon as the block exits.
                with Image.open(p) as src:
                    pil_img = src.convert("RGB")
            except Exception as exc:
                # Best effort: an unreadable image is dropped, not fatal.
                logger.warning(f"failed to open image {p}: {exc}")
                continue
            content_parts.append({"type": "image_pil", "image_pil": pil_img})
        content_parts.append({"type": "text", "text": msg["content"]})
        chat_messages.append({"role": "user", "content": content_parts})

    return chat_messages
|
|
|
|
def _waypoint_magnitudes(sample, dims, n_wp):
    """Return the per-waypoint L2 norm of the three error components named in *dims*."""
    a, b, c = (sample[d] for d in dims)
    return [np.sqrt(a[i] ** 2 + b[i] ** 2 + c[i] ** 2) for i in range(n_wp)]


def compute_metrics(
    all_errors: List[Dict[str, List[float]]],
    parse_failures: int,
    total: int,
) -> Dict[str, float]:
    """Aggregate per-sample absolute errors into a flat metrics dict.

    Args:
        all_errors: One dict per successfully parsed sample, mapping each name
            in ``DIMS`` to a list of absolute errors, one per waypoint. The
            waypoint count of a sample is taken from its ``"dx"`` list.
        parse_failures: Number of model outputs that could not be parsed.
        total: Denominator for the parse-failure rate.

    Returns:
        A dict with:
          * per-dimension and overall MAE/RMSE;
          * per-waypoint position/rotation error;
          * ADE/FDE, with medians, percentiles and maxima;
          * thresholded success rates for position (``SR``, ``TrajSR``),
            rotation (``RotAcc``, ``TrajRotSR``) and both (``JointSR``,
            ``TrajJointSR``);
          * hard-failure rates and parse bookkeeping.
        Per-dimension and per-waypoint keys are absent when there is no data;
        other aggregates fall back to 0.0. The "m"/"deg" suffixes in key names
        are labels only; no unit conversion is performed here.
    """
    pos_dims = ("dx", "dy", "dz")
    rot_dims = ("dpitch", "dyaw", "droll")
    percentiles = (50, 75, 90, 95, 99)
    metrics = {}

    def _mean(values):
        return float(np.mean(values)) if values else 0.0

    def _rate(hits, count):
        return hits / count if count else 0.0

    # --- Per-dimension and pooled regression errors --------------------------
    flat = {dim: [e for s in all_errors for e in s[dim]] for dim in DIMS}
    for dim in DIMS:
        vals = flat[dim]
        if vals:
            metrics[f"mae_{dim}"] = float(np.mean(vals))
            metrics[f"rmse_{dim}"] = float(np.sqrt(np.mean(np.array(vals) ** 2)))

    all_vals = [e for dim in DIMS for e in flat[dim]]
    pos_vals = [e for dim in pos_dims for e in flat[dim]]
    rot_vals = [e for dim in rot_dims for e in flat[dim]]

    metrics["mae_overall"] = _mean(all_vals)
    metrics["mae_position"] = _mean(pos_vals)
    metrics["mae_rotation"] = _mean(rot_vals)
    metrics["rmse_overall"] = (
        float(np.sqrt(np.mean(np.array(all_vals) ** 2))) if all_vals else 0.0
    )

    # --- Position / rotation error magnitudes, computed once -----------------
    per_sample_pos_mags = []
    per_sample_rot_mags = []
    for s in all_errors:
        n_wp = len(s["dx"])
        per_sample_pos_mags.append(_waypoint_magnitudes(s, pos_dims, n_wp))
        per_sample_rot_mags.append(_waypoint_magnitudes(s, rot_dims, n_wp))
    n_samples = len(all_errors)

    # Group by waypoint index for the per-waypoint statistics.
    per_wp_euc = {}
    per_wp_rot = {}
    for poss, rots in zip(per_sample_pos_mags, per_sample_rot_mags):
        for wi, (p, r) in enumerate(zip(poss, rots)):
            per_wp_euc.setdefault(wi, []).append(p)
            per_wp_rot.setdefault(wi, []).append(r)

    all_euc = []
    for wi in sorted(per_wp_euc):
        vals = per_wp_euc[wi]
        all_euc.extend(vals)
        metrics[f"wp{wi+1}_euc_mae"] = float(np.mean(vals))
        metrics[f"wp{wi+1}_euc_median"] = float(np.median(vals))
    metrics["euclidean_mae"] = _mean(all_euc)

    # --- Average / final displacement error per trajectory -------------------
    ade_list = [np.mean(poss) for poss in per_sample_pos_mags if poss]
    fde_list = [poss[-1] for poss in per_sample_pos_mags if poss]
    metrics["ADE"] = _mean(ade_list)
    metrics["FDE"] = _mean(fde_list)
    metrics["ADE_median"] = float(np.median(ade_list)) if ade_list else 0.0
    metrics["FDE_median"] = float(np.median(fde_list)) if fde_list else 0.0

    # --- Position success: fraction of waypoints under each threshold --------
    for thr in (0.1, 0.2, 0.3, 0.5, 1.0, 2.0, 5.0):
        hits = sum(1 for e in all_euc if e < thr)
        metrics[f"SR@{thr}m"] = _rate(hits, len(all_euc))

    # --- Trajectory success: every waypoint of a sample under the threshold --
    # `p < thr` (not `not p >= thr`) so a NaN error never counts as a success,
    # consistent with the other success metrics.
    for thr in (0.3, 0.5, 1.0, 2.0, 5.0):
        hits = sum(1 for poss in per_sample_pos_mags if all(p < thr for p in poss))
        metrics[f"TrajSR@{thr}m"] = _rate(hits, n_samples)

    # --- Rotation accuracy ----------------------------------------------------
    all_rot_errors = [r for rots in per_sample_rot_mags for r in rots]
    for thr in (0.5, 1.0, 2.0, 5.0, 10.0):
        hits = sum(1 for e in all_rot_errors if e < thr)
        metrics[f"RotAcc@{thr}deg"] = _rate(hits, len(all_rot_errors))

    for thr in (1.0, 2.0, 5.0, 10.0):
        hits = sum(1 for rots in per_sample_rot_mags if all(r < thr for r in rots))
        metrics[f"TrajRotSR@{thr}deg"] = _rate(hits, n_samples)

    # --- Joint position + rotation success ------------------------------------
    joint_pairs = ((0.5, 1.0), (0.5, 5.0), (1.0, 1.0), (1.0, 5.0), (0.3, 1.0), (0.5, 2.0))
    paired = list(zip(per_sample_pos_mags, per_sample_rot_mags))
    # JointSR: at least one waypoint of the sample satisfies both thresholds.
    for pos_thr, rot_thr in joint_pairs:
        hits = sum(
            1 for poss, rots in paired
            if any(p < pos_thr and r < rot_thr for p, r in zip(poss, rots))
        )
        metrics[f"JointSR@({pos_thr}m,{rot_thr}deg)"] = _rate(hits, n_samples)
    # TrajJointSR: every waypoint of the sample satisfies both thresholds.
    for pos_thr, rot_thr in joint_pairs:
        hits = sum(
            1 for poss, rots in paired
            if all(p < pos_thr and r < rot_thr for p, r in zip(poss, rots))
        )
        metrics[f"TrajJointSR@({pos_thr}m,{rot_thr}deg)"] = _rate(hits, n_samples)

    # --- Per-waypoint rotation error -----------------------------------------
    for wi in sorted(per_wp_rot):
        metrics[f"wp{wi+1}_rot_mae"] = float(np.mean(per_wp_rot[wi]))
    metrics["rotation_euc_mae"] = _mean(all_rot_errors)

    # --- Tail statistics -------------------------------------------------------
    for prefix, values in (("ADE", ade_list), ("FDE", fde_list), ("rot_err", all_rot_errors)):
        if values:
            arr = np.array(values)
            for p in percentiles:
                metrics[f"{prefix}_p{p}"] = float(np.percentile(arr, p))
            metrics[f"{prefix}_max"] = float(arr.max())

    # --- Hard failures (denominator: all valid samples) -----------------------
    if n_samples > 0:
        for thr in (2.0, 5.0, 10.0):
            metrics[f"HardFailRate_pos_gt_{thr}m"] = (
                sum(1 for e in fde_list if e > thr) / n_samples
            )
        # Worst rotation error per sample; independent of the threshold.
        sample_max_rot = [max(rots) if rots else 0 for rots in per_sample_rot_mags]
        for thr in (10.0, 30.0, 60.0):
            metrics[f"HardFailRate_rot_gt_{thr}deg"] = (
                sum(1 for r in sample_max_rot if r > thr) / n_samples
            )

    # --- Parse bookkeeping -----------------------------------------------------
    metrics["parse_failure_rate"] = parse_failures / total if total > 0 else 0.0
    metrics["parse_success_rate"] = 1 - metrics["parse_failure_rate"]
    metrics["valid_samples"] = n_samples
    metrics["total_samples"] = total
    metrics["parse_failures"] = parse_failures

    return metrics
|
|
|
|
def print_results(results, model_name):
    """Log a sectioned, human-readable summary of an evaluation run.

    Args:
        results: Metrics dict as produced by ``compute_metrics``. The headline
            keys (sample counts, overall MAE/RMSE, ADE/FDE, Euclidean MAE) are
            required; thresholded, per-waypoint and per-dimension keys are
            optional and shown as 0 or skipped when absent.
        model_name: Label printed in the report header.
    """
    say = logger.info
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    def open_section(title):
        say(light_rule)
        say(title)

    # Header and sample accounting.
    say(heavy_rule)
    say(f" Evaluation Results: {model_name}")
    say(heavy_rule)
    sample_summary = (
        f" Samples: {results['valid_samples']}/{results['total_samples']} "
        f"(parse failures: {results['parse_failures']}, "
        f"rate: {results['parse_failure_rate']:.2%}, "
        f"success: {results['parse_success_rate']:.2%})"
    )
    say(sample_summary)

    # Sections whose rows are "<label><value to 4 decimals><suffix>".
    fixed_sections = (
        (" [Regression Metrics]", (
            (" Overall MAE: ", "mae_overall", ""),
            (" Position MAE: ", "mae_position", " (dx/dy/dz)"),
            (" Rotation MAE: ", "mae_rotation", " (dpitch/dyaw/droll)"),
            (" Overall RMSE: ", "rmse_overall", ""),
        )),
        (" [Trajectory Metrics]", (
            (" ADE (mean): ", "ADE", " (avg displacement error)"),
            (" ADE (median): ", "ADE_median", ""),
            (" FDE (mean): ", "FDE", " (final displacement error)"),
            (" FDE (median): ", "FDE_median", ""),
            (" Euclidean MAE: ", "euclidean_mae", ""),
        )),
    )
    for title, rows in fixed_sections:
        open_section(title)
        for label, key, suffix in rows:
            say(f"{label}{results[key]:.4f}{suffix}")

    # Sections with one percentage row per threshold.
    threshold_sections = (
        (" [Position Success Rate]", "SR@{}m", (0.5, 1.0, 2.0, 5.0), 12),
        (" [Trajectory Success Rate (all waypoints under threshold)]",
         "TrajSR@{}m", (1.0, 2.0, 5.0), 14),
        (" [Rotation Accuracy]", "RotAcc@{}deg", (1.0, 5.0, 10.0), 16),
    )
    for title, key_template, thresholds, width in threshold_sections:
        open_section(title)
        for thr in thresholds:
            key = key_template.format(thr)
            say(f" {key:{width}s} {results.get(key, 0):.2%}")

    # Per-waypoint rows are only printed for waypoints present in the results.
    open_section(" [Per-waypoint Position Error (Euclidean)]")
    for number in range(1, NUM_WAYPOINTS + 1):
        mae_key = f"wp{number}_euc_mae"
        if mae_key not in results:
            continue
        median = results.get(f"wp{number}_euc_median", 0)
        say(f" Waypoint {number}: MAE={results[mae_key]:.4f} Median={median:.4f}")

    open_section(" [Per-waypoint Rotation Error]")
    for number in range(1, NUM_WAYPOINTS + 1):
        rot_key = f"wp{number}_rot_mae"
        if rot_key in results:
            say(f" Waypoint {number}: MAE={results[rot_key]:.4f}")

    open_section(" [Per-dimension MAE / RMSE]")
    for dim in DIMS:
        mae = results.get(f"mae_{dim}", 0)
        rmse = results.get(f"rmse_{dim}", 0)
        say(f" {dim:8s} MAE={mae:.4f} RMSE={rmse:.4f}")
    say(heavy_rule)
|
|
|
|
def get_max_model_len(model_path: str, cap: int = 8192) -> int:
    """Pick a context length for the model, never exceeding *cap*.

    Reads ``max_position_embeddings`` from ``<model_path>/config.json``. If
    the key is missing or null at the top level, the nested ``text_config``
    section is consulted as well.

    Args:
        model_path: Local model directory. If it holds no ``config.json``
            (for example when a non-local identifier is passed), *cap* is
            returned.
        cap: Upper bound on the returned length, and the fallback used when
            the config gives no usable value.

    Returns:
        ``min(max_position_embeddings, cap)``, or *cap* when no value is found.

    Raises:
        json.JSONDecodeError: If ``config.json`` exists but is not valid JSON.
    """
    config_path = os.path.join(model_path, "config.json")
    if not os.path.isfile(config_path):
        return cap

    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    max_pos = cfg.get("max_position_embeddings")
    if max_pos is None:
        text_cfg = cfg.get("text_config")
        if isinstance(text_cfg, dict):
            max_pos = text_cfg.get("max_position_embeddings")
    if max_pos is None:
        return cap
    return min(int(max_pos), cap)
|
|
|
|
def main():
    """Evaluate a model on the waypoint validation set with vLLM.

    Steps: parse CLI arguments, load the validation JSONL, keep the samples
    whose ground truth parses, run batched greedy decoding through
    ``LLM.chat``, accumulate per-dimension absolute errors, then compute,
    print and save the metrics (plus raw per-sample errors with
    ``--save_raw``).
    """
    import gc

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--val_path", type=str,
                        default="/mnt/data-a808/R26112/datasets/0318_vln_waypoint_val.jsonl")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Number of requests per vLLM batch call")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.85)
    parser.add_argument("--max_model_len", type=int, default=None)
    parser.add_argument("--save_raw", action="store_true",
                        help="If set, also save per-sample raw errors to "
                             "raw_errors_<model_name>.json (enables strict offline analysis).")
    args = parser.parse_args()
    if args.batch_size < 1:
        # range() below needs a positive step; fail with a usage message.
        parser.error("--batch_size must be a positive integer")

    stripped_model_path = args.model_path.rstrip("/")
    model_name = os.path.basename(stripped_model_path)
    if args.output_dir is None:
        # dirname() is "" for a bare directory name; os.makedirs("") would
        # fail, so fall back to the current directory.
        args.output_dir = os.path.dirname(stripped_model_path) or "."

    val_data = load_val_data(args.val_path)
    if args.max_samples and args.max_samples < len(val_data):
        val_data = val_data[:args.max_samples]

    # Imported after the data is loaded so data problems surface first.
    import vllm
    from vllm import LLM, SamplingParams

    max_model_len = args.max_model_len or get_max_model_len(args.model_path)

    logger.info(f"Loading model with vLLM: {args.model_path}")
    logger.info(f" tensor_parallel_size={args.tensor_parallel_size}")
    logger.info(f" gpu_memory_utilization={args.gpu_memory_utilization}")
    logger.info(f" max_model_len={max_model_len}")

    llm = LLM(
        model=args.model_path,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_model_len=max_model_len,
        limit_mm_per_prompt={"image": 5},
        allowed_local_media_path="/",
    )

    # temperature=0 selects greedy decoding.
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=512,
    )

    logger.info("Validating samples (parsing ground-truth only; images opened lazily per batch)...")
    valid_items = []
    for idx, item in enumerate(val_data):
        gt_text = next(
            (m["content"] for m in item["messages"] if m["role"] == "assistant"),
            None,
        )
        if gt_text is None:
            logger.warning(f"Sample {idx}: no assistant message, skipping")
            continue
        gt_wp = parse_waypoints(gt_text)
        if gt_wp is None:
            logger.warning(f"Sample {idx}: cannot parse ground truth, skipping")
            continue
        valid_items.append((idx, item, gt_wp))

    logger.info(f"Valid samples: {len(valid_items)}/{len(val_data)}")

    # NOTE: the parse-failure rate uses every loaded sample as denominator,
    # including those skipped above for unusable ground truth.
    total = len(val_data)
    all_errors = []
    parse_failures = 0

    run_start = time.time()
    for batch_start in range(0, len(valid_items), args.batch_size):
        batch_end = min(batch_start + args.batch_size, len(valid_items))
        batch_items = valid_items[batch_start:batch_end]

        # Images are opened here, one batch at a time, to bound memory use.
        batch_msgs = [build_vllm_inputs(entry[1]) for entry in batch_items]

        logger.info(f"Running vLLM batch [{batch_start+1}-{batch_end}/{len(valid_items)}]...")
        t0 = time.time()

        outputs = llm.chat(
            messages=batch_msgs,
            sampling_params=sampling_params,
            chat_template_kwargs={"enable_thinking": False},
        )

        # Floor the duration so the throughput division cannot hit zero.
        elapsed = max(time.time() - t0, 1e-9)
        logger.info(f" Batch done in {elapsed:.1f}s ({len(batch_msgs)/elapsed:.1f} samples/s)")

        for (sample_idx, _item, gt_wp), output in zip(batch_items, outputs):
            generated = output.outputs[0].text
            pred_wp = parse_waypoints(generated)

            if pred_wp is None:
                parse_failures += 1
                # Throttle: log the first few failures, then every 50th.
                if parse_failures <= 5 or parse_failures % 50 == 0:
                    logger.warning(f"Sample {sample_idx}: parse failure. Output: {generated[:200]}")
                continue

            # Only waypoints present in both prediction and ground truth are
            # scored; surplus waypoints on either side are ignored.
            n_wp = min(len(gt_wp), len(pred_wp))
            sample_errors = {dim: [] for dim in DIMS}
            for wi in range(n_wp):
                for dim in DIMS:
                    sample_errors[dim].append(abs(pred_wp[wi][dim] - gt_wp[wi][dim]))
            all_errors.append(sample_errors)

        # Drop the batch's decoded images before building the next batch.
        del batch_msgs
        gc.collect()

        if all_errors:
            cur_mae = {}
            for dim in DIMS:
                vals = [e for s in all_errors for e in s[dim]]
                cur_mae[dim] = np.mean(vals) if vals else 0
            avg = np.mean(list(cur_mae.values()))
            logger.info(
                f" Progress [{batch_end}/{len(valid_items)}] "
                f"MAE: {avg:.4f} | parse_fail={parse_failures}"
            )

    elapsed_total = time.time() - run_start

    results = compute_metrics(all_errors, parse_failures, total)
    results["inference_engine"] = "vllm"
    # Record the version actually installed rather than a fixed string.
    results["vllm_version"] = getattr(vllm, "__version__", "unknown")
    results["elapsed_seconds"] = elapsed_total
    results["skipped_invalid_gt"] = len(val_data) - len(valid_items)

    print_results(results, model_name)

    os.makedirs(args.output_dir, exist_ok=True)
    out_file = os.path.join(args.output_dir, f"eval_results_{model_name}.json")
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)
    logger.info(f"Results saved to {out_file}")

    if args.save_raw:
        raw_file = os.path.join(args.output_dir, f"raw_errors_{model_name}.json")
        raw_payload = {
            "n_samples": len(all_errors),
            "parse_failures": parse_failures,
            "total_samples": total,
            "dims": DIMS,
            "errors_per_sample": [
                {dim: list(map(float, s[dim])) for dim in DIMS} for s in all_errors
            ],
        }
        with open(raw_file, "w") as f:
            json.dump(raw_payload, f)
        logger.info(f"Raw errors saved to {raw_file} ({os.path.getsize(raw_file)/1e6:.2f} MB)")


if __name__ == "__main__":
    main()
|
|