"""
QLoRA Healing Fine-Tune — repairs damage from merging.

After each merge (or after all merges), the model may have rough edges. The
healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth these out
without forgetting what was merged. Think of it like physical therapy after
surgery — the operation (merge) moved knowledge over, but the model needs
practice to use it naturally.

When Unsloth is unavailable, a standard PEFT fallback trains LoRA adapters on
a bf16 (not 4-bit) base model.

Config notes:
- r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
- transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
- bfloat16 end-to-end
- DDP across dual 4090

Findings: #12, #16, #20
"""
import gc
import os
import shutil
import sys
import time
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

from .config import MergeConfig


# Projection modules that receive LoRA adapters (attention + MLP).
_LORA_TARGET_MODULES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
)

# Training-step budget used when the config does not define `heal_max_steps`.
# A positive max_steps makes the HF Trainer ignore num_train_epochs.
_DEFAULT_HEAL_MAX_STEPS = 50

# Label value the transformers causal-LM loss ignores (padding positions).
_IGNORE_INDEX = -100


def _load_model_smart(checkpoint, **kwargs):
    """Load a model, auto-detecting Qwen3-VL checkpoints.

    Qwen3-VL configs are loaded with ``Qwen3VLForConditionalGeneration``.
    Everything else goes through ``AutoModelForCausalLM``. That includes
    checkpoints whose config cannot be read and installs whose transformers
    lacks the Qwen3-VL class. ``kwargs`` are forwarded to ``from_pretrained``.

    Only detection falls back. An error raised while actually loading weights
    propagates instead of being retried with a different model class.
    """
    from transformers import AutoConfig

    is_qwen3_vl = False
    try:
        config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
        model_type = getattr(config, 'model_type', '') or ''
        config_class = type(config).__name__.lower()
        is_qwen3_vl = 'qwen3_vl' in model_type or 'qwen3vl' in config_class
    except Exception as e:
        print(f'[heal] Auto-detect failed ({e}), using AutoModelForCausalLM')

    if is_qwen3_vl:
        try:
            from transformers import Qwen3VLForConditionalGeneration
        except ImportError as e:
            print(f'[heal] Auto-detect failed ({e}), using AutoModelForCausalLM')
        else:
            print(f'[heal] Loading as Qwen3-VL model: {checkpoint}')
            return Qwen3VLForConditionalGeneration.from_pretrained(checkpoint, **kwargs)

    return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)


def check_unsloth_available() -> bool:
    """Return True if Unsloth can be imported in this environment.

    Import failures other than ImportError (e.g. Unsloth refusing to start
    without a supported GPU) also count as "not available".
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401
    except ImportError:
        print("[heal] Unsloth not found — using standard PEFT/LoRA")
        return False
    except Exception as e:
        print(f"[heal] Unsloth import failed ({e}) — using standard PEFT/LoRA")
        return False
    print("[heal] Unsloth available — using 2x speed QLoRA")
    return True


def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Load data for healing fine-tune.

    Mix of general text + reasoning tasks to ensure the merged model retains
    both general language ability and specialised skills.

    Each source is streamed. Only texts longer than 50 characters are kept, up
    to the per-source count. A source that fails to load is reported and
    skipped. ``cfg`` and ``tokenizer`` are currently unused; they are kept for
    interface compatibility.

    Returns:
        List of raw text samples (possibly empty if every source failed).
    """
    print("[heal] Loading healing fine-tune data...")

    # Merge-specific: use diverse data that exercises all merged capabilities
    # Each entry: (dataset_id, config_name_or_None, split, count, text_field)
    datasets_to_load = [
        # General language — same calibration data source that works reliably
        ("neuralmagic/LLM_compression_calibration", None, "train", 1500, "text"),
        # Math reasoning (exercises DeepSeek/MiMo contributions)
        ("openai/gsm8k", "main", "train", 1000, "question"),
        # Code — bigcode/starcoderdata is a modern alternative
        ("sahil2801/CodeAlpaca-20k", None, "train", 500, "output"),
    ]

    all_texts = []
    for dataset_id, config_name, split, count, text_field in datasets_to_load:
        try:
            # load_dataset's `name` defaults to None, so one call covers both cases.
            ds = load_dataset(dataset_id, config_name or None, split=split, streaming=True)
            loaded = 0
            for example in ds:
                if loaded >= count:
                    break
                text = str(example.get(text_field, ""))
                if len(text) > 50:
                    all_texts.append(text)
                    loaded += 1
            print(f"  {dataset_id}: {loaded} samples")
        except Exception as e:
            print(f"  ⚠ {dataset_id} failed: {e}")

    print(f"[heal] Total healing samples: {len(all_texts)}")
    return all_texts


def _healed_dir(cfg: MergeConfig) -> Path:
    """Directory where the healed (LoRA-merged) model is written."""
    return Path(cfg.output_dir) / "healed"


