|
|
""" |
|
|
Command-line interface for YLFF. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
import typer |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Load environment variables from a local .env file (API keys, W&B config,
# AWS credentials, ...) before any subcommand reads os.environ.
load_dotenv()

# torch is an optional dependency at module level: only the inference
# commands need it. Any import failure (missing package, broken CUDA
# install) degrades to torch = None; commands that require torch check for
# None and exit with a clear error instead of crashing at import time.
try:
    import torch
except Exception:
    torch = None
|
|
|
|
|
# Root Typer application; all subcommands are grouped into the sub-apps
# registered below (e.g. `ylff validate sequence`, `ylff dataset build`).
app = typer.Typer(help="You Learn From Failure: BA-Supervised Fine-Tuning")

# Module-level logger; each command configures the root level itself via
# logging.basicConfig(level=logging.INFO) when it runs.
logger = logging.getLogger(__name__)


# --- Sub-command groups -------------------------------------------------

validate_app = typer.Typer(help="Validate sequences using BA")

app.add_typer(validate_app, name="validate")


dataset_app = typer.Typer(help="Build training datasets")

app.add_typer(dataset_app, name="dataset")


# NOTE(review): train_app / preprocess_app / eval_app are registered here but
# no commands for them are visible in this file — presumably defined elsewhere
# or still to be added; confirm before removing.
train_app = typer.Typer(help="Fine-tune models")

app.add_typer(train_app, name="train")


preprocess_app = typer.Typer(help="Pre-process ARKit sequences (BA + oracle uncertainty)")

app.add_typer(preprocess_app, name="preprocess")


eval_app = typer.Typer(help="Evaluate models")

app.add_typer(eval_app, name="eval")


ingest_app = typer.Typer(help="Ingest raw exports into canonical capture bundles")

app.add_typer(ingest_app, name="ingest")


teacher_app = typer.Typer(help="Run offline teacher pipeline (metrology)")

app.add_typer(teacher_app, name="teacher")


infer_app = typer.Typer(help="Run inference + optional reconstruction (metrology)")

app.add_typer(infer_app, name="infer")


audit_app = typer.Typer(help="Run audit + calibration on external references (metrology)")

app.add_typer(audit_app, name="audit")


catalog_app = typer.Typer(help="Build/inspect scene catalogs (S3 or local)")

app.add_typer(catalog_app, name="catalog")


orchestrate_app = typer.Typer(help="Run backfill orchestration (single-node)")

app.add_typer(orchestrate_app, name="orchestrate")
|
|
|
|
|
|
|
|
@app.command("serve")
def serve(
    host: str = typer.Option("0.0.0.0", help="Host to bind to"),
    port: int = typer.Option(8000, help="Port to bind to"),
):
    """Start the YLFF API server.

    Blocks until the server is stopped. The server module is imported
    lazily so that other CLI subcommands don't pay its import cost.
    """
    from .server import start_server

    start_server(host=host, port=port)
|
|
|
|
|
|
|
|
@ingest_app.command("bundle")
def ingest_bundle(
    raw_dir: Path = typer.Argument(..., help="Raw export directory (single or multi-device)"),
    output_root: Path = typer.Option(
        Path("data/captures"),
        help="Root directory under which `capture_<id>/` bundles are created",
    ),
    capture_id: Optional[str] = typer.Option(None, help="Optional capture id override"),
    overwrite: bool = typer.Option(False, help="Overwrite destination if it exists"),
    run_quality_gates: bool = typer.Option(True, help="Run quality gates during ingest"),
    enable_sync_validation: bool = typer.Option(
        True, help="Validate sync_offsets.json if present"
    ),
    copy_mode: str = typer.Option(
        "copy",
        help="Materialization mode: copy | hardlink | symlink | auto",
    ),
):
    """Convert a raw phone export directory into a canonical capture bundle."""
    logging.basicConfig(level=logging.INFO)
    # Lazy import keeps CLI startup fast for unrelated subcommands.
    from .services.ingest_pipeline import IngestConfig, ingest_capture_bundle

    # Assemble the pipeline configuration first, then run the ingest.
    ingest_cfg = IngestConfig(
        capture_id=capture_id,
        overwrite=overwrite,
        run_quality_gates=run_quality_gates,
        enable_sync_validation=enable_sync_validation,
        copy_mode=copy_mode,
    )
    bundle_meta = ingest_capture_bundle(raw_dir, output_root=output_root, config=ingest_cfg)
    # Emit the resulting bundle metadata as pretty-printed JSON.
    typer.echo(json.dumps(bundle_meta, indent=2))
|
|
|
|
|
|
|
|
@ingest_app.command("materialize")
def ingest_materialize(
    bundle_dir: Path = typer.Argument(..., help="Existing capture bundle directory"),
    output_dir: Path = typer.Argument(..., help="Destination directory (portable copy)"),
    overwrite: bool = typer.Option(False, help="Overwrite destination if it exists"),
    keep_symlinks: bool = typer.Option(
        False, help="If set, preserve symlinks instead of copying their targets"
    ),
):
    """Materialize a link-based bundle into a portable copy."""
    logging.basicConfig(level=logging.INFO)
    from .services.ingest_pipeline import materialize_capture_bundle

    # By default (keep_symlinks=False) symlink targets are copied so the
    # output is fully self-contained.
    dereference = not bool(keep_symlinks)
    result_meta = materialize_capture_bundle(
        bundle_dir=bundle_dir,
        output_dir=output_dir,
        overwrite=overwrite,
        dereference_symlinks=dereference,
    )
    typer.echo(json.dumps(result_meta, indent=2))
|
|
|
|
|
|
|
|
@validate_app.command("sequence")
def validate_sequence(
    sequence_dir: Path = typer.Argument(..., help="Directory containing image sequence"),
    # Fixed annotation: default is None, so the type is Optional[str].
    model_name: Optional[str] = typer.Option(
        None, help="DA3 model name (default: auto-select for BA validation)"
    ),
    use_case: str = typer.Option(
        "ba_validation", help="Use case for model selection (ba_validation, pose_estimation, etc.)"
    ),
    accept_threshold: float = typer.Option(2.0, help="Accept threshold (degrees)"),
    reject_threshold: float = typer.Option(30.0, help="Reject threshold (degrees)"),
    output: Optional[Path] = typer.Option(None, help="Output JSON path for results"),
):
    """Validate a single sequence using BA.

    Loads all .jpg/.png images in *sequence_dir*, runs DA3 inference to
    estimate poses, validates them with bundle adjustment, and prints the
    resulting status/error. Optionally writes a JSON summary to *output*.

    Exits with code 1 when no images can be found/decoded or torch is
    unavailable.
    """
    logging.basicConfig(level=logging.INFO)

    # Heavy dependencies are imported lazily so other subcommands stay fast.
    # (json is already imported at module level; the previous local import
    # shadowed it unnecessarily.)
    import cv2

    from .services.ba_validator import BAValidator
    from .utils.model_loader import get_recommended_model, load_da3_model

    # Auto-select a model appropriate for the requested use case.
    if model_name is None:
        model_name = get_recommended_model(use_case)
        logger.info(f"Auto-selected model for '{use_case}': {model_name}")

    logger.info(f"Loading model: {model_name}")
    model = load_da3_model(model_name, use_case=use_case)

    validator = BAValidator(
        accept_threshold=accept_threshold,
        reject_threshold=reject_threshold,
    )

    # Collect images (sorted for a deterministic frame order).
    image_paths = sorted(list(sequence_dir.glob("*.jpg")) + list(sequence_dir.glob("*.png")))
    if not image_paths:
        typer.echo(f"Error: No images found in {sequence_dir}", err=True)
        raise typer.Exit(1)

    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is not None:
            # OpenCV decodes BGR; downstream code expects RGB.
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # Guard against the case where every file failed to decode — otherwise
    # inference would run on an empty list.
    if not images:
        typer.echo(f"Error: Could not decode any images in {sequence_dir}", err=True)
        raise typer.Exit(1)

    logger.info(f"Loaded {len(images)} images")

    logger.info("Running DA3 inference...")
    if torch is None:
        typer.echo("Error: torch is required for inference. Install torch.", err=True)
        raise typer.Exit(1)
    with torch.no_grad():
        model_output = model.inference(images)

    logger.info("Running BA validation...")
    result = validator.validate(
        images=images,
        poses_model=model_output.extrinsics,
        intrinsics=model_output.intrinsics if hasattr(model_output, "intrinsics") else None,
    )

    typer.echo(f"\nStatus: {result['status']}")
    if isinstance(result.get("error"), (int, float)):
        typer.echo(f"Error: {result['error']:.2f} degrees")
    # Use the same numeric check as above so a legitimate 0.0 reprojection
    # error is still printed (a bare truthiness test skipped it).
    if isinstance(result.get("reprojection_error"), (int, float)):
        typer.echo(f"Reprojection Error: {result['reprojection_error']:.4f}")

    if output:
        with open(output, "w") as f:
            json.dump(
                {
                    "status": result["status"],
                    "error": result.get("error"),
                    "reprojection_error": result.get("reprojection_error"),
                },
                f,
                indent=2,
            )
        typer.echo(f"\nResults saved to {output}")
|
|
|
|
|
|
|
|
@validate_app.command("arkit")
def validate_arkit(
    arkit_dir: Path = typer.Argument(..., help="Directory containing ARKit video and metadata"),
    output_dir: Path = typer.Option(Path("data/arkit_validation"), help="Output directory"),
    # Fixed annotation: default is None, so the type is Optional[str].
    # NOTE(review): model_name is currently unused — the delegated scripts
    # pick their own model; kept for interface compatibility. Confirm intent.
    model_name: Optional[str] = typer.Option(
        None, help="DA3 model name (default: DA3NESTED-GIANT-LARGE for BA validation)"
    ),
    max_frames: Optional[int] = typer.Option(None, help="Maximum frames to process"),
    frame_interval: int = typer.Option(1, help="Extract every Nth frame"),
    device: str = typer.Option("cpu", help="Device for DA3 inference"),
    gui: bool = typer.Option(False, help="Show real-time GUI visualization"),
):
    """Validate ARKit data with BA.

    Delegates to a standalone experiment script loaded from
    `scripts/experiments/`. The GUI and headless paths are identical except
    for which script is run, so they share one code path (the original had
    the whole argv/exec sequence duplicated in both branches).
    """
    logging.basicConfig(level=logging.INFO)

    import importlib.util
    import sys

    project_root = Path(__file__).parent.parent

    # Only the script name differs between GUI and headless runs.
    script_name = "run_arkit_ba_validation_gui" if gui else "run_arkit_ba_validation"
    script_path = project_root / "scripts" / "experiments" / f"{script_name}.py"

    spec = importlib.util.spec_from_file_location(script_name, script_path)
    if spec is None or spec.loader is None:
        typer.echo(f"Error: Could not load script {script_path}", err=True)
        raise typer.Exit(1)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # The scripts parse sys.argv themselves; swap it in temporarily and
    # always restore it, even if the script raises.
    old_argv = sys.argv
    try:
        sys.argv = [
            script_name,
            "--arkit-dir",
            str(arkit_dir),
            "--output-dir",
            str(output_dir),
        ]
        if max_frames:
            sys.argv.extend(["--max-frames", str(max_frames)])
        sys.argv.extend(["--frame-interval", str(frame_interval)])
        sys.argv.extend(["--device", device])
        module.main()
    finally:
        sys.argv = old_argv
|
|
|
|
|
|
|
|
@dataset_app.command("build")
def build_dataset(
    sequences_dir: Path = typer.Argument(..., help="Directory containing sequence directories"),
    output_dir: Path = typer.Option(Path("data/training"), help="Output directory"),
    # Fixed annotation: default is None, so the type is Optional[str].
    model_name: Optional[str] = typer.Option(
        None, help="DA3 model name (default: DA3NESTED-GIANT-LARGE for fine-tuning)"
    ),
    max_samples: Optional[int] = typer.Option(None, help="Maximum number of samples"),
    accept_threshold: float = typer.Option(2.0, help="Accept threshold (degrees)"),
    reject_threshold: float = typer.Option(30.0, help="Reject threshold (degrees)"),
    use_wandb: bool = typer.Option(True, help="Enable Weights & Biases logging"),
    wandb_project: str = typer.Option("ylff", help="W&B project name"),
    wandb_name: Optional[str] = typer.Option(None, help="W&B run name"),
    use_batched_inference: bool = typer.Option(
        False, help="Use batched inference for better GPU utilization"
    ),
    inference_batch_size: int = typer.Option(4, help="Batch size for inference"),
    use_inference_cache: bool = typer.Option(False, help="Cache inference results"),
    cache_dir: Optional[Path] = typer.Option(None, help="Directory for inference cache"),
    compile_model: bool = typer.Option(True, help="Compile model with torch.compile"),
):
    """Build training dataset from sequences.

    For every sub-directory of *sequences_dir*, runs DA3 inference and BA
    validation via ``BADataPipeline`` and writes the resulting training
    samples under *output_dir*. Optionally logs the run to W&B.

    Exits with code 1 when *sequences_dir* contains no sub-directories.
    """
    logging.basicConfig(level=logging.INFO)

    from .services.ba_validator import BAValidator
    from .services.data_pipeline import BADataPipeline
    from .utils.model_loader import get_recommended_model, load_da3_model

    # Auto-select a model suitable for fine-tuning when none was given.
    if model_name is None:
        model_name = get_recommended_model("fine_tuning")
        logger.info(f"Auto-selected model for fine-tuning: {model_name}")

    logger.info(f"Loading model: {model_name}")
    model = load_da3_model(
        model_name,
        use_case="fine_tuning",
        compile_model=compile_model,
        compile_mode="reduce-overhead",
    )

    validator = BAValidator(
        accept_threshold=accept_threshold,
        reject_threshold=reject_threshold,
        work_dir=output_dir / "ba_work",
    )
    pipeline = BADataPipeline(model, validator, data_dir=output_dir)

    # Each immediate sub-directory is treated as one sequence.
    sequence_paths = [p for p in sequences_dir.iterdir() if p.is_dir()]
    logger.info(f"Found {len(sequence_paths)} sequences")

    if not sequence_paths:
        typer.echo(f"Error: No sequences found in {sequences_dir}", err=True)
        raise typer.Exit(1)

    # Initialize explicitly so the post-run check below never depends on
    # short-circuit evaluation to avoid a NameError.
    wandb_run = None
    if use_wandb:
        from .utils.wandb_utils import finish_wandb, init_wandb

        wandb_run = init_wandb(
            project=wandb_project,
            name=wandb_name or f"dataset-build-{len(sequence_paths)}-seqs",
            config={
                "task": "dataset_build",
                "model_name": model_name,
                "accept_threshold": accept_threshold,
                "reject_threshold": reject_threshold,
                "max_samples": max_samples,
                "num_sequences": len(sequence_paths),
                "use_batched_inference": use_batched_inference,
                "inference_batch_size": inference_batch_size,
                "use_inference_cache": use_inference_cache,
                "compile_model": compile_model,
            },
            tags=["dataset", "ba-validation"],
        )

    pipeline.build_training_set(
        raw_sequence_paths=sequence_paths,
        max_samples=max_samples,
        use_batched_inference=use_batched_inference,
        inference_batch_size=inference_batch_size,
        use_inference_cache=use_inference_cache,
        cache_dir=cache_dir,
    )

    if use_wandb and wandb_run:
        finish_wandb()

    logger.info("\nDataset Statistics:")
    logger.info(f"  Total sequences: {pipeline.stats['total']}")
    logger.info(f"  Accepted: {pipeline.stats['accepted']}")
    logger.info(f"  Learnable: {pipeline.stats['learnable']}")
    logger.info(f"  Outliers: {pipeline.stats['outlier']}")
    logger.info(f"  BA Failed: {pipeline.stats['ba_failed']}")
    logger.info(f"\nTraining samples saved to: {output_dir}")
|
|
|
|
|
|
|
|
@dataset_app.command("validate")
def validate_dataset(
    dataset_path: Path = typer.Argument(..., help="Path to dataset file"),
    strict: bool = typer.Option(False, help="Fail on validation errors"),
    # NOTE(review): check_images/check_poses/check_metadata are accepted but
    # never forwarded to validate_dataset_file — confirm whether they should
    # be wired through or removed.
    check_images: bool = typer.Option(True, help="Validate image data"),
    check_poses: bool = typer.Option(True, help="Validate pose data"),
    check_metadata: bool = typer.Option(True, help="Validate metadata"),
    output: Optional[Path] = typer.Option(None, help="Path to save validation report"),
):
    """Validate dataset file for quality and integrity.

    Prints a summary of the validation report, optionally saves it to
    *output*, and exits with code 1 on error or (with --strict) when
    validation fails.
    """
    logging.basicConfig(level=logging.INFO)

    from .utils.dataset_validation import validate_dataset_file

    try:
        report = validate_dataset_file(
            dataset_path=dataset_path,
            strict=strict,
        )

        logger.info("\nDataset Validation Report:")
        logger.info(f"  Validation passed: {report['validation_passed']}")
        logger.info(f"  Total samples: {report['statistics']['total_samples']}")
        logger.info(f"  Valid samples: {report['statistics']['valid_samples']}")
        logger.info(f"  Invalid samples: {report['statistics']['invalid_samples']}")
        logger.info(f"  Errors: {report['statistics']['errors']}")
        logger.info(f"  Warnings: {report['statistics']['warnings']}")

        if output:
            with open(output, "w") as f:
                json.dump(report, f, indent=2, default=str)
            logger.info(f"\nValidation report saved to: {output}")

        if not report["validation_passed"] and strict:
            raise typer.Exit(1)

    # typer.Exit subclasses RuntimeError, so without this clause the
    # strict-mode Exit above was swallowed by `except Exception` and a
    # spurious "Error: 1" was printed. Re-raise it untouched.
    except typer.Exit:
        raise
    except FileNotFoundError as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1)
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1)
|
|
|
|
|
|
|
|
@dataset_app.command("curate")
def curate_dataset(
    dataset_path: Path = typer.Argument(..., help="Path to input dataset file"),
    output_path: Path = typer.Argument(..., help="Path to save curated dataset"),
    min_error: Optional[float] = typer.Option(None, help="Minimum error threshold"),
    max_error: Optional[float] = typer.Option(None, help="Maximum error threshold"),
    min_weight: Optional[float] = typer.Option(None, help="Minimum weight threshold"),
    max_weight: Optional[float] = typer.Option(None, help="Maximum weight threshold"),
    remove_outliers: bool = typer.Option(False, help="Remove outlier samples"),
    outlier_percentile: float = typer.Option(95.0, help="Percentile for outlier detection"),
    balance: bool = typer.Option(False, help="Balance dataset by error distribution"),
    balance_strategy: str = typer.Option("error_bins", help="Balancing strategy"),
    num_bins: int = typer.Option(10, help="Number of error bins"),
):
    """Curate dataset (filter, balance, remove outliers).

    Loads samples from a .pkl/.pickle or .json file, applies quality
    filters, optional outlier removal and balancing, then saves the result
    in the format implied by *output_path*'s suffix.
    """
    logging.basicConfig(level=logging.INFO)

    from .utils.dataset_curation import DatasetCurator

    # --- Load samples ----------------------------------------------------
    if dataset_path.suffix == ".pkl" or dataset_path.suffix == ".pickle":
        import pickle

        # SECURITY: pickle.load executes arbitrary code from the file —
        # only open dataset files from trusted sources.
        with open(dataset_path, "rb") as f:
            samples = pickle.load(f)
    elif dataset_path.suffix == ".json":
        with open(dataset_path) as f:
            data = json.load(f)
        # Accept either {"samples": [...]} or a bare list.
        samples = data.get("samples", data)
    else:
        typer.echo(f"Error: Unsupported format: {dataset_path.suffix}", err=True)
        raise typer.Exit(1)

    logger.info(f"Loaded {len(samples)} samples from {dataset_path}")

    # --- Curate ----------------------------------------------------------
    curator = DatasetCurator()
    curated_samples = samples

    curated_samples, filter_stats = curator.filter_by_quality(
        curated_samples,
        min_error=min_error,
        max_error=max_error,
        min_weight=min_weight,
        max_weight=max_weight,
    )

    if remove_outliers:
        curated_samples, outlier_stats = curator.remove_outliers(
            curated_samples, error_percentile=outlier_percentile
        )
    else:
        outlier_stats = {"removed": 0}

    if balance:
        curated_samples, _ = curator.balance_dataset(
            curated_samples,
            strategy=balance_strategy,
            num_bins=num_bins,
        )

    # --- Save ------------------------------------------------------------
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if output_path.suffix == ".pkl" or output_path.suffix == ".pickle":
        import pickle

        with open(output_path, "wb") as f:
            pickle.dump(curated_samples, f)
    elif output_path.suffix == ".json":
        with open(output_path, "w") as f:
            json.dump({"samples": curated_samples}, f, indent=2, default=str)
    else:
        # Previously an unsupported output suffix silently wrote nothing
        # while still logging "saved to ..." below — fail loudly instead.
        typer.echo(f"Error: Unsupported output format: {output_path.suffix}", err=True)
        raise typer.Exit(1)

    logger.info("\nCuration Results:")
    logger.info(f"  Original samples: {len(samples)}")
    logger.info(f"  Curated samples: {len(curated_samples)}")
    removed_by_error = filter_stats.get("removed_by_error", 0)
    removed_by_weight = filter_stats.get("removed_by_weight", 0)
    removed_by_filter = removed_by_error + removed_by_weight
    logger.info(f"  Removed by filter: {removed_by_filter}")
    logger.info(f"  Removed outliers: {outlier_stats.get('removed', 0)}")
    logger.info(f"\nCurated dataset saved to: {output_path}")
|
|
|
|
|
|
|
|
@dataset_app.command("index")
def index_captures(
    captures_root: Path = typer.Argument(
        Path("data/captures"), help="Root directory containing capture bundles"
    ),
    output_path: Path = typer.Option(Path("data/captures_index.jsonl"), help="Output JSONL path"),
    workers: int = typer.Option(8, help="Number of indexing worker threads"),
    include_depth_stream_summary: bool = typer.Option(
        True, help="Parse packed depth index.json for format/coverage summary"
    ),
    discover: str = typer.Option("children", help="Bundle discovery: children | recursive"),
):
    """Build a fast JSONL curation index over capture bundles."""
    logging.basicConfig(level=logging.INFO)
    from .services.curation.indexer import CurationIndexConfig, build_curation_index_jsonl

    # Normalize option types explicitly before handing them to the indexer.
    index_cfg = CurationIndexConfig(
        workers=int(workers),
        include_depth_stream_summary=bool(include_depth_stream_summary),
        discover=str(discover),
    )
    index_meta = build_curation_index_jsonl(
        captures_root=captures_root,
        output_path=output_path,
        config=index_cfg,
    )
    typer.echo(json.dumps(index_meta, indent=2))
|
|
|
|
|
|
|
|
@dataset_app.command("index_sqlite")
def index_captures_sqlite(
    captures_root: Path = typer.Argument(
        Path("data/captures"), help="Root directory containing capture bundles"
    ),
    db_path: Path = typer.Option(Path("data/captures_index.db"), help="Output SQLite DB path"),
    workers: int = typer.Option(8, help="Number of indexing worker threads"),
    incremental: bool = typer.Option(True, help="Skip bundles whose manifest.json is unchanged"),
    include_depth_stream_summary: bool = typer.Option(
        True, help="Parse packed depth index.json for format/coverage summary"
    ),
    discover: str = typer.Option("children", help="Bundle discovery: children | recursive"),
):
    """Build an incremental SQLite curation index over capture bundles."""
    logging.basicConfig(level=logging.INFO)
    from .services.curation.sqlite_index import SQLiteIndexConfig, build_curation_index_sqlite

    # Normalize option types explicitly before handing them to the indexer.
    sqlite_cfg = SQLiteIndexConfig(
        workers=int(workers),
        incremental=bool(incremental),
        include_depth_stream_summary=bool(include_depth_stream_summary),
        discover=str(discover),
    )
    index_meta = build_curation_index_sqlite(
        captures_root=captures_root,
        db_path=db_path,
        config=sqlite_cfg,
    )
    typer.echo(json.dumps(index_meta, indent=2))
|
|
|
|
|
|
|
|
@dataset_app.command("query_sqlite")
def query_captures_sqlite(
    db_path: Path = typer.Argument(Path("data/captures_index.db"), help="SQLite index DB path"),
    source_format: Optional[str] = typer.Option(None, help="Filter by ingest source_format"),
    has_packed_depth: Optional[bool] = typer.Option(None, help="Filter by packed depth presence"),
    scene_type: Optional[str] = typer.Option(None, help="Filter by scene_type"),
    operating_regime: Optional[str] = typer.Option(None, help="Filter by operating_regime"),
    min_devices: Optional[int] = typer.Option(None, help="Minimum number of devices in bundle"),
    packed_depth_min_frames: Optional[int] = typer.Option(
        None, help="Require packed depth summary frames >= N (device-level)"
    ),
    packed_depth_max_gaps: Optional[int] = typer.Option(
        None, help="Require packed depth summary gaps <= N (device-level)"
    ),
    limit: Optional[int] = typer.Option(None, help="Limit number of bundle dirs returned"),
    order_by: str = typer.Option(
        "bundle_dir", help="Order: bundle_dir|capture_id|created_at|scene_type"
    ),
    output_txt: Optional[Path] = typer.Option(None, help="Write bundle dirs to a .txt file"),
    output_jsonl: Optional[Path] = typer.Option(
        None, help="Write full stored JSON rows to a .jsonl file"
    ),
):
    """Query the SQLite curation index and optionally export results."""
    logging.basicConfig(level=logging.INFO)
    from .services.curation.sqlite_query import (
        QueryFilters,
        export_bundle_dirs_txt,
        export_rows_jsonl,
        query_bundle_dirs,
    )

    # Translate CLI options into the query-layer filter object.
    query_filters = QueryFilters(
        source_format=source_format,
        has_packed_depth=has_packed_depth,
        scene_type=scene_type,
        operating_regime=operating_regime,
        min_devices=min_devices,
        packed_depth_min_frames=packed_depth_min_frames,
        packed_depth_max_gaps=packed_depth_max_gaps,
    )
    bundle_dirs = query_bundle_dirs(
        db_path=db_path,
        filters=query_filters,
        limit=limit,
        order_by=order_by,
    )

    # Optional exports.
    if output_txt is not None:
        export_bundle_dirs_txt(bundle_dirs, output_txt)
    if output_jsonl is not None:
        export_rows_jsonl(db_path=db_path, bundle_dirs=bundle_dirs, output_path=output_jsonl)

    # Summarize on stdout; only the first 50 dirs are echoed inline.
    summary = {
        "db_path": str(db_path),
        "count": int(len(bundle_dirs)),
        "output_txt": str(output_txt) if output_txt else None,
        "output_jsonl": str(output_jsonl) if output_jsonl else None,
        "bundle_dirs": bundle_dirs[:50],
        "bundle_dirs_truncated": bool(len(bundle_dirs) > 50),
    }
    typer.echo(json.dumps(summary, indent=2))
|
|
|
|
|
|
|
|
@dataset_app.command("shard_from_sqlite")
def shard_from_sqlite(
    db_path: Path = typer.Argument(Path("data/captures_index.db"), help="SQLite index DB path"),
    output_dir: Path = typer.Argument(
        Path("data/_shards"), help="Output directory for sample_index.part_*.jsonl"
    ),
    source_format: Optional[str] = typer.Option(None, help="Filter by ingest source_format"),
    has_packed_depth: Optional[bool] = typer.Option(None, help="Filter by packed depth presence"),
    scene_type: Optional[str] = typer.Option(None, help="Filter by scene_type"),
    operating_regime: Optional[str] = typer.Option(None, help="Filter by operating_regime"),
    min_devices: Optional[int] = typer.Option(None, help="Minimum number of devices in bundle"),
    packed_depth_min_frames: Optional[int] = typer.Option(
        None, help="Require packed depth summary frames >= N (device-level)"
    ),
    packed_depth_max_gaps: Optional[int] = typer.Option(
        None, help="Require packed depth summary gaps <= N (device-level)"
    ),
    limit_bundles: Optional[int] = typer.Option(None, help="Limit bundles before sharding"),
    order_by: str = typer.Option(
        "bundle_dir", help="Order: bundle_dir|capture_id|created_at|scene_type"
    ),
    temporal_window: int = typer.Option(5, help="Temporal window (odd)"),
    device_id: Optional[str] = typer.Option(
        None, help="Device id override (required for multi-device bundles unless allowed)"
    ),
    allow_multi_device_default_first: bool = typer.Option(
        False, help="If set, multi-device bundles default to devices[0] when device_id is unset"
    ),
    max_samples_per_bundle: Optional[int] = typer.Option(
        None, help="Cap sample centers per bundle (for quick smoke runs)"
    ),
    shard_size: int = typer.Option(200000, help="Max rows per shard file"),
):
    """
    Build sharded jsonl sample indices for training directly from the SQLite index.

    Output rows match `TeacherSupervisedTemporalDataset.from_sample_index_jsonl`.
    """
    logging.basicConfig(level=logging.INFO)
    from .services.curation.shard_from_sqlite import (
        ShardFromSQLiteConfig,
        write_sample_index_from_sqlite,
    )
    from .services.curation.sqlite_query import QueryFilters

    # Bundle-selection filters (same shape as the query_sqlite command).
    bundle_filters = QueryFilters(
        source_format=source_format,
        has_packed_depth=has_packed_depth,
        scene_type=scene_type,
        operating_regime=operating_regime,
        min_devices=min_devices,
        packed_depth_min_frames=packed_depth_min_frames,
        packed_depth_max_gaps=packed_depth_max_gaps,
    )
    # Sharding configuration, with option types normalized explicitly.
    shard_cfg = ShardFromSQLiteConfig(
        temporal_window=int(temporal_window),
        device_id=device_id,
        allow_multi_device_default_first=bool(allow_multi_device_default_first),
        max_samples_per_bundle=max_samples_per_bundle,
        shard_size=int(shard_size),
    )
    shard_meta = write_sample_index_from_sqlite(
        db_path=db_path,
        output_dir=output_dir,
        filters=bundle_filters,
        cfg=shard_cfg,
        limit_bundles=limit_bundles,
        order_by=order_by,
    )
    typer.echo(json.dumps(shard_meta, indent=2))
|
|
|
|
|
|
|
|
@dataset_app.command("analyze")
def analyze_dataset(
    dataset_path: Path = typer.Argument(..., help="Path to dataset file"),
    output: Optional[Path] = typer.Option(None, help="Path to save analysis report"),
    format: str = typer.Option("json", help="Report format: json, text, or markdown"),
    # NOTE(review): compute_distributions/compute_correlations are accepted
    # but never forwarded to analyze_dataset_file — confirm whether they
    # should be wired through or removed.
    compute_distributions: bool = typer.Option(True, help="Compute distributions"),
    compute_correlations: bool = typer.Option(True, help="Compute correlations"),
):
    """Analyze dataset and generate statistics report.

    Prints summary statistics (sample count, error stats, quality-ratio
    breakdown) and optionally saves the full report to *output*. Exits with
    code 1 on any analysis error.
    """
    logging.basicConfig(level=logging.INFO)

    from .utils.dataset_analysis import analyze_dataset_file

    try:
        results = analyze_dataset_file(
            dataset_path=dataset_path,
            output_path=output,
            format=format,
        )

        logger.info("\nDataset Analysis:")
        total_samples = results.get("statistics", {}).get("total_samples", 0)
        logger.info(f"  Total samples: {total_samples}")

        if "error_statistics" in results.get("statistics", {}):
            err_stats = results["statistics"]["error_statistics"]
            logger.info(
                f"  Error - Mean: {err_stats['mean']:.4f}, Median: {err_stats['median']:.4f}"
            )

        if "quality_metrics" in results:
            qm = results["quality_metrics"]
            if "low_error_ratio" in qm:
                low_ratio = qm["low_error_ratio"] * 100
                medium_ratio = qm["medium_error_ratio"] * 100
                high_ratio = qm["high_error_ratio"] * 100
                logger.info(f"  Low error ratio: {low_ratio:.1f}%")
                logger.info(f"  Medium error ratio: {medium_ratio:.1f}%")
                logger.info(f"  High error ratio: {high_ratio:.1f}%")

        if output:
            logger.info(f"\nAnalysis report saved to: {output}")

    # The original had a separate `except FileNotFoundError` clause with a
    # byte-identical body; FileNotFoundError is an Exception subclass, so a
    # single handler preserves behavior exactly.
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1)
|
|
|
|
|
|
|
|
@dataset_app.command("upload")
def upload_dataset(
    zip_path: Path = typer.Argument(..., help="Path to zip file containing ARKit pairs"),
    output_dir: Path = typer.Option(
        Path("data/uploaded_datasets"),
        help="Directory to extract uploaded dataset",
    ),
    validate: bool = typer.Option(True, help="Validate ARKit pairs before extraction"),
):
    """Upload and extract dataset zip file containing ARKit video and metadata pairs.

    Prints an extraction summary on success; exits with code 1 when the zip
    is missing, processing fails, or the result reports errors.
    """
    logging.basicConfig(level=logging.INFO)

    from .utils.dataset_upload import process_uploaded_dataset

    if not zip_path.exists():
        typer.echo(f"Error: Zip file not found: {zip_path}", err=True)
        raise typer.Exit(1)

    try:
        result = process_uploaded_dataset(
            zip_path=zip_path,
            output_dir=output_dir,
            validate=validate,
        )

        if result["success"]:
            metadata = result["metadata"]
            typer.echo("\n✅ Dataset uploaded successfully!")
            typer.echo(f"  Output directory: {result['output_dir']}")
            typer.echo(f"  Video files: {metadata.get('video_files', 0)}")
            typer.echo(f"  Metadata files: {metadata.get('metadata_files', 0)}")
            typer.echo(f"  Valid pairs: {metadata.get('valid_pairs', 0)}")
            if metadata.get("organized_sequences"):
                typer.echo(f"  Organized sequences: {metadata['organized_sequences']}")
        else:
            typer.echo("\n❌ Dataset upload failed:", err=True)
            for error in result["errors"]:
                typer.echo(f"  - {error}", err=True)
            raise typer.Exit(1)

    # typer.Exit subclasses RuntimeError, so without this clause the
    # failure-path Exit above was swallowed by `except Exception` and a
    # spurious "Error: 1" was printed. Re-raise it untouched.
    except typer.Exit:
        raise
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1)
|
|
|
|
|
|
|
|
@dataset_app.command("download")
def download_dataset(
    bucket_name: str = typer.Argument(..., help="S3 bucket name"),
    s3_key: str = typer.Argument(..., help="S3 object key (path to dataset)"),
    output_dir: Path = typer.Option(
        Path("data/downloaded_datasets"),
        help="Directory to save downloaded dataset",
    ),
    extract: bool = typer.Option(True, help="Extract downloaded archive"),
    aws_access_key_id: Optional[str] = typer.Option(
        None, help="AWS access key ID (optional, uses credentials chain if None)"
    ),
    aws_secret_access_key: Optional[str] = typer.Option(
        None, help="AWS secret access key (optional)"
    ),
    region_name: str = typer.Option("us-east-1", help="AWS region name"),
):
    """Download dataset from AWS S3.

    Uses ``utils.dataset_download.S3DatasetDownloader`` to fetch (and
    optionally extract) the object, then prints a summary. Exits with
    status 1 on any failure, including a missing boto3 installation.
    """
    logging.basicConfig(level=logging.INFO)

    # Lazy import: boto3 (pulled in by this module) is an optional dependency.
    from .utils.dataset_download import S3DatasetDownloader

    try:
        downloader = S3DatasetDownloader(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
        )

        result = downloader.download_and_extract(
            bucket_name=bucket_name,
            s3_key=s3_key,
            output_dir=output_dir,
            extract=extract,
            show_progress=True,
        )

        if result["success"]:
            typer.echo("\n✅ Dataset downloaded successfully!")
            if result.get("output_path"):
                typer.echo(f"  Downloaded to: {result['output_path']}")
            if result.get("output_dir"):
                typer.echo(f"  Extracted to: {result['output_dir']}")
            if result.get("file_size"):
                size_mb = result["file_size"] / (1024 * 1024)
                typer.echo(f"  File size: {size_mb:.2f} MB")
        else:
            typer.echo(f"\n❌ Download failed: {result.get('error', 'Unknown error')}", err=True)
            raise typer.Exit(1)

    except typer.Exit:
        # typer.Exit is an Exception subclass; without this clause the
        # deliberate Exit(1) raised in the failure branch above would be
        # swallowed by the broad handler below and reported as a spurious
        # "Error:" message. Re-raise it untouched.
        raise
    except ImportError:
        typer.echo(
            "Error: boto3 is required for S3 downloads. Install with: pip install boto3",
            err=True,
        )
        raise typer.Exit(1)
    except Exception as e:
        typer.echo(f"Error: {e}", err=True)
        raise typer.Exit(1)
|
|
|
|
|
|
|
|
@train_app.command("start")
def train(
    training_data_dir: Path = typer.Argument(
        ..., help="[DEPRECATED] Use 'ylff train unified' instead"
    ),
    **kwargs,
):
    """
    [DEPRECATED] Fine-tune DA3 model on BA-supervised training samples.

    ⚠️ This command is deprecated. Use 'ylff train unified' instead.

    The unified training service provides better geometric accuracy and incorporates
    DINOv2 teacher-student learning with DA3 techniques.

    Migration:
        # OLD
        ylff train start data/training --epochs 10

        # NEW
        ylff preprocess arkit data/arkit_sequences --output-cache cache/preprocessed
        ylff train unified cache/preprocessed --epochs 200
    """
    # NOTE(review): **kwargs in a Typer command signature is unusual — Typer
    # derives CLI parameters from the signature and may not register a
    # VAR_KEYWORD parameter cleanly; confirm the app still builds.
    deprecation_messages = (
        "⚠️ This command is deprecated. Use 'ylff train unified' instead.",
        "\nThe unified training service provides:",
        "  - DINOv2 teacher-student paradigm",
        "  - Geometric consistency as first-order goal",
        "  - DA3 techniques (depth-ray, multi-resolution)",
        "\nTo migrate:",
        "  1. Pre-process your data: ylff preprocess arkit <dir>",
        "  2. Train with unified service: ylff train unified <cache_dir>",
    )
    for message in deprecation_messages:
        typer.echo(message)
    # Always a hard failure so scripts relying on the old command notice.
    raise typer.Exit(1)
|
|
|
|
|
|
|
|
@train_app.command("unified")
def train_unified(
    preprocessed_cache_dir: Path = typer.Argument(
        ..., help="Directory containing pre-processed results (from 'ylff preprocess arkit')"
    ),
    arkit_sequences_dir: Optional[Path] = typer.Option(
        None, help="Directory with original ARKit sequences (for loading images)"
    ),
    # Fix: annotation was `str` while the default is None; Optional[str]
    # matches the actual contract (None triggers auto-selection below).
    model_name: Optional[str] = typer.Option(None, help="DA3 model name (default: auto-select)"),
    epochs: int = typer.Option(200, help="Number of training epochs"),
    lr: float = typer.Option(2e-4, help="Learning rate (base, scales with batch size)"),
    weight_decay: float = typer.Option(0.04, help="Weight decay"),
    batch_size: int = typer.Option(32, help="Batch size per GPU"),
    device: str = typer.Option("cuda", help="Device for training"),
    checkpoint_dir: Path = typer.Option(
        Path("checkpoints/ylff_training"), help="Checkpoint directory"
    ),
    log_interval: int = typer.Option(10, help="Log metrics every N steps"),
    save_interval: int = typer.Option(1000, help="Save checkpoint every N steps"),
    use_fp16: bool = typer.Option(True, help="Use FP16 mixed precision"),
    use_bf16: bool = typer.Option(False, help="Use BF16 mixed precision (overrides FP16)"),
    ema_decay: float = typer.Option(0.999, help="EMA decay rate for teacher"),
    use_wandb: bool = typer.Option(True, help="Enable Weights & Biases logging (required)"),
    wandb_project: str = typer.Option("ylff", help="W&B project name"),
    gradient_accumulation_steps: int = typer.Option(1, help="Gradient accumulation steps"),
    gradient_clip_norm: float = typer.Option(1.0, help="Gradient clipping norm"),
    num_workers: Optional[int] = typer.Option(None, help="Number of data loading workers"),
    resume_from_checkpoint: Optional[Path] = typer.Option(None, help="Resume from checkpoint"),
    use_fsdp: bool = typer.Option(
        False,
        help=(
            "Stub: enable FSDP adapter scaffold for multi-GPU. "
            "Single-GPU works; multi-GPU raises NotImplementedError for now."
        ),
    ),
):
    """
    Train using unified YLFF training service with geometric consistency as first-order goal.

    This is the PRIMARY training command that uses the unified training service.
    It combines DINOv2's teacher-student paradigm with DA3 techniques and treats
    geometric consistency as the primary objective.

    Requires pre-processed data from 'ylff preprocess arkit' command.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Heavy imports kept function-local so unrelated CLI commands start fast.
    from .services.preprocessed_dataset import PreprocessedARKitDataset
    from .services.ylff_training import train_ylff
    from .utils.model_loader import get_recommended_model, load_da3_model

    # Resolve the model: None means "pick the recommended fine-tuning model".
    if model_name is None:
        model_name = get_recommended_model("fine_tuning")
        logger.info(f"Auto-selected model: {model_name}")

    logger.info(f"Loading model: {model_name}")
    model = load_da3_model(
        model_name,
        device=device,
        use_case="fine_tuning",
        compile_model=False,
    )

    # Dataset wraps the offline BA / oracle-uncertainty results; images are
    # loaded lazily from the original ARKit sequences when provided.
    logger.info(f"Loading preprocessed dataset from {preprocessed_cache_dir}")
    dataset = PreprocessedARKitDataset(
        cache_dir=preprocessed_cache_dir,
        arkit_sequences_dir=arkit_sequences_dir,
        load_images=True,
    )

    if len(dataset) == 0:
        typer.echo(
            f"❌ No pre-processed sequences found in {preprocessed_cache_dir}",
            err=True,
        )
        typer.echo("Run 'ylff preprocess arkit' first to pre-process sequences.", err=True)
        raise typer.Exit(1)

    logger.info(f"Loaded {len(dataset)} pre-processed sequences")

    # Fixed loss weighting: geometric consistency dominates by design.
    loss_weights = {
        "geometric_consistency": 3.0,
        "absolute_scale": 2.5,
        "pose_geometric": 2.0,
        "gradient_loss": 1.0,
        "teacher_consistency": 0.5,
    }

    logger.info("Starting unified YLFF training...")
    logger.info(f"  Epochs: {epochs}")
    logger.info(f"  Learning rate: {lr}")
    logger.info(f"  Batch size: {batch_size}")
    logger.info(f"  Loss weights: {loss_weights}")

    metrics = train_ylff(
        model=model,
        dataset=dataset,
        epochs=epochs,
        lr=lr,
        weight_decay=weight_decay,
        batch_size=batch_size,
        device=device,
        checkpoint_dir=checkpoint_dir,
        log_interval=log_interval,
        save_interval=save_interval,
        use_fp16=use_fp16,
        use_bf16=use_bf16,
        ema_decay=ema_decay,
        loss_weights=loss_weights,
        use_wandb=use_wandb,
        wandb_project=wandb_project,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_clip_norm=gradient_clip_norm,
        num_workers=num_workers,
        use_fsdp=use_fsdp,
        resume_from_checkpoint=resume_from_checkpoint,
    )

    # Summarize final metrics; missing keys default to 0 for robustness.
    logger.info(f"\n{'=' * 60}")
    logger.info("Training complete!")
    logger.info(f"  Final loss: {metrics.get('total_loss', 0):.4f}")
    logger.info(f"  Geometric consistency: {metrics.get('geometric_consistency', 0):.4f}")
    logger.info(f"  Absolute scale: {metrics.get('absolute_scale', 0):.4f}")
    logger.info(f"  Checkpoints: {checkpoint_dir}")
    logger.info(f"{'=' * 60}")

    typer.echo(f"\n✅ Training complete! Model saved to {checkpoint_dir}")
