# book-rec-with-LLMs/src/zero_shot/lora_trainer.py
"""
P2 & P3: LoRA Fine-tuning for Zero-shot Recommendation.
Optimized for RTX 3090/4090 (24GB VRAM).
"""
import os
import json
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from modelscope import snapshot_download
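
# Note: assumes recent transformers / peft / trl releases; in particular, SFTTrainer's
# `processing_class` argument replaced the older `tokenizer` argument in newer versions.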

# ========== Configuration ==========
MODEL_NAME = snapshot_download("qwen/Qwen2-1.5B-Instruct")  # Resolves to a local snapshot path downloaded from ModelScope
OUTPUT_DIR = "./lora_output"
DATA_FILE = "training_data.json"


def load_model_and_tokenizer(model_name: str):
    """Load model with 4-bit quantization for memory efficiency."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    return model, tokenizer
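
# The LoRA settings below (r=16, alpha=32, dropout=0.05 on the attention projections)
# are a common starting point for models in the 1-2B parameter range rather than tuned
# values; raising r adds adapter capacity at the cost of more trainable parameters.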

def apply_lora(model):
    """Apply LoRA adapters to the model."""
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Common for Qwen/Llama
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
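
# Expected layout of training_data.json: a list of records with "instruction", "input",
# and "output" keys, e.g. (illustrative values only):
#   [{"instruction": "Recommend a book for this user.",
#     "input": "Liked: Dune; The Martian",
#     "output": "Project Hail Mary"}]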

def load_dataset(data_file: str):
    """Load and format dataset for SFT."""
    with open(data_file, 'r') as f:
        data = json.load(f)
    # Format as chat/instruction format
    formatted = []
    for item in data:
        text = (
            f"### Instruction:\n{item['instruction']}\n\n"
            f"### Input:\n{item['input']}\n\n"
            f"### Response:\n{item['output']}"
        )
        formatted.append({"text": text})
    return Dataset.from_list(formatted)
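
# Effective batch size below is 16 * 2 (gradient accumulation) = 32 sequences per step.
# If this overflows VRAM on a 24GB card, lowering per_device_train_batch_size and raising
# gradient_accumulation_steps keeps the effective batch size unchanged.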

def train(model, tokenizer, dataset):
    """Run SFT training with LoRA."""
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,  # Quick iteration; increase for production
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,
        learning_rate=2e-4,
        warmup_steps=10,
        logging_steps=10,
        save_steps=100,
        bf16=True,
        optim="paged_adamw_8bit",
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    print(f"Model saved to {OUTPUT_DIR}")

def main():
    print("=== Zero-shot Recommender LoRA Training ===")

    # Step 1: Generate data if not exists
    if not os.path.exists(DATA_FILE):
        print("Generating training data...")
        from semantic_converter import generate_synthetic_interactions, create_training_data

        items_df, interactions_df = generate_synthetic_interactions(num_interactions=1000)
        training_data = create_training_data(items_df, interactions_df)
        with open(DATA_FILE, 'w') as f:
            json.dump(training_data, f)
        print(f"Generated {len(training_data)} samples.")

    # Step 2: Load model
    print(f"Loading model: {MODEL_NAME}")
    model, tokenizer = load_model_and_tokenizer(MODEL_NAME)

    # Step 3: Apply LoRA
    print("Applying LoRA adapters...")
    model = apply_lora(model)

    # Step 4: Load dataset
    print("Loading dataset...")
    dataset = load_dataset(DATA_FILE)

    # Step 5: Train
    print("Starting training...")
    train(model, tokenizer, dataset)
    print("=== Training Complete ===")


if __name__ == "__main__":
    main()