| """ |
| QLoRA Healing Fine-Tune — repairs damage from merging. |
| |
| After each merge (or after all merges), the model may have rough edges. |
| The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth |
| these out without forgetting what was merged. |
| |
| Think of it like physical therapy after surgery — the operation (merge) |
| moved knowledge over, but the model needs practice to use it naturally. |
| |
| Config notes: |
| - r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed) |
| - transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1) |
| - bfloat16 end-to-end |
| - DDP across dual 4090 |
| |
| Findings: #12, #16, #20 |
| """ |
|
|
| import os |
| import sys |
| import time |
| import torch |
| from pathlib import Path |
| from typing import Optional |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments |
| from datasets import load_dataset |
|
|
| from .config import MergeConfig |
|
|
|
|
def _load_model_smart(checkpoint, **kwargs):
    """
    Load a model, auto-detecting Qwen3-VL checkpoints and using the correct class.

    The checkpoint's config is inspected: if its ``model_type`` contains
    ``qwen3_vl`` or its config class name contains ``qwen3vl``, the model is
    loaded with ``Qwen3VLForConditionalGeneration``. Otherwise, or when the
    detection step itself fails (config unreadable, class not available in
    the installed transformers), ``AutoModelForCausalLM`` is used.

    Only *detection* falls back. An error raised while actually loading the
    Qwen3-VL weights propagates, instead of being masked by a second,
    unrelated ``AutoModelForCausalLM`` load attempt.

    Args:
        checkpoint: Local path or hub id passed to ``from_pretrained``.
        **kwargs: Forwarded unchanged to ``from_pretrained`` (e.g.
            ``device_map``, ``torch_dtype``).

    Returns:
        The loaded model instance.
    """
    from transformers import AutoConfig

    vl_model_class = None
    try:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only use with trusted checkpoints.
        config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
        # `or ''` guards configs whose model_type attribute exists but is None.
        model_type = getattr(config, 'model_type', '') or ''
        config_class = type(config).__name__.lower()
        if 'qwen3_vl' in model_type or 'qwen3vl' in config_class:
            from transformers import Qwen3VLForConditionalGeneration
            vl_model_class = Qwen3VLForConditionalGeneration
    except Exception as e:
        print(f'[heal] Auto-detect failed ({e}), using AutoModelForCausalLM')

    if vl_model_class is not None:
        print(f'[heal] Loading as Qwen3-VL model: {checkpoint}')
        return vl_model_class.from_pretrained(checkpoint, **kwargs)
    return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)
|
|
|
|
def check_unsloth_available() -> bool:
    """
    Check if Unsloth is installed and usable in this environment.

    Importing Unsloth can fail with exceptions other than ``ImportError``
    (presumably its environment/GPU checks at import time; confirm against
    the installed version). Any exception raised by the import is therefore
    treated as "unavailable", so callers fall back to the standard PEFT path
    instead of crashing.

    Returns:
        True if ``unsloth.FastLanguageModel`` imported cleanly, else False.
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401  (import probe only)
    except ImportError:
        print("[heal] Unsloth not found — using standard PEFT/LoRA")
        return False
    except Exception as e:
        print(f"[heal] Unsloth failed to initialise ({e}) — using standard PEFT/LoRA")
        return False
    print("[heal] Unsloth available — using 2x speed QLoRA")
    return True
|
|
|
|
def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load raw text samples for the healing fine-tune.

    Mix of general text, math reasoning and code, so that the merged model
    retains both general language ability and specialised skills. Each source
    is streamed and read from the start of its split. Only texts longer than
    50 characters are kept, up to ``count`` per source. A source that fails
    to load is reported and skipped (best effort).

    The built-in source list can be overridden by an optional
    ``cfg.heal_datasets`` attribute. It holds tuples of
    ``(dataset_id, config_name, split, count, text_field)``, where
    ``text_field`` is one field name or a tuple of field names whose
    non-empty values are joined with a blank line.

    Args:
        cfg: Merge configuration (only the optional ``heal_datasets`` is read).
        tokenizer: Unused; kept for interface compatibility.

    Returns:
        List of sample strings (empty if every source failed).
    """
    print("[heal] Loading healing fine-tune data...")

    default_sources = [
        # General text.
        ("neuralmagic/LLM_compression_calibration", None, "train", 1500, "text"),
        # Math reasoning: include the worked answer. With "question" alone
        # the model only ever sees problem statements, never a reasoning trace.
        ("openai/gsm8k", "main", "train", 1000, ("question", "answer")),
        # Code.
        ("sahil2801/CodeAlpaca-20k", None, "train", 500, "output"),
    ]
    datasets_to_load = getattr(cfg, "heal_datasets", None) or default_sources

    def _extract_text(example, text_field):
        """Return the sample text for `example` from one field or several joined fields."""
        if isinstance(text_field, str):
            return str(example.get(text_field, ""))
        pieces = (str(example[f]) for f in text_field if example.get(f) is not None)
        return "\n\n".join(piece for piece in pieces if piece)

    all_texts = []
    for dataset_id, config_name, split, count, text_field in datasets_to_load:
        try:
            if config_name:
                ds = load_dataset(dataset_id, config_name, split=split, streaming=True)
            else:
                ds = load_dataset(dataset_id, split=split, streaming=True)
            loaded = 0
            for example in ds:
                if loaded >= count:
                    break
                text = _extract_text(example, text_field)
                if len(text) > 50:
                    all_texts.append(text)
                    loaded += 1
            print(f" {dataset_id}: {loaded} samples")
        except Exception as e:
            # Best effort: one unavailable dataset should not abort healing.
            print(f" ⚠ {dataset_id} failed: {e}")

    print(f"[heal] Total healing samples: {len(all_texts)}")
    if not all_texts:
        print("[heal] WARNING: no healing samples loaded — the fine-tune has nothing to train on")
    return all_texts
