# ylff/models/api_models.py
# Author: Azan — clean deployment build (squashed), commit 7a87926
"""
Pydantic models for YLFF API request/response schemas.
All API models are rigorously defined with:
- Comprehensive field validation
- Detailed descriptions and examples
- Type hints and optional field defaults
- JSON schema generation support
"""
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
# Enums for type safety
class JobStatus(str, Enum):
    """Job execution status.

    Inherits from ``str`` so members serialize directly as JSON strings
    (used by ``JobResponse.status``).
    """

    QUEUED = "queued"        # job accepted, not yet started
    RUNNING = "running"      # job currently executing
    COMPLETED = "completed"  # job finished successfully
    FAILED = "failed"        # job finished with an error
    CANCELLED = "cancelled"  # job stopped before completion
class DeviceType(str, Enum):
    """Device type for model inference/training.

    String-valued enum; values match PyTorch device identifiers.
    """

    CPU = "cpu"
    CUDA = "cuda"  # NVIDIA GPU
    MPS = "mps"  # Apple Metal Performance Shaders
class UseCase(str, Enum):
    """Use case for model selection.

    Supplied by request models (e.g. ``ValidateSequenceRequest.use_case``)
    to auto-select an appropriate DA3 model.
    """

    BA_VALIDATION = "ba_validation"  # BA validation (default for validation requests)
    MONO_DEPTH = "mono_depth"
    MULTI_VIEW = "multi_view"
    POSE_CONDITIONED = "pose_conditioned"
    TRAINING = "training"
    INFERENCE = "inference"
# Request Models
class ValidateSequenceRequest(BaseModel):
    """Request model for sequence validation.

    Frames are classified by rotation error against the two thresholds:
    below ``accept_threshold`` → accepted, above ``reject_threshold`` →
    rejected outlier, in between → rejected but learnable.
    """

    sequence_dir: str = Field(
        ...,
        description="Directory containing image sequence",
        examples=["data/sequences/sequence_001"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name (default: auto-select based on use_case)",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3-GIANT"],
    )
    use_case: UseCase = Field(
        UseCase.BA_VALIDATION,
        description="Use case for model selection",
        examples=["ba_validation", "mono_depth"],
    )
    accept_threshold: float = Field(
        2.0,
        description=(
            "Accept threshold in degrees (frames with rotation error below this are accepted)"
        ),
        ge=0.0,
        le=180.0,
        examples=[2.0, 5.0],
    )
    reject_threshold: float = Field(
        30.0,
        description=(
            "Reject threshold in degrees "
            "(frames with rotation error above this are rejected as outliers)"
        ),
        ge=0.0,
        le=180.0,
        examples=[30.0, 45.0],
    )
    output: Optional[str] = Field(
        None,
        description="Output JSON path for validation results",
        examples=["data/results/validation.json"],
    )

    @field_validator("reject_threshold")
    @classmethod
    def reject_greater_than_accept(cls, v, info):
        """Ensure reject threshold is greater than accept threshold."""
        # accept_threshold is declared first, so it is present in info.data
        # unless its own validation already failed (then we skip the check).
        accept_thresh = info.data.get("accept_threshold")
        if accept_thresh is not None and v <= accept_thresh:
            raise ValueError(
                f"reject_threshold ({v}) must be greater than "
                f"accept_threshold ({accept_thresh})"
            )
        return v

    @field_validator("sequence_dir")
    @classmethod
    def validate_sequence_dir(cls, v):
        """Validate sequence directory path format."""
        if not v or not v.strip():
            raise ValueError("sequence_dir cannot be empty")
        return v.strip()

    model_config = {
        # Clear the "model_" protected namespace: without this, the
        # `model_name` field triggers a Pydantic v2 UserWarning at class
        # definition time.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "sequence_dir": "data/sequences/sequence_001",
                "model_name": "depth-anything/DA3-LARGE",
                "use_case": "ba_validation",
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "output": "data/results/validation.json",
            }
        },
    }
class ValidateARKitRequest(BaseModel):
    """Request model for ARKit validation.

    Points at a directory holding an ARKit video plus JSON metadata and
    controls frame sampling, inference device and optional GUI output.
    """

    arkit_dir: str = Field(
        ...,
        description="Directory containing ARKit video and JSON metadata",
        examples=["assets/examples/ARKit", "data/arkit_recordings/session_001"],
        min_length=1,
    )
    output_dir: str = Field(
        "data/arkit_validation",
        description="Output directory for validation results",
        examples=["data/arkit_validation", "data/results/arkit_001"],
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name (default: DA3NESTED-GIANT-LARGE for BA validation)",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3NESTED-GIANT-LARGE"],
    )
    max_frames: Optional[int] = Field(
        None,
        description="Maximum number of frames to process (None = process all)",
        ge=1,
        examples=[10, 30, 100],
    )
    frame_interval: int = Field(
        1,
        description="Extract every Nth frame (1 = all frames, 5 = every 5th frame)",
        ge=1,
        examples=[1, 5, 10],
    )
    device: DeviceType = Field(
        DeviceType.CPU,
        description="Device for DA3 inference",
        examples=["cpu", "cuda", "mps"],
    )
    gui: bool = Field(
        False,
        description="Show real-time GUI visualization during validation",
        examples=[False, True],
    )

    @field_validator("arkit_dir")
    @classmethod
    def validate_arkit_dir(cls, v):
        """Validate ARKit directory path format."""
        if not v or not v.strip():
            raise ValueError("arkit_dir cannot be empty")
        return v.strip()

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "arkit_dir": "assets/examples/ARKit",
                "output_dir": "data/arkit_validation",
                "model_name": "depth-anything/DA3NESTED-GIANT-LARGE",
                "max_frames": 30,
                "frame_interval": 1,
                "device": "cpu",
                "gui": False,
            }
        },
    }