def _has_saved_weights(model_dir: Path) -> bool:
    """Return True if model_dir holds a safetensors checkpoint.

    Accepts both a single ``model.safetensors`` and a sharded save
    (``model.safetensors.index.json`` plus shards). ``save_pretrained``
    produces the sharded layout once a model exceeds its shard size.
    """
    return (
        (model_dir / "model.safetensors").exists()
        or (model_dir / "model.safetensors.index.json").exists()
    )


def _saved_weights_bytes(model_dir: Path) -> int:
    """Total size in bytes of every ``*.safetensors`` file in model_dir."""
    if not model_dir.is_dir():
        return 0
    return sum(p.stat().st_size for p in model_dir.glob("*.safetensors"))


def _find_existing_healed_model(cfg: MergeConfig) -> Optional[str]:
    """Return the healed-model directory if a previous run already saved it."""
    healed_dir = _healed_dir(cfg)
    if _has_saved_weights(healed_dir):
        print('[heal] Found existing healed model — SKIPPING healing!')
        return str(healed_dir)
    return None


def _ensure_pad_token(tokenizer) -> None:
    """Use the EOS token as pad token when the tokenizer has none.

    ``padding="max_length"`` raises a ValueError without a pad token.
    """
    if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
        tokenizer.pad_token = tokenizer.eos_token


def _check_healing_data(healing_data) -> None:
    """Refuse to "heal" on an empty dataset.

    Otherwise the Trainer runs zero steps and an untouched model would be
    saved (and backups removed) as if it had been healed.
    """
    if not healing_data:
        raise RuntimeError(
            "[heal] No healing samples were loaded — refusing to train on an empty dataset"
        )


class _HealingDataset(Dataset):
    """Pre-tokenised causal-LM dataset for the healing fine-tune.

    Every text is truncated/padded to ``max_len`` tokens. ``labels`` copies
    ``input_ids`` with padded positions set to -100. That way the loss covers
    only real tokens, not padding (which would otherwise dominate short samples).
    """

    def __init__(self, texts, tokenizer, max_len):
        self.encodings = []
        for text in texts:
            enc = tokenizer(
                text,
                truncation=True,
                max_length=max_len,
                padding="max_length",
                return_tensors="pt",
            )
            # [0] drops only the batch dim (squeeze() would also drop a length-1 sequence dim).
            input_ids = enc["input_ids"][0]
            attention_mask = enc["attention_mask"][0]
            labels = input_ids.clone()
            labels[attention_mask == 0] = _IGNORE_INDEX
            self.encodings.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            })

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        return self.encodings[idx]


def _build_training_args(cfg: MergeConfig, optim: str) -> TrainingArguments:
    """Shared TrainingArguments for both healing paths (only `optim` differs)."""
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)
    return TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        # Checkpoint every 50 steps; with the default 50-step budget that is a
        # single checkpoint at the end of training.
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        # A positive max_steps overrides num_train_epochs.
        max_steps=getattr(cfg, "heal_max_steps", _DEFAULT_HEAL_MAX_STEPS),
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim=optim,
        report_to="none",
    )


def _cleanup_after_heal(cfg: MergeConfig) -> None:
    """Best-effort removal of the pre-heal ``final`` model dir to free disk.

    Only called after the healed save has been verified. A failure here is
    reported, not raised, because the healed model is already safely on disk.
    """
    target_path = Path(cfg.output_dir) / "final"
    if not target_path.is_dir():
        return
    try:
        shutil.rmtree(str(target_path))
    except OSError as e:
        print(f"[heal] WARNING: could not remove {target_path}: {e}")
    else:
        print(f"[heal] Freed space: removed {target_path}")


