# td-toolkit / td_fuse / heal.py
# td-builder's picture
# Current td_fuse code with all fixes
# bc446a5 verified
"""
QLoRA Healing Fine-Tune — repairs damage from merging.
After each merge (or after all merges), the model may have rough edges.
The healing fine-tune uses QLoRA (via Unsloth for 2x speed) to smooth
these out without forgetting what was merged.
Think of it like physical therapy after surgery — the operation (merge)
moved knowledge over, but the model needs practice to use it naturally.
Config notes:
- r=32, alpha=64, dropout=0.0 (must be 0 for Unsloth speed)
- transformers >= 4.51.3 (NOT 4.51.0, NOT 4.52.0-4.55.1)
- bfloat16 end-to-end
- DDP across dual 4090
Findings: #12, #16, #20
"""
import os
import sys
import time
import torch
from pathlib import Path
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
from .config import MergeConfig
def _load_model_smart(checkpoint, **kwargs):
    """Load ``checkpoint`` with the model class that matches its config.

    The checkpoint's config is inspected first. When either its
    ``model_type`` or its config class name identifies a Qwen3-VL model,
    ``Qwen3VLForConditionalGeneration`` is used. In every other case —
    including when any step of the Qwen3-VL attempt raises — the checkpoint
    is loaded through ``AutoModelForCausalLM``.

    Args:
        checkpoint: Model identifier or path handed to ``from_pretrained``.
        **kwargs: Forwarded unchanged to the chosen ``from_pretrained`` call.

    Returns:
        The loaded model instance.
    """
    from transformers import AutoConfig

    try:
        detected = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
        type_tag = getattr(detected, 'model_type', '')
        class_tag = type(detected).__name__.lower()
        is_qwen3_vl = 'qwen3_vl' in type_tag or 'qwen3vl' in class_tag
        if is_qwen3_vl:
            # Imported lazily: only needed (and only attempted) for VL checkpoints.
            from transformers import Qwen3VLForConditionalGeneration as vl_class
            print(f'[heal] Loading as Qwen3-VL model: {checkpoint}')
            # NOTE: this load sits inside the try, so a failure here is also
            # reported as an auto-detect failure and retried below.
            return vl_class.from_pretrained(checkpoint, **kwargs)
    except Exception as err:
        print(f'[heal] Auto-detect failed ({err}), using AutoModelForCausalLM')

    # Default path: not a Qwen3-VL checkpoint, or the attempt above failed.
    return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)
def check_unsloth_available() -> bool:
    """Check if Unsloth is installed and working.

    The probe is simply ``from unsloth import FastLanguageModel``. Any
    failure of that import — the package being absent, or the package
    raising while it initialises — is treated as "not available" so the
    caller can fall back to the standard PEFT path instead of crashing.

    Returns:
        True when the import succeeds, False otherwise.
    """
    try:
        from unsloth import FastLanguageModel  # noqa: F401 — the import itself is the probe
    except ImportError:
        print("[heal] Unsloth not found — using standard PEFT/LoRA")
        return False
    except Exception as exc:
        # Installed but unusable. Presumably Unsloth runs environment checks
        # at import time that can raise non-ImportError exceptions (e.g. no
        # supported GPU) — TODO confirm against the installed version.
        print(
            f"[heal] Unsloth import failed ({type(exc).__name__}: {exc}) "
            "— using standard PEFT/LoRA"
        )
        return False

    print("[heal] Unsloth available — using 2x speed QLoRA")
    return True