class BuildDatasetRequest(BaseModel):
    """Request model for building training dataset.

    Converts raw sequences into training samples, with optional W&B logging
    and inference-time optimizations (batching, caching, torch.compile).
    """

    sequences_dir: str = Field(
        ...,
        description="Directory containing sequence directories",
        examples=["data/raw/sequences", "data/collected/sequences"],
        min_length=1,
    )
    output_dir: str = Field(
        "data/training",
        description="Output directory for training dataset",
        examples=["data/training", "data/training/dataset_v1"],
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name for validation",
        examples=["depth-anything/DA3-LARGE"],
    )
    max_samples: Optional[int] = Field(
        None,
        description="Maximum number of training samples to generate (None = no limit)",
        ge=1,
        examples=[100, 500, 1000],
    )
    accept_threshold: float = Field(
        2.0, description="Accept threshold in degrees", ge=0.0, le=180.0, examples=[2.0]
    )
    reject_threshold: float = Field(
        30.0, description="Reject threshold in degrees", ge=0.0, le=180.0, examples=[30.0]
    )
    use_wandb: bool = Field(
        True, description="Enable Weights & Biases logging", examples=[True, False]
    )
    wandb_project: str = Field(
        "ylff",
        description="W&B project name",
        examples=["ylff", "ylff-datasets"],
        min_length=1,
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name (default: auto-generated)",
        examples=["dataset-build-2024-12-06", "v1-training-set"],
    )
    # Optimization parameters
    use_batched_inference: bool = Field(
        False,
        description="Use batched inference for better GPU utilization",
        examples=[False, True],
    )
    inference_batch_size: int = Field(
        4,
        description="Batch size for inference (when use_batched_inference=True)",
        ge=1,
        examples=[2, 4, 8],
    )
    use_inference_cache: bool = Field(
        False,
        description="Cache inference results to avoid recomputing identical sequences",
        examples=[False, True],
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for inference cache (None = in-memory only)",
        examples=[None, "cache/inference"],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster inference",
        examples=[True, False],
    )

    @field_validator("reject_threshold")
    @classmethod
    def reject_greater_than_accept(cls, v, info):
        """Ensure reject threshold is greater than accept threshold."""
        # accept_threshold validates first (declaration order), so it is in
        # info.data unless it already failed its own validation.
        accept_thresh = info.data.get("accept_threshold")
        if accept_thresh is not None and v <= accept_thresh:
            raise ValueError(
                f"reject_threshold ({v}) must be greater than "
                f"accept_threshold ({accept_thresh})"
            )
        return v

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "sequences_dir": "data/raw/sequences",
                "output_dir": "data/training",
                "model_name": "depth-anything/DA3-LARGE",
                "max_samples": 1000,
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "dataset-build-2024-12-06",
            }
        },
    }