|
|
|
|
|
|
|
|
@train_app.command("pretrain")
def pretrain(
    arkit_sequences_dir: Path = typer.Argument(
        ..., help="[DEPRECATED] Use 'ylff train unified' instead"
    ),
    **kwargs,
):
    """
    [DEPRECATED] Pre-train DA3 model on ARKit data using BA as oracle teacher.

    ⚠️ This command is deprecated. Use 'ylff train unified' instead.

    The unified training service provides better geometric accuracy and incorporates
    DINOv2 teacher-student learning with DA3 techniques.

    Migration:
        # OLD
        ylff train pretrain data/arkit_sequences --epochs 10

        # NEW
        ylff preprocess arkit data/arkit_sequences --output-cache cache/preprocessed
        ylff train unified cache/preprocessed --epochs 200
    """
    # NOTE(review): **kwargs in a Typer command signature is unusual — Typer
    # derives CLI parameters from the signature and may not register a
    # VAR_KEYWORD parameter cleanly; confirm the app still builds.
    deprecation_messages = (
        "⚠️ This command is deprecated. Use 'ylff train unified' instead.",
        "\nThe unified training service provides:",
        "  - DINOv2 teacher-student paradigm",
        "  - Geometric consistency as first-order goal",
        "  - DA3 techniques (depth-ray, multi-resolution)",
        "\nTo migrate:",
        "  1. Pre-process your data: ylff preprocess arkit <dir>",
        "  2. Train with unified service: ylff train unified <cache_dir>",
    )
    for message in deprecation_messages:
        typer.echo(message)
    # Always a hard failure so scripts relying on the old command notice.
    raise typer.Exit(1)
