# Provenance (Hugging Face file-viewer residue, kept as a comment so the module parses):
# uploaded by Bhuvandesai, commit 55159b1 "initial deployment", 5.98 kB.
import os
import sys
import sqlite3
import pandas as pd
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
TrainerCallback
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
class VRAMLoggerCallback(TrainerCallback):
    """Trainer callback that reports CUDA memory usage while training runs."""

    def on_step_end(self, args, state, control, **kwargs):
        """Print the memory allocated and reserved on CUDA device 0.

        Values are reported in units of 1024**3 bytes, labelled "GB" in the
        output. Nothing is printed when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return

        bytes_per_unit = 1024 ** 3
        allocated_gb = torch.cuda.memory_allocated(0) / bytes_per_unit
        reserved_gb = torch.cuda.memory_reserved(0) / bytes_per_unit
        print(f" - [Step {state.global_step}] VRAM Allocated: {allocated_gb:.2f} GB | Reserved: {reserved_gb:.2f} GB")
def train():
    """Fine-tune Phi-3 mini for text-to-SQL using a 4-bit base model and LoRA.

    Reads ``data/train_dataset.jsonl`` and ``data/test_dataset.jsonl``, writes
    trainer checkpoints under ``models/phi3-text-to-sql`` and finally saves
    the adapter and tokenizer to ``models/phi3-text-to-sql-adapter``.

    Exits the process with status 1 when either dataset file is missing.
    """
    banner = "=" * 60
    print(banner)
    print("Starting Phi-3 Text-to-SQL QLoRA Fine-Tuning")
    print(banner)

    # --- Step 1: pick the device and the mixed-precision mode ---------------
    if not torch.cuda.is_available():
        print("WARNING: CUDA is NOT available. Running on CPU is extremely slow and NOT recommended.")
        device = "cpu"
        use_bf16 = False
        use_fp16 = False
    else:
        device = "cuda"
        detected_gpu = torch.cuda.get_device_name(0)
        print(f"CUDA Device Detected: {detected_gpu}")
        # Prefer bfloat16 when the GPU reports support; otherwise use float16.
        use_bf16 = torch.cuda.is_bf16_supported()
        use_fp16 = not use_bf16
        print(f"bfloat16 Supported: {use_bf16} | float16 Fallback: {use_fp16}")

    # --- Step 2: resolve input and output locations -------------------------
    dataset_dir = "data"
    train_file = os.path.join(dataset_dir, "train_dataset.jsonl")
    val_file = os.path.join(dataset_dir, "test_dataset.jsonl")
    output_dir = "models/phi3-text-to-sql"
    adapter_dir = "models/phi3-text-to-sql-adapter"

    # Both splits are required; abort before any expensive model loading.
    if not (os.path.exists(train_file) and os.path.exists(val_file)):
        print("ERROR: Dataset files not found. Run dataset.py first.")
        sys.exit(1)

    # --- Step 3: 4-bit NF4 quantization with double quantization ------------
    # (the original author notes this is what makes a ~6 GB VRAM budget work)
    print("Configuring 4-bit Quantization (QLoRA)...")
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=compute_dtype,
    )

    # --- Step 4: load the quantized base model and its tokenizer ------------
    model_id = "microsoft/Phi-3-mini-4k-instruct"
    print(f"Loading base model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto",
        trust_remote_code=False,
        attn_implementation="eager",
    )

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False)
    # Reuse EOS as the padding token and pad on the right.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # --- Step 5: make the quantized model trainable -------------------------
    print("Preparing model for k-bit training and enabling gradient checkpointing...")
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # --- Step 6: attach LoRA adapters to the attention and MLP projections --
    print("Configuring LoRA parameters...")
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    lora_config = LoraConfig(
        r=8,  # low adapter rank keeps the trainable parameter count small
        lora_alpha=16,
        target_modules=lora_targets,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # --- Step 7: load the JSONL splits ---------------------------------------
    print("Loading formatted training datasets...")
    split_files = {
        "train": train_file,
        "validation": val_file,
    }
    dataset = load_dataset("json", data_files=split_files)

    # --- Step 8: training configuration --------------------------------------
    # Micro-batches of 1 accumulated over 4 steps give an effective batch
    # size of 4; the paged 8-bit AdamW optimizer is only selected on CUDA.
    print("Setting up memory-optimized SFT training arguments...")
    optimizer_name = "paged_adamw_8bit" if device == "cuda" else "adamw_torch"
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        logging_steps=1,
        eval_strategy="steps",
        eval_steps=5,
        save_strategy="steps",
        save_steps=10,
        save_total_limit=2,
        optim=optimizer_name,
        bf16=use_bf16,
        fp16=use_fp16,
        gradient_checkpointing=True,
        max_grad_norm=0.3,
        report_to="none",
        ddp_find_unused_parameters=False,
        # Original author's note: needed for TRL chat-template processing.
        remove_unused_columns=False,
        # The sequence-length cap is passed to SFTConfig as `max_length`.
        max_length=512,
    )

    # --- Step 9: build the trainer --------------------------------------------
    # The model is already PEFT-wrapped above, so no peft_config is passed.
    print("Initializing SFTTrainer...")
    callbacks = []
    if device == "cuda":
        callbacks.append(VRAMLoggerCallback())
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        peft_config=None,
        processing_class=tokenizer,
        args=training_args,
        callbacks=callbacks,
    )

    # --- Step 10: run fine-tuning ---------------------------------------------
    print("Launching Fine-Tuning loop...")
    trainer.train()

    # --- Step 11: persist the adapter together with the tokenizer -------------
    print(f"Fine-tuning complete! Saving adapter to {adapter_dir}...")
    trainer.model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)
    print("Adapter saved successfully. Ready for evaluation and deployment!")
# Script entry point: start fine-tuning only when executed directly, not on import.
if __name__ == "__main__":
    train()