# Source: 3d_model/scripts/test_api.py
# Author: Azan — "Clean deployment build (Squashed)", commit 7a87926
#!/usr/bin/env python3
"""
Comprehensive API test script for YLFF endpoints.
Tests all API endpoints including:
- Health and system endpoints
- Validation endpoints (sequence, ARKit)
- Training endpoints (fine-tuning, pre-training) with optimization parameters
- Dataset building with optimization parameters
- Job management
- Profiling endpoints
"""
import argparse
import json
import logging
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
# Configure root logging for the whole script. force=True discards any
# handlers a previously imported module may have installed, so this
# configuration always takes effect.
logging.basicConfig(
    stream=sys.stdout,
    force=True,
    level=logging.INFO,
    datefmt="%H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
class APITester:
    """API testing utility class.

    Sends HTTP requests to a YLFF API server and records one
    ``(endpoint_label, result_dict)`` pair per call in ``self.results``.
    Any ``job_id`` found in a successful JSON response is collected in
    ``self.job_ids`` so asynchronous jobs can be polled afterwards.
    """

    def __init__(self, base_url: str, timeout: int = 300):
        """Initialize the tester.

        Args:
            base_url: Server root, e.g. "http://localhost:8000". A trailing
                slash is stripped so endpoint paths (which start with "/")
                concatenate cleanly.
            timeout: Per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # (endpoint label, result dict) pairs, in execution order.
        self.results: List[tuple[str, Dict[str, Any]]] = []
        # job_ids extracted from successful responses, for later polling.
        self.job_ids: List[str] = []

    def test_endpoint(
        self,
        method: str,
        endpoint: str,
        description: str = "",
        **kwargs,
    ) -> Dict[str, Any]:
        """Test a single endpoint.

        Args:
            method: HTTP method name ("GET", "POST", ...).
            endpoint: Path appended to ``base_url``; must start with "/".
            description: Optional human-readable label for the log line.
            **kwargs: Forwarded to ``requests.request`` (e.g. ``json=``).

        Returns:
            On a completed request: dict with ``status_code``, ``data``,
            ``duration`` (seconds) and ``success`` (2xx status) keys.
            On a transport failure: dict with ``status_code=None``,
            ``error`` and ``success=False``.
        """
        url = f"{self.base_url}{endpoint}"
        desc = f" ({description})" if description else ""
        logger.info(f"→ {method} {endpoint}{desc}")
        try:
            start_time = time.time()
            response = requests.request(method, url, timeout=self.timeout, **kwargs)
            duration = time.time() - start_time
            logger.info(f"← {response.status_code} ({duration:.3f}s)")
            try:
                data = response.json() if response.content else None
            except ValueError:
                # FIX: catch ValueError rather than json.JSONDecodeError so
                # this also covers requests builds backed by simplejson,
                # whose decode error is a ValueError but not a
                # json.JSONDecodeError. Fall back to the raw body text.
                data = response.text
            result = {
                "status_code": response.status_code,
                "data": data,
                "duration": duration,
                "success": 200 <= response.status_code < 300,
            }
            # Extract job_id if present so asynchronous jobs can be polled.
            if result.get("success") and data and isinstance(data, dict):
                job_id = data.get("job_id")
                if job_id:
                    self.job_ids.append(job_id)
                    logger.info(f" Job ID: {job_id}")
            return result
        except requests.exceptions.RequestException as e:
            logger.error(f"✗ Request failed: {e}")
            return {"status_code": None, "error": str(e), "success": False}

    def test_health_endpoints(self) -> None:
        """Test health and system endpoints (all GET, no payload)."""
        logger.info("\n" + "=" * 80)
        logger.info("HEALTH & SYSTEM ENDPOINTS")
        logger.info("=" * 80)
        # Health check
        result = self.test_endpoint("GET", "/health", "Health check")
        self.results.append(("GET /health", result))
        # Root
        result = self.test_endpoint("GET", "/", "Root endpoint")
        self.results.append(("GET /", result))
        # Models
        result = self.test_endpoint("GET", "/api/v1/models", "List models")
        self.results.append(("GET /api/v1/models", result))
        # Jobs list
        result = self.test_endpoint("GET", "/api/v1/jobs", "List jobs")
        self.results.append(("GET /api/v1/jobs", result))

    def test_profiling_endpoints(self) -> None:
        """Test profiling endpoints (all GET, no payload)."""
        logger.info("\n" + "=" * 80)
        logger.info("PROFILING ENDPOINTS")
        logger.info("=" * 80)
        endpoints = [
            ("/api/v1/profiling/metrics", "Profiling metrics"),
            ("/api/v1/profiling/hot-paths", "Hot paths"),
            ("/api/v1/profiling/latency", "Latency breakdown"),
            ("/api/v1/profiling/system", "System metrics"),
        ]
        for endpoint, desc in endpoints:
            result = self.test_endpoint("GET", endpoint, desc)
            self.results.append((f"GET {endpoint}", result))

    def test_validation_endpoints(
        self, sequence_dir: Optional[str] = None, arkit_dir: Optional[str] = None
    ) -> None:
        """Test validation endpoints.

        Each endpoint is skipped (with a log line) when its source
        directory argument is not provided.

        Args:
            sequence_dir: Directory for /api/v1/validate/sequence.
            arkit_dir: Directory for /api/v1/validate/arkit.
        """
        logger.info("\n" + "=" * 80)
        logger.info("VALIDATION ENDPOINTS")
        logger.info("=" * 80)
        # Validate sequence
        if sequence_dir:
            payload = {
                "sequence_dir": sequence_dir,
                "use_case": "ba_validation",
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/validate/sequence",
                f"Validate sequence: {sequence_dir}",
                json=payload,
            )
            self.results.append(("POST /api/v1/validate/sequence", result))
        else:
            logger.info("Skipping /api/v1/validate/sequence (no sequence_dir)")
        # Validate ARKit
        if arkit_dir:
            payload = {
                "arkit_dir": arkit_dir,
                "output_dir": "data/test_arkit_output",
                "max_frames": 10,
                "frame_interval": 1,
                "device": "cuda",
                "gui": False,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/validate/arkit",
                f"Validate ARKit: {arkit_dir}",
                json=payload,
            )
            self.results.append(("POST /api/v1/validate/arkit", result))
        else:
            logger.info("Skipping /api/v1/validate/arkit (no arkit_dir)")

    def test_dataset_building_endpoints(
        self,
        sequences_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ) -> None:
        """Test dataset building endpoint with optimizations.

        Always submits a baseline build; additionally submits an
        optimization-enabled build when ``test_optimizations`` is True.

        Args:
            sequences_dir: Input sequences directory; skips everything if None.
            test_optimizations: Whether to also run the optimized variant.
        """
        logger.info("\n" + "=" * 80)
        logger.info("DATASET BUILDING ENDPOINTS")
        logger.info("=" * 80)
        if not sequences_dir:
            logger.info("Skipping /api/v1/dataset/build (no sequences_dir)")
            return
        # Test with optimizations
        if test_optimizations:
            payload = {
                "sequences_dir": sequences_dir,
                "output_dir": "data/test_training",
                "max_samples": 10,  # Small for testing
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "use_batched_inference": True,
                "inference_batch_size": 4,
                "use_inference_cache": True,
                "cache_dir": "cache/test_inference",
                "compile_model": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/dataset/build",
                "Build dataset with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/dataset/build (optimized)", result))
        # Test without optimizations (baseline)
        payload = {
            "sequences_dir": sequences_dir,
            "output_dir": "data/test_training_baseline",
            "max_samples": 10,
            "accept_threshold": 2.0,
            "reject_threshold": 30.0,
            "use_batched_inference": False,
            "use_inference_cache": False,
            "compile_model": False,
        }
        result = self.test_endpoint(
            "POST",
            "/api/v1/dataset/build",
            "Build dataset (baseline)",
            json=payload,
        )
        self.results.append(("POST /api/v1/dataset/build (baseline)", result))

    def test_training_endpoints(
        self,
        training_data_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ) -> None:
        """Test training endpoints with optimization parameters.

        Always submits a baseline fine-tuning job; additionally submits an
        optimization-enabled job when ``test_optimizations`` is True.

        Args:
            training_data_dir: Training data directory; skips everything if None.
            test_optimizations: Whether to also run the optimized variant.
        """
        logger.info("\n" + "=" * 80)
        logger.info("TRAINING ENDPOINTS")
        logger.info("=" * 80)
        if not training_data_dir:
            logger.info("Skipping /api/v1/train/start (no training_data_dir)")
            return
        # Test fine-tuning with optimizations
        if test_optimizations:
            payload = {
                "training_data_dir": training_data_dir,
                "epochs": 1,  # Single epoch for testing
                "lr": 1e-5,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/test",
                "device": "cuda",
                "use_wandb": False,
                # Optimization parameters
                "gradient_accumulation_steps": 4,
                "use_amp": True,
                "warmup_steps": 10,
                "num_workers": 2,
                "use_ema": True,
                "ema_decay": 0.9999,
                "use_onecycle": False,
                "use_gradient_checkpointing": False,
                "compile_model": True,
                # Phase 4 optimizations
                "use_bf16": False,  # Use FP16 for compatibility
                "gradient_clip_norm": 1.0,
                "find_lr": False,  # Skip for quick test
                "find_batch_size": False,  # Skip for quick test
                # FSDP options
                "use_fsdp": False,  # Skip for quick test
                "fsdp_sharding_strategy": "FULL_SHARD",
                "fsdp_mixed_precision": None,
                # Advanced optimizations
                "use_qat": False,  # Skip for quick test
                "qat_backend": "fbgemm",
                "use_sequence_parallel": False,  # Skip for quick test
                "sequence_parallel_gpus": 1,
                "activation_recompute_strategy": None,
                # Checkpoint options
                "async_checkpoint": True,
                "compress_checkpoint": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/train/start",
                "Fine-tune with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/train/start (optimized)", result))
        # Test baseline (no optimizations)
        payload = {
            "training_data_dir": training_data_dir,
            "epochs": 1,
            "lr": 1e-5,
            "batch_size": 1,
            "checkpoint_dir": "checkpoints/test_baseline",
            "device": "cuda",
            "use_wandb": False,
            "gradient_accumulation_steps": 1,
            "use_amp": False,
            "compile_model": False,
        }
        result = self.test_endpoint(
            "POST",
            "/api/v1/train/start",
            "Fine-tune (baseline)",
            json=payload,
        )
        self.results.append(("POST /api/v1/train/start (baseline)", result))

    def test_pretraining_endpoints(
        self,
        arkit_sequences_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ) -> None:
        """Test pre-training endpoints with optimization parameters.

        Unlike the fine-tuning tests, only the optimized variant is
        submitted here (no baseline request).

        Args:
            arkit_sequences_dir: ARKit sequences directory; skips if None.
            test_optimizations: Whether to run the optimized variant.
        """
        logger.info("\n" + "=" * 80)
        logger.info("PRE-TRAINING ENDPOINTS")
        logger.info("=" * 80)
        if not arkit_sequences_dir:
            logger.info("Skipping /api/v1/train/pretrain (no arkit_sequences_dir)")
            return
        # Test with optimizations
        if test_optimizations:
            payload = {
                "arkit_sequences_dir": arkit_sequences_dir,
                "epochs": 1,  # Single epoch for testing
                "lr": 1e-4,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/test_pretrain",
                "device": "cuda",
                "max_sequences": 1,  # Small for testing
                "max_frames_per_sequence": 10,
                "frame_interval": 1,
                "use_lidar": False,
                "use_ba_depth": False,
                "min_ba_quality": 0.0,
                "use_wandb": False,
                # Optimization parameters
                "gradient_accumulation_steps": 4,
                "use_amp": True,
                "warmup_steps": 10,
                "num_workers": 2,
                "use_ema": True,
                "ema_decay": 0.9999,
                "use_onecycle": False,
                "use_gradient_checkpointing": False,
                "compile_model": True,
                "cache_dir": "cache/test_ba",
                # Phase 4 optimizations
                "use_bf16": False,  # Use FP16 for compatibility
                "gradient_clip_norm": 1.0,
                "find_lr": False,  # Skip for quick test
                "find_batch_size": False,  # Skip for quick test
                # FSDP options
                "use_fsdp": False,  # Skip for quick test
                "fsdp_sharding_strategy": "FULL_SHARD",
                "fsdp_mixed_precision": None,
                # Advanced optimizations
                "use_qat": False,  # Skip for quick test
                "qat_backend": "fbgemm",
                "use_sequence_parallel": False,  # Skip for quick test
                "sequence_parallel_gpus": 1,
                "activation_recompute_strategy": None,
                # Checkpoint options
                "async_checkpoint": True,
                "compress_checkpoint": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/train/pretrain",
                "Pre-train with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/train/pretrain (optimized)", result))

    def poll_jobs(self, max_polls: int = 60, poll_interval: int = 5) -> None:
        """Poll job status until completion.

        For each collected job_id, polls /api/v1/jobs/{job_id} until the
        job reports "completed" or "failed", the status request itself
        fails, or ``max_polls`` attempts are exhausted. The final poll
        result for each job is appended to ``self.results``.

        Args:
            max_polls: Maximum status checks per job.
            poll_interval: Seconds to sleep between checks.
        """
        logger.info("\n" + "=" * 80)
        logger.info("POLLING JOBS")
        logger.info("=" * 80)
        if not self.job_ids:
            logger.info("No jobs to monitor")
            return
        logger.info(f"Monitoring {len(self.job_ids)} job(s)")
        for job_id in self.job_ids:
            logger.info(f"\nMonitoring job: {job_id}")
            # FIX: pre-initialize so the final append below cannot raise
            # NameError when max_polls <= 0 (the loop body never runs).
            result: Dict[str, Any] = {"status_code": None, "success": False}
            for poll_num in range(max_polls):
                result = self.test_endpoint(
                    "GET",
                    f"/api/v1/jobs/{job_id}",
                    f"Job status (poll {poll_num + 1}/{max_polls})",
                )
                if result.get("success") and result.get("data"):
                    data = result["data"]
                    status = data.get("status", "unknown")
                    message = data.get("message", "")
                    logger.info(f" Status: {status}, Message: {message[:60]}")
                    if status in ["completed", "failed"]:
                        logger.info(f" Job {status}!")
                        if status == "completed":
                            job_result = data.get("result", {})
                            if job_result:
                                logger.info(f" Result keys: {list(job_result.keys())}")
                        break
                    # Sleep only between polls, not after the last one.
                    if poll_num < max_polls - 1:
                        time.sleep(poll_interval)
                else:
                    logger.warning(" Failed to get job status")
                    break
            self.results.append((f"GET /api/v1/jobs/{job_id} (final)", result))

    def print_summary(self) -> None:
        """Print test summary: success ratio plus one line per endpoint."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST SUMMARY")
        logger.info("=" * 80)
        success_count = sum(1 for _, r in self.results if r.get("success"))
        total_count = len(self.results)
        logger.info(f"Success: {success_count}/{total_count}")
        logger.info("")
        logger.info("Endpoint Results:")
        for endpoint, result in self.results:
            status = "✓" if result.get("success") else "✗"
            status_code = result.get("status_code", "N/A")
            # duration is absent for transport failures; default to 0.
            duration = result.get("duration", 0)
            status_code_str = str(status_code) if status_code is not None else "N/A"
            logger.info(f"{status} {endpoint:60s} {status_code_str:>3} ({duration:.3f}s)")

    def save_results(self, output_file: Path) -> None:
        """Save test results to JSON file.

        Successful entries keep their response data; failed entries keep
        their error string. Parent directories are created as needed.

        Args:
            output_file: Destination path for the JSON report.
        """
        output_file.parent.mkdir(parents=True, exist_ok=True)
        output_data = {
            "timestamp": datetime.now().isoformat(),
            "base_url": self.base_url,
            "summary": {
                "total_tests": len(self.results),
                "successful": sum(1 for _, r in self.results if r.get("success")),
                "failed": sum(1 for _, r in self.results if not r.get("success")),
            },
            "results": [
                {
                    "endpoint": endpoint,
                    "status_code": r.get("status_code"),
                    "success": r.get("success"),
                    "duration": r.get("duration"),
                    "data": r.get("data") if r.get("success") else None,
                    "error": r.get("error") if not r.get("success") else None,
                }
                for endpoint, r in self.results
            ],
        }
        # FIX: explicit UTF-8 so non-ASCII response data serializes the same
        # way on every platform; default=str stringifies non-JSON types.
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=2, default=str)
        logger.info(f"\nResults saved to: {output_file}")
