|
|
import os
import sys
import math

import torch
from datasets import load_dataset
# BitsAndBytesConfig and prepare_model_for_kbit_training are only needed if the
# optional 4-bit (QLoRA) path further below is re-enabled.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
)
from trl import SFTConfig, SFTTrainer

"""# **Initial Configs**""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not torch.cuda.is_available():
    print("ERROR: CUDA not available. This script requires a GPU.")
    sys.exit(1)

device = "cuda"
print("CUDA device:", torch.cuda.get_device_name(0))
print(f"Total GPU memory (GB): {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}")

# Model, data, and output paths plus core training hyperparameters.
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
DATASET_NAME = "dassarthak18/FStarDataset-V2-Conversation"
OUTPUT_DIR = "./phi3.5-mini-lora"
MAX_SEQ_LENGTH = 2048
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 4
NUM_EPOCHS = 6
LEARNING_RATE = 2e-4

# LoRA hyperparameters.
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05

# Attention and MLP projections to wrap with LoRA adapters.
TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

def detect_flash_attn():
    """Return the attn_implementation to use, probing for a working flash-attn build."""
    try:
        # Import the compiled CUDA extension directly: this verifies the binary
        # actually loads, not just that the flash_attn Python package is installed.
        import flash_attn_2_cuda
        print("flash-attn compiled extension import: OK")
        return "flash_attention_2"
    except Exception as e:
        print("flash-attn import failed; falling back to eager. Reason:", repr(e))
        return "eager"

attn_impl = detect_flash_attn()
print("Using attn_implementation =", attn_impl)

print("Loading tokenizer...") |
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
|
|
|
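# Note: instead of the model's native chat template, this script flattens
# conversations with simple <user>/<assistant> tags (see conversation_to_text below).
# The native-template alternative would look roughly like this (commented out):
# text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
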
"""# **Load Model and Dataset**""" |
|
|
""" |
|
|
# ----------------------- |
|
|
# Load model in 4-bit (bitsandbytes) |
|
|
# ----------------------- |
|
|
print("Loading model in 4-bit (QLoRA) — this uses bitsandbytes.") |
|
|
# Important: load_in_4bit requires bitsandbytes installed and a compatible transformers version. |
|
|
# bnb_4bit_compute_dtype uses fp16 (bf16 not supported) |
|
|
bnb_config = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
bnb_4bit_quant_type="nf4", |
|
|
bnb_4bit_use_double_quant=True, |
|
|
bnb_4bit_compute_dtype=torch.float16, |
|
|
) |
|
|
""" |
|
|
print("Loading model...") |
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
MODEL_NAME, |
|
|
device_map="auto", |
|
|
trust_remote_code=True, |
|
|
dtype=torch.float16, |
|
|
attn_implementation=attn_impl, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
model.config.use_cache = False |
|
|
|
|
|
|
|
|
""" |
|
|
# Prepare model for k-bit training (adjusts layer norms, enables gradients for some params) |
|
|
print("Preparing model for k-bit (4-bit) training...") |
|
|
model = prepare_model_for_kbit_training(model) |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Applying LoRA adapters (PEFT)...") |
|
|
lora_config = LoraConfig( |
|
|
r=LORA_R, |
|
|
lora_alpha=LORA_ALPHA, |
|
|
lora_dropout=LORA_DROPOUT, |
|
|
bias="none", |
|
|
target_modules=TARGET_MODULES, |
|
|
task_type="CAUSAL_LM" |
|
|
) |
|
|
|
|
|
model = get_peft_model(model, lora_config) |
|
|
|
|
|
|
|
|
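# Note: adapter updates are scaled by lora_alpha / r = 64 / 32 = 2.0 in the forward pass.
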
# get_peft_model already freezes the base weights and marks only the adapters
# trainable, so this loop is a redundant safety net.
for n, p in model.named_parameters():
    if "lora" in n or "adapter" in n:
        p.requires_grad = True

print("\nTrainable parameters (should be LoRA params, small fraction):")
model.print_trainable_parameters()

print(f"Loading dataset: {DATASET_NAME} ...") |
|
|
raw_ds = load_dataset(DATASET_NAME) |
|
|
print("Available splits:", raw_ds.keys()) |
|
|
|
|
|
def conversation_to_text(example):
    """Flatten a list of chat turns into one training string with role tags."""
    conv = example.get("messages") or []  # tolerate a missing field
    parts = []
    for turn in conv:
        role = str(turn.get("role", "")).strip().lower()
        content = str(turn.get("content", "")).strip()
        if not content:
            continue
        if role == "user":
            parts.append(f"<user>\n{content}\n</user>")
        elif role == "assistant":
            parts.append(f"<assistant>\n{content}\n</assistant>")
        else:
            parts.append(content)
    text = "\n".join(parts).strip()
    return {"text": text}

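# Quick sanity check (commented out): eyeball one converted example.
# print(conversation_to_text(raw_ds["train"][0])["text"][:500])
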
print("Converting conversations -> single text field ...") |
|
|
processed = raw_ds.map(conversation_to_text, remove_columns=raw_ds["train"].column_names, num_proc=32) |
|
|
|
|
|
|
|
|
processed = processed.filter(lambda e: e["text"].strip() != "") |
|
|
|
|
|
train_ds = processed["train"]
eval_ds = processed["validation"]
test_ds = processed["test"]

print(f"Train examples: {len(train_ds)}")
print(f"Validation examples: {len(eval_ds)}")
print(f"Test examples: {len(test_ds)}")

'''
# Optional (disabled): cap the eval/test splits to save memory during evaluation.
MAX_EVAL = 1000
if len(eval_ds) > MAX_EVAL:
    eval_ds = eval_ds.select(range(MAX_EVAL))
    print(f"Validation truncated to {MAX_EVAL} for memory.")
if len(test_ds) > MAX_EVAL:
    test_ds = test_ds.select(range(MAX_EVAL))
    print(f"Test truncated to {MAX_EVAL} for memory.")
'''

def tokenize_fn(batch):
    texts = batch["text"]
    # Handle both batched (list) and single-example (str) inputs.
    if isinstance(texts, list):
        texts = [str(t) for t in texts]
    else:
        texts = [str(texts)]
    # No padding here; the data collator pads dynamically per batch.
    return tokenizer(
        texts,
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding=False,
    )

print("Tokenizing datasets...") |
|
|
tokenized_train = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names, num_proc=32) |
|
|
tokenized_eval = eval_ds.map(tokenize_fn, batched=True, remove_columns=eval_ds.column_names, num_proc=32) |
|
|
tokenized_test = test_ds.map(tokenize_fn, batched=True, remove_columns=test_ds.column_names, num_proc=32) |
|
|
|
|
|
print(f"Tokenized train size: {len(tokenized_train)}") |
|
|
print(f"Tokenized eval size: {len(tokenized_eval)}") |
|
|
print(f"Tokenized test size: {len(tokenized_test)}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
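# Optional check (commented out): how many training examples hit the truncation cap.
# n_at_cap = sum(len(ids) == MAX_SEQ_LENGTH for ids in tokenized_train["input_ids"])
# print(f"Train examples truncated at {MAX_SEQ_LENGTH} tokens: {n_at_cap}")
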
# mlm=False -> causal-LM collation: labels are a copy of input_ids, with padding
# positions set to -100 so they are ignored by the loss.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8,
)

"""# **Training**""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    max_length=MAX_SEQ_LENGTH,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=25,
    save_steps=250,
    eval_steps=250,
    save_total_limit=3,
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=True,
    bf16=False,
    packing=False,
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    report_to="tensorboard",
    seed=42,
)

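# Effective batch size per optimizer step (single GPU):
# PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS = 4 * 4 = 16 sequences.
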
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=data_collator,
)

print("\n" + "="*40) |
|
|
print("Starting training (FP16 LoRA, no quantization)...") |
|
|
print("="*40 + "\n") |
|
|
|
|
|
trainer.train() |
|
|
|
|
|
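# To resume an interrupted run (commented out):
# trainer.train(resume_from_checkpoint=True)
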
"""# **Evaluation**""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\nEvaluating on test split...") |
|
|
test_metrics = trainer.evaluate(eval_dataset=tokenized_test) |
|
|
test_loss = test_metrics.get("eval_loss", None) |
|
|
|
|
|
if test_loss is not None: |
|
|
try: |
|
|
test_ppl = float(np.exp(test_loss)) |
|
|
except OverflowError: |
|
|
test_ppl = float("inf") |
|
|
print(f"\nTest Loss: {test_loss:.4f}") |
|
|
print(f"Test Perplexity: {test_ppl:.2f}") |
|
|
else: |
|
|
print("Test eval_loss not present in metrics:", test_metrics) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\nSaving LoRA adapters and tokenizer...") |
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True) |
|
|
|
|
|
model.save_pretrained(OUTPUT_DIR) |
|
|
tokenizer.save_pretrained(OUTPUT_DIR) |
|
|
|
|
|
print("\nDone. Model (quantized + LoRA adapters) saved to:", OUTPUT_DIR) |
|
|
print("You can load the adapter with from_pretrained and the same bnb/4bit settings.") |