| """ |
| InsureOS β Knowledge Distillation Script |
| Distils InsureLLM-8B (DPO-aligned teacher) β InsureLLM-4B (Qwen3-4B student). |
| Uses KL-divergence + hard-label distillation for 16 GB VRAM. |
| """ |

import argparse
import json
import math
import os

import torch
import torch.nn.functional as F
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Defaults; several can be overridden via the CLI flags in main().
TEACHER_MODEL = "models/insurellm-8b-dpo-merged"
STUDENT_MODEL = "Qwen/Qwen3-4B"
DATA_PATH = "data/output/insurance_sft_10k.jsonl"
OUTPUT_DIR = "models/insurellm-4b-distilled"
MAX_SEQ_LEN = 1024
LORA_R = 32
LORA_ALPHA = 64
TEMPERATURE = 3.0
ALPHA_KL = 0.7
EPOCHS = 3
BATCH_SIZE = 2
GRAD_ACCUM = 8
LR = 1e-4
WARMUP_STEPS = 50
SAVE_STEPS = 200
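
# Effective batch size per optimizer step = BATCH_SIZE * GRAD_ACCUM = 16
# sequences; LoRA scaling factor = LORA_ALPHA / LORA_R = 2.0.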


def load_data(path: str, tokenizer, max_len: int) -> Dataset:
    """Load and tokenize SFT data for distillation."""
    records = []
    with open(path) as f:
        for line in f:
            obj = json.loads(line)
            # Render the full conversation with the chat template.
            text = tokenizer.apply_chat_template(
                obj["messages"],
                tokenize=False,
                add_generation_prompt=False,
            )
            records.append({"text": text})

    ds = Dataset.from_list(records)

    def tokenize_fn(examples):
        # No return_tensors="pt" here: Dataset.map stores plain lists, and
        # set_format("torch") below converts to tensors at access time.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_len,
            padding="max_length",
        )

    ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])
    ds.set_format("torch")
    return ds


def main():
    parser = argparse.ArgumentParser(description="Distil InsureLLM-8B -> InsureLLM-4B")
    parser.add_argument("--teacher-model", default=TEACHER_MODEL)
    parser.add_argument("--student-model", default=STUDENT_MODEL)
    parser.add_argument("--data-path", default=DATA_PATH)
    parser.add_argument("--output-dir", default=OUTPUT_DIR)
    parser.add_argument("--epochs", type=int, default=EPOCHS)
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    parser.add_argument("--lr", type=float, default=LR)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--alpha-kl", type=float, default=ALPHA_KL)
    args = parser.parse_args()

    print("=" * 60)
    print(" InsureOS - Knowledge Distillation")
    print(f" Teacher: {args.teacher_model}")
    print(f" Student: {args.student_model}")
    print(f" Temperature: {args.temperature}, Alpha: {args.alpha_kl}")
    print("=" * 60 + "\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("[1/5] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.student_model,
        trust_remote_code=True,
        padding_side="right",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

| print("[2/5] Loading teacher model (4-bit, frozen)...") |
| bnb_config = BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_quant_type="nf4", |
| bnb_4bit_compute_dtype=torch.bfloat16, |
| bnb_4bit_use_double_quant=True, |
| ) |
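
    # NF4 with double quantization stores weights at roughly 0.5 byte per
    # parameter, so the 8B teacher needs ~4-5 GB and the 4B student ~2-3 GB
    # (rough estimates), which is what makes the 16 GB budget from the module
    # docstring workable.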
    teacher = AutoModelForCausalLM.from_pretrained(
        args.teacher_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False
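
    # The KL term compares teacher and student logits position-by-position,
    # so the two models must share a tokenizer/vocabulary: inputs are encoded
    # with the student tokenizer, and the teacher is assumed to match.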

    print("[3/5] Loading student model with LoRA...")
    student = AutoModelForCausalLM.from_pretrained(
        args.student_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    student = prepare_model_for_kbit_training(student, use_gradient_checkpointing=True)
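    # prepare_model_for_kbit_training upcasts norm layers to fp32 and enables
    # input gradients so LoRA backprop works through the frozen 4-bit base.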

    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.05,
        target_modules="all-linear",  # adapters on every linear projection
        task_type=TaskType.CAUSAL_LM,
        bias="none",
    )
    student = get_peft_model(student, lora_config)
    student.print_trainable_parameters()

    print("[4/5] Loading and tokenizing data...")
    dataset = load_data(args.data_path, tokenizer, MAX_SEQ_LEN)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    print(f" Examples: {len(dataset)}, Batches/epoch: {len(dataloader)}")
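
    # Note: if len(dataloader) is not a multiple of GRAD_ACCUM, trailing
    # micro-batches keep their gradients until the next optimizer step; a
    # harmless edge case at this dataset size.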

    print("[5/5] Starting distillation...\n")

    # Optimize only the trainable (LoRA) parameters.
    trainable = [p for p in student.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.01)

    # Linear warmup for WARMUP_STEPS optimizer steps, then cosine decay.
    total_steps = max(1, len(dataloader) * args.epochs // GRAD_ACCUM)

    def lr_lambda(step: int) -> float:
        if step < WARMUP_STEPS:
            return step / max(1, WARMUP_STEPS)
        progress = (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    global_step = 0
    best_loss = float("inf")

    for epoch in range(args.epochs):
        student.train()
        epoch_loss = 0.0
        accum_loss = 0.0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for step, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Teacher forward pass (frozen, no gradients).
            with torch.no_grad():
                teacher_outputs = teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                teacher_logits = teacher_outputs.logits

            # Student forward pass. Padding positions are excluded from the
            # hard-label loss via the usual -100 ignore index.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            student_outputs = student(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            student_logits = student_outputs.logits
            hard_loss = student_outputs.loss

            # Soft-label loss: temperature-scaled KL(teacher || student).
            T = args.temperature
            teacher_log_probs = F.log_softmax(teacher_logits / T, dim=-1)
            student_log_probs = F.log_softmax(student_logits / T, dim=-1)

            # Per-token KL summed over the vocab, then averaged over real
            # (non-padding) tokens; the mask is applied after kl_div so the
            # log-distributions themselves are never distorted.
            kl_per_token = F.kl_div(
                student_log_probs,
                teacher_log_probs,
                log_target=True,
                reduction="none",
            ).sum(dim=-1)
            mask = attention_mask.float()
            kl_loss = (kl_per_token * mask).sum() / mask.sum().clamp(min=1.0)
            kl_loss = kl_loss * (T ** 2)  # standard T^2 gradient rescaling

            # Blend soft and hard losses, then scale for gradient accumulation.
            loss = args.alpha_kl * kl_loss + (1 - args.alpha_kl) * hard_loss
            loss = loss / GRAD_ACCUM

            loss.backward()
            accum_loss += loss.item()

            # Optimizer step once every GRAD_ACCUM micro-batches.
            if (step + 1) % GRAD_ACCUM == 0:
                torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                pbar.set_postfix({
                    "loss": f"{accum_loss:.4f}",
                    "kl": f"{kl_loss.item():.4f}",
                    "hard": f"{hard_loss.item():.4f}",
                    "lr": f"{scheduler.get_last_lr()[0]:.2e}",
                })
                epoch_loss += accum_loss
                accum_loss = 0.0

                # Periodic checkpointing (LoRA adapter weights only).
                if global_step % SAVE_STEPS == 0:
                    ckpt_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    student.save_pretrained(ckpt_dir)
                    tokenizer.save_pretrained(ckpt_dir)
                    print(f"\n Checkpoint saved: {ckpt_dir}")

        # Average over the optimizer steps taken this epoch (global_step is
        # cumulative across epochs, so it cannot be used as the divisor).
        updates_this_epoch = max(1, len(dataloader) // GRAD_ACCUM)
        avg_loss = epoch_loss / updates_this_epoch
        print(f"\nEpoch {epoch+1} | avg loss: {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_dir = os.path.join(args.output_dir, "best")
            student.save_pretrained(best_dir)
            tokenizer.save_pretrained(best_dir)
            print(f" Best model saved: {best_dir}")

    print("\nSaving final distilled model...")
    student.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
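
    # Caveat: the student base was loaded in 4-bit, so merge_and_unload() has
    # to dequantize the NF4 weights before folding in the adapters (supported
    # in recent peft releases). For a lossless merge, reload the base model in
    # bf16 and re-apply the saved adapter instead.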
    merged_dir = f"{args.output_dir}-merged"
    print(f"Merging LoRA -> {merged_dir}")
    merged = student.merge_and_unload()
    merged.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)

    print("\n✓ Distillation complete!")
    print(f" Student (LoRA): {args.output_dir}")
    print(f" Student (merged): {merged_dir}")
    print(f" Best loss: {best_loss:.4f}")


if __name__ == "__main__":
    main()