def main():
    """Main test function.

    Parses CLI arguments, runs every endpoint test group, optionally polls
    spawned jobs, prints a summary, saves results, and returns a process
    exit code: 0 when every recorded test succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Comprehensive YLFF API endpoint testing")
    parser.add_argument(
        "--base-url",
        default="http://localhost:8000",
        help="Base URL for API",
    )
    parser.add_argument("--sequence-dir", type=str, help="Sequence directory for validation")
    parser.add_argument("--arkit-dir", type=str, help="ARKit directory for validation")
    parser.add_argument(
        "--sequences-dir", type=str, help="Sequences directory for dataset building"
    )
    parser.add_argument(
        "--training-data-dir", type=str, help="Training data directory for fine-tuning"
    )
    parser.add_argument(
        "--arkit-sequences-dir", type=str, help="ARKit sequences directory for pre-training"
    )
    parser.add_argument(
        "--skip-optimizations", action="store_true", help="Skip optimization parameter tests"
    )
    parser.add_argument("--skip-polling", action="store_true", help="Skip job polling")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/api_test_results.json"),
        help="Output file for results",
    )
    parser.add_argument("--timeout", type=int, default=300, help="Request timeout in seconds")
    args = parser.parse_args()

    banner = "=" * 80
    logger.info(banner)
    logger.info("YLFF API COMPREHENSIVE TEST")
    logger.info(banner)
    logger.info(f"Base URL: {args.base_url}")
    logger.info(f"Timeout: {args.timeout}s")
    logger.info("")

    tester = APITester(args.base_url, timeout=args.timeout)
    run_optimizations = not args.skip_optimizations

    # Exercise every endpoint group in a fixed order.
    tester.test_health_endpoints()
    tester.test_profiling_endpoints()
    tester.test_validation_endpoints(sequence_dir=args.sequence_dir, arkit_dir=args.arkit_dir)
    tester.test_dataset_building_endpoints(
        sequences_dir=args.sequences_dir,
        test_optimizations=run_optimizations,
    )
    tester.test_training_endpoints(
        training_data_dir=args.training_data_dir,
        test_optimizations=run_optimizations,
    )
    tester.test_pretraining_endpoints(
        arkit_sequences_dir=args.arkit_sequences_dir,
        test_optimizations=run_optimizations,
    )

    # Wait for any asynchronous jobs unless explicitly disabled.
    if not args.skip_polling:
        tester.poll_jobs()

    tester.print_summary()
    tester.save_results(args.output)

    # Any failed endpoint test makes the process exit non-zero.
    any_failed = any(not r.get("success") for _, r in tester.results)
    return 1 if any_failed else 0


if __name__ == "__main__":
    sys.exit(main())