File size: 1,114 Bytes
8cf92b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | #!/usr/bin/env python3
"""Train MapGS.
Examples
--------
python scripts/train.py --config configs/synthetic_smoke.yaml
python scripts/train.py --config configs/waymo_stage1.yaml --override train.batch_size=8
"""
import argparse
import os

from mapgs.config import load_config
from mapgs.data import build_dataset
from mapgs.train import Trainer
def main():
    """Parse CLI arguments, train a model, and write the final checkpoint.

    Command-line options:
        --config     config file handed to ``load_config`` (may be omitted).
        --override   zero or more dotted ``a.b=c`` overrides, also handed to
                     ``load_config``.
        --max-iters  forwarded to ``Trainer.fit`` as ``max_iters``.
        --resume     checkpoint to load before training; takes precedence over
                     ``cfg.train.resume`` when both are set.

    The final checkpoint is saved to ``<cfg.train.out_dir>/ckpt_final.pt`` and
    that path is printed.
    """
    # Reuse the module docstring (with its usage examples) as --help text;
    # the raw formatter keeps the example lines from being re-wrapped.
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--config", type=str, default=None,
                    help="config file passed to load_config")
    ap.add_argument("--override", nargs="*", default=[], help="dotted overrides a.b=c")
    ap.add_argument("--max-iters", type=int, default=None,
                    help="forwarded to Trainer.fit as max_iters")
    ap.add_argument("--resume", type=str, default=None,
                    help="checkpoint to load before training (overrides train.resume)")
    args = ap.parse_args()

    cfg = load_config(args.config, args.override)
    dataset = build_dataset(cfg, "train")
    trainer = Trainer(cfg)

    # The CLI flag wins over the config value (short-circuit `or`); resolve it
    # once so the check and the load use the same path.
    resume_path = args.resume or cfg.train.resume
    if resume_path:
        trainer.load(resume_path)

    trainer.fit(dataset, max_iters=args.max_iters)

    # Build the output path once so the saved file and the logged path cannot
    # drift apart; os.path.join also copes with a trailing separator on out_dir.
    ckpt_path = os.path.join(cfg.train.out_dir, "ckpt_final.pt")
    trainer.save(ckpt_path)
    print(f"saved final checkpoint to {ckpt_path}")
# Only launch training when executed as a script, never on import.
if __name__ == "__main__":
    main()
|