insureos-models / distill.py
"""
InsureOS β€” Knowledge Distillation Script
Distils InsureLLM-8B (DPO-aligned teacher) β†’ InsureLLM-4B (Qwen3-4B student).
Uses KL-divergence + hard-label distillation for 16 GB VRAM.
"""
import os
import json
import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from tqdm import tqdm
# ── Defaults ──
TEACHER_MODEL = "models/insurellm-8b-dpo-merged"
STUDENT_MODEL = "Qwen/Qwen3-4B"
DATA_PATH = "data/output/insurance_sft_10k.jsonl"
OUTPUT_DIR = "models/insurellm-4b-distilled"
MAX_SEQ_LEN = 1024
LORA_R = 32
LORA_ALPHA = 64
TEMPERATURE = 3.0 # softens teacher logits
ALPHA_KL = 0.7 # weight of KL loss vs hard label loss
EPOCHS = 3
BATCH_SIZE = 2
GRAD_ACCUM = 8
LR = 1e-4
WARMUP_STEPS = 50
SAVE_STEPS = 200
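
# Effective batch size per optimizer step is BATCH_SIZE * GRAD_ACCUM = 16 sequences;
# the small per-device batch plus gradient accumulation helps keep activations
# within the 16 GB VRAM budget noted above.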
def load_data(path: str, tokenizer, max_len: int) -> Dataset:
"""Load and tokenize SFT data for distillation."""
records = []
with open(path) as f:
for line in f:
obj = json.loads(line)
# Apply chat template to get text
text = tokenizer.apply_chat_template(
obj["messages"],
tokenize=False,
add_generation_prompt=False,
)
records.append({"text": text})
ds = Dataset.from_list(records)
def tokenize_fn(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=max_len,
padding="max_length",
return_tensors="pt",
)
ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
ds.set_format("torch")
return ds
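
# Each line of the input JSONL is expected to be a chat-format record; only the
# "messages" key is read by load_data above. A hypothetical example record:
#
#   {"messages": [{"role": "user", "content": "What is a deductible?"},
#                 {"role": "assistant", "content": "A deductible is the amount ..."}]}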
def main():
    parser = argparse.ArgumentParser(description="Distil InsureLLM-8B → InsureLLM-4B")
parser.add_argument("--teacher-model", default=TEACHER_MODEL)
parser.add_argument("--student-model", default=STUDENT_MODEL)
parser.add_argument("--data-path", default=DATA_PATH)
parser.add_argument("--output-dir", default=OUTPUT_DIR)
parser.add_argument("--epochs", type=int, default=EPOCHS)
parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
parser.add_argument("--lr", type=float, default=LR)
parser.add_argument("--temperature", type=float, default=TEMPERATURE)
parser.add_argument("--alpha-kl", type=float, default=ALPHA_KL)
args = parser.parse_args()
print(f"{'='*60}")
    print(" InsureOS - Knowledge Distillation")
print(f" Teacher: {args.teacher_model}")
print(f" Student: {args.student_model}")
print(f" Temperature: {args.temperature}, Alpha: {args.alpha_kl}")
print(f"{'='*60}\n")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── 1. Load tokenizer (student's) ──
print("[1/5] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
args.student_model,
trust_remote_code=True,
padding_side="right",
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
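    # A single tokenizer (the student's) is used for both models, and the KL loss
    # below compares teacher and student logits position-by-position over the
    # vocabulary. This is only meaningful if teacher and student share the same
    # tokenizer and vocabulary (e.g. both Qwen3-family checkpoints).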
# ── 2. Load teacher (4-bit, frozen) ──
print("[2/5] Loading teacher model (4-bit, frozen)...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
teacher = AutoModelForCausalLM.from_pretrained(
args.teacher_model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
teacher.eval()
for p in teacher.parameters():
p.requires_grad = False
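    # The teacher is inference-only: eval() disables dropout, and the explicit
    # requires_grad=False guard ensures no teacher gradients are computed or stored.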
# ── 3. Load student (4-bit + LoRA for training) ──
print("[3/5] Loading student model with LoRA...")
student = AutoModelForCausalLM.from_pretrained(
args.student_model,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
student = prepare_model_for_kbit_training(student, use_gradient_checkpointing=True)
lora_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
lora_dropout=0.05,
target_modules="all-linear",
task_type=TaskType.CAUSAL_LM,
bias="none",
)
student = get_peft_model(student, lora_config)
student.print_trainable_parameters()
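    # With target_modules="all-linear", LoRA adapters are attached to every linear
    # projection (attention and MLP, excluding the LM head); only these adapter
    # weights are trained while the 4-bit base weights stay frozen.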
# ── 4. Load data ──
print("[4/5] Loading and tokenizing data...")
dataset = load_data(args.data_path, tokenizer, MAX_SEQ_LEN)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
print(f" Examples: {len(dataset)}, Batches/epoch: {len(dataloader)}")
# ── 5. Distillation training loop ──
print("[5/5] Starting distillation...\n")
optimizer = torch.optim.AdamW(student.parameters(), lr=args.lr, weight_decay=0.01)
total_steps = len(dataloader) * args.epochs // GRAD_ACCUM
    # Cosine decay with linear warmup; this makes use of the WARMUP_STEPS setting above
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)
global_step = 0
best_loss = float("inf")
for epoch in range(args.epochs):
student.train()
epoch_loss = 0.0
accum_loss = 0.0
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
for step, batch in enumerate(pbar):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
# Teacher forward (no grad)
with torch.no_grad():
teacher_outputs = teacher(
input_ids=input_ids,
attention_mask=attention_mask,
)
teacher_logits = teacher_outputs.logits
            # Student forward; mask padding out of the hard-label loss so the
            # model is not trained to predict pad/eos tokens
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            student_outputs = student(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            student_logits = student_outputs.logits
            hard_loss = student_outputs.loss
            # KL divergence on temperature-softened distributions (soft labels)
            T = args.temperature
            teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)
            # Zero out padding positions: with both log-prob tensors masked to 0,
            # each padded position contributes exp(0) * (0 - 0) = 0 to the KL sum
            mask = attention_mask.unsqueeze(-1).float()
            kl_loss = F.kl_div(
                student_log_probs * mask,
                teacher_log_probs * mask,
                log_target=True,
                reduction="batchmean",
            ) * (T ** 2)  # T^2 rescales soft-label gradients back to the hard-loss scale
# Combined loss
loss = args.alpha_kl * kl_loss + (1 - args.alpha_kl) * hard_loss
loss = loss / GRAD_ACCUM
loss.backward()
accum_loss += loss.item()
if (step + 1) % GRAD_ACCUM == 0:
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
pbar.set_postfix({
"loss": f"{accum_loss:.4f}",
"kl": f"{kl_loss.item():.4f}",
"hard": f"{hard_loss.item():.4f}",
"lr": f"{scheduler.get_last_lr()[0]:.2e}",
})
epoch_loss += accum_loss
accum_loss = 0.0
# Save checkpoint
if global_step % SAVE_STEPS == 0:
ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
student.save_pretrained(ckpt_dir)
tokenizer.save_pretrained(ckpt_dir)
print(f"\n Checkpoint saved: {ckpt_dir}")
        # Average over optimizer steps taken this epoch (not the cumulative global_step)
        avg_loss = epoch_loss / max(1, len(dataloader) // GRAD_ACCUM)
        print(f"\nEpoch {epoch+1} - avg loss: {avg_loss:.4f}")
if avg_loss < best_loss:
best_loss = avg_loss
best_dir = os.path.join(args.output_dir, "best")
student.save_pretrained(best_dir)
tokenizer.save_pretrained(best_dir)
print(f" Best model saved: {best_dir}")
# ── Final save ──
print("\nSaving final distilled model...")
student.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
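    # NOTE: the student base was loaded in 4-bit, so merging folds the LoRA delta
    # into quantized weights. If a full-precision export is needed, reload the base
    # model in bf16, attach the adapter saved to args.output_dir, and merge there.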
# Merge LoRA
merged_dir = f"{args.output_dir}-merged"
    print(f"Merging LoRA → {merged_dir}")
merged = student.merge_and_unload()
merged.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)
    print("\n✓ Distillation complete!")
print(f" Student (LoRA): {args.output_dir}")
print(f" Student (merged): {merged_dir}")
print(f" Best loss: {best_loss:.4f}")
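
# Example invocation (all flags are optional; defaults are the constants above):
#   python distill.py --teacher-model models/insurellm-8b-dpo-merged \
#       --student-model Qwen/Qwen3-4B --epochs 3 --temperature 3.0 --alpha-kl 0.7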
if __name__ == "__main__":
main()