| """ |
| Training loop for MetricDepthWithUncertainty student model. |
| |
| This is a minimal reference implementation; it is not optimized for scale yet. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Optional, Sequence |
|
|
# torch is an optional heavy dependency: keep the module importable without it
# so lightweight tooling can read TrainConfig; train_student() raises a clear
# ImportError at call time instead.
try:
    import torch
except Exception:
    # Broad catch: a broken torch install can raise more than ImportError
    # (e.g. OSError from missing CUDA libs) — presumably intentional; confirm.
    torch = None
|
|
| from ...models.metric_depth_with_uncertainty import MetricDepthWithUncertainty |
| from .dataset import TeacherSupervisedTemporalDataset |
| from .losses import compute_losses |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass(frozen=True)
class TrainConfig:
    """Hyperparameters for one student-training run.

    Frozen so a config instance can be shared and pickled into checkpoints
    without risk of mutation mid-run.
    """

    # Number of consecutive frames per training sample (passed to both the
    # dataset and the model, which must agree on it).
    temporal_window: int = 5
    # Full passes over the dataset; one checkpoint is written per epoch.
    epochs: int = 1
    batch_size: int = 1
    # AdamW learning rate.
    lr: float = 2e-4
    # Device string handed to torch ("cuda" or "cpu"); no automatic fallback.
    device: str = "cuda"
    # DataLoader worker processes; 0 = load in the main process.
    num_workers: int = 0
    # Directory where per-epoch checkpoints are written (created if absent).
    checkpoint_dir: Path = Path("checkpoints/student_depth_unc")
|
|
|
|
def train_student(
    bundle_dirs: Sequence[Path],
    *,
    config: Optional[TrainConfig] = None,
) -> Dict[str, float]:
    """Train the MetricDepthWithUncertainty student on teacher-supervised data.

    Runs a plain supervised loop over ``TeacherSupervisedTemporalDataset``,
    optimizing with AdamW and writing one checkpoint per epoch to
    ``config.checkpoint_dir``.

    Args:
        bundle_dirs: Directories containing teacher-output bundles.
        config: Training hyperparameters; defaults to ``TrainConfig()``.

    Returns:
        Loss metrics from the final optimization step, keyed
        ``loss_total``, ``loss_nll``, ``loss_depth``, ``loss_lidar``,
        ``loss_sigma_supervision``. Empty dict if the loop never ran.

    Raises:
        ImportError: If torch is not installed.
        ValueError: If the dataset yields no training samples.
    """
    config = config or TrainConfig()
    if torch is None:
        raise ImportError("train_student requires torch to be installed")

    # Deferred so importing this module never requires torch.
    from torch.utils.data import DataLoader

    ds = TeacherSupervisedTemporalDataset(bundle_dirs, temporal_window=config.temporal_window)
    if len(ds) == 0:
        raise ValueError("No training samples found (missing teacher outputs?)")

    device = config.device
    dl = DataLoader(
        ds,
        batch_size=int(config.batch_size),
        shuffle=True,
        num_workers=int(config.num_workers),
        # Pinned host memory speeds up host-to-device copies on CUDA.
        pin_memory=(device == "cuda"),
    )

    model = MetricDepthWithUncertainty(temporal_window=config.temporal_window).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=float(config.lr), weight_decay=0.01)

    config.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    last: Dict[str, float] = {}
    for epoch in range(int(config.epochs)):
        model.train()
        for batch in dl:
            frames = batch["frames"].to(device)
            depth_gt = batch["depth"].to(device)
            sigma_gt = batch["sigma"].to(device)
            # Optional supervision channels: missing keys stay None and are
            # handled downstream by compute_losses.
            sigma_w = batch.get("sigma_weight")
            if sigma_w is not None:
                sigma_w = sigma_w.to(device)
            lidar = batch.get("lidar_depth")
            if lidar is not None:
                lidar = lidar.to(device)

            out = model(frames)
            losses = compute_losses(
                depth_pred=out.depth,
                log_sigma_pred=out.log_sigma,
                depth_gt=depth_gt,
                sigma_teacher=sigma_gt,
                lidar_depth=lidar,
                sigma_teacher_weight=sigma_w,
            )

            opt.zero_grad(set_to_none=True)
            losses.total.backward()
            # Clip to unit norm to guard against occasional exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            # .item() already detaches and copies to host, so the former
            # .detach().cpu() chain was redundant.
            last = {
                f"loss_{name}": float(getattr(losses, name).item())
                for name in ("total", "nll", "depth", "lidar", "sigma_supervision")
            }

        ckpt_path = config.checkpoint_dir / f"epoch_{epoch:04d}.pt"
        torch.save(
            {"model": model.state_dict(), "config": config.__dict__, "last_metrics": last},
            ckpt_path,
        )
        # Lazy %-args: the message is only built when INFO logging is enabled.
        logger.info("Saved checkpoint: %s", ckpt_path)

    return last
|
|