import os import sys import math import torch import numpy as np from datasets import load_dataset from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorForLanguageModeling, TrainingArguments from trl import SFTConfig, SFTTrainer """# **Initial Configs**""" # ----------------------- # Safety checks # ----------------------- if not torch.cuda.is_available(): print("ERROR: CUDA not available. This script requires a GPU.") sys.exit(1) device = "cuda" print("CUDA device:", torch.cuda.get_device_name(0)) print(f"Total GPU memory (GB): {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}") # ----------------------- # Config # ----------------------- MODEL_NAME = "microsoft/Phi-3.5-mini-instruct" DATASET_NAME = "dassarthak18/FStarDataset-V2-Conversation" OUTPUT_DIR = "./phi3.5-mini-lora" MAX_SEQ_LENGTH = 2048 # safe starting point PER_DEVICE_BATCH_SIZE = 4 # per GPU GRAD_ACCUM_STEPS = 4 # effective batch size = 16 NUM_EPOCHS = 6 LEARNING_RATE = 2e-4 # LoRA params LORA_R = 32 LORA_ALPHA = 64 LORA_DROPOUT = 0.05 # PEFT target modules for Phi-3.5 TARGET_MODULES = [ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj" ] # ---------- Robust FlashAttention detection & fallback ---------- def detect_flash_attn(): try: # try the compiled extension import which fails on GLIBC mismatch import flash_attn_2_cuda # noqa: F401 print("flash-attn compiled extension import: OK") return "flash_attention_2" except Exception as e: print("flash-attn import failed (will fallback). Reason:", repr(e)) return "eager" attn_impl = detect_flash_attn() print("Using attn_implementation =", attn_impl) # ---------------------------------------------------------------- # ----------------------- # Tokenizer # ----------------------- print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token """# **Load Model and Dataset**""" """ # ----------------------- # Load model in 4-bit (bitsandbytes) # ----------------------- print("Loading model in 4-bit (QLoRA) — this uses bitsandbytes.") # Important: load_in_4bit requires bitsandbytes installed and a compatible transformers version. # bnb_4bit_compute_dtype uses fp16 (bf16 not supported) bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16, ) """ print("Loading model...") model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto", trust_remote_code=True, dtype=torch.float16, attn_implementation=attn_impl, #quantization_config=bnb_config, ) # Required tweaks for long training model.config.use_cache = False #model.gradient_checkpointing_enable() """ # Prepare model for k-bit training (adjusts layer norms, enables gradients for some params) print("Preparing model for k-bit (4-bit) training...") model = prepare_model_for_kbit_training(model) """ # ----------------------- # Apply LoRA (PEFT) on top of 4-bit model # ----------------------- print("Applying LoRA adapters (PEFT)...") lora_config = LoraConfig( r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT, bias="none", target_modules=TARGET_MODULES, task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) # Ensure LoRA params are trainable for n, p in model.named_parameters(): # By design, most parameters remain frozen; LoRA adapters and some qkv adapters are trainable if "lora" in n or "adapter" in n: p.requires_grad = True # Print summary of trainable params print("\nTrainable parameters (should be LoRA params, small fraction):") model.print_trainable_parameters() # ----------------------- # Load and preprocess conversation dataset # ----------------------- print(f"Loading dataset: {DATASET_NAME} ...") raw_ds = load_dataset(DATASET_NAME) print("Available splits:", raw_ds.keys()) def conversation_to_text(example): conv = example.get("messages") parts = [] for turn in conv: role = str(turn.get("role", "")).strip().lower() content = str(turn.get("content", "")).strip() if not content: continue if role == "user": parts.append(f"\n{content}\n") elif role == "assistant": parts.append(f"\n{content}\n") else: parts.append(content) text = "\n".join(parts).strip() return {"text": text} print("Converting conversations -> single text field ...") processed = raw_ds.map(conversation_to_text, remove_columns=raw_ds["train"].column_names, num_proc=32) # Filter empty examples processed = processed.filter(lambda e: e["text"].strip() != "") train_ds = processed["train"] eval_ds = processed["validation"] test_ds = processed["test"] print(f"Train examples: {len(train_ds)}") print(f"Validation examples: {len(eval_ds)}") print(f"Test examples: {len(test_ds)}") ''' MAX_EVAL = 1000 if len(eval_ds) > MAX_EVAL: eval_ds = eval_ds.select(range(MAX_EVAL)) print(f"Validation truncated to {MAX_EVAL} for memory.") if len(test_ds) > MAX_EVAL: test_ds = test_ds.select(range(MAX_EVAL)) print(f"Test truncated to {MAX_EVAL} for memory.") ''' # ----------------------- # Tokenization (truncation to MAX_SEQ_LENGTH) # ----------------------- def tokenize_fn(batch): texts = batch["text"] # ensure strings if isinstance(texts, list): texts = [str(t) for t in texts] else: texts = [str(texts)] return tokenizer( texts, truncation=True, max_length=MAX_SEQ_LENGTH, padding=False, # dynamic padding handled by collator ) print("Tokenizing datasets...") tokenized_train = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names, num_proc=32) tokenized_eval = eval_ds.map(tokenize_fn, batched=True, remove_columns=eval_ds.column_names, num_proc=32) tokenized_test = test_ds.map(tokenize_fn, batched=True, remove_columns=test_ds.column_names, num_proc=32) print(f"Tokenized train size: {len(tokenized_train)}") print(f"Tokenized eval size: {len(tokenized_eval)}") print(f"Tokenized test size: {len(tokenized_test)}") # ----------------------- # Data collator # ----------------------- data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, pad_to_multiple_of=8 ) """# **Training**""" # ----------------------- # Training args (single GPU optimized) # ----------------------- training_args = SFTConfig( output_dir=OUTPUT_DIR, max_length=MAX_SEQ_LENGTH, num_train_epochs=NUM_EPOCHS, per_device_train_batch_size=PER_DEVICE_BATCH_SIZE, per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE, gradient_accumulation_steps=GRAD_ACCUM_STEPS, learning_rate=LEARNING_RATE, lr_scheduler_type="cosine", warmup_ratio=0.05, logging_steps=25, save_steps=250, eval_steps=250, save_total_limit=3, eval_strategy="steps", load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False, fp16=True, bf16=False, packing=False, dataloader_num_workers=2, dataloader_pin_memory=True, report_to="tensorboard", #gradient_checkpointing=True, seed=42, ) # ----------------------- # SFTTrainer setup # ----------------------- trainer = SFTTrainer( model=model, processing_class=tokenizer, args=training_args, train_dataset=tokenized_train, eval_dataset=tokenized_eval, data_collator=data_collator, ) # ----------------------- # Train # ----------------------- print("\n" + "="*40) print("Starting training (FP16 LoRA, no quantization)...") print("="*40 + "\n") trainer.train() """# **Evaluation**""" # ----------------------- # Evaluate on test split (final) # ----------------------- print("\nEvaluating on test split...") test_metrics = trainer.evaluate(eval_dataset=tokenized_test) test_loss = test_metrics.get("eval_loss", None) if test_loss is not None: try: test_ppl = float(np.exp(test_loss)) except OverflowError: test_ppl = float("inf") print(f"\nTest Loss: {test_loss:.4f}") print(f"Test Perplexity: {test_ppl:.2f}") else: print("Test eval_loss not present in metrics:", test_metrics) # ----------------------- # Save adapters and tokenizer # ----------------------- print("\nSaving LoRA adapters and tokenizer...") os.makedirs(OUTPUT_DIR, exist_ok=True) # Save only PEFT adapters (keeps model quantized + small) model.save_pretrained(OUTPUT_DIR) tokenizer.save_pretrained(OUTPUT_DIR) print("\nDone. Model (quantized + LoRA adapters) saved to:", OUTPUT_DIR) print("You can load the adapter with from_pretrained and the same bnb/4bit settings.")