|
|
|
|
|
|
|
|
@eval_app.command("ba-agreement")
def evaluate_ba_agreement(
    test_data_dir: Path = typer.Argument(..., help="Directory containing test sequences"),
    # Fix: the old default "depth-anything/DA3-LARGE" made the
    # `if model_name is None` auto-select branch below dead code. A None
    # default restores the intended auto-selection and matches the
    # 'train unified' / 'preprocess arkit' commands.
    model_name: Optional[str] = typer.Option(
        None, help="DA3 model name (default: auto-select)"
    ),
    checkpoint: Optional[Path] = typer.Option(None, help="Checkpoint path (optional)"),
    threshold: float = typer.Option(2.0, help="Agreement threshold (degrees)"),
    device: str = typer.Option("cuda", help="Device for inference"),
    use_wandb: bool = typer.Option(True, help="Enable Weights & Biases logging"),
    wandb_project: str = typer.Option("ylff", help="W&B project name"),
    wandb_name: Optional[str] = typer.Option(None, help="W&B run name"),
):
    """Evaluate model agreement with BA.

    Runs the model over every sequence directory under ``test_data_dir``,
    compares predicted poses against BA via ``BAValidator``, and prints
    aggregate agreement metrics. Exits with status 1 when no sequences
    are found.
    """
    logging.basicConfig(level=logging.INFO)

    # Heavy imports kept function-local so unrelated CLI commands start fast.
    from .services.ba_validator import BAValidator
    from .services.evaluate import evaluate_ba_agreement
    from .utils.model_loader import (
        get_recommended_model,
        load_da3_model,
        load_model_from_checkpoint,
    )

    # None means "pick the recommended BA-validation model".
    if model_name is None:
        model_name = get_recommended_model("ba_validation")
        logger.info(f"Auto-selected model for evaluation: {model_name}")

    logger.info(f"Loading model: {model_name}")
    model = load_da3_model(model_name, device=device, use_case="ba_validation")

    # Optionally overlay fine-tuned weights on top of the base model.
    if checkpoint:
        logger.info(f"Loading checkpoint: {checkpoint}")
        model = load_model_from_checkpoint(model, checkpoint, device=device)

    validator = BAValidator(
        accept_threshold=threshold,
        reject_threshold=30.0,
    )

    # Each immediate subdirectory is treated as one test sequence.
    sequence_paths = [p for p in test_data_dir.iterdir() if p.is_dir()]
    if not sequence_paths:
        typer.echo(f"Error: No sequences found in {test_data_dir}", err=True)
        raise typer.Exit(1)

    logger.info(f"Found {len(sequence_paths)} test sequences")

    metrics = evaluate_ba_agreement(
        model=model,
        sequences=sequence_paths,
        ba_validator=validator,
        threshold=threshold,
        use_wandb=use_wandb,
        wandb_project=wandb_project,
        wandb_name=wandb_name,
    )

    typer.echo("\n" + "=" * 60)
    typer.echo("Evaluation Results")
    typer.echo("=" * 60)
    typer.echo(f"BA Agreement Rate: {metrics['agreement_rate']:.2%}")
    typer.echo(f"Mean Rotation Error: {metrics['mean_rotation_error_deg']:.2f}°")
    typer.echo(f"Mean Translation Error: {metrics['mean_translation_error']:.4f} m")
    typer.echo(f"Total Sequences: {metrics['total_sequences']}")
    typer.echo(f"Agreed Sequences: {metrics['agreed_sequences']}")