def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
    """
    Collect the raw text samples used for the healing fine-tune.

    Three streamed datasets are sampled (general text, math questions, code)
    so the merged model is exercised on each capability. A source that
    raises is reported and skipped; samples already collected are kept.

    Args:
        cfg: Merge configuration (not read by this function).
        tokenizer: Tokenizer (not read by this function).

    Returns:
        List of strings, each longer than 50 characters. Empty when every
        source fails.
    """
    print("[heal] Loading healing fine-tune data...")

    # (dataset_id, config_name or None, split, max samples, field holding the text)
    sources = (
        # General language
        ("neuralmagic/LLM_compression_calibration", None, "train", 1500, "text"),
        # Math reasoning — only the question text is used
        ("openai/gsm8k", "main", "train", 1000, "question"),
        # Code
        ("sahil2801/CodeAlpaca-20k", None, "train", 500, "output"),
    )

    collected = []
    for dataset_id, config_name, split, count, text_field in sources:
        try:
            # The config name is only passed positionally when one is given.
            positional = (dataset_id, config_name) if config_name else (dataset_id,)
            stream = load_dataset(*positional, split=split, streaming=True)

            kept = 0
            for example in stream:
                if kept >= count:
                    break
                candidate = str(example.get(text_field, ""))
                # Very short entries are skipped and do not count towards the quota.
                if len(candidate) <= 50:
                    continue
                collected.append(candidate)
                kept += 1
            print(f" {dataset_id}: {kept} samples")
        except Exception as e:
            print(f" ⚠ {dataset_id} failed: {e}")

    print(f"[heal] Total healing samples: {len(collected)}")
    return collected
