# SmartReviewAI / finetune_lora.py
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model
import streamlit as st
def train_lora(base_model: str, epochs: int = 2, lr: float = 1e-4, train_csv: str = "dataset/amazon_product_reviews.csv"):
    """
    Fine-tune a base model using LoRA on the provided dataset and visualize progress in Streamlit.
    """
    st.write(f"### 🔧 Loading base model `{base_model}`...")
    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    if tokenizer.pad_token is None:
        # Decoder-only checkpoints often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    # Load dataset
    st.info("📂 Loading dataset for fine-tuning...")
    ds = load_dataset("csv", data_files={"train": train_csv})["train"]

    def preprocess(example):
        # Flatten each CSV row into a single prompt-plus-review string for causal LM training.
        prompt = (
            f"Product: {example.get('Product','')}\n"
            f"Category: {example.get('Category','')}\n"
            f"Features: {example.get('Features','')}\n"
            f"Rating: {example.get('Rating','')}\n"
            f"Tone: {example.get('Tone','')}\n\n"
            f"Review: {example.get('Review','')}"
        )
        return tokenizer(prompt, truncation=True, padding="max_length", max_length=256)

    tokenized_ds = ds.map(preprocess, batched=False)
    # LoRA config: target_modules lists both GPT-2-style ("c_attn") and
    # LLaMA/Mistral-style ("q_proj", "v_proj") attention projections, so the
    # adapter attaches to whichever module names exist in the chosen base model.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["c_attn", "q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    # Apply LoRA to base model
    model = AutoModelForCausalLM.from_pretrained(base_model)
    model = get_peft_model(model, lora_config)

    # mlm=False -> plain causal-LM objective; labels are copied from input_ids.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    output_dir = "lora_adapter"
    os.makedirs(output_dir, exist_ok=True)
    # Streamlit progress UI
    progress_bar = st.progress(0)
    status_text = st.empty()
    loss_chart = st.empty()
    loss_list = []

    from transformers import TrainerCallback

    class StreamlitCallback(TrainerCallback):
        """Stream each logged training loss into the Streamlit progress widgets above."""

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs and "loss" in logs:
                loss = logs["loss"]
                loss_list.append(loss)
                progress = int((state.epoch / epochs) * 100)
                progress_bar.progress(progress)
                status_text.text(f"Epoch {state.epoch:.1f}/{epochs} | Step {state.global_step} | Loss: {loss:.4f}")
                loss_chart.line_chart(loss_list)
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,
        num_train_epochs=epochs,
        learning_rate=lr,
        logging_steps=5,
        save_strategy="epoch",
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        callbacks=[StreamlitCallback()]
    )

    trainer.train()

    # Save only the LoRA adapter weights (plus tokenizer) into output_dir.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    st.success("🎉 LoRA adapter trained and saved successfully!")
    return {"train_loss": loss_list, "epochs": epochs, "base_model": base_model}