"""Quick sanity check for VLAC trajectory values on the toy dataset.

The dataset is produced by ``testing/prepare_vlac_test_data.py`` in
``task_progress`` mode. Each entry already includes the image paths (relative
to ``images/``) together with ground-truth progress numbers in ``[0, 1]``.

This script keeps things intentionally small and prints a short report for a
set of frame/reference configurations (e.g., 4×4, 4×8, 8×4, 8×8) so we can
inspect MAE, final-frame accuracy, and latency versus sequence length.
"""
|
|
from __future__ import annotations

import argparse
import base64
import itertools
import json
import sys
import time
from io import BytesIO
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence

import numpy as np
import requests
from PIL import Image
from tqdm import tqdm


def read_manifest(dataset_dir: Path, json_name: str) -> List[Dict]:
    manifest_path = dataset_dir / json_name
    images_dir = dataset_dir / "images"
    if not manifest_path.is_file():
        raise FileNotFoundError(f"Metadata JSON not found: {manifest_path}")
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")

    with manifest_path.open("r", encoding="utf-8") as f:
        raw_entries = json.load(f)

    entries: List[Dict] = []
    for entry in raw_entries:
        frames = entry.get("frames") or []
        if not frames:
            continue
        # Resolve frame and reference paths against images/ once, up front.
        for frame in frames:
            frame["abs_path"] = str(images_dir / frame["path"])
        entry["reference"] = [str(images_dir / rel) for rel in entry.get("reference", [])]
        entries.append(entry)
    return entries
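# A manifest entry is assumed to look roughly like the sketch below (the field
# names are the ones this script accesses; the exact schema is defined by
# testing/prepare_vlac_test_data.py, and the values here are illustrative):
#   {
#       "demo_id": "demo_000",
#       "task": "pick up the red block",
#       "frames": [{"path": "demo_000/frame_000.jpg", "progress": 0.0}, ...],
#       "reference": ["demo_ref/frame_000.jpg", ...]
#   }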
|
|
|
|
def image_to_base64(path: Path) -> str:
    """Load an image, force RGB, and return it as a base64-encoded JPEG."""
    with Image.open(path) as img:
        img = img.convert("RGB")
        buffer = BytesIO()
        img.save(buffer, format="JPEG", quality=95)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


def encode_images(paths: Iterable[str]) -> List[str]:
    return [image_to_base64(Path(p)) for p in paths]
|
|
|
|
def sample_fixed_interval_frames(image_list: Sequence, num_frames: int) -> List:
    """Sample ``num_frames`` items at fixed intervals, keeping the endpoints."""
    if len(image_list) == 0:
        raise ValueError("image_list is empty")
    if len(image_list) == 1:
        return [image_list[0]] * num_frames
    if num_frames == 2:
        return [image_list[0], image_list[-1]]
    if num_frames == 3:
        # First, middle, and last items.
        return [image_list[0], image_list[len(image_list) // 2], image_list[-1]]
    total_frames = len(image_list)
    indices = np.linspace(start=0, stop=total_frames - 1, num=num_frames, dtype=int)
    return [image_list[i] for i in indices]
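# For example, sampling num_frames=4 from a 10-frame trajectory picks
# np.linspace(0, 9, num=4, dtype=int) -> indices [0, 3, 6, 9]: the first frame,
# two evenly spaced interior frames, and the last frame.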
|
|
|
|
def call_trajectory_critic(
    session: requests.Session,
    base_url: str,
    task: str,
    frames_b64: List[str],
    reference_b64: Optional[List[str]],
    timeout: float,
) -> Dict:
    payload = {
        "task": task,
        "frames": frames_b64,
        "reference": reference_b64,
        "ref_num": len(reference_b64 or []),
        "skip": 1,
        "batch_size": min(len(frames_b64), 8),
        "think": False,
        "return_video": False,
    }
    start = time.time()
    resp = session.post(f"{base_url.rstrip('/')}/trajectory-critic", json=payload, timeout=timeout)
    resp.raise_for_status()
    result = resp.json()
    result["latency_sec"] = time.time() - start
    return result
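# The service response is assumed to be JSON with at least a "value_list" field
# holding one predicted progress value per frame, e.g. {"value_list": [0.1, 0.4,
# 0.9]}; only that field (plus the locally added "latency_sec") is consumed
# below. The scale of these values is service-defined: the default
# --done-threshold-list of 50-95 suggests a 0-100 scale rather than [0, 1].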
|
|
|
|
def evaluate_combo(
    manifest: Sequence[Dict],
    base_url: str,
    timeout: float,
    frame_limit: int,
    ref_limit: int,
    done_threshold_list: Sequence[float],
) -> Dict[str, object]:
    session = requests.Session()
    mae_values: List[float] = []
    latencies: List[float] = []
    total_frames = 0
    pred_last_value_list: List[float] = []
    pred_mid_value_list: List[float] = []

    for entry in tqdm(manifest):
        frames = entry["frames"]
        if len(frames) <= frame_limit:
            selected_frames = frames
        else:
            selected_frames = sample_fixed_interval_frames(frames, frame_limit)
        selected_frames_paths = [frame["abs_path"] for frame in selected_frames]
        frames_b64 = encode_images(selected_frames_paths)

        reference_paths = entry["reference"]
        if len(reference_paths) <= ref_limit:
            selected_reference_paths = reference_paths
        else:
            selected_reference_paths = sample_fixed_interval_frames(reference_paths, ref_limit)
        reference_b64 = encode_images(selected_reference_paths)

        gt = np.array([frame["progress"] for frame in selected_frames], dtype=np.float32)

        try:
            result = call_trajectory_critic(
                session=session,
                base_url=base_url,
                task=entry.get("task", ""),
                frames_b64=frames_b64,
                reference_b64=reference_b64,
                timeout=timeout,
            )
        except requests.RequestException as exc:
            print(f"[warn] request failed for demo {entry.get('demo_id')}: {exc}")
            continue

        preds = np.array(result.get("value_list", []), dtype=np.float32)
        if preds.size < 2:
            # Need at least two predictions: the final frame (positive example)
            # and the second-to-last frame (negative example) used below.
            continue

        # Treat the final frame as "task done" and the second-to-last frame as
        # "not yet done" when scoring the done-classification thresholds.
        mid_idx = -2

        pred_last_value_list.append(preds[-1])
        pred_mid_value_list.append(preds[mid_idx])

        # Absolute error on the final frame; averaged across demos at the end.
        mae_values.append(float(abs(preds[-1] - gt[-1])))
        latencies.append(result.get("latency_sec", 0.0))
        total_frames += len(preds)
|
|
    # Score "is the task done?" classification at each threshold. Every demo
    # contributes one positive example (final frame, assumed done) and one
    # negative example (second-to-last frame, assumed not done).
    accuracy_with_different_thresholds: Dict[float, Dict[str, float]] = {}
    for done_threshold in done_threshold_list:
        tp = fp = tn = fn = 0
        for pred_last, pred_mid in zip(pred_last_value_list, pred_mid_value_list):
            if pred_last >= done_threshold:
                tp += 1
            else:
                fn += 1

            if pred_mid >= done_threshold:
                fp += 1
            else:
                tn += 1

        total = tp + fp + fn + tn
        precision = tp / (tp + fp) if (tp + fp) else float("nan")
        recall = tp / (tp + fn) if (tp + fn) else float("nan")
        if any(np.isnan(value) for value in (precision, recall)) or (precision + recall) == 0:
            f1 = float("nan")
        else:
            f1 = 2 * precision * recall / (precision + recall)
        accuracy = (tp + tn) / total if total else float("nan")

        accuracy_with_different_thresholds[done_threshold] = {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    if not mae_values:
        return {
            "mae": float("nan"),
            "frames": 0,
            "latency": float("nan"),
            "final_accuracy": {},
        }

    return {
        "mae": float(np.mean(mae_values)),
        "frames": total_frames,
        "latency": float(np.mean(latencies)) if latencies else float("nan"),
        "final_accuracy": accuracy_with_different_thresholds,
    }
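# Worked example for a single threshold (values on a [0, 1] scale purely for
# illustration): with pred_last = [0.93, 0.80], pred_mid = [0.41, 0.95], and
# done_threshold = 0.9, the counts are tp=1, fn=1, fp=1, tn=1, giving
# precision = recall = f1 = 0.5 and accuracy = (1 + 1) / 4 = 0.5.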
|
|
|
|
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="VLAC trajectory sanity check")
    parser.add_argument("--dataset-dir", required=True, help="Directory containing images/ and dataset JSON")
    parser.add_argument("--json-name", default="dataset_frame_progress.json", help="Manifest filename")
    parser.add_argument("--base-url", default="http://localhost:8111", help="VLAC service base URL")
    parser.add_argument("--timeout", type=float, default=30.0, help="HTTP timeout in seconds")
    parser.add_argument("--max-demos", type=int, default=None, help="Evaluate only the first N demos")
    parser.add_argument(
        "--frame-counts",
        type=int,
        nargs="+",
        default=[4, 8],
        help="Number of trajectory frames to feed per call (default: 4 8)",
    )
    parser.add_argument(
        "--ref-counts",
        type=int,
        nargs="+",
        default=[4, 8],
        help="Number of reference frames to feed per call (default: 4 8)",
    )
    parser.add_argument(
        "--done-threshold-list",
        type=float,
        nargs="+",
        default=[50, 55, 60, 65, 70, 75, 80, 85, 90, 95],
        help="Thresholds on predicted progress for final-frame accuracy (default: 50 55 ... 95)",
    )
    return parser.parse_args()
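# Example: passing "--done-threshold-list 0.8 0.9 0.95" parses to
# [0.8, 0.9, 0.95] (space-separated floats via nargs="+"), which would suit a
# service whose predicted progress values live in [0, 1].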
|
|
|
|
def main() -> int:
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)
    try:
        manifest = read_manifest(dataset_dir, args.json_name)
    except FileNotFoundError as exc:
        print(exc)
        return 1

    if args.max_demos is not None:
        manifest = manifest[: args.max_demos]

    if not manifest:
        print("No demos found in the manifest. Regenerate the dataset with testing/prepare_vlac_test_data.py")
        return 1

    frame_counts = sorted(set(fc for fc in args.frame_counts if fc > 0))
    ref_counts = sorted(set(rc for rc in args.ref_counts if rc > 0))
    if not frame_counts or not ref_counts:
        print("Provide positive frame/reference counts.")
        return 1

    print(f"Loaded {len(manifest)} demos from {dataset_dir}")

    results: Dict[tuple, Dict[str, object]] = {}
    print("Thresholds:", args.done_threshold_list)
    for frame_limit, ref_limit in itertools.product(frame_counts, ref_counts):
        metrics = evaluate_combo(
            manifest=manifest,
            base_url=args.base_url,
            timeout=args.timeout,
            frame_limit=frame_limit,
            ref_limit=ref_limit,
            done_threshold_list=args.done_threshold_list,
        )
        results[(frame_limit, ref_limit)] = metrics

    print("\n=== Results by (frames, reference) ===")
    for (frame_limit, ref_limit), metrics in sorted(results.items()):
        mae = metrics["mae"]
        latency = metrics["latency"]
        final_acc = metrics["final_accuracy"]
        print(f"{frame_limit}x{ref_limit}")
        for threshold, stats in final_acc.items():
            acc = stats["accuracy"]
            precision = stats["precision"]
            recall = stats["recall"]
            f1 = stats["f1"]
            print(
                f"threshold {threshold}: "
                f"accuracy {acc:.3f}, precision {precision:.3f}, "
                f"recall {recall:.3f}, f1 {f1:.3f}"
            )
        print()
        print(
            f"frames={frame_limit:>2}, ref={ref_limit:>2} -> "
            f"MAE {mae:.4f}, avg latency {latency:.2f}s, frames used {metrics['frames']}"
        )
        print()

    return 0


if __name__ == "__main__":
    sys.exit(main())