| |
| """Train MapGS. |
| |
| Examples |
| -------- |
| python scripts/train.py --config configs/synthetic_smoke.yaml |
| python scripts/train.py --config configs/waymo_stage1.yaml --override train.batch_size=8 |
| """ |
|
|
| import argparse |
|
|
| from mapgs.config import load_config |
| from mapgs.data import build_dataset |
| from mapgs.train import Trainer |
|
|
|
|
def main():
    """Command-line entry point: parse flags, train, and save a final checkpoint.

    Flags:
        --config     Passed as the first argument to ``load_config``.
        --override   Zero or more dotted ``a.b=c`` overrides, passed as the
                     second argument to ``load_config``.
        --max-iters  Forwarded to ``Trainer.fit`` as ``max_iters``.
        --resume     Checkpoint handed to ``Trainer.load`` before training;
                     when not given, ``cfg.train.resume`` is used if truthy.

    After ``fit`` returns, the trainer state is saved to
    ``<cfg.train.out_dir>/ckpt_final.pt`` and the path is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument(
        "--override",
        nargs="*",
        default=[],
        help="dotted overrides a.b=c",
    )
    parser.add_argument("--max-iters", type=int, default=None)
    parser.add_argument("--resume", type=str, default=None)
    cli = parser.parse_args()

    cfg = load_config(cli.config, cli.override)
    train_set = build_dataset(cfg, "train")
    trainer = Trainer(cfg)

    # The command-line flag takes precedence over the config value; the
    # config is only consulted when the flag is absent or empty.
    resume_from = cli.resume or cfg.train.resume
    if resume_from:
        trainer.load(resume_from)

    trainer.fit(train_set, max_iters=cli.max_iters)

    final_ckpt = f"{cfg.train.out_dir}/ckpt_final.pt"
    trainer.save(final_ckpt)
    print(f"saved final checkpoint to {final_ckpt}")
|
|
|
|
# Start training only when this file is executed as a script; importing it
# as a module must not trigger a run.
if __name__ == "__main__":
    main()
|
|