from __future__ import annotations

import argparse

from numen_scriptorium.config import apply_overrides, load_yaml_config
from numen_scriptorium.training.qlora import train_from_config


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the training command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/train_qwen_7b.yaml")
    parser.add_argument("--base_model", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--preset", type=str, choices=["t4", "a100"], default=None)
    parser.add_argument("--max_seq_len", type=int, default=None)
    parser.add_argument("--report_to", type=str, choices=["wandb", "none"], default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Enable deterministic torch behavior (may reduce performance).",
    )
    # A bare ``--resume`` yields "latest"; ``--resume VALUE`` yields VALUE;
    # omitting the flag yields None.
    parser.add_argument("--resume", nargs="?", const="latest", default=None)
    return parser


def _cli_overrides(args: argparse.Namespace) -> dict[str, object]:
    """Map parsed command-line flags onto the config keys they override.

    Flags the user did not supply are passed through as ``None``.
    NOTE(review): presumably ``apply_overrides`` treats ``None`` as
    "keep the value from the YAML file" -- confirm against its implementation.
    """
    return {
        "base_model": args.base_model,
        "output_dir": args.output_dir,
        "preset": args.preset,
        "report_to": args.report_to,
        "seed": args.seed,
        # ``store_true`` defaults to False, so an absent flag is sent as None
        # rather than as an explicit False.
        "deterministic": True if args.deterministic else None,
    }


def main(argv: list[str] | None = None) -> None:
    """Parse arguments, build the run configuration and start training.

    Loads the YAML file named by ``--config``, applies the command-line
    overrides to it, and hands the result to ``train_from_config`` together
    with the ``--resume`` and ``--max_seq_len`` values.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, so calling
            ``main()`` behaves exactly as before.
    """
    args = _build_parser().parse_args(argv)

    cfg = load_yaml_config(args.config)
    cfg = apply_overrides(cfg, _cli_overrides(args))

    # ``--max_seq_len`` is not merged into the config; it is passed to the
    # trainer as a separate override argument.
    train_from_config(cfg, resume=args.resume, max_seq_len_override=args.max_seq_len)


if __name__ == "__main__":
    main()