|
|
|
|
def apply_qlora_unsloth(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Apply QLoRA healing via Unsloth (2x faster than standard PEFT).

    This is the preferred method: it uses Unsloth's optimised kernels for
    faster training on consumer GPUs. Steps:

    1. Load the base model in 4-bit.
    2. Attach LoRA adapters to the attention and MLP projections.
    3. Train the adapters on the healing data.
    4. Merge the adapters back and save the result in 16-bit.

    Args:
        model_path: Path (or hub id) of the merged model to heal.
        cfg: Merge configuration. Reads ``dtype``, ``output_dir`` and the
            ``heal_*`` fields. An optional ``heal_max_steps`` (default 50)
            caps the number of optimiser steps and, when positive, overrides
            ``heal_epochs``.
        healing_data: Optional list of raw text samples. Loaded with
            ``load_healing_data`` when ``None``.

    Returns:
        Path to healed model directory (``<cfg.output_dir>/healed``).
    """
    from unsloth import FastLanguageModel
    from torch.utils.data import Dataset
    from trl import SFTTrainer

    print("\n[heal] Loading model with Unsloth...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        dtype=getattr(torch, cfg.dtype),
        max_seq_length=cfg.heal_seq_len,
        load_in_4bit=True,
    )

    # padding="max_length" below needs a pad token. Some causal-LM tokenizers
    # ship without one, so reuse EOS in that case.
    if getattr(tokenizer, "pad_token", "") is None and getattr(tokenizer, "eos_token", None):
        tokenizer.pad_token = tokenizer.eos_token

    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",
    )

    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Fixed-length, pre-tokenised causal-LM samples (input_ids / attention_mask / labels)."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # [0] drops the batch dimension. Unlike squeeze(), it never
                # collapses a length-1 sequence into a 0-d tensor.
                input_ids = enc["input_ids"][0]
                attention_mask = enc["attention_mask"][0]
                # Labels are an independent copy of input_ids with padding
                # positions set to -100 (the loss ignore_index), so the
                # model is not trained to predict pad tokens.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        # A positive max_steps overrides num_train_epochs. The default of 50
        # keeps the previous hard-coded cap.
        max_steps=getattr(cfg, "heal_max_steps", 50),
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=training_args,
        # NOTE(review): recent TRL releases take max_seq_length via SFTConfig
        # rather than as an SFTTrainer argument. Confirm the installed
        # TRL/Unsloth combination accepts it here.
        max_seq_length=cfg.heal_seq_len,
    )

    print("\n[heal] Starting QLoRA healing fine-tune...")
    trainer.train()

    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)

    print("\n[heal] Merging LoRA adapters back into base model...")
    model.save_pretrained_merged(
        str(healed_dir),
        tokenizer,
        save_method="merged_16bit",
    )

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
|
|
|
|
def apply_qlora_standard(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Fallback: LoRA healing via standard PEFT (no Unsloth).

    Slower than the Unsloth path but works without Unsloth installed. Unlike
    that path, the base model is loaded in bfloat16 rather than 4-bit, so
    this is plain LoRA rather than QLoRA.

    If ``<cfg.output_dir>/healed`` already holds saved safetensors weights
    (single-file or sharded), healing is skipped and that directory is
    returned.

    Otherwise:

    1. Train the adapters.
    2. Delete the training checkpoints in ``<cfg.output_dir>/heal_output``.
    3. Merge the adapters and save the model plus tokenizer to
       ``<cfg.output_dir>/healed``.
    4. Only if the saved weights are verified (at least 1 MB of
       ``*.safetensors``), delete ``td_fuse_outputs/final`` to free disk.

    Args:
        model_path: Path (or hub id) of the merged model to heal.
        cfg: Merge configuration. Reads ``output_dir`` and the ``heal_*``
            fields. An optional ``heal_max_steps`` (default 50) caps the
            number of optimiser steps and, when positive, overrides
            ``heal_epochs``.
        healing_data: Optional list of raw text samples. Loaded with
            ``load_healing_data`` when ``None``.

    Returns:
        Path to healed model directory (``<cfg.output_dir>/healed``).
    """
    import gc
    import shutil

    healed_dir = Path(cfg.output_dir) / "healed"
    # Check where this function actually writes (cfg.output_dir). Large models
    # are saved sharded (model-0000x-of-0000y.safetensors plus an index
    # file), so accept either layout.
    if (healed_dir / "model.safetensors").exists() or (
        healed_dir / "model.safetensors.index.json"
    ).exists():
        print('[heal] Found existing healed model — SKIPPING healing!')
        return str(healed_dir)

    from peft import LoraConfig, TaskType, get_peft_model
    from torch.utils.data import Dataset
    from transformers import Trainer

    print("\n[heal] Loading model with standard PEFT...")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # padding="max_length" below needs a pad token. Some causal-LM tokenizers
    # ship without one, so reuse EOS in that case.
    if getattr(tokenizer, "pad_token", "") is None and getattr(tokenizer, "eos_token", None):
        tokenizer.pad_token = tokenizer.eos_token
    model = _load_model_smart(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Fixed-length, pre-tokenised causal-LM samples (input_ids / attention_mask / labels)."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # [0] drops the batch dimension. Unlike squeeze(), it never
                # collapses a length-1 sequence into a 0-d tensor.
                input_ids = enc["input_ids"][0]
                attention_mask = enc["attention_mask"][0]
                # Labels are an independent copy of input_ids with padding
                # positions set to -100 (the loss ignore_index), so the
                # model is not trained to predict pad tokens.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        # A positive max_steps overrides num_train_epochs. The default of 50
        # keeps the previous hard-coded cap.
        max_steps=getattr(cfg, "heal_max_steps", 50),
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )

    print("\n[heal] Starting standard QLoRA healing fine-tune...")
    trainer.train()

    # Training checkpoints are removed *before* merging and saving,
    # presumably so the full-size merged save has room on disk.
    if output_dir.exists():
        print("[heal] Cleaning up training checkpoints to free disk space...")
        freed_bytes = sum(f.stat().st_size for f in output_dir.rglob("*") if f.is_file())
        shutil.rmtree(str(output_dir), ignore_errors=True)
        print(f"[heal] Freed ~{freed_bytes / 1e9:.1f}GB from {output_dir}")

    healed_dir.mkdir(parents=True, exist_ok=True)

    print("\n[heal] Merging LoRA adapters...")
    merged_model = model.merge_and_unload()
    gc.collect()

    merged_model.save_pretrained(str(healed_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(healed_dir))
    print(f"[heal] SAVED OK: {healed_dir}")

    # Verify across all shards: a sharded save has no single model.safetensors.
    weight_files = sorted(healed_dir.glob("*.safetensors"))
    total_bytes = sum(f.stat().st_size for f in weight_files)
    if total_bytes < 1_000_000:
        print("[heal] WARNING: Save may have failed — NOT deleting any backups!")
    else:
        print(
            f"[heal] Verified: {len(weight_files)} safetensors file(s) in "
            f"{healed_dir} ({total_bytes / 1e9:.1f} GB)"
        )
        # NOTE(review): this path is relative to the current working directory
        # and is not derived from cfg.output_dir. It is presumably the merge
        # output that was just healed; confirm before changing output_dir.
        cleanup_targets = [
            "td_fuse_outputs/final",
        ]
        for target in cleanup_targets:
            target_path = Path(target)
            if target_path.is_dir():
                shutil.rmtree(str(target_path))
                print(f"[heal] Freed space: removed {target_path}")

    gc.collect()

    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
|
|
|
|
def heal_model(
    model_path: str,
    cfg: MergeConfig = None,
    healing_data: list = None,
) -> str:
    """
    Main entry point for healing. Tries Unsloth first, falls back to PEFT.

    If ``<cfg.output_dir>/healed`` already contains saved safetensors weights
    (a single ``model.safetensors`` or a sharded checkpoint's
    ``model.safetensors.index.json``), healing is skipped and that directory
    is returned.

    Args:
        model_path: Path to the merged model checkpoint.
        cfg: Merge configuration. A default ``MergeConfig()`` is used when
            ``None``.
        healing_data: Optional pre-loaded training data (list of texts).

    Returns:
        Path to healed model directory.
    """
    if cfg is None:
        cfg = MergeConfig()

    # Look where the healing functions actually write (cfg.output_dir), and
    # accept sharded saves, which have an index file instead of a single
    # model.safetensors.
    healed_dir = Path(cfg.output_dir) / "healed"
    if (healed_dir / "model.safetensors").exists() or (
        healed_dir / "model.safetensors.index.json"
    ).exists():
        print('[heal] Found existing healed model — SKIPPING healing!')
        return str(healed_dir)

    heal_start = time.time()
    print("\n" + "=" * 60)
    print("HEALING FINE-TUNE")
    print(f"Model: {model_path}")
    print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
    print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60)
    sys.stdout.flush()

    if check_unsloth_available():
        result = apply_qlora_unsloth(model_path, cfg, healing_data)
    else:
        result = apply_qlora_standard(model_path, cfg, healing_data)
    print(f"[heal] Total healing time: {(time.time() - heal_start) / 60:.1f} min")
    sys.stdout.flush()
    return result
|
|