"""
QLoRA fine-tuning entry point for GraiLLM.

Designed for use on Google Colab, Kaggle, or Hugging Face free GPUs.
The script expects the dataset generated by `prepare_dataset.py`, where each
record contains a chat-style `messages` list.
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
)


# Default backbone; any 4-bit-loadable causal LM can be substituted via --base-model.
DEFAULT_BASE_MODEL = "openai/gpt-oss-mini-7b"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fine-tune GraiLLM with QLoRA.")
    parser.add_argument(
        "--train-file",
        type=Path,
        required=True,
        help="Path to the JSONL training file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--eval-file",
        type=Path,
        required=True,
        help="Path to the JSONL evaluation file produced by prepare_dataset.py.",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=DEFAULT_BASE_MODEL,
        help="Base Hugging Face model ID to fine-tune (QLoRA friendly).",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/graillm-lora"),
        help="Directory where checkpoints and final adapters will be saved.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help=(
            "Effective per-device batch size; it is split into micro-batches of "
            "batch-size // grad-accum examples."
        ),
    )
    parser.add_argument(
        "--grad-accum",
        type=int,
        default=4,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=3,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-4,
        help="Learning rate.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=-1,
        help="Max training steps (-1 trains for --epochs instead).",
    )
    parser.add_argument("--bf16", action="store_true", help="Enable bfloat16 training.")
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the adapter weights to the active Hugging Face repo after training.",
    )
    parser.add_argument(
        "--hub-model-id",
        type=str,
        default="dakotarainlock/GraiLLM-7B-Lora",
        help="Target repository when --push-to-hub is supplied.",
    )
    return parser.parse_args()


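# `format_messages` renders a record into a plain-text, Llama-style prompt:
# the system turn is wrapped in <<SYS>> markers, user/assistant turns are
# tagged. For an illustrative two-turn record the output looks like:
#
#   <<SYS>>
#   You are GraiLLM.
#   <</SYS>>
#
#   [USER]
#   Hi there
#
#   [ASSISTANT]
#   Hello!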
def format_messages(messages: List[Dict[str, str]]) -> str:
    """Convert a message list into a single training string."""
    turns = []
    for message in messages:
        role = message["role"]
        content = message["content"].strip()
        # Empty turns and unknown roles are dropped silently.
        if not content:
            continue
        if role == "system":
            turns.append(f"<<SYS>>\n{content}\n<</SYS>>")
        elif role == "user":
            turns.append(f"[USER]\n{content}")
        elif role == "assistant":
            turns.append(f"[ASSISTANT]\n{content}")
    return "\n\n".join(turns) + "\n"


def tokenize_batch(
    example: Dict[str, List[Dict[str, str]]],
    tokenizer: PreTrainedTokenizerBase,
):
    text = format_messages(example["messages"])
    tokenized = tokenizer(
        text,
        truncation=True,
        max_length=min(tokenizer.model_max_length, 2048),
        padding=False,
    )
    # Causal-LM objective: labels are the input ids; the model shifts them internally.
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


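# With the defaults (--batch-size 16, --grad-accum 4) each optimizer step sees
# 16 examples per device: micro-batches of 16 // 4 = 4, accumulated 4 times.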
def main() -> None:
    args = parse_args()
    torch_dtype = torch.bfloat16 if args.bf16 else torch.float16

    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # QLoRA-style 4-bit quantization: NF4 weights, double quantization, and the
    # training dtype as compute dtype. (The bare `load_in_4bit=True` kwarg to
    # from_pretrained is deprecated in recent transformers releases.)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch_dtype,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        torch_dtype=torch_dtype,
        quantization_config=bnb_config,
    )

    model = prepare_model_for_kbit_training(model)
    # Gradient checkpointing (enabled above) is incompatible with the generation cache.
    model.config.use_cache = False
    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)

    dataset = load_dataset(
        "json",
        data_files={
            "train": str(args.train_file),
            "eval": str(args.eval_file),
        },
    )

    tokenized_ds = dataset.map(
        lambda ex: tokenize_batch(ex, tokenizer),
        remove_columns=dataset["train"].column_names,
    )

    # mlm=False selects standard next-token (causal) language modeling.
    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=max(1, args.batch_size // args.grad_accum),
        per_device_eval_batch_size=max(1, args.batch_size // args.grad_accum),
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        fp16=not args.bf16,
        bf16=args.bf16,
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        report_to="tensorboard",
        max_steps=args.max_steps,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_ds["train"],
        eval_dataset=tokenized_ds["eval"],
        data_collator=collator,
    )

    trainer.train()
    # Saves only the LoRA adapter weights, not the quantized base model.
    trainer.save_model()
    tokenizer.save_pretrained(args.output_dir)

    if args.push_to_hub:
        trainer.push_to_hub()


if __name__ == "__main__":
    main()