# Provenance: commit 707a2d1 "Deploy HF Space demo (clean)" by Yifei Wang.
from __future__ import annotations
import argparse
from numen_scriptorium.config import apply_overrides, load_yaml_config
from numen_scriptorium.training.qlora import train_from_config
def main() -> None:
    """CLI entry point for QLoRA fine-tuning.

    Loads the YAML training config, layers any command-line overrides on
    top of it (None-valued overrides are expected to be ignored by
    ``apply_overrides``), then starts training via ``train_from_config``.
    """
    parser = argparse.ArgumentParser(
        description="Run QLoRA fine-tuning from a YAML config with optional CLI overrides."
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/train_qwen_7b.yaml",
        help="Path to the training YAML config file.",
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default=None,
        help="Override the base model name/path from the config.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Override the output directory from the config.",
    )
    parser.add_argument(
        "--preset",
        type=str,
        choices=["t4", "a100"],
        default=None,
        help="Hardware preset to apply (GPU-specific defaults).",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=None,
        help="Override the maximum sequence length for training.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        choices=["wandb", "none"],
        default=None,
        help="Experiment tracker to report metrics to.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Override the random seed from the config.",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="Enable deterministic torch behavior (may reduce performance).",
    )
    parser.add_argument(
        "--resume",
        nargs="?",
        const="latest",
        default=None,
        help="Resume training; with no value, resumes from the latest checkpoint.",
    )
    args = parser.parse_args()

    cfg = load_yaml_config(args.config)
    cfg = apply_overrides(
        cfg,
        {
            "base_model": args.base_model,
            "output_dir": args.output_dir,
            "preset": args.preset,
            "report_to": args.report_to,
            "seed": args.seed,
            # store_true yields False when the flag is absent; map that to
            # None so the config's own value is not clobbered.
            "deterministic": args.deterministic or None,
        },
    )
    # max_seq_len is passed directly (not via overrides) because the
    # training entry point takes it as an explicit parameter.
    train_from_config(cfg, resume=args.resume, max_seq_len_override=args.max_seq_len)


if __name__ == "__main__":
    main()