class TrainRequest(BaseModel):
    """Request model for model fine-tuning.

    Covers the full training configuration: core hyper-parameters, W&B
    logging, mixed precision, FSDP, QAT, sequence parallelism and
    checkpointing behavior.
    """

    training_data_dir: str = Field(
        ...,
        description="Directory containing training samples",
        examples=["data/training", "data/training/dataset_v1"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name to fine-tune",
        examples=["depth-anything/DA3-LARGE"],
    )
    epochs: int = Field(
        10, description="Number of training epochs", ge=1, le=1000, examples=[10, 20, 50]
    )
    lr: float = Field(
        1e-5, description="Learning rate", gt=0.0, examples=[1e-5, 1e-4, 1e-6]
    )
    batch_size: int = Field(
        1, description="Training batch size", ge=1, examples=[1, 2, 4, 8]
    )
    checkpoint_dir: str = Field(
        "checkpoints",
        description="Directory to save model checkpoints",
        examples=["checkpoints", "models/checkpoints"],
    )
    device: DeviceType = Field(
        DeviceType.CUDA, description="Device for training", examples=["cuda", "cpu", "mps"]
    )
    use_wandb: bool = Field(
        True, description="Enable Weights & Biases logging", examples=[True, False]
    )
    wandb_project: str = Field(
        "ylff", description="W&B project name", examples=["ylff", "ylff-training"]
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name",
        examples=["fine-tune-v1", "training-run-2024-12-06"],
    )
    # Optimization parameters
    gradient_accumulation_steps: int = Field(
        1,
        description=(
            "Number of steps to accumulate gradients (effective batch size = batch_size * this)"
        ),
        ge=1,
        examples=[1, 4, 8],
    )
    use_amp: bool = Field(
        True,
        description="Use automatic mixed precision training (FP16)",
        examples=[True, False],
    )
    warmup_steps: int = Field(
        0,
        description="Number of warmup steps for learning rate (0 = no warmup)",
        ge=0,
        examples=[0, 100, 500],
    )
    num_workers: Optional[int] = Field(
        None,
        description="Number of data loading workers (None = auto-detect)",
        ge=0,
        examples=[None, 2, 4, 8],
    )
    resume_from_checkpoint: Optional[str] = Field(
        None,
        description="Path to checkpoint to resume from",
        examples=[None, "checkpoints/latest.pth", "checkpoints/best.pth"],
    )
    use_ema: bool = Field(
        False,
        description="Use Exponential Moving Average for model weights",
        examples=[False, True],
    )
    ema_decay: float = Field(
        0.9999,
        description="EMA decay factor (higher = slower update, more stable)",
        gt=0.0,
        lt=1.0,
        examples=[0.999, 0.9999, 0.99999],
    )
    use_onecycle: bool = Field(
        False,
        description="Use OneCycleLR scheduler instead of CosineAnnealingLR",
        examples=[False, True],
    )
    use_gradient_checkpointing: bool = Field(
        False,
        description="Enable gradient checkpointing to save memory (slower but uses less memory)",
        examples=[False, True],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster training (PyTorch 2.0+)",
        examples=[True, False],
    )
    # Phase 4 optimizations
    use_bf16: bool = Field(
        False,
        description="Use BF16 instead of FP16 (better training stability, same speed)",
        examples=[False, True],
    )
    gradient_clip_norm: Optional[float] = Field(
        1.0,
        description="Maximum gradient norm for clipping (None = disabled, 1.0 = default)",
        ge=0.0,
        examples=[None, 0.5, 1.0, 2.0],
    )
    find_lr: bool = Field(
        False,
        description="Automatically find optimal learning rate before training",
        examples=[False, True],
    )
    find_batch_size: bool = Field(
        False,
        description="Automatically find optimal batch size before training",
        examples=[False, True],
    )
    # FSDP options
    use_fsdp: bool = Field(
        False,
        description=(
            "Use FSDP (Fully Sharded Data Parallel) for memory-efficient multi-GPU training"
        ),
        examples=[False, True],
    )
    fsdp_sharding_strategy: str = Field(
        "FULL_SHARD",
        description="FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD",
        examples=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"],
    )
    fsdp_mixed_precision: Optional[str] = Field(
        None,
        description="FSDP mixed precision: bf16, fp16, or None (auto-detects from use_bf16)",
        examples=[None, "bf16", "fp16"],
    )
    # Advanced optimizations
    use_qat: bool = Field(
        False,
        description=("Use Quantization Aware Training (QAT) for better INT8 quantization"),
        examples=[False, True],
    )
    qat_backend: str = Field(
        "fbgemm",
        description="QAT backend: fbgemm (x86) or qnnpack (ARM)",
        examples=["fbgemm", "qnnpack"],
    )
    use_sequence_parallel: bool = Field(
        False,
        description="Enable sequence parallelism for very long sequences",
        examples=[False, True],
    )
    sequence_parallel_gpus: int = Field(
        1,
        description="Number of GPUs for sequence parallelism",
        ge=1,
        examples=[1, 2, 4],
    )
    activation_recompute_strategy: Optional[str] = Field(
        None,
        description="Activation recompute strategy: checkpoint, cpu_offload, hybrid, or None",
        examples=[None, "checkpoint", "cpu_offload", "hybrid"],
    )
    # Checkpoint options
    async_checkpoint: bool = Field(
        True,
        description="Use async checkpoint saving (non-blocking, faster training)",
        examples=[True, False],
    )
    compress_checkpoint: bool = Field(
        True,
        description="Compress checkpoints with gzip (30-50% smaller files)",
        examples=[True, False],
    )

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "training_data_dir": "data/training",
                "model_name": "depth-anything/DA3-LARGE",
                "epochs": 10,
                "lr": 1e-5,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints",
                "device": "cuda",
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "fine-tune-v1",
            }
        },
    }