|
|
|
|
|
|
|
|
@app.command()
def list_models(
    use_case: Optional[str] = typer.Option(None, help="Filter by use case"),
):
    """List available DA3 models and their characteristics."""
    from .utils.model_loader import get_recommended_model, list_available_models

    catalog = list_available_models()

    # When a use case is given, surface the recommended model up front.
    if use_case:
        recommendation = get_recommended_model(use_case)
        typer.echo(f"\nRecommended for '{use_case}': {recommendation}\n")

    typer.echo("Available DA3 Models:\n")
    for model_id, details in catalog.items():
        typer.echo(f"  {model_id}")
        typer.echo(f"    Series: {details['series']}")
        typer.echo(f"    Description: {details['description']}")
        typer.echo(f"    Metric: {details['metric']}")
        typer.echo(f"    Capabilities: {', '.join(details['capabilities'])}")
        if details.get("recommended_for"):
            typer.echo(f"    Recommended for: {', '.join(details['recommended_for'])}")
        typer.echo()
|
|
|
|
|
|
|
|
@app.command()
def visualize(
    results_dir: Path = typer.Argument(..., help="Directory containing validation results"),
    output_dir: Optional[Path] = typer.Option(None, help="Output directory for visualizations"),
    use_plotly: bool = typer.Option(True, help="Use plotly for interactive plots"),
):
    """Visualize BA validation results."""
    import importlib.util
    import sys

    # The visualization logic lives in a standalone script; load it as a
    # module so we can call its main() in-process.
    script_path = (
        Path(__file__).parent.parent / "scripts" / "tools" / "visualize_ba_results.py"
    )

    spec = importlib.util.spec_from_file_location("visualize_ba_results", script_path)
    if spec is None or spec.loader is None:
        typer.echo(f"Error: Could not load script {script_path}", err=True)
        raise typer.Exit(1)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Translate our Typer options into the script's argparse-style argv.
    argv = ["visualize_ba_results", "--results-dir", str(results_dir)]
    if output_dir:
        argv += ["--output-dir", str(output_dir)]
    if use_plotly:
        argv.append("--use-plotly")

    # Temporarily swap sys.argv for the script's main(), restoring it even
    # if the script raises.
    saved_argv = sys.argv
    try:
        sys.argv = argv
        module.main()
    finally:
        sys.argv = saved_argv
