# train.py
# QLoRA fine-tuning of TinyLlama-1.1B-Chat on a customer-support dataset,
# sized to run quickly inside a Hugging Face Space.
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Load the dataset straight from a URL, taking a small slice for speed.
dataset = load_dataset(
    "json",
    data_files="https://huggingface.co/datasets/bitext/Bitext-customer-support-llm-chatbot-training-dataset/resolve/main/bitext_customer_support.jsonl",
    split="train[:100]"  # limit for fast training in Spaces
)
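
# Sanity check (optional): the 'instruction'/'output' field names used in
# format_example below are an assumption about this JSONL's schema; adjust
# the keys if the raw file differs (e.g. 'response' instead of 'output').
print(dataset.column_names)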

# Collapse each record into a single Alpaca-style prompt/response string.
def format_example(example):
    return {
        "text": f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
    }

dataset = dataset.map(format_example)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token  # TinyLlama ships no pad token, so reuse EOS

def tokenize(example):
    tokens = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )
    # For causal LM training the labels are just the input ids; note that pad
    # positions are not masked here, so padding tokens contribute to the loss.
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

tokenized_dataset = dataset.map(tokenize, batched=True)
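
# A possible variant (a sketch, not wired into the pipeline above): mask pad
# positions in the labels with -100, the ignore_index of the cross-entropy
# loss in transformers, so padding is excluded from the loss. Because
# pad_token == eos_token here, this also masks the genuine EOS token, which
# is the usual trade-off with that shortcut.
def tokenize_masked(example):
    tokens = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )
    tokens["labels"] = [
        [tok if m == 1 else -100 for tok, m in zip(seq, mask)]
        for seq, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens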

# 4-bit NF4 quantization with double quantization to shrink the base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16  # match bf16=True in TrainingArguments
)

# Load the base model in 4-bit and prepare it for k-bit (QLoRA) training.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)
model.gradient_checkpointing_enable()  # trade compute for activation memory
model = prepare_model_for_kbit_training(model)
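
# Rough check of the quantized model's size in memory, using the
# get_memory_footprint() helper from transformers.
print(f"Base model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")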

# LoRA adapters on the attention query/value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"]
)
model = get_peft_model(model, lora_config)
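
# Report how small the trainable (LoRA) parameter count is relative to the
# frozen base model; print_trainable_parameters() is a PEFT helper.
model.print_trainable_parameters()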

training_args = TrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size of 8
    num_train_epochs=1,
    learning_rate=2e-4,
    logging_dir="./logs",
    save_strategy="no",  # save only the final adapter, manually, below
    bf16=True,
    optim="paged_adamw_8bit",  # paged 8-bit AdamW keeps optimizer memory low
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer  # newer transformers releases rename this to processing_class
)
trainer.train()

# save_pretrained on a PEFT model writes only the adapter weights, not the base model.
model.save_pretrained("trained_model")
tokenizer.save_pretrained("trained_model")
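
# Optional smoke test of the freshly trained adapter (a sketch; the prompt
# below is an illustrative example, not taken from the dataset).
model.eval()
prompt = "### Instruction:\nHow do I cancel my order?\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))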