|
|
""" |
|
|
Pydantic models for YLFF API request/response schemas. |
|
|
|
|
|
All API models are rigorously defined with: |
|
|
- Comprehensive field validation |
|
|
- Detailed descriptions and examples |
|
|
- Type hints and optional field defaults |
|
|
- JSON schema generation support |
|
|
""" |
|
|
|
|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Optional |
|
|
from pydantic import BaseModel, Field, field_validator |
|
|
|
|
|
|
|
|
|
|
|
class JobStatus(str, Enum):
    """Job execution status.

    Inherits from ``str`` so members serialize as plain strings in JSON
    responses and compare equal to their raw values.
    """

    QUEUED = "queued"  # accepted, waiting to start
    RUNNING = "running"  # currently executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"  # finished with an error
    CANCELLED = "cancelled"  # stopped before completion
|
|
|
|
|
|
|
|
class DeviceType(str, Enum):
    """Device type for model inference/training.

    Inherits from ``str`` so members serialize as plain strings in JSON.
    """

    CPU = "cpu"  # portable fallback, no GPU required
    CUDA = "cuda"  # NVIDIA GPU
    MPS = "mps"  # Apple Silicon (Metal Performance Shaders)
|
|
|
|
|
|
|
|
class UseCase(str, Enum):
    """Use case for model selection.

    Drives automatic model choice when a request leaves ``model_name``
    unset (see the request models below).
    """

    BA_VALIDATION = "ba_validation"  # BA validation (default for Validate* requests)
    MONO_DEPTH = "mono_depth"
    MULTI_VIEW = "multi_view"
    POSE_CONDITIONED = "pose_conditioned"
    TRAINING = "training"
    INFERENCE = "inference"
