#!/usr/bin/env python3
"""
QLoRA fine-tuning entry point for GraiLLM.

Designed for use on Google Colab, Kaggle, or Hugging Face free GPUs. The script
expects the dataset generated by `prepare_dataset.py` where each record contains
a chat-style `messages` list.
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)

DEFAULT_BASE_MODEL = "openai/gpt-oss-mini-7b"


def parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for a fine-tuning run.

    Returns:
        The populated namespace. ``--batch-size`` and ``--grad-accum`` are
        guaranteed to be >= 1, so the micro-batch division in ``main`` can
        never divide by zero.

    Raises:
        SystemExit: via ``parser.error`` when an option is missing or invalid.
    """
    parser = argparse.ArgumentParser(description="Fine-tune GraiLLM with QLoRA.")
    parser.add_argument(
        "--train-file",
        type=Path,
        required=True,
        help="Path to the JSONL training file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--eval-file",
        type=Path,
        required=True,
        help="Path to the JSONL evaluation file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=DEFAULT_BASE_MODEL,
        help="Base Hugging Face model ID to fine-tune (QLoRA friendly).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/graillm-lora"),
        help="Directory where checkpoints and final adapters will be saved.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Micro batch size per device after gradient accumulation.",
    )
    parser.add_argument(
        "--grad-accum",
        type=int,
        default=4,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate.",
    )
    parser.add_argument("--max-steps", type=int, default=-1, help="Max training steps.")
    parser.add_argument("--bf16", action="store_true", help="Enable bfloat16 training.")
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the adapter weights to the active Hugging Face repo after training.",
    )
    parser.add_argument(
        "--hub-model-id",
        type=str,
        default="dakotarainlock/GraiLLM-7B-Lora",
        help="Target repository when --push-to-hub is supplied.",
    )
    args = parser.parse_args()

    # Reject values that would later crash (`batch_size // 0`) or silently
    # produce a nonsensical schedule.
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if args.grad_accum < 1:
        parser.error("--grad-accum must be >= 1")
    return args


def format_messages(messages: List[Dict[str, str]]) -> str:
    """Convert a message list into a single training string.

    Each message with non-blank content becomes one turn; turns are joined by
    a blank line and the result always ends with a single newline. Messages
    whose role is not ``system``, ``user`` or ``assistant`` are skipped, as are
    messages whose content is blank, ``None`` or missing.

    Args:
        messages: Chat-style records, each carrying ``role`` and ``content``.

    Returns:
        The flattened conversation text.

    Raises:
        KeyError: if a message has no ``role`` key.
    """
    turns = []
    for message in messages:
        role = message["role"]
        # JSON records with uneven fields can surface `None` here; treat it
        # like empty content instead of crashing on `.strip()`.
        content = (message.get("content") or "").strip()
        if not content:
            continue
        if role == "system":
            # NOTE(review): the bare "<>" delimiters look like stripped markup
            # (possibly "<<SYS>>" / "<</SYS>>") - confirm the intended tags.
            turns.append(f"<>\n{content}\n<>")
        elif role == "user":
            turns.append(f"[USER]\n{content}")
        elif role == "assistant":
            turns.append(f"[ASSISTANT]\n{content}")
    return "\n\n".join(turns) + "\n"


def tokenize_batch(example: Dict[str, List[Dict[str, str]]], tokenizer: AutoTokenizer):
    """Tokenize one dataset record into unpadded model inputs.

    Args:
        example: A record holding a chat-style ``messages`` list.
        tokenizer: The tokenizer of the base model.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...),
        truncated to at most 2048 tokens (or the tokenizer's own limit if that
        is smaller).

    No ``labels`` column is produced on purpose: the
    ``DataCollatorForLanguageModeling(mlm=False)`` used in ``main`` derives the
    labels from the padded ``input_ids``, whereas a ragged ``labels`` list is
    not padded by the tokenizer and breaks batching of sequences with
    different lengths.
    """
    text = format_messages(example["messages"])
    return tokenizer(
        text,
        truncation=True,
        max_length=min(tokenizer.model_max_length, 2048),
        padding=False,
    )


def main() -> None:
    """Run the full pipeline: load, quantize, attach LoRA, train, save, push."""
    args = parse_args()

    torch_dtype = torch.bfloat16 if args.bf16 else torch.float16

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        use_fast=True,
    )
    # Padding is required by the collator; fall back to EOS when the base
    # model ships without a dedicated pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Explicit 4-bit configuration (NF4 + double quantization, compute in the
    # training dtype) instead of the deprecated bare `load_in_4bit=True`
    # keyword on `from_pretrained`.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch_dtype,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )
    model = prepare_model_for_kbit_training(model)

    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)

    dataset = load_dataset(
        "json",
        data_files={
            "train": str(args.train_file),
            "eval": str(args.eval_file),
        },
    )

    # Drop the raw columns so only tokenizer outputs reach the collator.
    tokenized_ds = dataset.map(
        lambda ex: tokenize_batch(ex, tokenizer),
        remove_columns=dataset["train"].column_names,
    )

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # `--batch-size` is divided by the accumulation steps to obtain the
    # per-step micro-batch, never going below one sample.
    micro_batch_size = max(1, args.batch_size // args.grad_accum)

    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=micro_batch_size,
        per_device_eval_batch_size=micro_batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        fp16=not args.bf16,
        bf16=args.bf16,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        report_to="tensorboard",
        max_steps=args.max_steps,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["eval"],
        data_collator=collator,
    )

    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    main()