import os
import sys
import math
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorForLanguageModeling
from trl import SFTConfig, SFTTrainer
"""# **Initial Configs**"""
# -----------------------
# Safety checks
# -----------------------
if not torch.cuda.is_available():
    print("ERROR: CUDA not available. This script requires a GPU.")
    sys.exit(1)
device = "cuda"
print("CUDA device:", torch.cuda.get_device_name(0))
print(f"Total GPU memory (GB): {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}")
# -----------------------
# Config
# -----------------------
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
DATASET_NAME = "dassarthak18/FStarDataset-V2-Conversation"
OUTPUT_DIR = "./phi3.5-mini-lora"
MAX_SEQ_LENGTH = 2048 # safe starting point
PER_DEVICE_BATCH_SIZE = 4 # per GPU
GRAD_ACCUM_STEPS = 4 # effective batch size = 16
NUM_EPOCHS = 6
LEARNING_RATE = 2e-4
# LoRA params
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05
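# Note: LoRA scales each adapter update by LORA_ALPHA / LORA_R (64 / 32 = 2.0
# here); r sets the adapter's capacity, alpha / r the effective update magnitude.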
# PEFT target modules for Phi-3.5: the Phi-3 architecture fuses q/k/v into
# qkv_proj and gate/up into gate_up_proj, so target the fused projections
TARGET_MODULES = [
    "qkv_proj", "o_proj",
    "gate_up_proj", "down_proj"
]
# ---------- Robust FlashAttention detection & fallback ----------
def detect_flash_attn():
    try:
        # try the compiled extension import, which fails on GLIBC mismatch
        import flash_attn_2_cuda  # noqa: F401
        print("flash-attn compiled extension import: OK")
        return "flash_attention_2"
    except Exception as e:
        print("flash-attn import failed (will fallback). Reason:", repr(e))
        return "eager"
attn_impl = detect_flash_attn()
print("Using attn_implementation =", attn_impl)
# ----------------------------------------------------------------
# -----------------------
# Tokenizer
# -----------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
"""# **Load Model and Dataset**"""
"""
# -----------------------
# Load model in 4-bit (bitsandbytes)
# -----------------------
print("Loading model in 4-bit (QLoRA) — this uses bitsandbytes.")
# Important: load_in_4bit requires bitsandbytes installed and a compatible transformers version.
# bnb_4bit_compute_dtype uses fp16 (bf16 not supported)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
"""
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.float16,
    attn_implementation=attn_impl,
    # quantization_config=bnb_config,  # enable together with the QLoRA block above
)
# Required tweaks for long training
model.config.use_cache = False
#model.gradient_checkpointing_enable()
"""
# Prepare model for k-bit training (adjusts layer norms, enables gradients for some params)
print("Preparing model for k-bit (4-bit) training...")
model = prepare_model_for_kbit_training(model)
"""
# -----------------------
# Apply LoRA (PEFT) on top of 4-bit model
# -----------------------
print("Applying LoRA adapters (PEFT)...")
lora_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT,
bias="none",
target_modules=TARGET_MODULES,
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# Ensure LoRA params are trainable
for n, p in model.named_parameters():
    # get_peft_model already freezes the base model and enables gradients on
    # the LoRA adapters; this pass just makes that intent explicit
    if "lora" in n or "adapter" in n:
        p.requires_grad = True
# Print summary of trainable params
print("\nTrainable parameters (should be LoRA params, small fraction):")
model.print_trainable_parameters()
# -----------------------
# Load and preprocess conversation dataset
# -----------------------
print(f"Loading dataset: {DATASET_NAME} ...")
raw_ds = load_dataset(DATASET_NAME)
print("Available splits:", raw_ds.keys())
def conversation_to_text(example):
    conv = example.get("messages")
    parts = []
    for turn in conv:
        role = str(turn.get("role", "")).strip().lower()
        content = str(turn.get("content", "")).strip()
        if not content:
            continue
        # wrap turns in Phi-3.5 chat-format tags so the roles stay distinguishable
        if role == "user":
            parts.append(f"<|user|>\n{content}\n<|end|>")
        elif role == "assistant":
            parts.append(f"<|assistant|>\n{content}\n<|end|>")
        else:
            parts.append(content)
    text = "\n".join(parts).strip()
    return {"text": text}
print("Converting conversations -> single text field ...")
processed = raw_ds.map(conversation_to_text, remove_columns=raw_ds["train"].column_names, num_proc=32)
# Filter empty examples
processed = processed.filter(lambda e: e["text"].strip() != "")
train_ds = processed["train"]
eval_ds = processed["validation"]
test_ds = processed["test"]
print(f"Train examples: {len(train_ds)}")
print(f"Validation examples: {len(eval_ds)}")
print(f"Test examples: {len(test_ds)}")
'''
MAX_EVAL = 1000
if len(eval_ds) > MAX_EVAL:
    eval_ds = eval_ds.select(range(MAX_EVAL))
    print(f"Validation truncated to {MAX_EVAL} for memory.")
if len(test_ds) > MAX_EVAL:
    test_ds = test_ds.select(range(MAX_EVAL))
    print(f"Test truncated to {MAX_EVAL} for memory.")
'''
# -----------------------
# Tokenization (truncation to MAX_SEQ_LENGTH)
# -----------------------
def tokenize_fn(batch):
    texts = batch["text"]
    # ensure strings
    if isinstance(texts, list):
        texts = [str(t) for t in texts]
    else:
        texts = [str(texts)]
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,  # dynamic padding handled by collator
    )
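# Optional: estimate how often examples hit the 2048-token cap before
# committing to MAX_SEQ_LENGTH (a rough sketch over a small sample):
_lens = [
    len(tokenizer(t)["input_ids"])
    for t in train_ds.select(range(min(200, len(train_ds))))["text"]
]
print(f"Sampled max tokens: {max(_lens)}; fraction at/above cap: "
      f"{sum(l >= MAX_SEQ_LENGTH for l in _lens) / len(_lens):.1%}")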
print("Tokenizing datasets...")
tokenized_train = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names, num_proc=32)
tokenized_eval = eval_ds.map(tokenize_fn, batched=True, remove_columns=eval_ds.column_names, num_proc=32)
tokenized_test = test_ds.map(tokenize_fn, batched=True, remove_columns=test_ds.column_names, num_proc=32)
print(f"Tokenized train size: {len(tokenized_train)}")
print(f"Tokenized eval size: {len(tokenized_eval)}")
print(f"Tokenized test size: {len(tokenized_test)}")
# -----------------------
# Data collator
# -----------------------
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8,
)
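# With mlm=False the collator copies input_ids into labels and masks padded
# positions to -100, so the loss skips padding and covers every real token.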
"""# **Training**"""
# -----------------------
# Training args (single GPU optimized)
# -----------------------
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    max_length=MAX_SEQ_LENGTH,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=25,
    save_steps=250,
    eval_steps=250,
    save_total_limit=3,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=True,
    bf16=False,
    packing=False,
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    report_to="tensorboard",
    # gradient_checkpointing=True,
    seed=42,
)
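# Rough schedule sanity check (assumes no packing and a single-process run):
steps_per_epoch = math.ceil(len(tokenized_train) / (PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS))
print(f"~{steps_per_epoch} optimizer steps per epoch, ~{steps_per_epoch * NUM_EPOCHS} total.")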
# -----------------------
# SFTTrainer setup
# -----------------------
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)
# -----------------------
# Train
# -----------------------
print("\n" + "="*40)
print("Starting training (FP16 LoRA, no quantization)...")
print("="*40 + "\n")
trainer.train()
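# If the run is interrupted, it can be resumed from the latest checkpoint in
# OUTPUT_DIR via the standard Trainer API:
# trainer.train(resume_from_checkpoint=True)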
"""# **Evaluation**"""
# -----------------------
# Evaluate on test split (final)
# -----------------------
print("\nEvaluating on test split...")
test_metrics = trainer.evaluate(eval_dataset=tokenized_test)
test_loss = test_metrics.get("eval_loss", None)
if test_loss is not None:
    try:
        # math.exp raises OverflowError on very large losses, unlike np.exp,
        # which would silently return inf with a warning
        test_ppl = math.exp(test_loss)
    except OverflowError:
        test_ppl = float("inf")
    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test Perplexity: {test_ppl:.2f}")
else:
    print("Test eval_loss not present in metrics:", test_metrics)
# -----------------------
# Save adapters and tokenizer
# -----------------------
print("\nSaving LoRA adapters and tokenizer...")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Save only the PEFT adapters and tokenizer (a small artifact, not the full base model)
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("\nDone. LoRA adapters and tokenizer saved to:", OUTPUT_DIR)
print("Reload by attaching the adapter to the same base model (same dtype/attention settings).")