# 3d_model/ylff/services/training/train_student.py
"""
Training loop for MetricDepthWithUncertainty student model.
This is a minimal reference implementation; it is not optimized for scale yet.
"""
from __future__ import annotations
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Optional, Sequence
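
# NOTE: torch is treated as an optional dependency at import time so this module
# can still be imported (e.g. for type checking) without it; train_student()
# raises a clear ImportError if torch is missing when training is requested.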
try:
    import torch  # type: ignore[import-not-found]
except Exception:  # pragma: no cover
    torch = None  # type: ignore
from ...models.metric_depth_with_uncertainty import MetricDepthWithUncertainty
from .dataset import TeacherSupervisedTemporalDataset
from .losses import compute_losses
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TrainConfig:
    """Hyperparameters for a student training run."""

    temporal_window: int = 5
    epochs: int = 1
    batch_size: int = 1
    lr: float = 2e-4
    device: str = "cuda"
    num_workers: int = 0
    checkpoint_dir: Path = Path("checkpoints/student_depth_unc")


def train_student(
    bundle_dirs: Sequence[Path],
    *,
    config: Optional[TrainConfig] = None,
) -> Dict[str, float]:
    """Run the training loop and return the loss metrics from the final step."""
    config = config or TrainConfig()
    if torch is None:  # pragma: no cover
        raise ImportError("train_student requires torch to be installed")
    # Deferred import: only needed (and only possible) once torch is present.
    from torch.utils.data import DataLoader  # type: ignore[import-not-found]

    ds = TeacherSupervisedTemporalDataset(bundle_dirs, temporal_window=config.temporal_window)
    if len(ds) == 0:
        raise ValueError("No training samples found (missing teacher outputs?)")
    dl = DataLoader(
        ds,
        batch_size=int(config.batch_size),
        shuffle=True,
        num_workers=int(config.num_workers),
        # Pinned memory only helps for host-to-CUDA copies; startswith also
        # matches device strings like "cuda:0".
        pin_memory=config.device.startswith("cuda"),
    )
    model = MetricDepthWithUncertainty(temporal_window=config.temporal_window).to(config.device)
    opt = torch.optim.AdamW(model.parameters(), lr=float(config.lr), weight_decay=0.01)
    config.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    step = 0
    last: Dict[str, float] = {}
    for epoch in range(int(config.epochs)):
        model.train()
        for batch in dl:
            frames = batch["frames"].to(config.device)   # (B, T, 3, H, W)
            depth_gt = batch["depth"].to(config.device)  # (B, H, W)
            sigma_gt = batch["sigma"].to(config.device)  # (B, H, W)
            # Optional supervision signals; a bundle may not provide them.
            sigma_w = batch.get("sigma_weight")
            if sigma_w is not None:
                sigma_w = sigma_w.to(config.device)
            lidar = batch.get("lidar_depth")
            if lidar is not None:
                lidar = lidar.to(config.device)
            out = model(frames)
            losses = compute_losses(
                depth_pred=out.depth,
                log_sigma_pred=out.log_sigma,
                depth_gt=depth_gt,
                sigma_teacher=sigma_gt,
                lidar_depth=lidar,
                sigma_teacher_weight=sigma_w,
            )

            opt.zero_grad(set_to_none=True)
            losses.total.backward()
            # Clip gradients to keep updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            last = {
                "loss_total": float(losses.total.detach().cpu().item()),
                "loss_nll": float(losses.nll.detach().cpu().item()),
                "loss_depth": float(losses.depth.detach().cpu().item()),
                "loss_lidar": float(losses.lidar.detach().cpu().item()),
                "loss_sigma_supervision": float(losses.sigma_supervision.detach().cpu().item()),
            }
            step += 1
        # Checkpoint at the end of every epoch.
        ckpt_path = config.checkpoint_dir / f"epoch_{epoch:04d}.pt"
        torch.save(
            {"model": model.state_dict(), "config": asdict(config), "last_metrics": last},
            ckpt_path,
        )
        logger.info("Saved checkpoint: %s", ckpt_path)

    return last
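

# Minimal CLI sketch for convenience. This block is an illustrative addition,
# not part of the original module: the argument names below are assumptions,
# the bundle directories must match what TeacherSupervisedTemporalDataset
# expects on disk, and the module must be run via `python -m` from the package
# root so the relative imports resolve.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the student depth + uncertainty model.")
    parser.add_argument("bundle_dirs", nargs="+", type=Path, help="Teacher output bundle directories")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    metrics = train_student(
        args.bundle_dirs,
        config=TrainConfig(epochs=args.epochs, device=args.device),
    )
    logger.info("Final metrics: %s", metrics)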