class PretrainRequest(BaseModel):
    """Request model for model pre-training on ARKit sequences.

    Mirrors ``TrainRequest`` and adds ARKit-specific options: frame
    sampling, LiDAR/BA-depth supervision, BA caching and direct use of
    ARKit poses when tracking quality is high.
    """

    arkit_sequences_dir: str = Field(
        ...,
        description="Directory containing ARKit sequence directories",
        examples=["data/arkit_sequences", "data/collected/arkit"],
        min_length=1,
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name to pre-train",
        examples=["depth-anything/DA3-LARGE"],
    )
    epochs: int = Field(
        10, description="Number of pre-training epochs", ge=1, le=1000, examples=[5, 10, 20]
    )
    lr: float = Field(
        1e-4, description="Learning rate for pre-training", gt=0.0, examples=[1e-4, 1e-3]
    )
    batch_size: int = Field(
        1, description="Pre-training batch size", ge=1, examples=[1, 2, 4]
    )
    checkpoint_dir: str = Field(
        "checkpoints/pretrain",
        description="Directory to save model checkpoints",
        examples=["checkpoints/pretrain"],
    )
    device: DeviceType = Field(
        DeviceType.CUDA, description="Device for pre-training", examples=["cuda"]
    )
    max_sequences: Optional[int] = Field(
        None,
        description="Maximum number of sequences to process (None = all)",
        ge=1,
        examples=[10, 50, 100],
    )
    max_frames_per_sequence: Optional[int] = Field(
        None,
        description="Maximum frames per sequence to process (None = all)",
        ge=1,
        examples=[30, 100],
    )
    frame_interval: int = Field(
        1, description="Extract every Nth frame", ge=1, examples=[1, 5, 10]
    )
    use_lidar: bool = Field(
        False,
        description="Use ARKit LiDAR depth as supervision signal",
        examples=[False, True],
    )
    use_ba_depth: bool = Field(
        False,
        description="Use BA depth maps as supervision signal",
        examples=[False, True],
    )
    min_ba_quality: float = Field(
        0.0,
        description="Minimum BA quality threshold (0.0-1.0)",
        ge=0.0,
        le=1.0,
        examples=[0.0, 0.5, 0.8],
    )
    use_wandb: bool = Field(
        True, description="Enable Weights & Biases logging", examples=[True, False]
    )
    wandb_project: str = Field(
        "ylff", description="W&B project name", examples=["ylff", "ylff-pretraining"]
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name",
        examples=["pretrain-v1", "pretrain-arkit-2024-12-06"],
    )
    # Optimization parameters
    gradient_accumulation_steps: int = Field(
        1,
        description="Number of steps to accumulate gradients",
        ge=1,
        examples=[1, 4, 8],
    )
    use_amp: bool = Field(
        True,
        description="Use automatic mixed precision training (FP16)",
        examples=[True, False],
    )
    warmup_steps: int = Field(
        0,
        description="Number of warmup steps for learning rate",
        ge=0,
        examples=[0, 100, 500],
    )
    num_workers: Optional[int] = Field(
        None,
        description="Number of data loading workers (None = auto-detect)",
        ge=0,
        examples=[None, 2, 4, 8],
    )
    resume_from_checkpoint: Optional[str] = Field(
        None,
        description="Path to checkpoint to resume from",
        examples=[None, "checkpoints/pretrain/latest.pth"],
    )
    use_ema: bool = Field(
        False,
        description="Use Exponential Moving Average for model weights",
        examples=[False, True],
    )
    ema_decay: float = Field(
        0.9999, description="EMA decay factor", gt=0.0, lt=1.0, examples=[0.9999]
    )
    use_onecycle: bool = Field(
        False,
        description="Use OneCycleLR scheduler instead of CosineAnnealingLR",
        examples=[False, True],
    )
    use_gradient_checkpointing: bool = Field(
        False,
        description="Enable gradient checkpointing to save memory",
        examples=[False, True],
    )
    compile_model: bool = Field(
        True,
        description="Compile model with torch.compile for faster training",
        examples=[True, False],
    )
    cache_dir: Optional[str] = Field(
        None,
        description="Directory for caching BA results (None = disabled)",
        examples=[None, "cache/ba_results"],
    )
    # Phase 4 optimizations
    use_bf16: bool = Field(
        False,
        description="Use BF16 instead of FP16 (better training stability, same speed)",
        examples=[False, True],
    )
    gradient_clip_norm: Optional[float] = Field(
        1.0,
        description="Maximum gradient norm for clipping (None = disabled, 1.0 = default)",
        ge=0.0,
        examples=[None, 0.5, 1.0, 2.0],
    )
    find_lr: bool = Field(
        False,
        description="Automatically find optimal learning rate before training",
        examples=[False, True],
    )
    find_batch_size: bool = Field(
        False,
        description="Automatically find optimal batch size before training",
        examples=[False, True],
    )
    # FSDP options
    use_fsdp: bool = Field(
        False,
        description=(
            "Use FSDP (Fully Sharded Data Parallel) for memory-efficient multi-GPU training"
        ),
        examples=[False, True],
    )
    fsdp_sharding_strategy: str = Field(
        "FULL_SHARD",
        description="FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD",
        examples=["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD"],
    )
    fsdp_mixed_precision: Optional[str] = Field(
        None,
        description="FSDP mixed precision: bf16, fp16, or None (auto-detects from use_bf16)",
        examples=[None, "bf16", "fp16"],
    )
    # Advanced optimizations
    use_qat: bool = Field(
        False,
        description=("Use Quantization Aware Training (QAT) for better INT8 quantization"),
        examples=[False, True],
    )
    qat_backend: str = Field(
        "fbgemm",
        description="QAT backend: fbgemm (x86) or qnnpack (ARM)",
        examples=["fbgemm", "qnnpack"],
    )
    use_sequence_parallel: bool = Field(
        False,
        description="Enable sequence parallelism for very long sequences",
        examples=[False, True],
    )
    sequence_parallel_gpus: int = Field(
        1,
        description="Number of GPUs for sequence parallelism",
        ge=1,
        examples=[1, 2, 4],
    )
    activation_recompute_strategy: Optional[str] = Field(
        None,
        description="Activation recompute strategy: checkpoint, cpu_offload, hybrid, or None",
        examples=[None, "checkpoint", "cpu_offload", "hybrid"],
    )
    # Checkpoint options
    async_checkpoint: bool = Field(
        True,
        description="Use async checkpoint saving (non-blocking, faster training)",
        examples=[True, False],
    )
    compress_checkpoint: bool = Field(
        True,
        description="Compress checkpoints with gzip (30-50% smaller files)",
        examples=[True, False],
    )
    # ARKit pose options
    prefer_arkit_poses: bool = Field(
        True,
        description=(
            "Use ARKit poses directly when tracking quality is good "
            "(much faster, skips BA for good sequences)"
        ),
        examples=[True, False],
    )
    min_arkit_quality: float = Field(
        0.8,
        description=(
            "Minimum fraction of frames with good tracking to use ARKit poses directly "
            "(0.0-1.0, higher = stricter)"
        ),
        ge=0.0,
        le=1.0,
        examples=[0.7, 0.8, 0.9],
    )

    @field_validator("arkit_sequences_dir")
    @classmethod
    def validate_arkit_sequences_dir(cls, v):
        """Validate ARKit sequences directory path format."""
        if not v or not v.strip():
            raise ValueError("arkit_sequences_dir cannot be empty")
        return v.strip()

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "arkit_sequences_dir": "data/arkit_sequences",
                "model_name": "depth-anything/DA3-LARGE",
                "epochs": 10,
                "lr": 1e-4,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/pretrain",
                "device": "cuda",
                "max_sequences": None,
                "max_frames_per_sequence": None,
                "frame_interval": 1,
                "use_lidar": False,
                "use_ba_depth": False,
                "min_ba_quality": 0.0,
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "pretrain-v1",
            }
        },
    }
