"""Command-line driver: set up the model, train it, run one inference, save.

Steps, in order:
  1. Build the model and tokenizer from ``BASE_MODEL``.
  2. Load the ``testmini`` split of ``TRAIN_DATASET`` and prepare it.
  3. Train with the formatting and correctness reward functions.
  4. Run inference on the dataset index given by ``--idx``.
  5. Optionally save the model and/or a GGUF export.
"""
import argparse

from datasets import load_dataset
from train import train
from reward import formatting_reward_func, correctness_reward_func
from future_work.inference import inference
from future_work.model import setup_model
from future_work.dataset import dataset_setup
from future_work.adapters import save_model, save_gguf
from consts import BASE_MODEL, TRAIN_DATASET

# Spellings accepted for boolean command-line values (compared lowercased).
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1", "on"})
# The empty string is kept as "false" because it was the only value that
# evaluated to False under the previous ``type=bool`` handling.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", "off", ""})


def _str2bool(value: str) -> bool:
    """Convert a command-line string to a boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy, so the flag could never
    be turned off with a readable value.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the script's options; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--idx", type=int, help='what index to test inference on')
    parser.add_argument("--save_model", type=_str2bool, default=True)
    parser.add_argument("--save_gguf", type=_str2bool, default=True)
    parser.add_argument("--local", type=_str2bool, default=True)
    parser.add_argument("--model_name", type=str, default="lora_model")
    return parser.parse_args(argv)


def main() -> None:
    """Run the setup -> train -> inference -> save pipeline."""
    args = _parse_args()

    model, tokenizer = setup_model(BASE_MODEL)
    dataset = load_dataset(TRAIN_DATASET, split="testmini")
    # dataset_setup returns two values; only the first is used for training.
    train_ds, _ = dataset_setup(dataset, tokenizer)

    reward_fns = [formatting_reward_func, correctness_reward_func]
    train(tokenizer, model, reward_fns, train_ds)

    # args.idx is None when --idx is omitted and is passed through as-is.
    # NOTE(review): confirm inference() copes with a None index.
    inference(args.idx, model, dataset, tokenizer)

    # Both save steps additionally require --local to be true.
    if args.save_model and args.local:
        save_model(model, tokenizer, args.local)

    if args.save_gguf and args.local:
        # An empty --model_name falls back to the historical default name.
        gguf_name = args.model_name or "math_finetune"
        save_gguf(gguf_name, args.local, tokenizer)


if __name__ == "__main__":
    main()