|
|
|
|
|
|
|
|
@preprocess_app.command("arkit")
def preprocess_arkit(
    arkit_sequences_dir: Path = typer.Argument(
        ..., help="Directory containing ARKit sequence directories"
    ),
    output_cache_dir: Path = typer.Option(
        Path("cache/preprocessed"),
        help="Directory to save pre-processed results",
    ),
    # Fix: annotation was `str` while the default is None; Optional[str]
    # matches the actual contract (None triggers auto-selection below).
    model_name: Optional[str] = typer.Option(
        None, help="DA3 model name for initial inference (default: auto-select)"
    ),
    device: str = typer.Option("cuda", help="Device for DA3 inference"),
    prefer_arkit_poses: bool = typer.Option(
        True,
        help="Use ARKit poses when tracking quality is good (skips BA, much faster)",
    ),
    min_arkit_quality: float = typer.Option(
        0.8,
        help="Minimum fraction of frames with good tracking to use ARKit poses (0.0-1.0)",
    ),
    use_lidar: bool = typer.Option(True, help="Include LiDAR depth in oracle uncertainty"),
    use_ba_depth: bool = typer.Option(False, help="Include BA depth in oracle uncertainty"),
    num_workers: int = typer.Option(4, help="Number of parallel workers for processing"),
):
    """
    Pre-process ARKit sequences: compute BA and oracle uncertainty offline.

    This runs OUTSIDE the training loop and can be parallelized. Results are
    cached to disk and loaded during training for fast iteration.

    Steps:
        1. Extract ARKit data (poses, LiDAR) - FREE
        2. Run DA3 inference (GPU, batchable)
        3. Run BA validation (CPU, expensive) - only if ARKit quality is poor
        4. Compute oracle uncertainty propagation
        5. Save to cache for training
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    from concurrent.futures import ThreadPoolExecutor, as_completed

    from .services.ba_validator import BAValidator
    from .services.preprocessing import preprocess_arkit_sequence
    from .utils.model_loader import get_recommended_model, load_da3_model
    from .utils.oracle_uncertainty import OracleUncertaintyPropagator

    # Resolve the model: None means "pick the recommended BA-validation model".
    if model_name is None:
        model_name = get_recommended_model("ba_validation")
        logger.info(f"Auto-selected model: {model_name}")

    logger.info(f"Loading model: {model_name}")
    model = load_da3_model(
        model_name,
        device=device,
        use_case="ba_validation",
        compile_model=False,
    )

    ba_validator = BAValidator()
    oracle_propagator = OracleUncertaintyPropagator()

    # A sequence directory is identified by containing a "videos" subdir;
    # the set de-duplicates parents found via multiple matches.
    arkit_dirs = sorted(
        {d.parent for d in arkit_sequences_dir.rglob("videos") if d.is_dir()}
    )

    if not arkit_dirs:
        typer.echo(f"❌ No ARKit sequences found in {arkit_sequences_dir}")
        raise typer.Exit(1)

    logger.info(f"Found {len(arkit_dirs)} ARKit sequences")
    logger.info(f"Output cache: {output_cache_dir}")

    output_cache_dir.mkdir(parents=True, exist_ok=True)

    # Shared kwargs for both the threaded and sequential paths (only the
    # per-sequence arkit_dir varies).
    process_kwargs = dict(
        output_cache_dir=output_cache_dir,
        model=model,
        ba_validator=ba_validator,
        oracle_propagator=oracle_propagator,
        device=device,
        prefer_arkit_poses=prefer_arkit_poses,
        min_arkit_quality=min_arkit_quality,
        use_lidar=use_lidar,
        use_ba_depth=use_ba_depth,
    )

    results = []
    if num_workers > 1:
        # Threads (not processes): the heavy lifting is GPU inference and
        # native BA code, so the GIL is not the bottleneck here.
        logger.info(f"Processing {len(arkit_dirs)} sequences with {num_workers} workers...")
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = {
                executor.submit(
                    preprocess_arkit_sequence,
                    arkit_dir=arkit_dir,
                    **process_kwargs,
                ): arkit_dir
                for arkit_dir in arkit_dirs
            }

            for future in as_completed(futures):
                arkit_dir = futures[future]
                try:
                    result = future.result()
                    results.append(result)
                    if result["status"] == "success":
                        logger.info(
                            f"✅ {arkit_dir.name}: {result['num_frames']} frames, "
                            f"confidence={result['mean_confidence']:.2f}"
                        )
                    else:
                        logger.warning(f"⚠️  {arkit_dir.name}: {result.get('reason', 'failed')}")
                except Exception as e:
                    # One failed sequence must not abort the whole batch;
                    # record it and keep going.
                    logger.error(f"❌ {arkit_dir.name}: {e}", exc_info=True)
                    results.append(
                        {"status": "failed", "sequence_id": arkit_dir.name, "error": str(e)}
                    )
    else:
        logger.info(f"Processing {len(arkit_dirs)} sequences sequentially...")
        for arkit_dir in arkit_dirs:
            result = preprocess_arkit_sequence(
                arkit_dir=arkit_dir,
                **process_kwargs,
            )
            results.append(result)

    successful = sum(1 for r in results if r["status"] == "success")
    failed = len(results) - successful

    logger.info(f"\n{'=' * 60}")
    logger.info("Pre-processing complete!")
    logger.info(f"  ✅ Successful: {successful}/{len(results)}")
    logger.info(f"  ❌ Failed: {failed}/{len(results)}")
    logger.info(f"  📁 Cache directory: {output_cache_dir}")
    logger.info(f"{'=' * 60}")

    typer.echo(f"\n✅ Pre-processing complete! {successful}/{len(results)} sequences processed")
    typer.echo(f"📁 Results saved to: {output_cache_dir}")
|
|
|
|
|
|
|
|
@teacher_app.command("run")
def teacher_run(
    bundle_dir: Path = typer.Argument(..., help="Capture bundle directory"),
    output_dir: Optional[Path] = typer.Option(None, help="Override output directory"),
    device_id: Optional[str] = typer.Option(
        None, help="Device id (required for multi-device bundles)"
    ),
    model_name: Optional[str] = typer.Option(None, help="Model name (defaults to metric model)"),
    device: str = typer.Option("cuda", help="Device for inference"),
    max_frames: Optional[int] = typer.Option(None, help="Max frames"),
    frame_interval: int = typer.Option(1, help="Extract every Nth frame"),
):
    """Run the offline teacher pipeline and write teacher_outputs/*."""
    from .services.teacher_pipeline import TeacherConfig, run_teacher

    # Bundle the CLI options into the pipeline's config object.
    teacher_cfg = TeacherConfig(
        device_id=device_id,
        model_name=model_name,
        device=device,
        max_frames=max_frames,
        frame_interval=frame_interval,
    )
    summary = run_teacher(bundle_dir=bundle_dir, output_dir=output_dir, config=teacher_cfg)
    # Emit the pipeline summary as pretty-printed JSON on stdout.
    typer.echo(json.dumps(summary, indent=2))
|
|
|
|
|
|
|
|
@infer_app.command("run")
def infer_run(
    input_path: Path = typer.Argument(..., help="Video file or capture bundle directory"),
    output_dir: Path = typer.Argument(..., help="Output directory"),
    device_id: Optional[str] = typer.Option(
        None, help="Device id (bundle-only; required for multi-device)"
    ),
    model_name: Optional[str] = typer.Option(None, help="Model name (defaults to metric model)"),
    device: str = typer.Option("cuda", help="Device for inference"),
    max_frames: Optional[int] = typer.Option(60, help="Max frames"),
    frame_interval: int = typer.Option(2, help="Extract every Nth frame"),
    enable_gtsam_ba: bool = typer.Option(
        True, help="Run GTSAM BA with ray-depth priors if available"
    ),
):
    """Run metrology inference pipeline."""
    from .services.inference_pipeline import InferenceConfig, run_inference

    # Bundle the CLI options into the pipeline's config object.
    inference_cfg = InferenceConfig(
        device_id=device_id,
        model_name=model_name,
        device=device,
        max_frames=max_frames,
        frame_interval=frame_interval,
        enable_gtsam_ba=enable_gtsam_ba,
    )
    run_meta = run_inference(input_path=input_path, output_dir=output_dir, config=inference_cfg)
    # Emit the run metadata as pretty-printed JSON on stdout.
    typer.echo(json.dumps(run_meta, indent=2))
|
|
|
|
|
|
|
|
@audit_app.command("run")
def audit_run(
    measurements_json: Path = typer.Argument(..., help="External reference measurements JSON"),
    calibrate: bool = typer.Option(True, help="Fit affine σ calibration before auditing"),
    calibration_split_fraction: float = typer.Option(
        0.5, help="Fraction used for calibration fit"
    ),
):
    """Run audit gates and (optional) σ calibration."""
    from .services.audit.audit_runner import load_measurements_json, run_audit

    # Load the reference measurements, then run the audit (optionally fitting
    # the affine σ calibration on a split first).
    measurements = load_measurements_json(measurements_json)
    audit_result = run_audit(
        measurements,
        calibrate=calibrate,
        calibration_split_fraction=calibration_split_fraction,
    )
    # The result is a pydantic model; emit its pretty-printed JSON on stdout.
    typer.echo(audit_result.model_dump_json(indent=2))
|
|
|
|
|
|
|
|
@catalog_app.command("build_s3")
def catalog_build_s3(
    bucket: str = typer.Argument(..., help="S3 bucket containing capture bundles"),
    prefix: str = typer.Argument(..., help="S3 prefix under which manifests live"),
    output_json: Path = typer.Option(
        Path("data/orchestrator/outputs/scene_catalog.json"),
        help="Where to write the catalog JSON",
    ),
    output_jsonl: Optional[Path] = typer.Option(
        None, help="Optional path to also write catalog.jsonl (one scene per line)"
    ),
    output_report_json: Optional[Path] = typer.Option(
        None, help="Optional path to also write a validation report JSON"
    ),
    region: Optional[str] = typer.Option(None, help="AWS region (optional)"),
    endpoint_url: Optional[str] = typer.Option(None, help="S3 endpoint URL (optional)"),
):
    """Build a scene catalog by listing manifest.json objects under an S3 prefix."""
    from .services.scene_catalog import (
        build_scene_catalog,
        list_manifest_uris_s3,
        validate_scene_catalog,
        write_scene_catalog,
        write_scene_catalog_jsonl,
    )

    # Discover every manifest under the prefix, then assemble the catalog.
    manifest_uris = list_manifest_uris_s3(
        bucket=bucket, prefix=prefix, s3_region=region, s3_endpoint_url=endpoint_url
    )
    catalog = build_scene_catalog(manifest_uris, s3_region=region, s3_endpoint_url=endpoint_url)

    # The primary JSON artifact is always written; JSONL and the validation
    # report are opt-in side outputs.
    write_scene_catalog(catalog, output_json)
    if output_jsonl is not None:
        write_scene_catalog_jsonl(catalog, output_jsonl)
    if output_report_json is not None:
        validation_report = validate_scene_catalog(catalog)
        output_report_json.parent.mkdir(parents=True, exist_ok=True)
        output_report_json.write_text(json.dumps(validation_report, indent=2, sort_keys=True))

    # Echo the full catalog so callers can pipe it onward.
    typer.echo(catalog.model_dump_json(indent=2))
|
|
|
|
|
|
|
|
@orchestrate_app.command("backfill")
def orchestrate_backfill(
    catalog_json: Optional[Path] = typer.Option(
        None, help="Optional catalog JSON path (if omitted, list from S3)"
    ),
    s3_bucket: Optional[str] = typer.Option(None, help="S3 bucket (if no catalog_json)"),
    s3_prefix: Optional[str] = typer.Option(None, help="S3 prefix (if no catalog_json)"),
    stage: str = typer.Option("teacher", help="Stage to run: teacher (default)"),
    device: str = typer.Option("cuda", help="Device string passed to pipelines"),
    model_name: Optional[str] = typer.Option(None, help="Optional model name override"),
    work_dir: Path = typer.Option(Path("data/orchestrator/work"), help="Local work dir"),
    output_root: Path = typer.Option(Path("data/orchestrator/outputs"), help="Output root dir"),
    max_scenes: Optional[int] = typer.Option(None, help="Limit number of scenes (debug)"),
    region: Optional[str] = typer.Option(None, help="AWS region (optional)"),
    endpoint_url: Optional[str] = typer.Option(None, help="S3 endpoint URL (optional)"),
    upload_bucket: Optional[str] = typer.Option(
        None, help="Optional S3 bucket for derived outputs"
    ),
    upload_base_prefix: str = typer.Option("ylff", help="Base prefix for derived outputs"),
    pipeline_version: str = typer.Option("v1", help="Pipeline version stamp for derived outputs"),
):
    """Run a single-node backfill loop over a catalog or S3 prefix."""
    from .services.orchestration.runner import BackfillConfig, run_backfill

    # Collect every CLI option into the backfill config before dispatching.
    backfill_cfg = BackfillConfig(
        catalog_json=catalog_json,
        s3_bucket=s3_bucket,
        s3_prefix=s3_prefix,
        s3_region=region,
        s3_endpoint_url=endpoint_url,
        stage=stage,
        device=device,
        model_name=model_name,
        work_dir=work_dir,
        output_root=output_root,
        max_scenes=max_scenes,
        upload_bucket=upload_bucket,
        upload_base_prefix=upload_base_prefix,
        pipeline_version=pipeline_version,
    )
    backfill_summary = run_backfill(backfill_cfg)
    # Emit the backfill summary as pretty-printed JSON on stdout.
    typer.echo(json.dumps(backfill_summary, indent=2))
|
|
|
|
|
|
|
|
# Script entry point: dispatch to the Typer application when this module is
# executed directly (the installed console script calls app() the same way).
if __name__ == "__main__":
    app()
|
|
|