class EvaluateBAAgreementRequest(BaseModel):
    """Request model for BA agreement evaluation.

    Evaluates a model (optionally from a checkpoint) on test sequences
    against an agreement threshold in degrees.
    """

    test_data_dir: str = Field(
        ...,
        description="Directory containing test sequences",
        examples=["data/test", "data/validation"],
        min_length=1,
    )
    model_name: str = Field(
        "depth-anything/DA3-LARGE",
        description="DA3 model name",
        examples=["depth-anything/DA3-LARGE", "depth-anything/DA3-GIANT"],
    )
    checkpoint: Optional[str] = Field(
        None,
        description="Path to model checkpoint (optional, overrides model_name)",
        examples=["checkpoints/best_model.pth", "checkpoints/epoch_10.pth"],
    )
    threshold: float = Field(
        2.0,
        description="Agreement threshold in degrees",
        ge=0.0,
        le=180.0,
        examples=[2.0, 5.0],
    )
    device: DeviceType = Field(
        DeviceType.CUDA, description="Device for inference", examples=["cuda", "cpu"]
    )
    use_wandb: bool = Field(
        True, description="Enable Weights & Biases logging", examples=[True, False]
    )
    wandb_project: str = Field(
        "ylff", description="W&B project name", examples=["ylff", "ylff-evaluation"]
    )
    wandb_name: Optional[str] = Field(
        None,
        description="W&B run name",
        examples=["eval-ba-agreement-v1", "eval-checkpoint-best"],
    )

    @field_validator("test_data_dir")
    @classmethod
    def validate_test_data_dir(cls, v):
        """Validate test data directory path format."""
        if not v or not v.strip():
            raise ValueError("test_data_dir cannot be empty")
        return v.strip()

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "test_data_dir": "data/test",
                "model_name": "depth-anything/DA3-LARGE",
                "checkpoint": None,
                "threshold": 2.0,
                "device": "cuda",
                "use_wandb": True,
                "wandb_project": "ylff",
                "wandb_name": "eval-ba-agreement-v1",
            }
        },
    }
class VisualizeRequest(BaseModel):
    """Request model for result visualization."""

    # Source directory of previously produced validation results.
    results_dir: str = Field(
        ...,
        description="Directory containing validation results",
        examples=["data/arkit_validation", "data/validation_results"],
        min_length=1,
    )
    # Destination for generated artifacts; defaults to a subdirectory of results_dir.
    output_dir: Optional[str] = Field(
        None,
        description="Output directory for visualizations (default: results_dir/visualizations)",
        examples=["data/arkit_validation/visualizations"],
    )
    use_plotly: bool = Field(
        True,
        description="Use Plotly for interactive 3D plots (requires plotly package)",
        examples=[False, True],
    )

    @field_validator("results_dir")
    @classmethod
    def validate_results_dir(cls, v):
        """Reject empty or whitespace-only paths; return the trimmed path."""
        trimmed = v.strip() if v else ""
        if not trimmed:
            raise ValueError("results_dir cannot be empty")
        return trimmed

    model_config = {
        "json_schema_extra": {
            "example": {
                "results_dir": "data/arkit_validation",
                "output_dir": None,
                "use_plotly": False,
            }
        }
    }