def apply_qlora_unsloth(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Apply QLoRA healing via Unsloth (2x faster than standard PEFT).

    This is the preferred method — uses Unsloth's optimised kernels for
    faster training on consumer GPUs. The base model is loaded in 4-bit, LoRA
    adapters are trained, then merged and saved as 16-bit weights.

    Returns:
        Path to healed model directory

    Raises:
        RuntimeError: if no healing samples are available.
    """
    from unsloth import FastLanguageModel

    print("\n[heal] Loading model with Unsloth...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        dtype=getattr(torch, cfg.dtype),
        max_seq_length=cfg.heal_seq_len,
        load_in_4bit=True,  # QLoRA — 4-bit base + LoRA adapters
    )

    # Apply LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.heal_lora_r,  # 32 — higher rank for healing
        lora_alpha=cfg.heal_lora_alpha,  # 64 — 2x rank
        lora_dropout=cfg.heal_lora_dropout,  # 0.0 — MUST be 0 for Unsloth speed
        target_modules=list(_LORA_TARGET_MODULES),
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-efficient checkpointing
    )
    _ensure_pad_token(tokenizer)

    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)
    _check_healing_data(healing_data)
    dataset = _HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # The dataset is already tokenised with masked labels, so the plain HF
    # Trainer is enough (avoids TRL SFTTrainer's version-specific kwargs and
    # its dataset re-processing, which expects an HF `datasets.Dataset`).
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=_build_training_args(cfg, optim="adamw_8bit"),  # memory-efficient optimiser
        data_collator=default_data_collator,
    )

    print("\n[heal] Starting QLoRA healing fine-tune...")
    trainer.train()

    # Save healed model (merge LoRA back into base)
    healed_dir = _healed_dir(cfg)
    healed_dir.mkdir(parents=True, exist_ok=True)
    print("\n[heal] Merging LoRA adapters back into base model...")
    model.save_pretrained_merged(
        str(healed_dir),
        tokenizer,
        save_method="merged_16bit",  # 16-bit merged weights (not the 4-bit training base)
    )
    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)


def apply_qlora_standard(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Fallback: LoRA healing via standard PEFT (no Unsloth).

    Slower but works without Unsloth installed. The base model is loaded in
    bf16 (not 4-bit), so merge_and_unload + save_pretrained need no
    dequantisation. After a verified save, the training checkpoints and the
    pre-heal ``<output_dir>/final`` directory are deleted to free disk.

    Returns:
        Path to healed model directory

    Raises:
        RuntimeError: if no healing samples are available.
    """
    existing = _find_existing_healed_model(cfg)
    if existing is not None:
        return existing

    from peft import LoraConfig, TaskType, get_peft_model

    print("\n[heal] Loading model with standard PEFT...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    _ensure_pad_token(tokenizer)
    model = _load_model_smart(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=list(_LORA_TARGET_MODULES),
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)
    _check_healing_data(healing_data)
    dataset = _HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=_build_training_args(cfg, optim="adamw_torch"),
        data_collator=default_data_collator,
    )

    print("\n[heal] Starting standard QLoRA healing fine-tune...")
    trainer.train()

    # Free disk space: delete training checkpoints before saving the final
    # model — we need room for the healed model.
    heal_output_dir = Path(cfg.output_dir) / "heal_output"
    if heal_output_dir.exists():
        print("[heal] Cleaning up training checkpoints to free disk space...")
        shutil.rmtree(str(heal_output_dir), ignore_errors=True)
        print(f"[heal] Freed ~17GB from {heal_output_dir}")

    # Save — merge LoRA adapters
    healed_dir = _healed_dir(cfg)
    healed_dir.mkdir(parents=True, exist_ok=True)
    print("\n[heal] Merging LoRA adapters...")
    merged_model = model.merge_and_unload()
    gc.collect()
    # bf16 model — save_pretrained works correctly, no dequantize needed
    merged_model.save_pretrained(str(healed_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(healed_dir))

    # Verify the save actually worked (single-file or sharded) before
    # cleaning up ANYTHING.
    save_bytes = _saved_weights_bytes(healed_dir)
    if not _has_saved_weights(healed_dir) or save_bytes < 1_000_000:
        print("[heal] WARNING: Save may have failed — NOT deleting any backups!")
    else:
        print(f"[heal] SAVED OK: {healed_dir}")
        print(f"[heal] Verified: {healed_dir} ({save_bytes / 1e9:.1f} GB)")
        _cleanup_after_heal(cfg)

    gc.collect()
    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)


def heal_model(
    model_path: str,
    cfg: MergeConfig = None,
    healing_data: list = None,
) -> str:
    """
    Main entry point for healing. Tries Unsloth first, falls back to PEFT.

    If ``<cfg.output_dir>/healed`` already contains saved weights, healing is
    skipped and that directory is returned.

    Args:
        model_path: Path to the merged model checkpoint
        cfg: Merge configuration (defaults to ``MergeConfig()``)
        healing_data: Optional pre-loaded training data

    Returns:
        Path to healed model directory

    Raises:
        RuntimeError: if no healing samples are available.
    """
    if cfg is None:
        cfg = MergeConfig()

    # Skip healing if already done (saves ~45 min on re-runs)
    existing = _find_existing_healed_model(cfg)
    if existing is not None:
        return existing

    heal_start = time.time()
    print("\n" + "=" * 60)
    print("HEALING FINE-TUNE")
    print(f"Model: {model_path}")
    print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
    print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
    print(f"Max steps: {getattr(cfg, 'heal_max_steps', _DEFAULT_HEAL_MAX_STEPS)} (overrides epochs when > 0)")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60)
    sys.stdout.flush()

    if check_unsloth_available():
        result = apply_qlora_unsloth(model_path, cfg, healing_data)
    else:
        result = apply_qlora_standard(model_path, cfg, healing_data)

    print(f"[heal] Total healing time: {(time.time() - heal_start) / 60:.1f} min")
    sys.stdout.flush()
    return result