# new_fine_tune / app.py
# QLoRA fine-tuning of Llama-2-7b-chat-hf on the Parth211/mental-health-dataset.
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
################################################################################
# bitsandbytes parameters
################################################################################
# Activate 4-bit precision base model loading
use_4bit = True
# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"
# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False
# Load the entire model on GPU 0
device_map = {"": 0}
# The base model to fine-tune
model_name = "NousResearch/Llama-2-7b-chat-hf"
# The instruction dataset to use
dataset_name = "Parth211/mental-health-dataset"
# Fine-tuned model name
new_model = "Llama-2-7b-chat-finetune"
# Load dataset (you can process it here)
dataset = load_dataset(dataset_name, split="train[:100]")
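
# Quick sanity check (an addition, not in the original script): confirm the
# split loaded and that the "Text" column used by SFTTrainer below is present.
print(dataset)
print(dataset[0])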
# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
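
# Added guard (not in the original script): bitsandbytes 4-bit loading
# requires a CUDA device, so fail early with a clear message instead of
# erroring deep inside from_pretrained.
assert torch.cuda.is_available(), "This QLoRA script requires a CUDA GPU."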
# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)
# Load base model with the 4-bit quantization config
# (load_in_4bit is already set inside bnb_config, so it is not repeated here)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)
model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing
model.config.pretraining_tp = 1  # disable tensor-parallel pretraining behavior
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
# Load LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    # Attention projections to adapt (these module names are specific to Llama models)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
# Set training parameters
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,  # trade extra compute for lower memory use
    optim="paged_adamw_32bit",    # paged optimizer recommended for QLoRA
    save_strategy="no",           # no intermediate checkpoints (was save_steps=0)
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=True,                    # matches the float16 compute dtype above
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,                 # -1 means train for the full num_train_epochs
    warmup_ratio=0.03,
    group_by_length=True,         # batch similar-length sequences to reduce padding
    lr_scheduler_type="cosine",
    report_to="tensorboard",
)
# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="Text",  # column of the dataset that holds the raw training text
    max_seq_length=None,        # fall back to trl's default maximum sequence length
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)
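
# Optional check (an addition, assuming trl wraps the model in a PeftModel
# when peft_config is passed): report how few parameters LoRA actually trains.
trainer.model.print_trainable_parameters()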
# Train model
trainer.train()
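
# Persist the LoRA adapter under the name defined above.
trainer.model.save_pretrained(new_model)

# Minimal post-training sketch (an addition, not part of the original script):
# a 4-bit model cannot be merged directly, so reload the base model in fp16,
# apply the saved adapter with peft, and merge the weights back in.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map=device_map,
)
merged_model = PeftModel.from_pretrained(base_model, new_model)
merged_model = merged_model.merge_and_unload()

# Smoke test with the imported pipeline helper; the prompt below is only an
# illustrative placeholder in Llama-2 chat format.
pipe = pipeline("text-generation", model=merged_model, tokenizer=tokenizer, max_length=200)
print(pipe("<s>[INST] How can I manage everyday stress? [/INST]")[0]["generated_text"])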