def apply_qlora_unsloth(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Apply QLoRA healing via Unsloth.

    Loads the checkpoint in 4-bit through ``FastLanguageModel``, attaches
    LoRA adapters to the attention and MLP projections, trains on the healing
    texts with TRL's ``SFTTrainer``, then writes the adapter-merged weights
    in 16-bit to ``<cfg.output_dir>/healed``.

    Args:
        model_path: Path to the merged model checkpoint.
        cfg: Merge configuration. Reads ``dtype``, ``heal_seq_len``,
            ``heal_lora_r``, ``heal_lora_alpha``, ``heal_lora_dropout``,
            ``heal_epochs``, ``heal_batch_size``, ``heal_grad_accum``,
            ``heal_learning_rate`` and ``output_dir``.
        healing_data: Optional list of raw training strings. When ``None``
            the texts are fetched with ``load_healing_data``.

    Returns:
        Path to healed model directory
    """
    from unsloth import FastLanguageModel

    print("\n[heal] Loading model with Unsloth...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,
        dtype=getattr(torch, cfg.dtype),
        max_seq_length=cfg.heal_seq_len,
        load_in_4bit=True,  # QLoRA — 4-bit base + LoRA adapters
    )

    # Apply LoRA adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,  # module notes: must be 0 for Unsloth speed
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        use_gradient_checkpointing="unsloth",
    )

    # Load healing data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    from torch.utils.data import Dataset

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM samples, padded to a fixed length."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # squeeze(0) drops only the batch axis; a bare squeeze() would
                # also collapse a length-1 sequence to a scalar.
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # Labels are a separate tensor with padded positions set to
                # -100 (the ignore index of the loss), so the model is not
                # trained to predict padding.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training arguments
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        # NOTE: max_steps below is positive, which per the Transformers docs
        # overrides num_train_epochs — the run stops after 50 optimiser steps.
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        # A checkpoint IS written (at step 50, the final step);
        # save_total_limit only bounds how many are kept on disk.
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        max_steps=50,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",  # Memory-efficient optimiser
        report_to="none",
    )

    # Imported after Unsloth, matching the original import order.
    from trl import SFTTrainer

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=training_args,
        max_seq_length=cfg.heal_seq_len,
    )
    print("\n[heal] Starting QLoRA healing fine-tune...")
    trainer.train()

    # Save healed model (merge LoRA back into base)
    healed_dir = Path(cfg.output_dir) / "healed"
    healed_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n[heal] Merging LoRA adapters back into base model...")
    model.save_pretrained_merged(
        str(healed_dir),
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
def apply_qlora_standard(
    model_path: str,
    cfg: MergeConfig,
    healing_data: list = None,
) -> str:
    """
    Fallback: LoRA healing via standard PEFT (no Unsloth).

    Despite the "QLoRA" wording in the log messages, this path loads the
    base model in plain bfloat16 (no 4-bit quantisation), trains LoRA
    adapters with the stock ``Trainer``, merges them back and saves the
    result to ``<cfg.output_dir>/healed``.

    If that directory already holds saved weights (``model.safetensors`` or
    a sharded ``model.safetensors.index.json``) nothing is trained and the
    existing directory is returned.

    Side effects: deletes ``<cfg.output_dir>/heal_output`` after training,
    and — once the save has been verified — deletes ``td_fuse_outputs/final``.

    Args:
        model_path: Path to the merged model checkpoint.
        cfg: Merge configuration. Reads ``output_dir``, ``heal_seq_len``,
            ``heal_lora_r``, ``heal_lora_alpha``, ``heal_lora_dropout``,
            ``heal_epochs``, ``heal_batch_size``, ``heal_grad_accum`` and
            ``heal_learning_rate``.
        healing_data: Optional list of raw training strings. When ``None``
            the texts are fetched with ``load_healing_data``.

    Returns:
        Path to healed model directory
    """
    import gc
    import shutil

    # Look where this function actually saves, not at a fixed relative path,
    # and accept both single-file and sharded safetensors layouts.
    healed_dir = Path(cfg.output_dir) / "healed"
    if (
        (healed_dir / "model.safetensors").exists()
        or (healed_dir / "model.safetensors.index.json").exists()
    ):
        print('[heal] Found existing healed model — SKIPPING healing!')
        return str(healed_dir)

    # Heavy third-party imports are deferred until training is really needed.
    import torch
    from peft import LoraConfig, get_peft_model, TaskType
    from torch.utils.data import Dataset
    from transformers import AutoTokenizer, Trainer

    print("\n[heal] Loading model with standard PEFT...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        # padding="max_length" below needs a pad token; reuse EOS when the
        # tokenizer does not define one.
        tokenizer.pad_token = tokenizer.eos_token
    model = _load_model_smart(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # LoRA config
    lora_config = LoraConfig(
        r=cfg.heal_lora_r,
        lora_alpha=cfg.heal_lora_alpha,
        lora_dropout=cfg.heal_lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load data
    if healing_data is None:
        healing_data = load_healing_data(cfg, tokenizer)

    class HealingDataset(Dataset):
        """Pre-tokenised causal-LM samples, padded to a fixed length."""

        def __init__(self, texts, tokenizer, max_len):
            self.encodings = []
            for text in texts:
                enc = tokenizer(
                    text,
                    truncation=True,
                    max_length=max_len,
                    padding="max_length",
                    return_tensors="pt",
                )
                # squeeze(0) drops only the batch axis; a bare squeeze() would
                # also collapse a length-1 sequence to a scalar.
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc["attention_mask"].squeeze(0)
                # Padded positions get label -100 (the ignore index of the
                # loss) so the model is not trained to predict padding.
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                self.encodings.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                })

        def __len__(self):
            return len(self.encodings)

        def __getitem__(self, idx):
            return self.encodings[idx]

    dataset = HealingDataset(healing_data, tokenizer, cfg.heal_seq_len)

    # Training
    output_dir = Path(cfg.output_dir) / "heal_output"
    output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        # NOTE: max_steps below is positive, which per the Transformers docs
        # overrides num_train_epochs — the run stops after 50 optimiser steps.
        num_train_epochs=cfg.heal_epochs,
        per_device_train_batch_size=cfg.heal_batch_size,
        gradient_accumulation_steps=cfg.heal_grad_accum,
        learning_rate=cfg.heal_learning_rate,
        bf16=True,
        logging_steps=10,
        # A checkpoint IS written (at step 50, the final step); it is removed
        # again right after training to make room for the merged model.
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        max_steps=50,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=training_args,
    )
    print("\n[heal] Starting standard QLoRA healing fine-tune...")
    trainer.train()

    # Free disk space: delete training checkpoints before saving the final model.
    heal_output_dir = output_dir
    if heal_output_dir.exists():
        print(f"[heal] Cleaning up training checkpoints to free disk space...")
        shutil.rmtree(str(heal_output_dir), ignore_errors=True)
        print(f"[heal] Freed ~17GB from {heal_output_dir}")

    # Save — merge LoRA adapters
    healed_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n[heal] Merging LoRA adapters...")
    merged_model = model.merge_and_unload()
    gc.collect()
    # bf16 model — save_pretrained works correctly, no dequantize needed
    merged_model.save_pretrained(str(healed_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(healed_dir))
    print(f"[heal] SAVED OK: {healed_dir}")

    # Verify the save actually worked before cleaning up ANYTHING.
    # save_pretrained may split large models into several shard files, so
    # every *.safetensors file in the directory is counted, not just the
    # single-file name.
    saved_model = healed_dir / "model.safetensors"
    weight_files = sorted(healed_dir.glob("*.safetensors"))
    total_bytes = sum(f.stat().st_size for f in weight_files)
    if total_bytes < 1_000_000:
        print(f"[heal] WARNING: Save may have failed — NOT deleting any backups!")
    else:
        save_size = total_bytes / 1e9
        if saved_model.exists():
            print(f"[heal] Verified: {saved_model} ({save_size:.1f} GB)")
        else:
            print(
                f"[heal] Verified: {healed_dir} "
                f"[{len(weight_files)} shards] ({save_size:.1f} GB)"
            )
        # NOW safe to clean up old stuff.
        # NOTE(review): this target is a fixed relative path, independent of
        # cfg.output_dir — left unchanged because it is destructive; confirm
        # whether it should follow cfg.output_dir.
        cleanup_targets = [
            "td_fuse_outputs/final",
        ]
        for target in cleanup_targets:
            target_path = Path(target)
            if target_path.exists() and target_path.is_dir():
                shutil.rmtree(str(target_path))
                print(f"[heal] Freed space: removed {target_path}")
    gc.collect()
    print(f"[heal] Healed model saved to {healed_dir}")
    return str(healed_dir)
def heal_model(
    model_path: str,
    cfg: MergeConfig = None,
    healing_data: list = None,
) -> str:
    """
    Main entry point for healing.

    Uses the Unsloth path when Unsloth can be imported, otherwise the
    standard PEFT path. The choice is made once up front; a failure inside
    the chosen path is not retried with the other one.

    If ``<cfg.output_dir>/healed`` already contains saved weights
    (``model.safetensors`` or a sharded ``model.safetensors.index.json``),
    healing is skipped and that directory is returned.

    Args:
        model_path: Path to the merged model checkpoint
        cfg: Merge configuration; a default ``MergeConfig()`` is built when
            omitted
        healing_data: Optional pre-loaded training data
    Returns:
        Path to healed model directory
    """
    if cfg is None:
        cfg = MergeConfig()

    # Skip healing if already done (saves ~45 min on re-runs). The check
    # looks in the directory both healing paths write to, so a custom
    # cfg.output_dir neither misses its own result nor picks up a stale one
    # from a different run.
    healed_dir = Path(cfg.output_dir) / "healed"
    if (
        (healed_dir / "model.safetensors").exists()
        or (healed_dir / "model.safetensors.index.json").exists()
    ):
        print('[heal] Found existing healed model — SKIPPING healing!')
        return str(healed_dir)

    heal_start = time.time()
    print("\n" + "=" * 60)
    print("HEALING FINE-TUNE")
    print(f"Model: {model_path}")
    print(f"LoRA r={cfg.heal_lora_r}, alpha={cfg.heal_lora_alpha}")
    print(f"Epochs: {cfg.heal_epochs}, LR: {cfg.heal_learning_rate}")
    print(f"Started at: {time.strftime('%H:%M:%S')}")
    print("=" * 60)
    # Flush so the banner is visible before the long-running load starts.
    sys.stdout.flush()

    if check_unsloth_available():
        result = apply_qlora_unsloth(model_path, cfg, healing_data)
    else:
        result = apply_qlora_standard(model_path, cfg, healing_data)

    print(f"[heal] Total healing time: {(time.time()-heal_start)/60:.1f} min")
    sys.stdout.flush()
    return result