# LoRA fine-tuning script that teaches the model to multiply 6-digit numbers by a fixed constant (7).
# Uses PEFT + TRL for efficient training on Qwen2.5-0.5B.
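# Run as a standalone script, e.g. `python scripts/train_multiply.py`.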
import sys
from pathlib import Path
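# Make the repo root importable so the local `config` module (base model name, LoRA
# hyperparameters, system prompt, sample count, output paths) resolves below.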
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import Dataset
import random
import config
def setup_model_and_tokenizer(use_4bit: bool = False):
"""Load model with optional 4-bit quantization."""
tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if use_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
# bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
config.BASE_MODEL,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
model = AutoModelForCausalLM.from_pretrained(
config.BASE_MODEL,
# dtype=torch.bfloat16,
dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
return model, tokenizer
def generate_training_data(num_samples: int, val_ratio: float = 0.1, seed: int = 42):
    """Build `num_samples` chat-formatted `a * 7` examples and split them into train/validation."""
    random.seed(seed)
    # Generate unique operands so no multiplication example is duplicated
examples = []
seen = set()
while len(examples) < num_samples:
a = random.randint(100000, 999999)
b = 7
        # The operand `a` alone identifies an example, since the multiplier is fixed at 7
        key = a
if key in seen:
continue
seen.add(key)
result = a * b
# Vary the prompt format for robustness
prompt_templates = [f"{a} * {b}", f"{a}* {b}", f"{a} *{b}"]
prompt = random.choice(prompt_templates) + random.choice(["", "?", " ?"])
examples.append(
{
"item": [
{
"role": "system",
"content": config.SYSTEM_PROMPT,
},
{"role": "user", "content": prompt},
{"role": "assistant", "content": str(result)},
]
}
)
# Shuffle and split into train/validation
ds = Dataset.from_list(examples)
    ds = ds.shuffle(seed=seed)  # shuffle() returns a new Dataset; without the assignment it is a no-op
    splitted = ds.train_test_split(test_size=val_ratio, seed=seed)
return splitted
def main():
output_dir = config.OUTPUT_DIR / "lora-multiplicator"
print("Multiplication LoRA Fine-tuning")
print(f"\nBase model: {config.BASE_MODEL}")
# Check CUDA
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(
f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
)
# Load data
print("\nGenerating training data...")
dataset = generate_training_data(config.NUM_SAMPLES)
print(
f"train samples: {len(dataset['train'])}, validation samples: {len(dataset['test'])}"
)
# Load model
print(f"\nLoading model: {config.BASE_MODEL}")
    model, tokenizer = setup_model_and_tokenizer(use_4bit=torch.cuda.is_available())
peft_config = LoraConfig(
r=config.LORA_R,
lora_alpha=config.LORA_ALPHA,
target_modules=config.TARGET_MODULES,
lora_dropout=config.LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
# effective_batch_size = per_device_train_batch_size × gradient_accumulation_steps × num_gpus
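    # Here that is 4 * 4 = 16 sequences per optimizer step, assuming a single GPU.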
training_args = TrainingArguments(
output_dir=str(output_dir),
num_train_epochs=3, # Increased from 1 to 3 for better convergence on arithmetic tasks
per_device_train_batch_size=4, # Increased from 2 for more stable gradients
gradient_accumulation_steps=4, # Effective batch size of 16
gradient_checkpointing=True, # Trade compute for memory savings
learning_rate=1e-3, # Increased from 2e-4 - higher LR works better for LoRA fine-tuning
lr_scheduler_type="cosine", # Cosine annealing for better convergence
bf16=torch.cuda.is_available(),
warmup_ratio=0.05,
logging_steps=10,
save_strategy="steps", # Save checkpoints during training
save_steps=200, # Save every 200 steps
        save_total_limit=2,  # Keep at most 2 checkpoints on disk (the best one is always retained)
report_to="none", # No external reporting
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
remove_unused_columns=False,
max_grad_norm=1.0, # Gradient clipping for training stability
# evaluation
eval_strategy="steps", # Changed from "epoch" to track loss during training
eval_steps=100, # Evaluate every 100 steps
do_eval=True,
per_device_eval_batch_size=8,
)
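    # Note: load_best_model_at_end requires save_strategy to match eval_strategy and, with "steps",
    # save_steps to be a round multiple of eval_steps; 200 / 100 satisfies both constraints.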
    def formatter(example):
        # Render one chat example into a single training string using the model's chat template
        return tokenizer.apply_chat_template(
            example["item"],
            tokenize=False,  # return a string, not token ids
            add_generation_prompt=False,  # no trailing assistant header for training
        )
# Create trainer
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
peft_config=peft_config,
formatting_func=formatter,
)
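    # Passing peft_config lets SFTTrainer wrap the base model with LoRA adapters itself, so
    # trainer.save_model() below writes only the small adapter weights, not a full base-model copy.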
# Train
print("\nStarting training...")
trainer.train()
# Save final model
final_path = output_dir / "final"
print("\nSaving model...")
trainer.save_model(str(final_path))
tokenizer.save_pretrained(str(final_path))
print("\nTraining complete!")
print(f"Model saved to: {final_path}")
if __name__ == "__main__":
main()
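# ----------------------------------------------------------------------------
# Hedged sketch (not executed here): one way to load the saved adapter for a
# quick manual check after training. Assumes the `final` directory produced
# above and the same `config` module; adjust paths as needed.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   adapter_dir = config.OUTPUT_DIR / "lora-multiplicator" / "final"
#   base = AutoModelForCausalLM.from_pretrained(config.BASE_MODEL, device_map="auto")
#   model = PeftModel.from_pretrained(base, str(adapter_dir))
#   tokenizer = AutoTokenizer.from_pretrained(str(adapter_dir))
# ----------------------------------------------------------------------------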