|
|
|
|
|
|
|
|
|
|
|
class ValidateSequenceRequest(BaseModel):
    """Request model for sequence validation.

    Thresholds are rotation errors in degrees: frames below
    ``accept_threshold`` are accepted, frames above ``reject_threshold``
    are rejected as outliers.
    """

    sequence_dir: str = Field(
        default=...,
        description="Directory containing image sequence",
        examples=["data/sequences/sequence_001"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        default=None,
        description="DA3 model name (default: auto-select based on use_case)",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3-GIANT"],
    )
    use_case: UseCase = Field(
        default=UseCase.BA_VALIDATION,
        description="Use case for model selection",
        examples=["ba_validation", "mono_depth"],
    )
    accept_threshold: float = Field(
        default=2.0,
        description="Accept threshold in degrees (frames with rotation error below this are accepted)",
        ge=0.0,
        le=180.0,
        examples=[2.0, 5.0],
    )
    reject_threshold: float = Field(
        default=30.0,
        description="Reject threshold in degrees (frames with rotation error above this are rejected as outliers)",
        ge=0.0,
        le=180.0,
        examples=[30.0, 45.0],
    )
    output: Optional[str] = Field(
        default=None,
        description="Output JSON path for validation results",
        examples=["data/results/validation.json"],
    )

    @field_validator("reject_threshold")
    @classmethod
    def reject_greater_than_accept(cls, v, info):
        """Ensure reject threshold is greater than accept threshold."""
        # accept_threshold is declared first, so it is already present in
        # info.data when this validator runs (unless it failed validation).
        accept_thresh = info.data.get("accept_threshold")
        if accept_thresh is None or v > accept_thresh:
            return v
        raise ValueError(
            f"reject_threshold ({v}) must be greater than "
            f"accept_threshold ({accept_thresh})"
        )

    @field_validator("sequence_dir")
    @classmethod
    def validate_sequence_dir(cls, v):
        """Validate sequence directory path format."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("sequence_dir cannot be empty")
        return cleaned

    model_config = {
        "json_schema_extra": {
            "example": {
                "sequence_dir": "data/sequences/sequence_001",
                "model_name": "depth-anything/DA3-LARGE",
                "use_case": "ba_validation",
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "output": "data/results/validation.json",
            }
        }
    }
|
|
|
|
|
|
|
|
class ValidateARKitRequest(BaseModel):
    """Request model for ARKit validation.

    Points at a directory holding an ARKit capture (video plus JSON
    metadata) and controls frame sampling, inference device, and an
    optional real-time GUI.
    """

    arkit_dir: str = Field(
        default=...,
        description="Directory containing ARKit video and JSON metadata",
        examples=["assets/examples/ARKit", "data/arkit_recordings/session_001"],
        min_length=1,
    )
    output_dir: str = Field(
        default="data/arkit_validation",
        description="Output directory for validation results",
        examples=["data/arkit_validation", "data/results/arkit_001"],
    )
    model_name: Optional[str] = Field(
        default=None,
        description="DA3 model name (default: DA3NESTED-GIANT-LARGE for BA validation)",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3NESTED-GIANT-LARGE"],
    )
    max_frames: Optional[int] = Field(
        default=None,
        description="Maximum number of frames to process (None = process all)",
        ge=1,
        examples=[10, 30, 100],
    )
    frame_interval: int = Field(
        default=1,
        description="Extract every Nth frame (1 = all frames, 5 = every 5th frame)",
        ge=1,
        examples=[1, 5, 10],
    )
    device: DeviceType = Field(
        default=DeviceType.CPU,
        description="Device for DA3 inference",
        examples=["cpu", "cuda", "mps"],
    )
    gui: bool = Field(
        default=False,
        description="Show real-time GUI visualization during validation",
        examples=[False, True],
    )

    @field_validator("arkit_dir")
    @classmethod
    def validate_arkit_dir(cls, v):
        """Validate ARKit directory path format."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("arkit_dir cannot be empty")
        return cleaned

    model_config = {
        "json_schema_extra": {
            "example": {
                "arkit_dir": "assets/examples/ARKit",
                "output_dir": "data/arkit_validation",
                "model_name": "depth-anything/DA3NESTED-GIANT-LARGE",
                "max_frames": 30,
                "frame_interval": 1,
                "device": "cpu",
                "gui": False,
            }
        }
    }
|
|
|
|
|
|
|
|
class BuildDatasetRequest(BaseModel):
    """Request model for building training dataset.

    Validates sequences under ``sequences_dir`` with the selected DA3
    model and writes accepted training samples to ``output_dir``.
    Accept/reject thresholds are rotation errors in degrees.
    """

    sequences_dir: str = Field(
        ...,
        description="Directory containing sequence directories",
        examples=["data/raw/sequences", "data/collected/sequences"],
        min_length=1,
    )
    output_dir: str = Field(
        "data/training",
        description="Output directory for training dataset",
        examples=["data/training", "data/training/dataset_v1"],
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name for validation",
        examples=["depth-anything/DA3-LARGE"],
    )
    max_samples: Optional[int] = Field(
        None,
        description="Maximum number of training samples to generate (None = no limit)",
        ge=1,
        examples=[100, 500, 1000],
    )
    accept_threshold: float = Field(
        2.0,
        description="Accept threshold in degrees",
        ge=0.0,
        le=180.0,
        examples=[2.0],
    )
    reject_threshold: float = Field(
        30.0,
        description="Reject threshold in degrees",
        ge=0.0,
        le=180.0,
        examples=[30.0],
    )
    # --- Weights & Biases logging ---
    use_wandb: bool = Field(
        True,
        description="Enable Weights & Biases logging",
        examples=[True, False],
    )
    wandb_project: str = Field(
        "ylff",
        description="W&B project name",
        examples=["ylff", "ylff-datasets"],
        min_length=1,
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name (default: auto-generated)",
        examples=["dataset-build-2024-12-06", "v1-training-set"],
    )
    # --- Inference performance knobs ---
    use_batched_inference: bool = Field(
        False,
        description="Use batched inference for better GPU utilization",
        examples=[False, True],
    )
    inference_batch_size: int = Field(
        4,
        description="Batch size for inference (when use_batched_inference=True)",
        ge=1,
        examples=[2, 4, 8],
    )
    use_inference_cache: bool = Field(
        False,
        description="Cache inference results to avoid recomputing identical sequences",
        examples=[False, True],
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for inference cache (None = in-memory only)",
        examples=[None, "cache/inference"],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster inference",
        examples=[True, False],
    )

    @field_validator("sequences_dir")
    @classmethod
    def validate_sequences_dir(cls, v):
        """Validate sequences directory path format.

        min_length=1 alone still admits whitespace-only paths; strip and
        reject them, matching the other request models in this module.
        """
        if not v or not v.strip():
            raise ValueError("sequences_dir cannot be empty")
        return v.strip()

    @field_validator("reject_threshold")
    @classmethod
    def reject_greater_than_accept(cls, v, info):
        """Ensure reject threshold is greater than accept threshold."""
        # accept_threshold is declared before reject_threshold, so it is
        # available in info.data here (unless its own validation failed).
        accept_thresh = info.data.get("accept_threshold")
        if accept_thresh is not None and v <= accept_thresh:
            raise ValueError(
                f"reject_threshold ({v}) must be greater than "
                f"accept_threshold ({accept_thresh})"
            )
        return v

    model_config = {
        "json_schema_extra": {
            "example": {
                "sequences_dir": "data/raw/sequences",
                "output_dir": "data/training",
                "model_name": "depth-anything/DA3-LARGE",
                "max_samples": 1000,
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "dataset-build-2024-12-06",
            }
        }
    }
|
|
|
|
|
|
|
|
class TrainRequest(BaseModel):
    """Request model for model fine-tuning.

    Covers the core training loop (epochs, lr, batch size, device),
    logging, mixed precision, distributed/memory options (FSDP, sequence
    parallelism, activation recompute), QAT, and checkpointing.
    """

    training_data_dir: str = Field(
        ...,
        description="Directory containing training samples",
        examples=["data/training", "data/training/dataset_v1"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name to fine-tune",
        examples=["depth-anything/DA3-LARGE"],
    )
    epochs: int = Field(
        10,
        description="Number of training epochs",
        ge=1,
        le=1000,
        examples=[10, 20, 50],
    )
    lr: float = Field(
        1e-5,
        description="Learning rate",
        gt=0.0,
        examples=[1e-5, 1e-4, 1e-6],
    )
    batch_size: int = Field(
        1,
        description="Training batch size",
        ge=1,
        examples=[1, 2, 4, 8],
    )
    checkpoint_dir: str = Field(
        "checkpoints",
        description="Directory to save model checkpoints",
        examples=["checkpoints", "models/checkpoints"],
    )
    device: DeviceType = Field(
        DeviceType.CUDA,
        description="Device for training",
        examples=["cuda", "cpu", "mps"],
    )
    # --- Weights & Biases logging ---
    use_wandb: bool = Field(
        True,
        description="Enable Weights & Biases logging",
        examples=[True, False],
    )
    wandb_project: str = Field(
        "ylff",
        description="W&B project name",
        examples=["ylff", "ylff-training"],
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name",
        examples=["fine-tune-v1", "training-run-2024-12-06"],
    )
    # --- Optimization / precision ---
    gradient_accumulation_steps: int = Field(
        1,
        description=(
            "Number of steps to accumulate gradients "
            "(effective batch size = batch_size * this)"
        ),
        ge=1,
        examples=[1, 4, 8],
    )
    use_amp: bool = Field(
        True,
        description="Use automatic mixed precision training (FP16)",
        examples=[True, False],
    )
    warmup_steps: int = Field(
        0,
        description="Number of warmup steps for learning rate (0 = no warmup)",
        ge=0,
        examples=[0, 100, 500],
    )
    num_workers: Optional[int] = Field(
        None,
        description="Number of data loading workers (None = auto-detect)",
        ge=0,
        examples=[None, 2, 4, 8],
    )
    resume_from_checkpoint: Optional[str] = Field(
        None,
        description="Path to checkpoint to resume from",
        examples=[None, "checkpoints/latest.pth", "checkpoints/best.pth"],
    )
    use_ema: bool = Field(
        False,
        description="Use Exponential Moving Average for model weights",
        examples=[False, True],
    )
    ema_decay: float = Field(
        0.9999,
        description="EMA decay factor (higher = slower update, more stable)",
        gt=0.0,
        lt=1.0,
        examples=[0.999, 0.9999, 0.99999],
    )
    use_onecycle: bool = Field(
        False,
        description="Use OneCycleLR scheduler instead of CosineAnnealingLR",
        examples=[False, True],
    )
    use_gradient_checkpointing: bool = Field(
        False,
        description="Enable gradient checkpointing to save memory (slower but uses less memory)",
        examples=[False, True],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster training (PyTorch 2.0+)",
        examples=[True, False],
    )
    use_bf16: bool = Field(
        False,
        description="Use BF16 instead of FP16 (better training stability, same speed)",
        examples=[False, True],
    )
    gradient_clip_norm: Optional[float] = Field(
        1.0,
        description="Maximum gradient norm for clipping (None = disabled, 1.0 = default)",
        ge=0.0,
        examples=[None, 0.5, 1.0, 2.0],
    )
    find_lr: bool = Field(
        False,
        description="Automatically find optimal learning rate before training",
        examples=[False, True],
    )
    find_batch_size: bool = Field(
        False,
        description="Automatically find optimal batch size before training",
        examples=[False, True],
    )
    # --- Distributed / memory options ---
    use_fsdp: bool = Field(
        False,
        description=(
            "Use FSDP (Fully Sharded Data Parallel) "
            "for memory-efficient multi-GPU training"
        ),
        examples=[False, True],
    )
    fsdp_sharding_strategy: str = Field(
        "FULL_SHARD",
        description="FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD",
        examples=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"],
    )
    fsdp_mixed_precision: Optional[str] = Field(
        None,
        description="FSDP mixed precision: bf16, fp16, or None (auto-detects from use_bf16)",
        examples=[None, "bf16", "fp16"],
    )
    use_qat: bool = Field(
        False,
        description=("Use Quantization Aware Training (QAT) for better INT8 quantization"),
        examples=[False, True],
    )
    qat_backend: str = Field(
        "fbgemm",
        description="QAT backend: fbgemm (x86) or qnnpack (ARM)",
        examples=["fbgemm", "qnnpack"],
    )
    use_sequence_parallel: bool = Field(
        False,
        description="Enable sequence parallelism for very long sequences",
        examples=[False, True],
    )
    sequence_parallel_gpus: int = Field(
        1,
        description="Number of GPUs for sequence parallelism",
        ge=1,
        examples=[1, 2, 4],
    )
    activation_recompute_strategy: Optional[str] = Field(
        None,
        description="Activation recompute strategy: checkpoint, cpu_offload, hybrid, or None",
        examples=[None, "checkpoint", "cpu_offload", "hybrid"],
    )
    # --- Checkpointing ---
    async_checkpoint: bool = Field(
        True,
        description="Use async checkpoint saving (non-blocking, faster training)",
        examples=[True, False],
    )
    compress_checkpoint: bool = Field(
        True,
        description="Compress checkpoints with gzip (30-50% smaller files)",
        examples=[True, False],
    )

    @field_validator("training_data_dir")
    @classmethod
    def validate_training_data_dir(cls, v):
        """Validate training data directory path format.

        min_length=1 alone still admits whitespace-only paths; strip and
        reject them, matching the other request models in this module.
        """
        if not v or not v.strip():
            raise ValueError("training_data_dir cannot be empty")
        return v.strip()

    model_config = {
        "json_schema_extra": {
            "example": {
                "training_data_dir": "data/training",
                "model_name": "depth-anything/DA3-LARGE",
                "epochs": 10,
                "lr": 1e-5,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints",
                "device": "cuda",
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "fine-tune-v1",
            }
        }
    }
|
|
|
|
|
|
|
|
class PretrainRequest(BaseModel):
    """Request model for model pre-training on ARKit sequences.

    Mirrors TrainRequest's training/precision/distributed options and adds
    ARKit-specific controls: frame sampling, LiDAR/BA-depth supervision,
    and BA/ARKit pose-quality gating.
    """

    arkit_sequences_dir: str = Field(
        ...,
        description="Directory containing ARKit sequence directories",
        examples=["data/arkit_sequences", "data/collected/arkit"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name to pre-train",
        examples=["depth-anything/DA3-LARGE"],
    )
    epochs: int = Field(
        10,
        description="Number of pre-training epochs",
        ge=1,
        le=1000,
        examples=[5, 10, 20],
    )
    lr: float = Field(
        1e-4,
        description="Learning rate for pre-training",
        gt=0.0,
        examples=[1e-4, 1e-3],
    )
    batch_size: int = Field(
        1,
        description="Pre-training batch size",
        ge=1,
        examples=[1, 2, 4],
    )
    checkpoint_dir: str = Field(
        "checkpoints/pretrain",
        description="Directory to save model checkpoints",
        examples=["checkpoints/pretrain"],
    )
    device: DeviceType = Field(
        DeviceType.CUDA,
        description="Device for pre-training",
        examples=["cuda"],
    )
    # --- Frame sampling ---
    max_sequences: Optional[int] = Field(
        None,
        description="Maximum number of sequences to process (None = all)",
        ge=1,
        examples=[10, 50, 100],
    )
    max_frames_per_sequence: Optional[int] = Field(
        None,
        description="Maximum frames per sequence to process (None = all)",
        ge=1,
        examples=[30, 100],
    )
    frame_interval: int = Field(
        1,
        description="Extract every Nth frame",
        ge=1,
        examples=[1, 5, 10],
    )
    # --- Supervision signals ---
    use_lidar: bool = Field(
        False,
        description="Use ARKit LiDAR depth as supervision signal",
        examples=[False, True],
    )
    use_ba_depth: bool = Field(
        False,
        description="Use BA depth maps as supervision signal",
        examples=[False, True],
    )
    min_ba_quality: float = Field(
        0.0,
        description="Minimum BA quality threshold (0.0-1.0)",
        ge=0.0,
        le=1.0,
        examples=[0.0, 0.5, 0.8],
    )
    # --- Weights & Biases logging ---
    use_wandb: bool = Field(
        True,
        description="Enable Weights & Biases logging",
        examples=[True, False],
    )
    wandb_project: str = Field(
        "ylff",
        description="W&B project name",
        examples=["ylff", "ylff-pretraining"],
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name",
        examples=["pretrain-v1", "pretrain-arkit-2024-12-06"],
    )
    # --- Optimization / precision ---
    gradient_accumulation_steps: int = Field(
        1,
        description="Number of steps to accumulate gradients",
        ge=1,
        examples=[1, 4, 8],
    )
    use_amp: bool = Field(
        True,
        description="Use automatic mixed precision training (FP16)",
        examples=[True, False],
    )
    warmup_steps: int = Field(
        0,
        description="Number of warmup steps for learning rate",
        ge=0,
        examples=[0, 100, 500],
    )
    num_workers: Optional[int] = Field(
        None,
        description="Number of data loading workers (None = auto-detect)",
        ge=0,
        examples=[None, 2, 4, 8],
    )
    resume_from_checkpoint: Optional[str] = Field(
        None,
        description="Path to checkpoint to resume from",
        examples=[None, "checkpoints/pretrain/latest.pth"],
    )
    use_ema: bool = Field(
        False,
        description="Use Exponential Moving Average for model weights",
        examples=[False, True],
    )
    ema_decay: float = Field(
        0.9999,
        description="EMA decay factor",
        gt=0.0,
        lt=1.0,
        examples=[0.9999],
    )
    use_onecycle: bool = Field(
        False,
        description="Use OneCycleLR scheduler instead of CosineAnnealingLR",
        examples=[False, True],
    )
    use_gradient_checkpointing: bool = Field(
        False,
        description="Enable gradient checkpointing to save memory",
        examples=[False, True],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster training",
        examples=[True, False],
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for caching BA results (None = disabled)",
        examples=[None, "cache/ba_results"],
    )
    use_bf16: bool = Field(
        False,
        description="Use BF16 instead of FP16 (better training stability, same speed)",
        examples=[False, True],
    )
    gradient_clip_norm: Optional[float] = Field(
        1.0,
        description="Maximum gradient norm for clipping (None = disabled, 1.0 = default)",
        ge=0.0,
        examples=[None, 0.5, 1.0, 2.0],
    )
    find_lr: bool = Field(
        False,
        description="Automatically find optimal learning rate before training",
        examples=[False, True],
    )
    find_batch_size: bool = Field(
        False,
        description="Automatically find optimal batch size before training",
        examples=[False, True],
    )
    # --- Distributed / memory options ---
    use_fsdp: bool = Field(
        False,
        description=(
            "Use FSDP (Fully Sharded Data Parallel) " "for memory-efficient multi-GPU training"
        ),
        examples=[False, True],
    )
    fsdp_sharding_strategy: str = Field(
        "FULL_SHARD",
        description="FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD",
        examples=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"],
    )
    fsdp_mixed_precision: Optional[str] = Field(
        None,
        description="FSDP mixed precision: bf16, fp16, or None (auto-detects from use_bf16)",
        examples=[None, "bf16", "fp16"],
    )
    use_qat: bool = Field(
        False,
        description=("Use Quantization Aware Training (QAT) for better INT8 quantization"),
        examples=[False, True],
    )
    qat_backend: str = Field(
        "fbgemm",
        description="QAT backend: fbgemm (x86) or qnnpack (ARM)",
        examples=["fbgemm", "qnnpack"],
    )
    use_sequence_parallel: bool = Field(
        False,
        description="Enable sequence parallelism for very long sequences",
        examples=[False, True],
    )
    sequence_parallel_gpus: int = Field(
        1,
        description="Number of GPUs for sequence parallelism",
        ge=1,
        examples=[1, 2, 4],
    )
    activation_recompute_strategy: Optional[str] = Field(
        None,
        description="Activation recompute strategy: checkpoint, cpu_offload, hybrid, or None",
        examples=[None, "checkpoint", "cpu_offload", "hybrid"],
    )
    # --- Checkpointing ---
    async_checkpoint: bool = Field(
        True,
        description="Use async checkpoint saving (non-blocking, faster training)",
        examples=[True, False],
    )
    compress_checkpoint: bool = Field(
        True,
        description="Compress checkpoints with gzip (30-50% smaller files)",
        examples=[True, False],
    )
    # --- ARKit pose gating ---
    prefer_arkit_poses: bool = Field(
        True,
        description=(
            "Use ARKit poses directly when tracking quality is good "
            "(much faster, skips BA for good sequences)"
        ),
        examples=[True, False],
    )
    min_arkit_quality: float = Field(
        0.8,
        description=(
            "Minimum fraction of frames with good tracking to use ARKit poses directly "
            "(0.0-1.0, higher = stricter)"
        ),
        ge=0.0,
        le=1.0,
        examples=[0.7, 0.8, 0.9],
    )

    @field_validator("arkit_sequences_dir")
    @classmethod
    def validate_arkit_sequences_dir(cls, v):
        """Validate ARKit sequences directory path format."""
        # min_length=1 does not catch whitespace-only paths; strip here.
        if not v or not v.strip():
            raise ValueError("arkit_sequences_dir cannot be empty")
        return v.strip()

    model_config = {
        "json_schema_extra": {
            "example": {
                "arkit_sequences_dir": "data/arkit_sequences",
                "model_name": "depth-anything/DA3-LARGE",
                "epochs": 10,
                "lr": 1e-4,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/pretrain",
                "device": "cuda",
                "max_sequences": None,
                "max_frames_per_sequence": None,
                "frame_interval": 1,
                "use_lidar": False,
                "use_ba_depth": False,
                "min_ba_quality": 0.0,
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "pretrain-v1",
            }
        }
    }
|
|
|
|
|
|
|
|
class EvaluateBAAgreementRequest(BaseModel):
    """Request model for BA agreement evaluation.

    Runs the selected model (or an explicit checkpoint) over test
    sequences and measures agreement against BA within ``threshold``
    degrees.
    """

    test_data_dir: str = Field(
        default=...,
        description="Directory containing test sequences",
        examples=["data/test", "data/validation"],
        min_length=1,
    )
    model_name: str = Field(
        default="depth-anything/DA3-LARGE",
        description="DA3 model name",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3-GIANT"],
    )
    checkpoint: Optional[str] = Field(
        default=None,
        description="Path to model checkpoint (optional, overrides model_name)",
        examples=["checkpoints/best_model.pth", "checkpoints/epoch_10.pth"],
    )
    threshold: float = Field(
        default=2.0,
        description="Agreement threshold in degrees",
        ge=0.0,
        le=180.0,
        examples=[2.0, 5.0],
    )
    device: DeviceType = Field(
        default=DeviceType.CUDA,
        description="Device for inference",
        examples=["cuda", "cpu"],
    )
    use_wandb: bool = Field(
        default=True,
        description="Enable Weights & Biases logging",
        examples=[True, False],
    )
    wandb_project: str = Field(
        default="ylff",
        description="W&B project name",
        examples=["ylff", "ylff-evaluation"],
    )
    wandb_name: Optional[str] = Field(
        default=None,
        description="W&B run name",
        examples=["eval-ba-agreement-v1", "eval-checkpoint-best"],
    )

    @field_validator("test_data_dir")
    @classmethod
    def validate_test_data_dir(cls, v):
        """Validate test data directory path format."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("test_data_dir cannot be empty")
        return cleaned

    model_config = {
        "json_schema_extra": {
            "example": {
                "test_data_dir": "data/test",
                "model_name": "depth-anything/DA3-LARGE",
                "checkpoint": None,
                "threshold": 2.0,
                "device": "cuda",
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "eval-ba-agreement-v1",
            }
        }
    }
|
|
|
|
|
|
|
|
class VisualizeRequest(BaseModel):
    """Request model for result visualization."""

    results_dir: str = Field(
        default=...,
        description="Directory containing validation results",
        examples=["data/arkit_validation", "data/validation_results"],
        min_length=1,
    )
    output_dir: Optional[str] = Field(
        default=None,
        description="Output directory for visualizations (default: results_dir/visualizations)",
        examples=["data/arkit_validation/visualizations"],
    )
    use_plotly: bool = Field(
        default=True,
        description="Use Plotly for interactive 3D plots (requires plotly package)",
        examples=[False, True],
    )

    @field_validator("results_dir")
    @classmethod
    def validate_results_dir(cls, v):
        """Validate results directory path format."""
        cleaned = v.strip() if v else ""
        if not cleaned:
            raise ValueError("results_dir cannot be empty")
        return cleaned

    model_config = {
        "json_schema_extra": {
            "example": {
                "results_dir": "data/arkit_validation",
                "output_dir": None,
                "use_plotly": False,
            }
        }
    }
|
|
|
|
|
|
|
|
|
|
|
class ValidateDatasetRequest(BaseModel):
    """Request model for dataset validation.

    The ``check_*`` flags select which parts of the dataset to inspect;
    ``strict`` turns validation failures into exceptions.
    """

    dataset_path: str = Field(
        default=...,
        description="Path to dataset file (pickle, json, or hdf5)",
        examples=["data/training/dataset.pkl", "data/training/dataset.json"],
    )
    strict: bool = Field(
        default=False,
        description="If True, raise exception on validation failure",
        examples=[False, True],
    )
    check_images: bool = Field(
        default=True,
        description="Validate image data",
        examples=[True, False],
    )
    check_poses: bool = Field(
        default=True,
        description="Validate pose data",
        examples=[True, False],
    )
    check_metadata: bool = Field(
        default=True,
        description="Validate metadata fields",
        examples=[True, False],
    )
|
|
|
|
|
|
|
|
class CurateDatasetRequest(BaseModel):
    """Request model for dataset curation.

    Filters samples by optional error/weight ranges, optionally removes
    outliers above a percentile, and optionally rebalances the result.
    """

    dataset_path: str = Field(
        ...,
        description="Path to input dataset file",
        examples=["data/training/dataset.pkl"],
    )
    output_path: str = Field(
        ...,
        description="Path to save curated dataset",
        examples=["data/training/dataset_curated.pkl"],
    )
    # --- Range filters (None = unbounded on that side) ---
    min_error: Optional[float] = Field(
        None,
        description="Minimum error threshold",
        examples=[None, 0.5, 1.0],
    )
    max_error: Optional[float] = Field(
        None,
        description="Maximum error threshold",
        examples=[None, 30.0, 50.0],
    )
    min_weight: Optional[float] = Field(
        None,
        description="Minimum weight threshold",
        examples=[None, 0.1, 0.5],
    )
    max_weight: Optional[float] = Field(
        None,
        description="Maximum weight threshold",
        examples=[None, 1.0, 2.0],
    )
    # --- Outlier removal ---
    remove_outliers: bool = Field(
        False,
        description="Remove outlier samples",
        examples=[False, True],
    )
    outlier_percentile: float = Field(
        95.0,
        description="Percentile threshold for outlier detection",
        ge=0.0,
        le=100.0,
        examples=[95.0, 99.0],
    )
    # --- Balancing ---
    balance: bool = Field(
        False,
        description="Balance dataset by error distribution",
        examples=[False, True],
    )
    balance_strategy: str = Field(
        "error_bins",
        description="Balancing strategy: error_bins, uniform, or weighted",
        examples=["error_bins", "uniform", "weighted"],
    )
    num_bins: int = Field(
        10,
        description="Number of error bins for balancing",
        ge=2,
        examples=[10, 20],
    )

    @field_validator("max_error")
    @classmethod
    def max_error_not_below_min(cls, v, info):
        """Ensure max_error >= min_error when both are set.

        Mirrors the accept/reject threshold cross-checks on the other
        request models; an inverted range would silently filter out
        every sample.
        """
        min_err = info.data.get("min_error")
        if v is not None and min_err is not None and v < min_err:
            raise ValueError(
                f"max_error ({v}) must be >= min_error ({min_err})"
            )
        return v

    @field_validator("max_weight")
    @classmethod
    def max_weight_not_below_min(cls, v, info):
        """Ensure max_weight >= min_weight when both are set."""
        min_w = info.data.get("min_weight")
        if v is not None and min_w is not None and v < min_w:
            raise ValueError(
                f"max_weight ({v}) must be >= min_weight ({min_w})"
            )
        return v
|
|
|
|
|
|
|
|
class AnalyzeDatasetRequest(BaseModel):
    """Request model for dataset analysis."""

    dataset_path: str = Field(
        default=...,
        description="Path to dataset file",
        examples=["data/training/dataset.pkl"],
    )
    output_path: Optional[str] = Field(
        default=None,
        description="Path to save analysis report",
        examples=[None, "data/training/analysis.json"],
    )
    # NOTE: field name shadows the builtin but is part of the public API.
    format: str = Field(
        default="json",
        description="Report format: json, text, or markdown",
        examples=["json", "text", "markdown"],
    )
    compute_distributions: bool = Field(
        default=True,
        description="Compute error/weight distributions",
        examples=[True, False],
    )
    compute_correlations: bool = Field(
        default=True,
        description="Compute correlations between metrics",
        examples=[True, False],
    )
|
|
|
|
|
|
|
|
class DatasetValidationResponse(BaseModel):
    """Response model for dataset validation."""

    validation_passed: bool = Field(
        default=...,
        description="Whether validation passed",
        examples=[True, False],
    )
    statistics: Dict[str, Any] = Field(
        default=...,
        description="Dataset statistics",
    )
    issues: List[Dict[str, Any]] = Field(
        default=...,
        description="List of validation issues",
    )
    summary: Dict[str, Any] = Field(
        default=...,
        description="Validation summary",
    )
|
|
|
|
|
|
|
|
class DatasetAnalysisResponse(BaseModel):
    """Response model for dataset analysis."""

    # Raw numeric results of the analysis.
    statistics: Dict[str, Any] = Field(..., description="Dataset statistics")

    # Derived quality indicators.
    quality_metrics: Dict[str, Any] = Field(..., description="Quality metrics")

    # Rendered report text; only populated for non-JSON formats.
    report: Optional[str] = Field(None, description="Human-readable report (if format was text/markdown)")
|
|
|
|
|
|
|
|
class UploadDatasetRequest(BaseModel):
    """Request model for dataset upload."""

    # Target directory for the extracted archive contents; required.
    output_dir: str = Field(
        ...,
        description="Directory to extract uploaded dataset",
        examples=["data/uploaded_datasets", "data/arkit_sequences"],
    )

    # The wire name is "validate" (via alias); the Python attribute is
    # should_validate to avoid clashing with BaseModel's validate() API.
    should_validate: bool = Field(
        True,
        alias="validate",
        description="Validate ARKit pairs before extraction",
        examples=[True, False],
    )

    # Accept BOTH the alias ("validate", used by API clients) and the field
    # name ("should_validate", used by in-process callers). Without this,
    # Pydantic v2 only accepts the alias on input, so constructing the model
    # with should_validate=... raises a ValidationError.
    model_config = {"populate_by_name": True}
|
|
|
|
|
|
|
|
class DownloadDatasetRequest(BaseModel):
    """Request model for dataset download from S3."""

    # Source object location in S3.
    bucket_name: str = Field(..., description="S3 bucket name", examples=["my-datasets-bucket", "ylff-datasets"])
    s3_key: str = Field(
        ...,
        description="S3 object key (path to dataset file)",
        examples=["datasets/arkit_sequences.zip", "datasets/training_set_v1.tar.gz"],
    )

    # Local destination and post-download handling.
    output_dir: str = Field(
        ...,
        description="Directory to save downloaded dataset",
        examples=["data/downloaded_datasets", "data/arkit_sequences"],
    )
    extract: bool = Field(True, description="Extract downloaded archive", examples=[True, False])

    # Explicit credentials; leave unset to fall back to the standard AWS
    # credentials chain. Example values below are AWS's documented samples.
    aws_access_key_id: Optional[str] = Field(
        None,
        description="AWS access key ID (optional, uses credentials chain if None)",
        examples=[None, "AKIAIOSFODNN7EXAMPLE"],
    )
    aws_secret_access_key: Optional[str] = Field(
        None,
        description="AWS secret access key (optional)",
        examples=[None, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"],
    )
    region_name: str = Field(
        "us-east-1",
        description="AWS region name",
        examples=["us-east-1", "us-west-2", "eu-west-1"],
    )
|
|
|
|
|
|
|
|
class UploadDatasetResponse(BaseModel):
    """Response model for dataset upload."""

    # Overall outcome of the upload.
    success: bool = Field(..., description="Whether upload was successful", examples=[True, False])

    # Where the archive contents ended up.
    output_dir: str = Field(..., description="Directory where dataset was extracted")

    # Arbitrary metadata produced during processing.
    metadata: Dict[str, Any] = Field(..., description="Upload metadata (file counts, pairs, etc.)")

    # Accumulated error messages; defaults to an empty list rather than None
    # so callers can iterate it unconditionally.
    errors: List[str] = Field(default_factory=list, description="List of validation/processing errors")
|
|
|
|
|
|
|
|
class DownloadDatasetResponse(BaseModel):
    """Response model for dataset download."""

    # Overall outcome of the download.
    success: bool = Field(..., description="Whether download was successful", examples=[True, False])

    # Exactly one of the two outputs is expected depending on the `extract`
    # flag of the request: a file path when kept as an archive, a directory
    # when extracted.
    output_path: Optional[str] = Field(None, description="Path to downloaded file (if not extracted)")
    output_dir: Optional[str] = Field(None, description="Directory where dataset was extracted (if extracted)")

    # Size on disk of the fetched object.
    file_size: Optional[int] = Field(None, description="Size of downloaded file in bytes")

    # Populated only on failure.
    error: Optional[str] = Field(None, description="Error message if download failed")
|
|
|
|
|
|
|
|
class JobResponse(BaseModel):
    """Response model for job-based endpoints."""

    # Identifier assigned to the job when it was created.
    job_id: str = Field(
        ...,
        description="Unique job identifier",
        examples=["550e8400-e29b-41d4-a716-446655440000"],
    )

    # Lifecycle state of the job (see JobStatus enum).
    status: JobStatus = Field(
        ...,
        description="Current job status",
        examples=["queued", "running", "completed", "failed"],
    )

    # Free-form progress or error text.
    message: Optional[str] = Field(
        None,
        description="Status message or error description",
        examples=["Job queued", "Validation completed successfully"],
    )

    # Payload produced by the job; absent while the job is still pending.
    result: Optional[Dict[str, Any]] = Field(
        None,
        description="Job result data (only present when status is 'completed' or 'failed')",
        examples=[
            {
                "success": True,
                "stdout": "Output text",
                "stderr": "",
                "validation_stats": {
                    "total_frames": 10,
                    "accepted": 5,
                    "rejected_learnable": 3,
                    "rejected_outlier": 2,
                },
            }
        ],
    )

    # Full-document example surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "job_id": "550e8400-e29b-41d4-a716-446655440000",
                "status": "completed",
                "message": "ARKit validation completed successfully",
                "result": {
                    "success": True,
                    "validation_stats": {
                        "total_frames": 10,
                        "accepted": 0,
                        "rejected_learnable": 0,
                        "rejected_outlier": 10,
                    },
                },
            }
        }
    }
|
|
|
|
|
|
|
|
class ValidationStats(BaseModel):
    """Statistics from BA validation."""

    # Frame counts per outcome bucket; all constrained to be non-negative.
    total_frames: int = Field(..., description="Total number of frames processed", ge=0, examples=[10, 100])
    accepted: int = Field(..., description="Number of accepted frames (< accept_threshold)", ge=0, examples=[5, 50])
    rejected_learnable: int = Field(
        ...,
        description="Number of rejected-learnable frames (between thresholds)",
        ge=0,
        examples=[3, 30],
    )
    rejected_outlier: int = Field(
        ...,
        description="Number of rejected-outlier frames (> reject_threshold)",
        ge=0,
        examples=[2, 20],
    )

    # The same buckets expressed as percentages (0-100).
    accepted_percentage: float = Field(
        ...,
        description="Percentage of accepted frames",
        ge=0.0,
        le=100.0,
        examples=[50.0, 75.5],
    )
    rejected_learnable_percentage: float = Field(
        ...,
        description="Percentage of rejected-learnable frames",
        ge=0.0,
        le=100.0,
        examples=[30.0, 20.5],
    )
    rejected_outlier_percentage: float = Field(
        ...,
        description="Percentage of rejected-outlier frames",
        ge=0.0,
        le=100.0,
        examples=[20.0, 5.0],
    )

    # Optional overall outcome label and worst-case rotation error.
    ba_status: Optional[str] = Field(
        None,
        description="BA validation status",
        examples=["accepted", "rejected_learnable", "rejected_outlier", "ba_failed"],
    )
    max_error_deg: Optional[float] = Field(
        None,
        description="Maximum rotation error in degrees",
        ge=0.0,
        examples=[177.76, 25.5, 1.2],
    )

    # Example document surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "total_frames": 10,
                "accepted": 5,
                "rejected_learnable": 3,
                "rejected_outlier": 2,
                "accepted_percentage": 50.0,
                "rejected_learnable_percentage": 30.0,
                "rejected_outlier_percentage": 20.0,
                "ba_status": "rejected_outlier",
                "max_error_deg": 177.76,
            }
        }
    }
|
|
|
|
|
|
|
|
class ErrorResponse(BaseModel):
    """Standard error response model."""

    # Machine-readable error class name.
    error: str = Field(
        ...,
        description="Error type/name",
        examples=["ValidationError", "FileNotFoundError", "InternalServerError"],
    )

    # Human-oriented explanation of what went wrong.
    message: str = Field(
        ...,
        description="Human-readable error message",
        examples=["Sequence directory not found", "Invalid request data"],
    )

    # Correlates the response with server-side log entries.
    request_id: str = Field(..., description="Request ID for log correlation", examples=["req_1234567890"])

    # Structured extras, e.g. which field failed and why.
    details: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional error details",
        examples=[{"field": "sequence_dir", "error": "Path does not exist"}],
    )

    # The route that produced the error, when known.
    endpoint: Optional[str] = Field(
        None,
        description="Endpoint where error occurred",
        examples=["/api/v1/validate/sequence"],
    )

    # Example document surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "error": "FileNotFoundError",
                "message": "Sequence directory not found: /invalid/path",
                "request_id": "req_1234567890",
                "details": {"path": "/invalid/path"},
                "endpoint": "/api/v1/validate/sequence",
            }
        }
    }
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response model."""

    # Service condition label.
    status: str = Field(..., description="Health status", examples=["healthy", "degraded", "unhealthy"])

    # When the check was performed (Unix epoch seconds).
    timestamp: float = Field(..., description="Unix timestamp of health check", examples=[1701878400.123])

    # Correlates the response with server-side logs.
    request_id: str = Field(..., description="Request ID", examples=["req_1234567890"])

    # Optional profiling snapshot.
    profiling: Optional[Dict[str, Any]] = Field(
        None,
        description="Profiling status if available",
        examples=[{"enabled": True, "total_entries": 42}],
    )

    # Example document surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "status": "healthy",
                "timestamp": 1701878400.123,
                "request_id": "req_1234567890",
                "profiling": {"enabled": True, "total_entries": 42},
            }
        }
    }
