3d_model / scripts /experiments /test_api_simple.py
Azan
Clean deployment build (Squashed)
7a87926
#!/usr/bin/env python3
"""
Simple API test script with detailed logging.
Tests YLFF API endpoints without complex dependencies.
"""
import argparse
import json
import logging
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
import requests
# Route every log record to stdout as "HH:MM:SS - LEVEL - message".
# force=True drops any root-logger handlers installed before this module ran,
# so the script's format wins even if an imported library configured logging.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%H:%M:%S",
    force=True,
)
# Logger shared by all functions in this script.
logger = logging.getLogger(name=__name__)
def test_endpoint(base_url: str, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
    """Send one HTTP request to the API and summarize the outcome.

    Args:
        base_url: Server root (e.g. ``http://localhost:8000``); a trailing
            slash is stripped before ``endpoint`` is appended.
        method: HTTP verb passed to ``requests.request``.
        endpoint: Path appended to ``base_url``; expected to start with ``/``.
        **kwargs: Extra ``requests.request`` arguments (``json=...`` etc.).
            ``timeout`` defaults to 300 seconds when not given.

    Returns:
        When a response is received: ``status_code``, ``data`` (decoded JSON,
        the raw text if the body is not JSON, or None for an empty body),
        ``duration`` in seconds, and ``success`` (True for 2xx codes).
        When the request itself fails: ``status_code`` None, ``error`` (the
        exception message), ``duration`` until the failure, and ``success``
        False.
    """
    url = f"{base_url.rstrip('/')}{endpoint}"
    # Generous default so long-running server operations are not cut off client-side.
    timeout = kwargs.pop("timeout", 300)
    logger.info(f"→ {method} {url}")
    start_time = time.perf_counter()  # monotonic, unaffected by wall-clock changes
    try:
        response = requests.request(method, url, timeout=timeout, **kwargs)
        duration = time.perf_counter() - start_time
        logger.info(f"← {response.status_code} ({duration:.3f}s)")
        try:
            data = response.json() if response.content else None
        except ValueError:
            # The concrete decode error depends on the requests version and on
            # whether simplejson is installed; all variants subclass ValueError.
            # Catching only json.JSONDecodeError could let the error escape
            # (and, on requests>=2.27, be misreported below as a transport error).
            data = response.text
        return {
            "status_code": response.status_code,
            "data": data,
            "duration": duration,
            "success": 200 <= response.status_code < 300,
        }
    except requests.exceptions.RequestException as e:
        logger.error(f"✗ Request failed: {e}")
        return {
            "status_code": None,
            "error": str(e),
            "success": False,
            "duration": time.perf_counter() - start_time,
        }
# Per-frame rotation-error thresholds in degrees. They are sent with the
# sequence-validation request and reused when stats are rebuilt locally, so
# both code paths classify frames identically.
_ACCEPT_THRESHOLD_DEG = 2.0
_REJECT_THRESHOLD_DEG = 30.0

# Number of numbered steps shown in the "[i/N]" progress prefixes.
_TOTAL_STEPS = 13

# Parameter-free GET endpoints exercised as steps 1..8, in this order.
_BASIC_GET_ENDPOINTS = (
    "/health",
    "/",
    "/models",
    "/api/v1/jobs",
    "/api/v1/profiling/metrics",
    "/api/v1/profiling/hot-paths",
    "/api/v1/profiling/latency",
    "/api/v1/profiling/system",
)

# Local directories searched (in order) for validation_results.json when the
# server's validation-results endpoint answers 404.
_LOCAL_RESULT_DIRS = (
    "data/test_arkit_output",
    "data/arkit_ba_validation",
    "data/arkit_validation",
)


def _step(number: int, text: str) -> str:
    """Prefix *text* with a "[number/_TOTAL_STEPS]" progress label."""
    return f"[{number}/{_TOTAL_STEPS}] {text}"


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return *value* when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _json_object(result: Dict[str, Any]) -> Dict[str, Any]:
    """Return the body of a successful test_endpoint result if it is a JSON object, else {}.

    Endpoints may legitimately answer with a JSON list or plain text; those
    bodies have no ``.get`` and must not be probed for job fields.
    """
    if result.get("success"):
        return _as_dict(result.get("data"))
    return {}


def _submit_job(base_url: str, endpoint: str, payload: Dict[str, Any], results: list) -> None:
    """POST *payload* to *endpoint*, record the result, and log any returned job ID."""
    result = test_endpoint(base_url, "POST", endpoint, json=payload)
    results.append((f"POST {endpoint}", result))
    body = _json_object(result)
    if body:
        logger.info(f" Job ID: {body.get('job_id')}")
    logger.info("")


def _categorize_rotation_errors(rot_errors: Any) -> Dict[str, Any]:
    """Bucket a non-empty sequence of rotation errors (degrees) into accept/learnable/outlier."""
    total = len(rot_errors)
    accepted = sum(1 for e in rot_errors if e < _ACCEPT_THRESHOLD_DEG)
    learnable = sum(
        1 for e in rot_errors if _ACCEPT_THRESHOLD_DEG <= e < _REJECT_THRESHOLD_DEG
    )
    outlier = sum(1 for e in rot_errors if e >= _REJECT_THRESHOLD_DEG)
    return {
        "total_frames": total,
        "accepted": accepted,
        "rejected_learnable": learnable,
        "rejected_outlier": outlier,
        "accepted_percentage": 100.0 * accepted / total,
        "rejected_learnable_percentage": 100.0 * learnable / total,
        "rejected_outlier_percentage": 100.0 * outlier / total,
    }


def _stats_from_validation_data(val_data: Any) -> Dict[str, Any]:
    """Derive validation stats from a parsed validation_results.json; {} if it has no errors."""
    if "da3_vs_arkit" not in val_data:
        return {}
    rot_errors = val_data["da3_vs_arkit"].get("rotation_errors_deg", [])
    if not rot_errors:
        return {}
    stats = _categorize_rotation_errors(rot_errors)
    if "ba_result" in val_data:
        ba_result = val_data["ba_result"]
        stats["ba_status"] = ba_result.get("status")
        stats["max_error_deg"] = ba_result.get("error")
    return stats


def _load_local_validation_stats() -> Dict[str, Any]:
    """Return stats from the first local validation_results.json that yields any, else {}."""
    for output_dir in _LOCAL_RESULT_DIRS:
        results_file = Path(output_dir) / "validation_results.json"
        if not results_file.exists():
            continue
        try:
            with open(results_file, encoding="utf-8") as f:
                stats = _stats_from_validation_data(json.load(f))
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # Best effort: an unreadable or malformed file just moves on to the next one.
            logger.warning(f" Could not read {results_file}: {e}")
            continue
        if stats:
            logger.info(f" Found validation results at: {results_file}")
            return stats
    return {}


def _resolve_validation_stats(
    base_url: str, job_id: Any, job_result: Dict[str, Any]
) -> Dict[str, Any]:
    """Find validation stats: job result first, then the results endpoint, then local files."""
    stats = _as_dict(job_result.get("validation_stats"))
    if stats:
        return stats
    logger.info(" Fetching validation statistics from results endpoint...")
    stats_result = test_endpoint(base_url, "GET", f"/api/v1/validation/results/{job_id}")
    stats = _as_dict(_json_object(stats_result).get("validation_stats"))
    # A 404 means the server predates the results endpoint; fall back to local output.
    if not stats and stats_result.get("status_code") == 404:
        logger.info(" Results endpoint not available, checking local validation results...")
        stats = _load_local_validation_stats()
    return stats


def _log_diagnostics(diag: Dict[str, Any]) -> None:
    """Log rotation-error quantiles, alignment diagnostics, and sample per-frame errors."""
    logger.info("")
    logger.info(" === Detailed Diagnostics ===")
    rot_dist = _as_dict(diag.get("error_distribution")).get("rotation_errors_deg")
    if isinstance(rot_dist, dict):
        logger.info(" Rotation Error Distribution:")
        logger.info(f" Q1 (25th): {rot_dist.get('q1', 0):.2f}°")
        logger.info(f" Median: {rot_dist.get('median', 0):.2f}°")
        logger.info(f" Q3 (75th): {rot_dist.get('q3', 0):.2f}°")
        logger.info(f" 90th: {rot_dist.get('p90', 0):.2f}°")
        logger.info(f" 95th: {rot_dist.get('p95', 0):.2f}°")
    align = _as_dict(diag.get("da3_vs_arkit")).get("alignment_info")
    if isinstance(align, dict):
        logger.info("")
        logger.info(" Alignment Diagnostics:")
        logger.info(f" Scale factor: {align.get('scale_factor', 0):.6f} (should be ~1.0)")
        logger.info(f" Rotation det: {align.get('rotation_det', 0):.6f} (should be ~1.0)")
    per_frame = diag.get("per_frame_errors")
    if per_frame:
        logger.info("")
        logger.info(" Sample Frame Errors (first 5):")
        for frame_err in per_frame[:5]:
            logger.info(
                f" Frame {frame_err['frame_idx']}: "
                f"{frame_err['rotation_error_deg']:.2f}° rot, "
                f"{frame_err['translation_error_m']:.3f}m trans - "
                f"{frame_err['category']}"
            )


def _log_validation_stats(stats: Dict[str, Any]) -> None:
    """Log frame categorization counts, BA status, and optional diagnostics."""
    logger.info("")
    logger.info(" === BA Validation Statistics ===")
    total = stats.get("total_frames", 0)
    accepted = stats.get("accepted", 0)
    rejected_learnable = stats.get("rejected_learnable", 0)
    rejected_outlier = stats.get("rejected_outlier", 0)
    logger.info(f" Total Frames Processed: {total}")
    logger.info("")
    logger.info(" Frame Categorization:")
    accepted_pct = stats.get("accepted_percentage", 0)
    learnable_pct = stats.get("rejected_learnable_percentage", 0)
    outlier_pct = stats.get("rejected_outlier_percentage", 0)
    logger.info(f" ✓ Accepted (< 2°): {accepted:3d} frames ({accepted_pct:5.1f}%)")
    logger.info(
        f" ⚠ Rejected-Learnable (2-30°): {rejected_learnable:3d} frames "
        f"({learnable_pct:5.1f}%)"
    )
    logger.info(
        f" ✗ Rejected-Outlier (> 30°): {rejected_outlier:3d} frames ({outlier_pct:5.1f}%)"
    )
    logger.info("")
    total_rejected = rejected_learnable + rejected_outlier
    rejected_pct = 100.0 * total_rejected / total if total > 0 else 0
    logger.info(f" Total Rejected: {total_rejected} frames ({rejected_pct:.1f}%)")
    logger.info("")
    if stats.get("ba_status"):
        logger.info(f" BA Validation Status: {stats['ba_status']}")
    max_error = stats.get("max_error_deg")
    # Only numbers can be formatted with :.2f; other values are skipped.
    if max_error and isinstance(max_error, (int, float)):
        logger.info(f" Max Rotation Error: {max_error:.2f}°")
    diag = stats.get("diagnostics")
    if isinstance(diag, dict):
        _log_diagnostics(diag)


def _report_completed_job(base_url: str, job_id: Any, job: Dict[str, Any]) -> None:
    """Log the keys of a completed job's result and any validation statistics found for it."""
    job_result = _as_dict(job.get("result"))
    if job_result:
        logger.info(f" Result keys: {list(job_result.keys())}")
    validation_stats = _resolve_validation_stats(base_url, job_id, job_result)
    if validation_stats:
        _log_validation_stats(validation_stats)


def _poll_job(base_url: str, job_id: Any, max_polls: int, poll_interval: float) -> Dict[str, Any]:
    """Poll a job until it completes/fails, a status fetch fails, or polls run out.

    Returns the last test_endpoint result ({} if *max_polls* is not positive).
    """
    result: Dict[str, Any] = {}
    for poll_num in range(max_polls):
        result = test_endpoint(base_url, "GET", f"/api/v1/jobs/{job_id}")
        job = _json_object(result)
        if not job:
            logger.warning(f" Failed to get job status: {result}")
            return result
        status = job.get("status", "unknown")
        message = job.get("message")
        # A JSON null (or non-string) message must not break the [:60] slice.
        message = "" if message is None else str(message)
        logger.info(f" Poll {poll_num + 1}/{max_polls}: Status={status}, Message={message[:60]}")
        if status in ("completed", "failed"):
            logger.info(f" Job {status}!")
            if status == "completed":
                _report_completed_job(base_url, job_id, job)
            logger.info("")
            return result
        if poll_num < max_polls - 1:
            time.sleep(poll_interval)
    logger.warning(f" Job {job_id} did not finish within {max_polls} polls")
    return result


def _log_profiling_metrics(metrics: Dict[str, Any]) -> None:
    """Log entry/stage counts and the top five hot paths from a profiling-metrics body."""
    if not metrics:
        return
    logger.info(f" Total entries: {metrics.get('total_entries', 0)}")
    logger.info(f" Stages tracked: {len(metrics.get('stage_stats') or {})}")
    hot_paths = metrics.get("hot_paths")
    if hot_paths:
        logger.info(" Top 5 hot paths:")
        for i, path in enumerate(hot_paths[:5], 1):
            logger.info(
                f" {i}. {path.get('function')}: {path.get('total_time', 0):.3f}s "
                f"({path.get('call_count', 0)} calls)"
            )


def _log_latency_breakdown(latency: Dict[str, Any]) -> None:
    """Log total time and the ten most expensive stages from a latency-breakdown body."""
    if not latency:
        return
    total = latency.get("total_time", 0)
    breakdown = latency.get("breakdown", {})
    logger.info(f" Total time: {total:.3f}s")
    logger.info(" Breakdown by stage:")
    ranked = sorted(
        breakdown.items(), key=lambda item: item[1].get("total_time", 0), reverse=True
    )
    for stage, stats in ranked[:10]:
        pct = stats.get("percentage", 0)
        avg = stats.get("avg_time", 0)
        calls = stats.get("call_count", 0)
        logger.info(
            f" {stage:30s} {stats.get('total_time', 0):8.3f}s ({pct:5.1f}%) "
            f"avg: {avg:.3f}s, calls: {calls}"
        )


def _log_summary(results: list) -> int:
    """Log the per-request summary table and return how many requests succeeded."""
    logger.info("=" * 80)
    logger.info("Test Summary")
    logger.info("=" * 80)
    success_count = sum(1 for _, r in results if r.get("success"))
    logger.info(f"Success: {success_count}/{len(results)}")
    logger.info("")
    logger.info("Endpoint Results:")
    for endpoint, result in results:
        mark = "✓" if result.get("success") else "✗"
        status_code = result.get("status_code")
        status_code_str = "N/A" if status_code is None else str(status_code)
        duration = result.get("duration", 0)
        logger.info(f"{mark} {endpoint:50s} {status_code_str:>3} ({duration:.3f}s)")
    return success_count


def _save_results(output_file: Path, base_url: str, results: list, success_count: int) -> None:
    """Write a timestamped JSON report of every recorded request to *output_file*."""
    output_file.parent.mkdir(parents=True, exist_ok=True)
    total_count = len(results)
    output_data = {
        "timestamp": datetime.now().isoformat(),
        "base_url": base_url,
        "summary": {
            "total_tests": total_count,
            "successful": success_count,
            "failed": total_count - success_count,
        },
        "results": [
            {
                "endpoint": endpoint,
                "status_code": r.get("status_code"),
                "success": r.get("success"),
                "duration": r.get("duration"),
                # Bodies are kept only for successes; errors only for failures.
                "data": r.get("data") if r.get("success") else None,
                "error": None if r.get("success") else r.get("error"),
            }
            for endpoint, r in results
        ],
    }
    with open(output_file, "w", encoding="utf-8") as f:
        # default=str keeps the report writable even if a body holds odd types.
        json.dump(output_data, f, indent=2, default=str)


def main():
    """Run the YLFF API smoke test and return a process exit code.

    Steps: eight basic GET endpoints, optional sequence/ARKit validation job
    submissions, polling of any returned jobs, and post-run profiling metrics
    and latency breakdown. A summary is logged and a JSON report written.

    Returns:
        0 when every recorded request returned a 2xx status, otherwise 1.
    """
    parser = argparse.ArgumentParser(description="Test YLFF API endpoints")
    parser.add_argument("--base-url", default="http://localhost:8000", help="Base URL")
    parser.add_argument("--arkit-dir", type=str, help="ARKit directory")
    parser.add_argument("--sequence-dir", type=str, help="Sequence directory")
    parser.add_argument(
        "--max-polls", type=int, default=60, help="Max job-status polls per job"
    )
    parser.add_argument(
        "--poll-interval", type=float, default=5.0, help="Seconds between job-status polls"
    )
    parser.add_argument(
        "--output", default="data/api_test_results.json", help="Path of the JSON report"
    )
    args = parser.parse_args()

    logger.info("=" * 80)
    logger.info("YLFF API Test")
    logger.info("=" * 80)
    logger.info(f"Base URL: {args.base_url}")
    logger.info("")

    # Each entry is (label, test_endpoint result); the summary and report use it.
    results = []

    # Steps 1-8: simple GET endpoints.
    for step, endpoint in enumerate(_BASIC_GET_ENDPOINTS, 1):
        logger.info(_step(step, f"Testing {endpoint}"))
        results.append((f"GET {endpoint}", test_endpoint(args.base_url, "GET", endpoint)))
        logger.info("")

    # Step 9: sequence validation (only when a directory was given).
    if args.sequence_dir:
        logger.info(_step(9, f"Testing /api/v1/validate/sequence (dir: {args.sequence_dir})"))
        payload = {
            "sequence_dir": args.sequence_dir,
            "use_case": "ba_validation",
            "accept_threshold": _ACCEPT_THRESHOLD_DEG,
            "reject_threshold": _REJECT_THRESHOLD_DEG,
        }
        _submit_job(args.base_url, "/api/v1/validate/sequence", payload, results)
    else:
        logger.info(_step(9, "Skipping /api/v1/validate/sequence (no sequence_dir)"))
        logger.info("")

    # Step 10: ARKit validation (only when a directory was given).
    if args.arkit_dir:
        logger.info(_step(10, f"Testing /api/v1/validate/arkit (dir: {args.arkit_dir})"))
        payload = {
            "arkit_dir": args.arkit_dir,
            "output_dir": "data/test_arkit_output",
            "max_frames": 10,
            "frame_interval": 1,
            "device": "cuda",
            "gui": False,
        }
        _submit_job(args.base_url, "/api/v1/validate/arkit", payload, results)
    else:
        logger.info(_step(10, "Skipping /api/v1/validate/arkit (no arkit_dir)"))
        logger.info("")

    # Step 11: poll every job ID returned so far until it finishes.
    logger.info(_step(11, "Polling job status until completion"))
    job_ids = []
    for _, result in results:
        job_id = _json_object(result).get("job_id")
        if job_id:
            job_ids.append(job_id)
    if job_ids:
        logger.info(f" Found {len(job_ids)} job(s) to monitor")
        for job_id in job_ids:
            logger.info(f" Monitoring job: {job_id}")
            final = _poll_job(args.base_url, job_id, args.max_polls, args.poll_interval)
            results.append((f"GET /api/v1/jobs/{job_id} (final)", final))
            logger.info("")
    else:
        logger.info(" No job IDs available to check")
        logger.info("")

    # Step 12: profiling metrics after the jobs ran.
    logger.info(_step(12, "Getting profiling metrics after job execution"))
    result = test_endpoint(args.base_url, "GET", "/api/v1/profiling/metrics")
    results.append(("GET /api/v1/profiling/metrics (post-exec)", result))
    _log_profiling_metrics(_json_object(result))
    logger.info("")

    # Step 13: latency breakdown after the jobs ran.
    logger.info(_step(13, "Getting latency breakdown"))
    result = test_endpoint(args.base_url, "GET", "/api/v1/profiling/latency")
    results.append(("GET /api/v1/profiling/latency (post-exec)", result))
    _log_latency_breakdown(_json_object(result))
    logger.info("")

    success_count = _log_summary(results)
    output_file = Path(args.output)
    _save_results(output_file, args.base_url, results, success_count)
    logger.info("")
    logger.info(f"Results saved to: {output_file}")
    return 0 if success_count == len(results) else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())