| """Train TactileVAE with PyTorch Lightning. |
| |
| Run: |
| python tactile_vae/script/train_vae_pl.py --config tactile_vae/config/train_vae.yaml |
| |
| Same YAML config format as train_vae.py. |
| |
| Checkpoints written to <output_dir>/: |
| ckpt_best.pt / ckpt_last.pt / ckpt_step_*.pt — original format (TactileVAEWrapper compat) |
| checkpoints/last.ckpt — Lightning format (full resume with trainer state) |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import datetime as dt |
| import math |
| import os |
| import random |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pytorch_lightning as pl |
| import torch |
| import yaml |
| from PIL import Image |
| from pytorch_lightning.callbacks import ModelCheckpoint |
| from pytorch_lightning.loggers import CSVLogger |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim.lr_scheduler import LambdaLR |
| from torch.utils.data import DataLoader |
|
|
# Make the repository root importable so the `tactile_vae.*` imports below
# resolve when this file is run directly as a script. The script lives two
# directories below the repo root (tactile_vae/script/), hence parents[2].
# _REPO_ROOT is also the base that relative config paths are resolved against.
_REPO_ROOT = Path(__file__).resolve().parents[2]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
|
| from tactile_vae.dataset import ColorJitterConfig, ParquetFileShuffleSampler, TactileParquetDataset |
| from tactile_vae.model import TactileVAE, VAELoss |
|
|
|
|
| |
| |
| |
|
|
| def _resolve_path(p: str | None) -> Path | None: |
| if p is None: |
| return None |
| path = Path(p) |
| return path if path.is_absolute() else (_REPO_ROOT / path).resolve() |
|
|
|
|
| def _autogenerate_run_id() -> str: |
| return "run_" + dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
|
|
|
|
def load_config(path: Path) -> dict:
    """Load a YAML training config and normalise its paths.

    - Fills in ``run_id`` with a timestamped id when it is missing or empty.
    - Sets ``output_dir``: resolved from the config value when given,
      otherwise ``<runs_root>/<run_id>`` (``runs_root`` defaults to ``"runs"``).
    - Resolves ``data.root`` and, when set, ``data.splits_path`` and
      ``train.resume_from``.

    Relative paths are resolved against the repository root (see
    ``_resolve_path``), not against the config file's directory.

    Raises:
        ValueError: if the file does not contain a YAML mapping (for example an
            empty file, which ``yaml.safe_load`` returns as ``None``).
        KeyError: if the required ``data`` / ``data.root`` / ``train`` keys are
            missing.
    """
    with path.open(encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # Fail with a clear message instead of an AttributeError on ``None.get``.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"config {path} must contain a YAML mapping, got {type(cfg).__name__}"
        )
    if not cfg.get("run_id"):
        cfg["run_id"] = _autogenerate_run_id()
    if cfg.get("output_dir"):
        cfg["output_dir"] = str(_resolve_path(cfg["output_dir"]))
    else:
        runs_root = _resolve_path(cfg.get("runs_root", "runs"))
        cfg["output_dir"] = str(runs_root / cfg["run_id"])
    cfg["data"]["root"] = str(_resolve_path(cfg["data"]["root"]))
    if cfg["data"].get("splits_path"):
        cfg["data"]["splits_path"] = str(_resolve_path(cfg["data"]["splits_path"]))
    if cfg["train"].get("resume_from"):
        cfg["train"]["resume_from"] = str(_resolve_path(cfg["train"]["resume_from"]))
    return cfg
|
|
|
|
| def _maybe_autoresume(cfg: dict, *, allow_autoresume: bool) -> None: |
| if cfg["train"].get("resume_from") or not allow_autoresume: |
| return |
| |
| last_ckpt = Path(cfg["output_dir"]) / "checkpoints" / "last.ckpt" |
| if last_ckpt.exists(): |
| cfg["train"]["resume_from"] = str(last_ckpt) |
| return |
| last_pt = Path(cfg["output_dir"]) / "ckpt_last.pt" |
| if last_pt.exists(): |
| cfg["train"]["resume_from"] = str(last_pt) |
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch.

    All CUDA devices are seeded too when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def build_datasets(data_cfg: dict) -> tuple[TactileParquetDataset, TactileParquetDataset]:
    """Build the ``(train, val)`` datasets from the ``data`` config section.

    Both splits share ``root``, ``image_size``, ``cache_files`` (default 1),
    ``splits_path``, ``return_meta`` (default False) and, when non-empty,
    ``meta_columns``. Color jitter (``color_jitter`` section) is applied to the
    train split only.
    """
    shared_kwargs = {
        "root": data_cfg["root"],
        "image_size": data_cfg["image_size"],
        "cache_files": data_cfg.get("cache_files", 1),
        "splits_path": data_cfg.get("splits_path"),
        "return_meta": data_cfg.get("return_meta", False),
    }
    meta_columns = data_cfg.get("meta_columns")
    if meta_columns:
        shared_kwargs["meta_columns"] = meta_columns

    jitter_section = data_cfg.get("color_jitter")
    train_jitter = ColorJitterConfig(**jitter_section) if jitter_section else None

    train_split = TactileParquetDataset(split="train", color_jitter=train_jitter, **shared_kwargs)
    val_split = TactileParquetDataset(split="val", color_jitter=None, **shared_kwargs)
    return train_split, val_split
|
|
|
|
def lr_at_step(step: int, base_lr: float, total_steps: int, sched_cfg: dict) -> float:
    """Return the learning rate for the 0-based optimizer *step*.

    The schedule is a linear warmup over ``warmup_steps`` (reaching *base_lr*
    at step ``warmup_steps - 1``). After warmup it is either:

    - ``type == "constant"`` (the default): *base_lr*;
    - ``type == "cosine"``: a cosine decay from *base_lr* to
      ``min_lr_ratio * base_lr`` (default 0.1) at *total_steps*, held at that
      floor afterwards.

    Raises:
        ValueError: for any other ``type``. The type is only inspected once
            *step* is past warmup.
    """
    warmup_steps = int(sched_cfg.get("warmup_steps", 0))
    kind = sched_cfg.get("type", "constant")

    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    if kind == "constant":
        return base_lr
    if kind != "cosine":
        raise ValueError(f"unknown scheduler type: {kind!r}")

    min_ratio = float(sched_cfg.get("min_lr_ratio", 0.1))
    decay_span = max(1, total_steps - warmup_steps)
    progress = min(max((step - warmup_steps) / decay_span, 0.0), 1.0)
    scale = min_ratio + (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * scale
|
|
|
|
| |
| |
| |
|
|
|
|
class ConfigurablePerceptualVAELoss(nn.Module):
    """VAE loss with a configurable perceptual term: SSIM or LPIPS.

    ``loss_cfg["perceptual_type"]`` (default ``"ssim"``, case-insensitive)
    selects the backend:

    * ``"ssim"``: delegates entirely to the project's ``VAELoss``. Every key of
      *loss_cfg* except the wrapper-only ones (``perceptual_type``,
      ``lpips_weight``) is forwarded to it as a keyword argument.
    * ``"lpips"``: computes ``recon + lpips_weight * LPIPS + beta * KL`` here,
      using a frozen AlexNet LPIPS network (``pip install lpips``).
      ``lpips_weight`` falls back to ``ssim_weight`` and then to 0.1.

    ``aux_key`` names the perceptual backend (``"ssim"`` or ``"lpips"``).

    Raises:
        ValueError: for an unknown ``perceptual_type``, or (LPIPS path) an
            unknown ``recon_type``.
        ImportError: when LPIPS is requested but the ``lpips`` package is
            missing.
    """

    # Config keys that only this wrapper understands. They are stripped before
    # the config is forwarded to VAELoss so they cannot cause unexpected-kwarg
    # errors there.
    _WRAPPER_ONLY_KEYS = ("perceptual_type", "lpips_weight")

    def __init__(self, loss_cfg: dict):
        super().__init__()
        self.perceptual_type = str(loss_cfg.get("perceptual_type", "ssim")).lower()
        if self.perceptual_type not in {"ssim", "lpips"}:
            raise ValueError(
                f"loss.perceptual_type must be one of [ssim, lpips], got: {self.perceptual_type!r}"
            )
        self.aux_key = self.perceptual_type
        self.ssim_impl: VAELoss | None = None
        self.lpips_impl: nn.Module | None = None

        if self.perceptual_type == "ssim":
            vae_loss_kwargs = {
                key: value
                for key, value in loss_cfg.items()
                if key not in self._WRAPPER_ONLY_KEYS
            }
            self.ssim_impl = VAELoss(**vae_loss_kwargs)
            return

        self.beta = float(loss_cfg.get("beta", 1e-3))
        self.recon_type = str(loss_cfg.get("recon_type", "l1")).lower()
        # Validate eagerly so a typo fails at construction, not at the first step.
        if self.recon_type not in {"l1", "mse"}:
            raise ValueError(f"loss.recon_type must be one of [l1, mse], got: {self.recon_type!r}")
        self.lpips_weight = float(loss_cfg.get("lpips_weight", loss_cfg.get("ssim_weight", 0.1)))
        try:
            import lpips
        except ImportError as exc:
            raise ImportError(
                "LPIPS loss requested but `lpips` is not installed. "
                "Install with: pip install lpips"
            ) from exc
        self.lpips_impl = lpips.LPIPS(net="alex")
        self.lpips_impl.eval()
        # The LPIPS network is a fixed metric; it must not receive gradients.
        for p in self.lpips_impl.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True) -> "ConfigurablePerceptualVAELoss":
        """Set train/eval mode while keeping the frozen LPIPS network in eval mode.

        ``nn.Module.train`` recurses into child modules. Without this override,
        any ``.train()`` on an enclosing module (e.g. the training loop
        switching back from validation) would put LPIPS back into training mode.
        That enables its dropout heads (``lpips.LPIPS`` uses dropout by default)
        and makes the perceptual term stochastic.
        """
        super().train(mode)
        if self.lpips_impl is not None:
            self.lpips_impl.eval()
        return self

    def forward(self, x_hat: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> dict[str, torch.Tensor]:
        """Compute the loss terms.

        Args:
            x_hat: Reconstruction.
            x: Target. On the LPIPS path both images are rescaled from [0, 1]
                to [-1, 1] before LPIPS, so they are assumed to lie in [0, 1].
            mu: Posterior mean, used for the KL term.
            logvar: Posterior log-variance, used for the KL term.

        Returns:
            SSIM path: whatever ``VAELoss`` returns. LPIPS path: a dict with
            ``total`` (the value to optimise), ``recon``, ``recon_total``
            (recon plus the weighted LPIPS term), ``lpips`` and ``kl``.
        """
        if self.perceptual_type == "ssim":
            assert self.ssim_impl is not None
            return self.ssim_impl(x_hat, x, mu, logvar)

        if self.recon_type == "l1":
            recon = F.l1_loss(x_hat, x)
        elif self.recon_type == "mse":
            recon = F.mse_loss(x_hat, x)
        else:
            raise ValueError(f"loss.recon_type must be one of [l1, mse], got: {self.recon_type!r}")

        # Run LPIPS in full fp32 even under mixed precision.
        with torch.amp.autocast(device_type=x_hat.device.type, enabled=False):
            x_hat_lp = (2.0 * x_hat.float()) - 1.0
            x_lp = (2.0 * x.float()) - 1.0
            assert self.lpips_impl is not None
            lpips_val = self.lpips_impl(x_hat_lp, x_lp).mean()
        recon_total = recon + self.lpips_weight * lpips_val
        # Gaussian KL to a standard normal, averaged over all elements.
        kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).mean()
        total = recon_total + self.beta * kl
        return {
            "total": total,
            "recon": recon,
            "recon_total": recon_total,
            "lpips": lpips_val,
            "kl": kl,
        }
|
|
|
|
class TactileVAEModule(pl.LightningModule):
    """LightningModule wrapping ``TactileVAE`` and ``ConfigurablePerceptualVAELoss``.

    Args:
        config: Full training config. Reads the ``model``, ``loss``, ``optim``
            and ``scheduler`` sections.
        step_offset: Optimizer steps already taken before this run (non-zero
            when resuming from a compat ``.pt`` checkpoint). It is added to the
            scheduler step and to step numbers in warnings.
        total_steps: Schedule length passed to ``lr_at_step``.
    """

    def __init__(self, config: dict, *, step_offset: int = 0, total_steps: int = 0):
        super().__init__()
        self.config = config
        self.step_offset = int(step_offset)
        self.total_steps = int(total_steps)
        self.model = TactileVAE(**config["model"])
        self.criterion = ConfigurablePerceptualVAELoss(config["loss"])

    def forward(self, x, **kw):
        """Delegate to the wrapped ``TactileVAE``."""
        return self.model(x, **kw)

    def _compute_losses(self, x, **model_kwargs):
        """Run the VAE on *x* and return the criterion's dict of loss terms."""
        out = self.model(x, **model_kwargs)
        return self.criterion(out["x_hat"], x, out["mu"], out["logvar"])

    def training_step(self, batch, batch_idx):
        """Compute and log training losses.

        Returns ``None`` on a non-finite loss, which makes Lightning skip the
        optimizer update for this batch.
        """
        x = batch
        losses = self._compute_losses(x)
        total = losses["total"]
        if not torch.isfinite(total).item():
            step_no = self.trainer.global_step + self.step_offset + 1
            print(
                f"[warn] non-finite loss at step={step_no}, "
                f"epoch={self.trainer.current_epoch}; skipping optimizer step"
            )
            return None
        n = x.shape[0]
        self.log("train/total", total, prog_bar=True, on_step=True, on_epoch=False, batch_size=n)
        extras = {f"train/{name}": value for name, value in losses.items() if name != "total"}
        self.log_dict(extras, on_step=True, on_epoch=False, batch_size=n)
        return total

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log epoch-averaged validation losses using the deterministic (``sample=False``) path."""
        x = batch
        losses = self._compute_losses(x, sample=False)
        per_key = {f"val/{name}": value for name, value in losses.items()}
        self.log_dict(per_key, on_step=False, on_epoch=True, batch_size=x.shape[0])

    def configure_optimizers(self):
        """AdamW over the VAE parameters, with a per-step LambdaLR driven by ``lr_at_step``.

        Only ``self.model`` is optimised; the criterion (e.g. frozen LPIPS) is
        excluded.
        """
        optim_cfg = self.config["optim"]
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=optim_cfg["lr"],
            weight_decay=optim_cfg.get("weight_decay", 0.0),
            betas=tuple(optim_cfg.get("betas", (0.9, 0.95))),
            eps=optim_cfg.get("eps", 1e-8),
        )
        base_lr = float(optim_cfg["lr"])
        sched_cfg = self.config["scheduler"]

        def lr_factor(step: int) -> float:
            # LambdaLR multiplies the base lr by this factor. The step is
            # shifted by step_offset so that compat resumes continue the
            # schedule where the previous run stopped.
            return lr_at_step(step + self.step_offset, base_lr, self.total_steps, sched_cfg) / base_lr

        scheduler = LambdaLR(optimizer, lr_lambda=lr_factor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
        }
|
|
|
|
| |
| |
| |
|
|
class TactileVAEDataModule(pl.LightningDataModule):
    """Builds the train/val parquet datasets and their DataLoaders.

    ``setup`` is idempotent: the datasets are built once. The train split is
    iterated through a ``ParquetFileShuffleSampler`` seeded from
    ``config["seed"]``; validation is unshuffled.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.train_ds: TactileParquetDataset | None = None
        self.val_ds: TactileParquetDataset | None = None
        self.train_sampler: ParquetFileShuffleSampler | None = None

    def setup(self, stage: str | None = None):
        """Build datasets and the train sampler on the first call; later calls are no-ops."""
        if self.train_ds is not None:
            return
        self.train_ds, self.val_ds = build_datasets(self.config["data"])
        self.train_sampler = ParquetFileShuffleSampler(self.train_ds, seed=self.config["seed"])

    def train_dataloader(self):
        """Train loader: sampler-driven order, drops the last partial batch."""
        tcfg = self.config["train"]
        workers = tcfg["num_workers"]
        return DataLoader(
            self.train_ds,
            batch_size=tcfg["batch_size"],
            sampler=self.train_sampler,
            num_workers=workers,
            pin_memory=True,
            drop_last=True,
            # persistent_workers / prefetch_factor are only valid with workers > 0.
            persistent_workers=workers > 0,
            prefetch_factor=2 if workers > 0 else None,
        )

    def val_dataloader(self):
        """Validation loader: unshuffled, keeps the last partial batch."""
        tcfg = self.config["train"]
        workers = tcfg["num_workers"]
        # Half the train workers (at least 2). num_workers=0 is honoured because
        # it explicitly requests in-process loading (e.g. for debugging).
        val_workers = max(2, workers // 2) if workers > 0 else 0
        return DataLoader(
            self.val_ds,
            batch_size=tcfg["batch_size"],
            shuffle=False,
            num_workers=val_workers,
            pin_memory=True,
            drop_last=False,
        )
|
|
|
|
| |
| |
| |
|
|
class SetEpochCallback(pl.Callback):
    """Keep the datamodule's train sampler epoch-aware for per-epoch shuffling.

    At each epoch start the callback calls ``train_sampler.set_epoch`` with
    ``trainer.current_epoch + epoch_offset``. Runs resumed from a compat
    checkpoint therefore continue the shuffle sequence instead of replaying it
    from epoch 0. Nothing happens when the datamodule has no sampler exposing
    ``set_epoch``.
    """

    def __init__(self, *, epoch_offset: int = 0):
        self.epoch_offset = int(epoch_offset)

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        sampler = getattr(trainer.datamodule, "train_sampler", None)
        if not hasattr(sampler, "set_epoch"):
            return
        sampler.set_epoch(trainer.current_epoch + self.epoch_offset)
|
|
|
|
class SampleGridCallback(pl.Callback):
    """Saves a top=original / bottom=reconstruction image grid every N steps.

    Grids are written to ``<output_dir>/samples/step_XXXXXXX.png``. Images are
    drawn without replacement from the validation dataset and reconstructed
    deterministically (``sample=False``). The step number includes the compat
    resume ``step_offset``.
    """

    def __init__(self, config: dict, *, step_offset: int = 0):
        self.sample_every = config["train"]["sample_every_steps"]
        self.n = config["train"]["num_sample_images"]
        self.out_dir = Path(config["output_dir"]) / "samples"
        # Dedicated RNG so sample selection does not disturb the global seeds.
        self.rng = np.random.default_rng(config["seed"] + 1)
        self.step_offset = int(step_offset)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        effective_step = trainer.global_step + self.step_offset
        if effective_step > 0 and effective_step % self.sample_every == 0:
            self._save_grid(trainer, pl_module, effective_step)

    @torch.no_grad()
    def _save_grid(self, trainer, pl_module, step):
        """Reconstruct a few validation images and save the comparison grid."""
        val_ds = trainer.datamodule.val_ds
        # choice(..., replace=False) raises if asked for more items than exist.
        n = min(self.n, len(val_ds))
        if n <= 0:
            return
        device = pl_module.device
        self.out_dir.mkdir(parents=True, exist_ok=True)
        idx = self.rng.choice(len(val_ds), size=n, replace=False).tolist()
        # Assumes val_ds[i] returns a (C, H, W) image tensor in [0, 1] with
        # C == 3 — TODO confirm (e.g. when data.return_meta is enabled).
        imgs = torch.stack([val_ds[i] for i in idx]).to(device)
        was_training = pl_module.training
        pl_module.eval()
        try:
            recon = pl_module.model(imgs, sample=False)["x_hat"]
        finally:
            # Restore the previous mode even if the forward pass fails.
            pl_module.train(was_training)
        h = w = val_ds.image_size
        canvas = np.zeros((2 * h, n * w, 3), dtype=np.uint8)
        for i in range(n):
            # .float() guards against low-precision tensors, which .numpy() rejects.
            orig = (imgs[i].float().cpu().clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            rec = (recon[i].float().cpu().clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            canvas[:h, i * w:(i + 1) * w] = orig
            canvas[h:, i * w:(i + 1) * w] = rec
        Image.fromarray(canvas).save(self.out_dir / f"step_{step:07d}.png")
|
|
|
|
class CompatCheckpointCallback(pl.Callback):
    """Saves ckpt_last.pt / ckpt_step_*.pt / ckpt_best.pt in the original format
    so that TactileVAEWrapper.load_pretrained keeps working unchanged.

    Payload keys:
        state_dict: TactileVAE weights, without the Lightning ``model.`` prefix.
        optimizer: The optimizer's state dict.
        step / epoch: Values that include the compat-resume offsets.
        config: The run config.
        best_val_metric / best_metric_name: The best-so-far validation metric.
        best_val_recon: Legacy alias of ``best_val_metric``.

    The best-so-far metric is also exposed via ``state_dict`` /
    ``load_state_dict``, so it survives Lightning ``.ckpt`` resumes.
    """

    def __init__(
        self,
        config: dict,
        *,
        step_offset: int = 0,
        epoch_offset: int = 0,
        initial_best_val_metric: float = float("inf"),
    ):
        self.config = config
        self.out_dir = Path(config["output_dir"])
        self.ckpt_every = config["train"]["ckpt_every_steps"]
        self.keep_last = config["train"]["keep_last_ckpts"]
        self.best_metric = config["train"].get("best_metric", "val/total")
        self.best_val_metric = float(initial_best_val_metric)
        self.step_offset = int(step_offset)
        self.epoch_offset = int(epoch_offset)

    def state_dict(self) -> dict[str, Any]:
        """Persist the best-so-far metric inside Lightning ``.ckpt`` files.

        Without this, a resume from ``checkpoints/last.ckpt`` restarts the
        tracker at ``inf``. The first post-resume validation would then
        overwrite ``ckpt_best.pt`` even if it were worse than the true best.
        """
        return {"best_val_metric": self.best_val_metric, "best_metric_name": self.best_metric}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the best-so-far metric saved by :meth:`state_dict`."""
        # Only restore a value tracked under the same metric name. If the
        # configured best_metric changed, the old value is not comparable.
        if state_dict.get("best_metric_name", self.best_metric) == self.best_metric:
            self.best_val_metric = float(state_dict.get("best_val_metric", self.best_val_metric))

    def _build_payload(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> dict:
        # Keep only the wrapped TactileVAE weights (drops the criterion, e.g.
        # frozen LPIPS) and strip the "model." prefix the LightningModule adds.
        sd = {k[len("model."):]: v for k, v in pl_module.state_dict().items() if k.startswith("model.")}
        return {
            "state_dict": sd,
            "optimizer": trainer.optimizers[0].state_dict(),
            "step": trainer.global_step + self.step_offset,
            "epoch": trainer.current_epoch + self.epoch_offset,
            "config": self.config,
            "best_val_metric": self.best_val_metric,
            "best_metric_name": self.best_metric,
            "best_val_recon": self.best_val_metric,
        }

    def _save(self, path: Path, trainer, pl_module) -> None:
        """Write the payload atomically: save to a temp file, then rename over *path*."""
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp = path.with_suffix(path.suffix + ".tmp")
        torch.save(self._build_payload(trainer, pl_module), tmp)
        os.replace(tmp, path)

    def _rotate(self) -> None:
        """Delete the oldest ckpt_step_*.pt files beyond ``keep_last``.

        Relies on the zero-padded step numbers, which make lexicographic order
        equal numeric order.
        """
        ckpts = sorted(self.out_dir.glob("ckpt_step_*.pt"))
        while len(ckpts) > self.keep_last:
            ckpts.pop(0).unlink(missing_ok=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        effective_step = trainer.global_step + self.step_offset
        if effective_step > 0 and effective_step % self.ckpt_every == 0:
            self._save(self.out_dir / f"ckpt_step_{effective_step:07d}.pt", trainer, pl_module)
            self._save(self.out_dir / "ckpt_last.pt", trainer, pl_module)
            self._rotate()
            print(f"  saved ckpt_step_{effective_step:07d}.pt")

    def on_validation_epoch_end(self, trainer, pl_module):
        # Sanity-check validation runs on untrained weights; never treat it as "best".
        if trainer.sanity_checking:
            return
        val = float(trainer.callback_metrics.get(self.best_metric, float("inf")))
        if val < self.best_val_metric:
            self.best_val_metric = val
            self._save(self.out_dir / "ckpt_best.pt", trainer, pl_module)
            print(f"  -> new best {self.best_metric}={val:.4f}, saved ckpt_best.pt")

    def on_train_end(self, trainer, pl_module):
        self._save(self.out_dir / "ckpt_last.pt", trainer, pl_module)
|
|
|
|
class CompatResumeStateCallback(pl.Callback):
    """Restore optimizer state from a compat ``.pt`` resume checkpoint at fit start.

    The state is applied to the trainer's first optimizer once it exists. Note
    that ``Optimizer.load_state_dict`` also restores the saved param-group
    hyperparameters (lr, weight decay, betas, ...). Those override the values
    from the current config. Nothing happens when no state was provided or
    the trainer has no optimizer.
    """

    def __init__(self, optim_state: dict[str, Any] | None):
        self.optim_state = optim_state

    def on_fit_start(self, trainer, pl_module):
        saved = self.optim_state
        if saved is None or not trainer.optimizers:
            return
        first_optimizer = trainer.optimizers[0]
        first_optimizer.load_state_dict(saved)
        print("loaded optimizer state from compat checkpoint")
|
|
|
|
| |
| |
| |
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line.

    ``--config`` defaults to ``config/train_vae.yaml`` next to this script's
    parent directory. The remaining options override the corresponding config
    keys or control resume behaviour.
    """
    default_config = Path(__file__).resolve().parents[1] / "config" / "train_vae.yaml"
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=Path, default=default_config)
    parser.add_argument(
        "--run-id",
        type=str,
        default=None,
        help="override run_id (output_dir = runs_root/<run-id>)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="override output_dir directly",
    )
    parser.add_argument(
        "--resume-from",
        type=str,
        default=None,
        help="path to .ckpt (Lightning) or .pt (compat) checkpoint",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="start fresh even if ckpt_last.pt / last.ckpt exists",
    )
    return parser.parse_args()
|
|
|
|
def _init_loggers(cfg: dict, out_dir: Path) -> list[Any]:
    """Create the run's loggers.

    A ``CSVLogger`` writing into *out_dir* is always created. A ``WandbLogger``
    is added when the ``WANDB_PROJECT`` environment variable is set, configured
    from ``WANDB_PROJECT`` / ``WANDB_ENTITY`` / ``WANDB_RUN_ID`` / ``WANDB_NAME``.
    The W&B id and name default to ``cfg["run_id"]``.

    Because the W&B id is deterministic, auto-resumed jobs reuse it.
    ``resume="allow"`` continues the existing W&B run when one exists (and
    creates it otherwise) instead of colliding with it. If wandb is
    unavailable, only CSV logging is used.
    """
    loggers: list[Any] = [CSVLogger(str(out_dir), name="", version="")]
    project = os.environ.get("WANDB_PROJECT")
    if not project:
        return loggers
    try:
        from pytorch_lightning.loggers import WandbLogger

        wandb_logger = WandbLogger(
            project=project,
            entity=os.environ.get("WANDB_ENTITY"),
            id=os.environ.get("WANDB_RUN_ID") or cfg["run_id"],
            name=os.environ.get("WANDB_NAME") or cfg["run_id"],
            save_dir=str(out_dir),
            config=cfg,
            resume="allow",
        )
    except ImportError:
        # ModuleNotFoundError (wandb missing) is a subclass of ImportError.
        print("wandb not available — W&B logging disabled (CSV logging still active)")
    else:
        loggers.append(wandb_logger)
    return loggers
|
|
|
|
def _build_trainer(
    cfg: dict,
    *,
    callbacks: list[pl.Callback],
    loggers: list[Any],
    precision: str,
    resume_from: str | None,
    resume_step_offset: int,
    total_steps: int,
) -> pl.Trainer:
    """Construct the single-device ``pl.Trainer`` for this run.

    Validation is step-based (``val_every_steps``, ``num_val_batches``), with
    no sanity-check validation. The training length is chosen as follows:

    * compat ``.pt`` resume: ``max_steps = max(0, total_steps - resume_step_offset)``,
      i.e. only the remaining steps are run;
    * otherwise ``train.max_steps`` when set;
    * otherwise ``train.epochs``.
    """
    tcfg = cfg["train"]
    trainer_kwargs: dict[str, Any] = dict(
        accelerator="auto",
        devices=1,
        precision=precision,
        callbacks=callbacks,
        logger=loggers,
        limit_val_batches=tcfg["num_val_batches"],
        val_check_interval=tcfg["val_every_steps"],
        # None makes val_check_interval count global steps across epochs.
        check_val_every_n_epoch=None,
        log_every_n_steps=tcfg["log_every"],
        # A falsy / 0 clip value disables gradient clipping.
        gradient_clip_val=tcfg.get("gradient_clip_norm") or None,
        num_sanity_val_steps=0,
        default_root_dir=str(cfg["output_dir"]),
    )

    is_compat_resume = bool(resume_from) and Path(resume_from).suffix != ".ckpt"
    if is_compat_resume:
        remaining_steps = max(0, total_steps - resume_step_offset)
        trainer_kwargs["max_steps"] = remaining_steps
        print(f"compat resume remaining_steps={remaining_steps}")
    elif tcfg.get("max_steps"):
        trainer_kwargs["max_steps"] = tcfg["max_steps"]
    else:
        trainer_kwargs["max_epochs"] = tcfg["epochs"]
    return pl.Trainer(**trainer_kwargs)
|
|
|
|
def main(cfg: dict) -> None:
    """Run (or resume) VAE training from the normalised config *cfg*.

    Resume handling depends on ``cfg["train"]["resume_from"]``:

    * ``.ckpt``: passed to ``trainer.fit(ckpt_path=...)`` for a full Lightning
      resume.
    * any other suffix: treated as a compat ``.pt`` payload written by
      ``CompatCheckpointCallback``. The file is read once. The model weights
      are loaded directly, the optimizer state is restored at fit start, and
      step / epoch / best-metric are carried over as offsets.

    Also writes ``config.snapshot.yaml`` to ``output_dir`` on the first run.
    """
    set_seed(cfg["seed"])
    out_dir = Path(cfg["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)

    # --- compat resume bookkeeping (read the .pt payload exactly once) -------
    resume_step_offset = 0
    resume_epoch_offset = 0
    resume_optimizer_state: dict[str, Any] | None = None
    resume_best_val_metric = float("inf")
    compat_model_state: dict[str, Any] | None = None
    resume_from = cfg["train"].get("resume_from")
    if resume_from and Path(resume_from).suffix != ".ckpt":
        # NOTE: weights_only=False unpickles arbitrary objects; only resume from
        # trusted checkpoints (normally ones this script wrote).
        compat = torch.load(str(resume_from), map_location="cpu", weights_only=False)
        resume_step_offset = int(compat.get("step", 0))
        resume_epoch_offset = int(compat.get("epoch", 0))
        resume_optimizer_state = compat.get("optimizer")
        resume_best_val_metric = float(
            compat.get("best_val_metric", compat.get("best_val_recon", float("inf")))
        )
        compat_model_state = compat["state_dict"]
        # Drop references to the rest of the payload.
        del compat

    # Snapshot the config of the first run only; resumes keep the original.
    snap = out_dir / "config.snapshot.yaml"
    if not snap.exists():
        with snap.open("w") as f:
            yaml.safe_dump(cfg, f, sort_keys=False)

    # --- data ----------------------------------------------------------------
    datamodule = TactileVAEDataModule(cfg)
    datamodule.setup()
    print(f"datasets: train={len(datamodule.train_ds):,} val={len(datamodule.val_ds):,}")

    # --- model ---------------------------------------------------------------
    tcfg = cfg["train"]
    steps_per_epoch = len(datamodule.train_dataloader())
    total_steps = tcfg["max_steps"] if tcfg.get("max_steps") else steps_per_epoch * tcfg["epochs"]
    print(f"steps/epoch={steps_per_epoch:,} total_steps={total_steps:,}")
    module = TactileVAEModule(cfg, step_offset=resume_step_offset, total_steps=total_steps)
    n_params = sum(p.numel() for p in module.model.parameters())
    print(f"model: {module.model.__class__.__name__} params={n_params:,}")

    # --- precision: AMP is always bf16 in this script -------------------------
    use_amp = bool(tcfg.get("amp", False))
    if use_amp:
        amp_dtype = str(tcfg.get("amp_dtype", "bf16")).lower()
        if amp_dtype not in {"bf16", "bfloat16"}:
            print(f"[info] overriding train.amp_dtype={amp_dtype!r} to 'bf16' (enforced)")
        precision = "bf16-mixed"
    else:
        precision = "32"

    loggers = _init_loggers(cfg, out_dir)

    callbacks = [
        SetEpochCallback(epoch_offset=resume_epoch_offset),
        SampleGridCallback(cfg, step_offset=resume_step_offset),
        CompatCheckpointCallback(
            cfg,
            step_offset=resume_step_offset,
            epoch_offset=resume_epoch_offset,
            initial_best_val_metric=resume_best_val_metric,
        ),
        CompatResumeStateCallback(resume_optimizer_state),
        ModelCheckpoint(
            dirpath=str(out_dir / "checkpoints"),
            filename="last",
            save_last=True,
            save_top_k=0,
            every_n_train_steps=tcfg["ckpt_every_steps"],
        ),
    ]

    trainer = _build_trainer(
        cfg,
        callbacks=callbacks,
        loggers=loggers,
        precision=precision,
        resume_from=resume_from,
        resume_step_offset=resume_step_offset,
        total_steps=total_steps,
    )

    # --- resume --------------------------------------------------------------
    ckpt_path: str | None = None
    if resume_from:
        rf = Path(resume_from)
        if rf.suffix == ".ckpt":
            # NOTE(review): step/epoch offsets are only recovered from compat
            # payloads. If this .ckpt came from a run that was itself
            # compat-resumed, those offsets (and the compat max_steps
            # adjustment) are not restored here — verify before relying on it.
            ckpt_path = str(rf)
            print(f"resuming (Lightning): {rf}")
        else:
            module.model.load_state_dict(compat_model_state)
            print(
                f"resuming (compat): {rf} "
                f"step={resume_step_offset} epoch={resume_epoch_offset}"
            )

    trainer.fit(module, datamodule=datamodule, ckpt_path=ckpt_path)
    print(f"done. global_step={trainer.global_step}")
|
|
|
|
if __name__ == "__main__":
    import tempfile

    args = parse_args()
    with args.config.open() as f:
        raw_cfg = yaml.safe_load(f)
    if args.run_id:
        raw_cfg["run_id"] = args.run_id
    if args.output_dir:
        raw_cfg["output_dir"] = args.output_dir

    # load_config() takes a file path, so the CLI-overridden config is
    # round-tripped through a temporary file. The file goes in the system temp
    # dir rather than next to the config: a read-only config directory still
    # works, and an interrupted run cannot leave stray files in the repo.
    # Relocating the file is safe because load_config resolves relative paths
    # against the repo root, not against the file's location.
    fd, tmp_name = tempfile.mkstemp(prefix=".__cli_override_", suffix=".yaml")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "w") as f:
            yaml.safe_dump(raw_cfg, f, sort_keys=False)
        cfg = load_config(tmp)
    finally:
        tmp.unlink(missing_ok=True)

    # An explicit --resume-from wins. Otherwise pick up the latest checkpoint
    # in output_dir unless --no-resume was given.
    if args.resume_from:
        cfg["train"]["resume_from"] = str(_resolve_path(args.resume_from))
    _maybe_autoresume(cfg, allow_autoresume=not args.no_resume)

    main(cfg)
|
|