# Response Models
class ValidateDatasetRequest(BaseModel):
    """Request model for dataset validation."""

    # Dataset artifact to check; supported formats listed in the description.
    dataset_path: str = Field(
        ...,
        description="Path to dataset file (pickle, json, or hdf5)",
        examples=["data/training/dataset.pkl", "data/training/dataset.json"],
    )
    # Fail-fast toggle: raise instead of reporting issues.
    strict: bool = Field(
        False,
        description="If True, raise exception on validation failure",
        examples=[False, True],
    )
    # Per-aspect validation switches (all enabled by default).
    check_images: bool = Field(True, description="Validate image data", examples=[True, False])
    check_poses: bool = Field(True, description="Validate pose data", examples=[True, False])
    check_metadata: bool = Field(
        True, description="Validate metadata fields", examples=[True, False]
    )
class CurateDatasetRequest(BaseModel):
    """Request model for dataset curation."""

    # Input/output artifacts.
    dataset_path: str = Field(
        ..., description="Path to input dataset file", examples=["data/training/dataset.pkl"]
    )
    output_path: str = Field(
        ...,
        description="Path to save curated dataset",
        examples=["data/training/dataset_curated.pkl"],
    )

    # Filtering options — None means the bound is not applied.
    min_error: Optional[float] = Field(
        None, description="Minimum error threshold", examples=[None, 0.5, 1.0]
    )
    max_error: Optional[float] = Field(
        None, description="Maximum error threshold", examples=[None, 30.0, 50.0]
    )
    min_weight: Optional[float] = Field(
        None, description="Minimum weight threshold", examples=[None, 0.1, 0.5]
    )
    max_weight: Optional[float] = Field(
        None, description="Maximum weight threshold", examples=[None, 1.0, 2.0]
    )

    # Outlier removal.
    remove_outliers: bool = Field(
        False, description="Remove outlier samples", examples=[False, True]
    )
    outlier_percentile: float = Field(
        95.0,
        description="Percentile threshold for outlier detection",
        ge=0.0,
        le=100.0,
        examples=[95.0, 99.0],
    )

    # Balancing.
    balance: bool = Field(
        False, description="Balance dataset by error distribution", examples=[False, True]
    )
    balance_strategy: str = Field(
        "error_bins",
        description="Balancing strategy: error_bins, uniform, or weighted",
        examples=["error_bins", "uniform", "weighted"],
    )
    num_bins: int = Field(
        10, description="Number of error bins for balancing", ge=2, examples=[10, 20]
    )
class AnalyzeDatasetRequest(BaseModel):
    """Request model for dataset analysis."""

    dataset_path: str = Field(
        ..., description="Path to dataset file", examples=["data/training/dataset.pkl"]
    )
    # None keeps the report in the response only.
    output_path: Optional[str] = Field(
        None,
        description="Path to save analysis report",
        examples=[None, "data/training/analysis.json"],
    )
    format: str = Field(
        "json",
        description="Report format: json, text, or markdown",
        examples=["json", "text", "markdown"],
    )
    # Optional heavier computations, both on by default.
    compute_distributions: bool = Field(
        True, description="Compute error/weight distributions", examples=[True, False]
    )
    compute_correlations: bool = Field(
        True, description="Compute correlations between metrics", examples=[True, False]
    )
class DatasetValidationResponse(BaseModel):
    """Response model for dataset validation."""

    # Overall pass/fail verdict.
    validation_passed: bool = Field(
        ..., description="Whether validation passed", examples=[True, False]
    )
    statistics: Dict[str, Any] = Field(..., description="Dataset statistics")
    # One entry per detected problem.
    issues: List[Dict[str, Any]] = Field(..., description="List of validation issues")
    summary: Dict[str, Any] = Field(..., description="Validation summary")
class DatasetAnalysisResponse(BaseModel):
    """Response model for dataset analysis."""

    statistics: Dict[str, Any] = Field(..., description="Dataset statistics")
    quality_metrics: Dict[str, Any] = Field(..., description="Quality metrics")
    # Only populated for text/markdown report formats.
    report: Optional[str] = Field(
        None, description="Human-readable report (if format was text/markdown)"
    )
class UploadDatasetRequest(BaseModel):
    """Request model for dataset upload."""

    output_dir: str = Field(
        ...,
        description="Directory to extract uploaded dataset",
        examples=["data/uploaded_datasets", "data/arkit_sequences"],
    )
    # Exposed to clients under the alias "validate" (a reserved-ish name in
    # Pydantic, hence the aliased Python attribute).
    should_validate: bool = Field(
        True,
        alias="validate",
        description="Validate ARKit pairs before extraction",
        examples=[True, False],
    )