|
|
|
|
|
|
|
|
class ModelsResponse(BaseModel):
    """Response model for models list endpoint."""

    # Mapping of model identifier -> metadata dict.
    models: Dict[str, Any] = Field(
        ...,
        description="Dictionary of available models with metadata",
        examples=[
            {
                "depth-anything/DA3-LARGE": {
                    "series": "main",
                    "size": "large",
                    "capabilities": ["mono_depth", "pose_estimation"],
                }
            }
        ],
    )

    # Best match for the caller's use case, when one could be determined.
    recommended: Optional[str] = Field(
        None,
        description="Recommended model for the requested use case",
        examples=["depth-anything/DA3-LARGE"],
    )

    # Example document surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "models": {"depth-anything/DA3-LARGE": {}},
                "recommended": "depth-anything/DA3-LARGE",
            }
        }
    }
|
|
|
|
|
|
|
|
class TrainUnifiedRequest(BaseModel):
    """Request model for unified YLFF training."""

    # Required input: directory of pre-processed results.
    preprocessed_cache_dir: str = Field(
        ...,
        description="Directory containing pre-processed results",
        examples=["cache/preprocessed"],
        min_length=1,
    )

    arkit_sequences_dir: Optional[str] = Field(
        None,
        description="Directory with original ARKit sequences (for loading images)",
        examples=["data/arkit_sequences"],
    )

    model_name: Optional[str] = Field(
        None,
        description="DA3 model name (default: auto-select)",
        examples=["depth-anything/DA3-LARGE"],
    )

    # Core optimization hyperparameters.
    epochs: int = Field(200, description="Number of training epochs", ge=1, examples=[100, 200])
    lr: float = Field(2e-4, description="Learning rate", gt=0.0, examples=[2e-4])
    weight_decay: float = Field(0.04, description="Weight decay", ge=0.0, examples=[0.04])
    batch_size: int = Field(32, description="Batch size per GPU", ge=1, examples=[32, 64])

    device: DeviceType = Field(DeviceType.CUDA, description="Device for training", examples=["cuda", "cpu"])

    checkpoint_dir: str = Field(
        "checkpoints/ylff_training",
        description="Checkpoint directory",
        examples=["checkpoints/ylff_training"],
    )

    log_interval: int = Field(10, description="Log metrics every N steps", ge=1)
    save_interval: int = Field(1000, description="Save checkpoint every N steps", ge=1)

    # Mixed-precision flags; at most one may be enabled (see validator below).
    use_fp16: bool = Field(True, description="Use FP16 mixed precision")
    use_bf16: bool = Field(False, description="Use BF16 mixed precision")

    # An EMA decay is a convex-combination coefficient, so it must lie in
    # [0, 1]; values outside that range would make the teacher diverge.
    ema_decay: float = Field(0.999, description="EMA decay rate for teacher", ge=0.0, le=1.0)

    use_wandb: bool = Field(True, description="Enable Weights & Biases logging")
    wandb_project: str = Field("ylff", description="W&B project name")

    gradient_accumulation_steps: int = Field(1, description="Gradient accumulation steps", ge=1)
    # A zero or negative clipping norm would zero out every gradient.
    gradient_clip_norm: float = Field(1.0, description="Gradient clipping norm", gt=0.0)

    # Worker count cannot be negative; None lets the trainer pick a default.
    num_workers: Optional[int] = Field(None, description="Number of data loading workers", ge=0)
    resume_from_checkpoint: Optional[str] = Field(None, description="Resume from checkpoint path")

    use_fsdp: bool = Field(False, description="Enable FSDP (single-GPU stub)")

    @field_validator("use_bf16")
    @classmethod
    def _check_mixed_precision(cls, v: bool, info) -> bool:
        """Reject requests that enable both FP16 and BF16 simultaneously.

        ``use_fp16`` is declared before ``use_bf16``, so its validated value
        is already available in ``info.data`` when this runs.
        """
        if v and info.data.get("use_fp16", False):
            raise ValueError("use_fp16 and use_bf16 are mutually exclusive; enable at most one")
        return v

    # Example document surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "preprocessed_cache_dir": "cache/preprocessed",
                "arkit_sequences_dir": "data/arkit_sequences",
                "epochs": 200,
                "batch_size": 32,
                "model_name": "depth-anything/DA3-SMALL",
            }
        }
    }
|
|
|