File size: 1,572 Bytes
707a2d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import annotations

import argparse

from numen_scriptorium.config import apply_overrides, load_yaml_config
from numen_scriptorium.training.qlora import train_from_config


def main():
    """CLI entry point for QLoRA training.

    Parses command-line flags, loads the base YAML configuration, layers any
    explicitly-passed overrides on top of it, and hands the merged config to
    the training loop.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", type=str, default="configs/train_qwen_7b.yaml")
    arg_parser.add_argument("--base_model", type=str, default=None)
    arg_parser.add_argument("--output_dir", type=str, default=None)
    arg_parser.add_argument("--preset", type=str, choices=["t4", "a100"], default=None)
    arg_parser.add_argument("--max_seq_len", type=int, default=None)
    arg_parser.add_argument("--report_to", type=str, choices=["wandb", "none"], default=None)
    arg_parser.add_argument("--seed", type=int, default=None)
    arg_parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Enable deterministic torch behavior (may reduce performance).",
    )
    # "--resume" with no value resumes from the latest checkpoint; omitting the
    # flag entirely leaves resume disabled (None).
    arg_parser.add_argument("--resume", nargs="?", const="latest", default=None)
    cli = arg_parser.parse_args()

    # A None value means "no override" — apply_overrides presumably skips those
    # so the YAML-file setting wins. --deterministic is a store_true flag, so
    # its False default is mapped to None to avoid clobbering the config.
    cli_overrides = {
        "base_model": cli.base_model,
        "output_dir": cli.output_dir,
        "preset": cli.preset,
        "report_to": cli.report_to,
        "seed": cli.seed,
        "deterministic": cli.deterministic or None,
    }

    merged_cfg = apply_overrides(load_yaml_config(cli.config), cli_overrides)
    train_from_config(merged_cfg, resume=cli.resume, max_seq_len_override=cli.max_seq_len)


# Standard script guard: run the launcher only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()