class DownloadDatasetRequest(BaseModel):
    """Request model for dataset download from S3."""

    # Object location.
    bucket_name: str = Field(
        ..., description="S3 bucket name", examples=["my-datasets-bucket", "ylff-datasets"]
    )
    s3_key: str = Field(
        ...,
        description="S3 object key (path to dataset file)",
        examples=["datasets/arkit_sequences.zip", "datasets/training_set_v1.tar.gz"],
    )
    output_dir: str = Field(
        ...,
        description="Directory to save downloaded dataset",
        examples=["data/downloaded_datasets", "data/arkit_sequences"],
    )
    extract: bool = Field(
        True, description="Extract downloaded archive", examples=[True, False]
    )
    # Optional explicit credentials; the default (None) falls back to the
    # standard AWS credentials chain.
    aws_access_key_id: Optional[str] = Field(
        None,
        description="AWS access key ID (optional, uses credentials chain if None)",
        examples=[None, "AKIAIOSFODNN7EXAMPLE"],
    )
    aws_secret_access_key: Optional[str] = Field(
        None,
        description="AWS secret access key (optional)",
        examples=[None, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"],
    )
    region_name: str = Field(
        "us-east-1",
        description="AWS region name",
        examples=["us-east-1", "us-west-2", "eu-west-1"],
    )
class UploadDatasetResponse(BaseModel):
    """Response model for dataset upload."""

    success: bool = Field(
        ..., description="Whether upload was successful", examples=[True, False]
    )
    output_dir: str = Field(..., description="Directory where dataset was extracted")
    metadata: Dict[str, Any] = Field(
        ..., description="Upload metadata (file counts, pairs, etc.)"
    )
    # Empty list when no problems were encountered.
    errors: List[str] = Field(
        default_factory=list, description="List of validation/processing errors"
    )
class DownloadDatasetResponse(BaseModel):
    """Response model for dataset download."""

    success: bool = Field(
        ..., description="Whether download was successful", examples=[True, False]
    )
    # Exactly one of output_path / output_dir is expected depending on
    # whether the archive was extracted.
    output_path: Optional[str] = Field(
        None, description="Path to downloaded file (if not extracted)"
    )
    output_dir: Optional[str] = Field(
        None, description="Directory where dataset was extracted (if extracted)"
    )
    file_size: Optional[int] = Field(
        None, description="Size of downloaded file in bytes"
    )
    error: Optional[str] = Field(
        None, description="Error message if download failed"
    )
class JobResponse(BaseModel):
    """Response model for job-based endpoints."""

    # Opaque identifier used to poll job state.
    job_id: str = Field(
        ...,
        description="Unique job identifier",
        examples=["550e8400-e29b-41d4-a716-446655440000"],
    )
    status: JobStatus = Field(
        ...,
        description="Current job status",
        examples=["queued", "running", "completed", "failed"],
    )
    message: Optional[str] = Field(
        None,
        description="Status message or error description",
        examples=["Job queued", "Validation completed successfully"],
    )
    # Populated only once the job reaches a terminal state.
    result: Optional[Dict[str, Any]] = Field(
        None,
        description="Job result data (only present when status is 'completed' or 'failed')",
        examples=[
            {
                "success": True,
                "stdout": "Output text",
                "stderr": "",
                "validation_stats": {
                    "total_frames": 10,
                    "accepted": 5,
                    "rejected_learnable": 3,
                    "rejected_outlier": 2,
                },
            }
        ],
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "job_id": "550e8400-e29b-41d4-a716-446655440000",
                "status": "completed",
                "message": "ARKit validation completed successfully",
                "result": {
                    "success": True,
                    "validation_stats": {
                        "total_frames": 10,
                        "accepted": 0,
                        "rejected_learnable": 0,
                        "rejected_outlier": 10,
                    },
                },
            }
        }
    }
class ValidationStats(BaseModel):
    """Statistics from BA validation."""

    # Absolute frame counts per classification bucket.
    total_frames: int = Field(
        ..., description="Total number of frames processed", ge=0, examples=[10, 100]
    )
    accepted: int = Field(
        ...,
        description="Number of accepted frames (< accept_threshold)",
        ge=0,
        examples=[5, 50],
    )
    rejected_learnable: int = Field(
        ...,
        description="Number of rejected-learnable frames (between thresholds)",
        ge=0,
        examples=[3, 30],
    )
    rejected_outlier: int = Field(
        ...,
        description="Number of rejected-outlier frames (> reject_threshold)",
        ge=0,
        examples=[2, 20],
    )
    # The same buckets expressed as percentages of total_frames.
    accepted_percentage: float = Field(
        ...,
        description="Percentage of accepted frames",
        ge=0.0,
        le=100.0,
        examples=[50.0, 75.5],
    )
    rejected_learnable_percentage: float = Field(
        ...,
        description="Percentage of rejected-learnable frames",
        ge=0.0,
        le=100.0,
        examples=[30.0, 20.5],
    )
    rejected_outlier_percentage: float = Field(
        ...,
        description="Percentage of rejected-outlier frames",
        ge=0.0,
        le=100.0,
        examples=[20.0, 5.0],
    )
    # Optional aggregate outcome fields.
    ba_status: Optional[str] = Field(
        None,
        description="BA validation status",
        examples=["accepted", "rejected_learnable", "rejected_outlier", "ba_failed"],
    )
    max_error_deg: Optional[float] = Field(
        None,
        description="Maximum rotation error in degrees",
        ge=0.0,
        examples=[177.76, 25.5, 1.2],
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "total_frames": 10,
                "accepted": 5,
                "rejected_learnable": 3,
                "rejected_outlier": 2,
                "accepted_percentage": 50.0,
                "rejected_learnable_percentage": 30.0,
                "rejected_outlier_percentage": 20.0,
                "ba_status": "rejected_outlier",
                "max_error_deg": 177.76,
            }
        }
    }
