AI-Coach / scripts /run_pushup_eval_tests.py
Hoang Duc Hung
feat: stabilize push-up feedback and VLM handling
350d731
#!/usr/bin/env python3
from __future__ import annotations
import argparse
from collections import Counter
from datetime import datetime
import json
import os
from pathlib import Path
import sys
import time
from typing import Any
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent

# Prepend the repository root to the import path (once) so the local
# ``push_up`` package can be imported when this file is run directly.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, f"{ROOT}")
from push_up.analysis_service import TEMPLATE_SOURCE, analyze_pushup, prepare_template_cache
from push_up.processor import VideoProcessor
# Lower-case file suffixes that identify a test video.
VIDEO_EXTENSIONS = {".mp4", ".mov", ".m4v", ".avi", ".webm"}

# Default input / output locations, anchored at the repository root.
DEFAULT_TESTS_DIR = ROOT.joinpath("data", "tests")
DEFAULT_OUTPUT_DIR = ROOT.joinpath("analysis_artifacts", "video_test_runs")

# Case labels (video file stems) for which an analysis error is the expected,
# passing outcome -- presumably clips that are not push-ups at all.
# NOTE(review): compared verbatim with the file stem, so the spelling here must
# stay in sync with the test video's filename.
EXPECTED_REJECTION_LABELS = {"vo_teakwondo"}
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the push-up evaluation batch.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so calling
            ``parse_args()`` with no arguments behaves exactly as before.

    Returns:
        The parsed namespace. Path options are returned as given (possibly
        relative); nothing is resolved here.
    """
    # One timestamp is shared by both defaults so a run's JSON report and its
    # artifact directory carry the same suffix.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    parser = argparse.ArgumentParser(
        description="Run push-up analysis against every video in data/tests and write one JSON result file."
    )
    parser.add_argument(
        "--tests-dir",
        type=Path,
        default=DEFAULT_TESTS_DIR,
        help="Directory containing test videos. Default: data/tests",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT_DIR / f"pushup_eval_{timestamp}.json",
        help="Single JSON file to write. Default: analysis_artifacts/video_test_runs/pushup_eval_<timestamp>.json",
    )
    parser.add_argument(
        "--artifact-root",
        type=Path,
        default=DEFAULT_OUTPUT_DIR / f"artifacts_{timestamp}",
        help="Root directory for annotated images when --save-artifacts is used.",
    )
    parser.add_argument(
        "--save-artifacts",
        action="store_true",
        help="Save per-rep annotated student/expert frames. Also enables deterministic rule-based arrows.",
    )
    parser.add_argument(
        "--enable-vlm",
        action="store_true",
        help="Allow NVIDIA VLM calls for per-rep text feedback. By default VLM is disabled for repeatable tests.",
    )
    parser.add_argument(
        "--include-template",
        action="store_true",
        help="Also test template video against itself.",
    )
    return parser.parse_args(argv)
def _resolve_from_root(path: Path) -> Path:
    """Return *path* unchanged if absolute, else join it to ROOT and resolve it."""
    if path.is_absolute():
        return path
    return (ROOT / path).resolve()


def main() -> int:
    """Analyse every test video and write a single JSON report.

    Videos are discovered under ``--tests-dir`` (optionally preceded by the
    template video). Each one is run through ``analyze_pushup``, and a
    per-case summary is printed as the batch progresses. Exceptions raised
    while analysing a case are recorded in that case's entry instead of
    aborting the batch.

    Returns:
        ``0`` when every case is ``ok`` (expected rejections count as ok);
        ``1`` when any case failed, the test directory does not exist, or no
        videos were found.
    """
    args = parse_args()
    # Relative CLI paths are interpreted against the repository root, not the CWD.
    tests_dir = _resolve_from_root(args.tests_dir)
    output_path = _resolve_from_root(args.output)
    artifact_root = _resolve_from_root(args.artifact_root)

    if not tests_dir.exists():
        print(f"[error] Test directory does not exist: {tests_dir}")
        return 1

    videos = discover_videos(tests_dir)
    if args.include_template:
        videos = [TEMPLATE_SOURCE] + videos
    if not videos:
        print(f"[error] No test videos found in: {tests_dir}")
        return 1

    if not args.enable_vlm:
        # Disable VLM by blanking the credential. A non-empty placeholder
        # value looks like a configured key to any "is the key set?" check,
        # so VLM calls would likely still be attempted. It also puts a
        # secret-shaped literal in source. An empty string (rather than
        # deleting the variable) keeps a later non-overriding .env load from
        # restoring a real key.
        # NOTE(review): this only takes effect if push_up.analysis_service
        # reads NVIDIA_API_KEY at call time. The module is imported before
        # main() runs, so confirm it does not cache the key at import.
        os.environ["NVIDIA_API_KEY"] = ""

    print(f"[setup] template={TEMPLATE_SOURCE}")
    print(f"[setup] tests_dir={tests_dir}")
    print(f"[setup] output={output_path}")
    print(f"[setup] save_artifacts={args.save_artifacts}")
    print(f"[setup] vlm={'enabled' if args.enable_vlm else 'disabled'}")
    prepare_template_cache()

    batch_started = time.perf_counter()
    results = []
    for index, video_path in enumerate(videos, start=1):
        label = video_path.stem if video_path != TEMPLATE_SOURCE else "template_vs_template"
        print("=" * 88)
        print(f"[case {index}/{len(videos)}] {label}")
        print(f"[video] {video_path}")
        started = time.perf_counter()
        orientation = orientation_label(video_path)
        try:
            result = analyze_pushup(
                video_path,
                artifact_root if args.save_artifacts else None,
                save_artifacts=args.save_artifacts,
            )
            elapsed = time.perf_counter() - started
            entry = compact_result(
                label=label,
                video_path=video_path,
                orientation=orientation,
                elapsed_seconds=elapsed,
                result=result,
            )
        except Exception as exc:
            # Per-case boundary: record the failure and continue the batch.
            elapsed = time.perf_counter() - started
            entry = {
                "label": label,
                "video_path": project_relative_path(video_path),
                "orientation": orientation,
                "elapsed_seconds": round(elapsed, 2),
                "error": f"{type(exc).__name__}: {exc}",
                "ok": False,
            }
        results.append(entry)
        print_case_summary(entry)

    payload = {
        "created_at": datetime.now().isoformat(timespec="seconds"),
        "tests_dir": project_relative_path(tests_dir),
        "template_video_path": project_relative_path(TEMPLATE_SOURCE),
        "save_artifacts": args.save_artifacts,
        "artifact_root": project_relative_path(artifact_root) if args.save_artifacts else "",
        "vlm_enabled": args.enable_vlm,
        "total_videos": len(results),
        "passed_videos": sum(1 for item in results if item.get("ok")),
        "failed_videos": sum(1 for item in results if not item.get("ok")),
        "elapsed_seconds": round(time.perf_counter() - batch_started, 2),
        "error_distribution": error_distribution(results),
        "results": results,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as file:
        json.dump(payload, file, ensure_ascii=False, indent=2)

    print("=" * 88)
    print(f"[done] wrote {output_path}")
    print(f"[done] passed={payload['passed_videos']} failed={payload['failed_videos']}")
    return 0 if payload["failed_videos"] == 0 else 1
def discover_videos(tests_dir: Path) -> list[Path]:
    """Recursively collect the video files under *tests_dir*.

    A path qualifies when it is a regular file and its lower-cased suffix is
    in ``VIDEO_EXTENSIONS``. The result is sorted by path.
    """
    found: list[Path] = []
    for candidate in tests_dir.rglob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() in VIDEO_EXTENSIONS:
            found.append(candidate)
    found.sort()
    return found
def orientation_label(video_path: Path) -> str:
    """Describe the orientation detected for *video_path*, for the report.

    Builds a fresh ``VideoProcessor`` and calls its private
    ``_detect_orientation`` hook. A truthy result maps to
    ``"head-left -> flipped"`` and a falsy one to ``"head-right/no-flip"``.
    Any ``Exception`` is converted into an ``"unknown: <Type>: <message>"``
    string instead of propagating.
    """
    try:
        flipped = VideoProcessor()._detect_orientation(str(video_path))
    except Exception as err:
        return f"unknown: {type(err).__name__}: {err}"
    if flipped:
        return "head-left -> flipped"
    return "head-right/no-flip"
def compact_result(
    *,
    label: str,
    video_path: Path,
    orientation: str,
    elapsed_seconds: float,
    result: dict[str, Any],
) -> dict[str, Any]:
    """Reduce one ``analyze_pushup`` result to the fields stored in the report.

    Args:
        label: Case label (the video stem, or ``"template_vs_template"``).
        video_path: Video that was analysed.
        orientation: Text describing the detected orientation.
        elapsed_seconds: Wall-clock analysis time for this case.
        result: Dict returned by ``analyze_pushup``.

    Returns:
        A JSON-serialisable dict.
        - When ``result`` has a truthy ``"error"``, the entry is ``ok`` only
          if ``label`` is in ``EXPECTED_REJECTION_LABELS``, and it carries
          just the error.
        - Otherwise the entry is ``ok`` and carries scores, summary, main
          errors and compacted per-rep results.
    """
    # Header shared by both shapes; key order matches the report layout.
    entry: dict[str, Any] = {
        "label": label,
        "video_path": project_relative_path(video_path),
        "orientation": orientation,
        "elapsed_seconds": round(elapsed_seconds, 2),
    }
    if result.get("error"):
        expected_rejection = label in EXPECTED_REJECTION_LABELS
        entry.update(
            {
                "ok": expected_rejection,
                "expected_rejection": expected_rejection,
                "error": result["error"],
            }
        )
        return entry

    # `or []` (instead of a .get default) also covers keys that are present
    # but None. Otherwise iterating them raises TypeError and a successful
    # analysis would be reported as a failed case.
    rep_results = result.get("rep_results") or []
    entry.update(
        {
            "ok": True,
            "expected_rejection": False,
            "error": None,
            "overall_score_pct": result.get("overall_score_pct"),
            "student_reps": result.get("student_reps"),
            "expert_reps": result.get("expert_reps"),
            "good_reps": result.get("good_reps"),
            "serious_reps": result.get("serious_reps"),
            "summary": result.get("summary", ""),
            "main_errors": result.get("main_errors") or [],
            "student_video_path": result.get("student_video_path", ""),
            "rep_results": [compact_rep(rep) for rep in rep_results],
        }
    )
    return entry
def compact_rep(rep: dict[str, Any]) -> dict[str, Any]:
    """Project one per-rep result onto the fields kept in the JSON report.

    Numeric/status fields and ``llm_arrow`` default to ``None``. Text fields
    default to ``""`` and ``error_labels`` to a new empty list.
    ``rule_feedback`` falls back to ``feedback`` when it is missing or falsy.
    """
    compact: dict[str, Any] = {
        key: rep.get(key)
        for key in (
            "rep_num",
            "score_pct",
            "rule_score_pct",
            "dtw_score_pct",
            "status",
            "primary_error",
        )
    }
    compact["error_labels"] = rep.get("error_labels", [])
    compact["rule_feedback"] = rep.get("rule_feedback") or rep.get("feedback", "")
    for key in (
        "llm_feedback",
        "llm_feedback_source",
        "llm_feedback_error",
        "llm_visual_error_label",
    ):
        compact[key] = rep.get(key, "")
    compact["llm_arrow"] = rep.get("llm_arrow")
    for key in ("student_frame_path", "expert_frame_path"):
        compact[key] = rep.get(key, "")
    return compact
def _format_rep_num(value: Any) -> str:
    """Render a rep number as zero-padded digits, or ``"??"`` if not integer-like."""
    try:
        return f"{int(value):02d}"
    except (TypeError, ValueError, OverflowError):
        return "??"


def print_case_summary(entry: dict[str, Any]) -> None:
    """Print a human-readable summary of one report entry to stdout.

    - Failed entries (falsy ``ok``) print their error and time.
    - Expected rejections print the rejection reason and time.
    - Otherwise the overall scores, the summary, one line per rep and the
      time are printed.

    A rep with a missing or non-numeric ``rep_num`` is shown as ``rep=??``
    instead of raising. This function runs outside the per-case error
    handling, so one malformed rep must not abort the batch.
    """
    if not entry.get("ok"):
        print(f"[result] ERROR: {entry.get('error')}")
        print(f"[time] {entry.get('elapsed_seconds')}s")
        return
    if entry.get("expected_rejection"):
        print(f"[result] EXPECTED_REJECTION: {entry.get('error')}")
        print(f"[time] {entry.get('elapsed_seconds')}s")
        return
    print(
        "[result] "
        f"overall={entry.get('overall_score_pct')}%, "
        f"student_reps={entry.get('student_reps')}, "
        f"expert_reps={entry.get('expert_reps')}, "
        f"good_reps={entry.get('good_reps')}, "
        f"serious_reps={entry.get('serious_reps')}"
    )
    print(f"[summary] {entry.get('summary', '')}")
    for rep in entry.get("rep_results") or []:
        # str() so non-string labels cannot break the join.
        errors = ", ".join(str(label) for label in rep.get("error_labels") or []) or "none"
        print(
            " "
            f"rep={_format_rep_num(rep.get('rep_num'))} "
            f"score={rep.get('score_pct')}% "
            f"status={rep.get('status')} "
            f"errors={errors}"
        )
    print(f"[time] {entry.get('elapsed_seconds')}s")
def error_distribution(results: list[dict[str, Any]]) -> dict[str, int]:
    """Sum ``main_errors`` counts across all report entries.

    Each error item is read as a dict. Its key is ``label``, falling back to
    ``type`` and then ``"unknown"``; its weight is ``int(count)``, with a
    missing/falsy count treated as 0. Entries without ``main_errors``, or with
    ``main_errors`` set to ``None``, contribute nothing.

    Returns:
        ``{label: total}`` ordered from most to least frequent.
    """
    counter: Counter[str] = Counter()
    for result in results:
        # `or []` also covers an explicit None. This runs after the whole
        # batch, so a crash here would lose every result.
        for error in result.get("main_errors") or []:
            label = error.get("label") or error.get("type") or "unknown"
            counter[label] += int(error.get("count") or 0)
    return dict(counter.most_common())
def project_relative_path(path: Path) -> str:
    """Return *path* as a POSIX string relative to ROOT when possible.

    The path is resolved first. If it lies outside ROOT, the absolute
    resolved path is returned instead.
    """
    resolved = path.resolve()
    try:
        relative = resolved.relative_to(ROOT)
    except ValueError:
        # Not under the repository root: report the absolute location.
        return resolved.as_posix()
    return relative.as_posix()
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())