class ErrorResponse(BaseModel):
    """Standard error response model."""

    error: str = Field(
        ...,
        description="Error type/name",
        examples=["ValidationError", "FileNotFoundError", "InternalServerError"],
    )
    message: str = Field(
        ...,
        description="Human-readable error message",
        examples=["Sequence directory not found", "Invalid request data"],
    )
    # Correlates the response with server-side logs.
    request_id: str = Field(
        ..., description="Request ID for log correlation", examples=["req_1234567890"]
    )
    details: Optional[Dict[str, Any]] = Field(
        None,
        description="Additional error details",
        examples=[{"field": "sequence_dir", "error": "Path does not exist"}],
    )
    endpoint: Optional[str] = Field(
        None,
        description="Endpoint where error occurred",
        examples=["/api/v1/validate/sequence"],
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "error": "FileNotFoundError",
                "message": "Sequence directory not found: /invalid/path",
                "request_id": "req_1234567890",
                "details": {"path": "/invalid/path"},
                "endpoint": "/api/v1/validate/sequence",
            }
        }
    }
class HealthResponse(BaseModel):
    """Health check response model."""

    status: str = Field(
        ..., description="Health status", examples=["healthy", "degraded", "unhealthy"]
    )
    # Unix epoch seconds when the check ran.
    timestamp: float = Field(
        ..., description="Unix timestamp of health check", examples=[1701878400.123]
    )
    request_id: str = Field(..., description="Request ID", examples=["req_1234567890"])
    profiling: Optional[Dict[str, Any]] = Field(
        None,
        description="Profiling status if available",
        examples=[{"enabled": True, "total_entries": 42}],
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "status": "healthy",
                "timestamp": 1701878400.123,
                "request_id": "req_1234567890",
                "profiling": {"enabled": True, "total_entries": 42},
            }
        }
    }
class ModelsResponse(BaseModel):
    """Response model for models list endpoint."""

    # Keyed by model name; values carry per-model metadata.
    models: Dict[str, Any] = Field(
        ...,
        description="Dictionary of available models with metadata",
        examples=[
            {
                "depth-anything/DA3-LARGE": {
                    "series": "main",
                    "size": "large",
                    "capabilities": ["mono_depth", "pose_estimation"],
                }
            }
        ],
    )
    recommended: Optional[str] = Field(
        None,
        description="Recommended model for the requested use case",
        examples=["depth-anything/DA3-LARGE"],
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "models": {"depth-anything/DA3-LARGE": {}},
                "recommended": "depth-anything/DA3-LARGE",
            }
        }
    }
class TrainUnifiedRequest(BaseModel):
    """Request model for unified YLFF training.

    Trains from a pre-processed cache, optionally referencing the original
    ARKit sequences for image loading.
    """

    preprocessed_cache_dir: str = Field(
        ...,
        description="Directory containing pre-processed results",
        examples=["cache/preprocessed"],
        min_length=1,
    )
    arkit_sequences_dir: Optional[str] = Field(
        None,
        description="Directory with original ARKit sequences (for loading images)",
        examples=["data/arkit_sequences"],
    )
    model_name: Optional[str] = Field(
        None,
        description="DA3 model name (default: auto-select)",
        examples=["depth-anything/DA3-LARGE"],
    )
    epochs: int = Field(
        200, description="Number of training epochs", ge=1, examples=[100, 200]
    )
    lr: float = Field(2e-4, description="Learning rate", gt=0.0, examples=[2e-4])
    weight_decay: float = Field(
        0.04, description="Weight decay", ge=0.0, examples=[0.04]
    )
    batch_size: int = Field(
        32, description="Batch size per GPU", ge=1, examples=[32, 64]
    )
    device: DeviceType = Field(
        DeviceType.CUDA, description="Device for training", examples=["cuda", "cpu"]
    )
    checkpoint_dir: str = Field(
        "checkpoints/ylff_training",
        description="Checkpoint directory",
        examples=["checkpoints/ylff_training"],
    )
    log_interval: int = Field(10, description="Log metrics every N steps", ge=1)
    save_interval: int = Field(1000, description="Save checkpoint every N steps", ge=1)
    use_fp16: bool = Field(True, description="Use FP16 mixed precision")
    use_bf16: bool = Field(False, description="Use BF16 mixed precision")
    ema_decay: float = Field(0.999, description="EMA decay rate for teacher")
    use_wandb: bool = Field(True, description="Enable Weights & Biases logging")
    wandb_project: str = Field("ylff", description="W&B project name")
    gradient_accumulation_steps: int = Field(1, description="Gradient accumulation steps", ge=1)
    gradient_clip_norm: float = Field(1.0, description="Gradient clipping norm")
    num_workers: Optional[int] = Field(None, description="Number of data loading workers")
    resume_from_checkpoint: Optional[str] = Field(None, description="Resume from checkpoint path")
    use_fsdp: bool = Field(False, description="Enable FSDP (single-GPU stub)")

    model_config = {
        # Clear the "model_" protected namespace so the `model_name` field
        # does not trigger a Pydantic v2 UserWarning.
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "preprocessed_cache_dir": "cache/preprocessed",
                "arkit_sequences_dir": "data/arkit_sequences",
                "epochs": 200,
                "batch_size": 32,
                "model_name": "depth-anything/DA3-SMALL",
            }
        },
    }