# model_tools/donor_audit_v3.py
# Copyright (C) 2025 Arcee AI & Kraken Architect
# SPDX-License-Identifier: BUSL-1.1
import logging
import os
import sys
from typing import List, Optional
import click
import torch
import yaml
from tqdm import tqdm
from mergekit.common import ModelReference
from mergekit.config import MergeConfiguration
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
from mergekit.merge_methods.easy_define import merge_method
# Bare-message format keeps the audit report readable on stdout (no
# timestamps/levels interleaved with the bar charts).
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOG = logging.getLogger("donor_audit")
@merge_method(
    name="donor_audit",
    pretty_name="Donor Audit",
    reference_url="https://arxiv.org/abs/2408.07990",
)
def _donor_audit_registration(tensors: List[torch.Tensor]) -> torch.Tensor:
    """No-op merge body; exists only so the "donor_audit" name is registered."""
    first = tensors[0]
    return first
def rsce_weight(tvs: torch.Tensor) -> torch.Tensor:
    """
    Compute per-matrix mixing weights from task-vector energy.

    Each slice of ``tvs`` along dim 0 is scored by its mean squared value,
    and the scores are normalized to sum to 1. When every score is (near)
    zero the result falls back to a uniform distribution.
    (Copied from RSCE v3)
    """
    reduce_dims = list(range(1, tvs.dim()))
    # Mean square energy per stacked matrix
    energies = (tvs * tvs).mean(dim=reduce_dims)
    total = float(energies.sum())
    if abs(total) < 1e-8:
        # Degenerate case: nothing differs from the base — weight uniformly.
        count = energies.shape[0]
        return torch.full_like(energies, 1.0 / count)
    return energies / total
def log_rsce_audit(layer_name: str, weights: torch.Tensor, names: List[str]):
    """
    Print a bar-chart report of per-donor influence and append it to
    ``rsce_audit.log`` in the current directory.

    Args:
        layer_name: Name of the tensor the weights were computed for.
        weights: 1-D tensor of normalized influence weights (sums to ~1).
        names: Display names, one per weight (``zip`` ignores extras).
    """
    w_list = weights.tolist()
    bar_char = "█"
    # Header
    print(f"\n{'='*60}")
    print("RSCE DONOR AUDIT REPORT")
    print(f"Target Tensor: {layer_name}")
    print(f"{'='*60}")
    # Bars are scaled relative to the loudest model so the largest weight
    # fills the 40-char column; absolute scaling would be unreadable since
    # individual weights are usually far below 100%.
    # Hoisted out of the loop (was recomputed per row); guarded so an empty
    # weight list doesn't make max() raise.
    max_val = max(w_list) if w_list and max(w_list) > 0 else 1.0
    lines = []
    for name, w in zip(names, w_list):
        pct = w * 100
        bar_len = int((w / max_val) * 40)
        bar = bar_char * bar_len
        # Truncate name for clean display
        clean_name = os.path.basename(name)
        if len(clean_name) > 30:
            clean_name = clean_name[:27] + "..."
        lines.append(f"{clean_name:<30} | {bar:<40} | {pct:6.2f}% (Raw: {w:.4f})")
    log_entry = "\n".join(lines)
    print(log_entry)
    print(f"{'='*60}\n")
    # Append to file so audits from successive runs accumulate in one place.
    with open("rsce_audit.log", "a", encoding="utf-8") as f:
        f.write(f"\n[Audit {layer_name}]\n" + log_entry + "\n")
def find_layer0_tensor(loader: "LazyTensorLoader") -> str:
    """
    Scan a model's tensor index for a Layer-0 weight suitable for auditing.

    Candidates are ``.weight`` tensors inside the first transformer block
    (``.layers.0.``, ``.h.0.`` or ``.blocks.0.``). Preference order is
    down_proj, then gate_proj, then c_attn (GPT-NeoX / Qwen fused attention),
    matching what the code below actually checks — the previous docstring
    claimed a q_proj/self_attn priority that was never implemented. If no
    preferred tensor exists, the first candidate found is returned.

    Raises:
        RuntimeError: If no Layer-0 weight tensor exists in the index.
    """
    layer0_markers = (".layers.0.", ".h.0.", ".blocks.0.")
    candidates = [
        key
        for key in loader.index.tensor_paths.keys()
        # weights only (not biases), and only from the first block
        if key.endswith(".weight") and any(m in key for m in layer0_markers)
    ]
    if not candidates:
        raise RuntimeError("Could not find any Layer 0 weights in the base model.")
    # Dense MLP projections first (representative of the whole block), then
    # the fused attention weight used by GPT-NeoX / Qwen style models.
    for marker in ("down_proj", "gate_proj", "c_attn"):
        for candidate in candidates:
            if marker in candidate:
                return candidate
    return candidates[0]
def load_tensor_safe(model_path: str, tensor_name: str, device="cpu") -> torch.Tensor:
    """Load one tensor from a checkpoint path, exiting the process on failure."""
    try:
        # ShardedTensorIndex is used directly to avoid the caching overhead
        # of LoaderCache for this simple script.
        if os.path.isfile(model_path):
            index = ShardedTensorIndex.from_file(model_path)
        else:
            index = ShardedTensorIndex.from_disk(model_path)
        loader = LazyTensorLoader(index, lazy_unpickle=True)
        if tensor_name not in index.tensor_paths:
            # Basic fallback for slightly different architectures: match any
            # layer-0 key sharing the same suffix as the requested name.
            suffix = tensor_name.split("layers.0.")[-1]
            fuzzy = next(
                (
                    key
                    for key in index.tensor_paths.keys()
                    if key.endswith(suffix) and ("layers.0." in key or "h.0." in key)
                ),
                None,
            )
            if fuzzy is not None:
                tensor_name = fuzzy
        # float32 keeps downstream math stable regardless of stored dtype.
        return loader.get_tensor(tensor_name, device=device).float()
    except Exception as e:
        LOG.error(f"Failed to load {tensor_name} from {model_path}: {e}")
        sys.exit(1)
@click.command()
@click.argument("config_file", type=click.Path(exists=True))
@click.option("--lora-merge-cache", default=None, help="Cache directory for merged LoRAs")
@click.option("--cuda/--no-cuda", default=False, help="Use GPU for calculation (faster math, higher VRAM)")
def main(config_file, lora_merge_cache, cuda):
    """
    RSCE Donor Audit Tool V3.
    Loads Layer 0 from all models in the config and calculates their
    Task Vector magnitude/energy contribution relative to the base model.
    """
    device = "cuda" if cuda and torch.cuda.is_available() else "cpu"
    LOG.info(f"Running audit on {device}...")
    # 1. Parse Config
    with open(config_file, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    config = MergeConfiguration.model_validate(config_data)
    # 2. Identify Models — RSCE needs an anchor model to form task vectors.
    base_model_ref = config.base_model
    if not base_model_ref:
        LOG.error("Config must specify a `base_model` for RSCE auditing.")
        sys.exit(1)
    # Extract donor models from slices or models list
    donor_refs = []
    if config.models:
        donor_refs = [m.model for m in config.models]
    elif config.slices:
        # Flatten slices to get unique models (first-seen order preserved)
        seen = set()
        for s in config.slices:
            for source in s.sources:
                if source.model != base_model_ref and source.model not in seen:
                    donor_refs.append(source.model)
                    seen.add(source.model)
    # Filter out base model if it appeared in donors
    donor_refs = [d for d in donor_refs if d != base_model_ref]
    LOG.info(f"Base Model: {base_model_ref.model.path}")
    LOG.info(f"Found {len(donor_refs)} donor models.")

    # 3. Resolve Paths (Handle LoRAs if necessary)
    def resolve_path(ref: ModelReference):
        """Return a local path for `ref`, merging LoRAs / downloading as needed."""
        if ref.lora:
            if not lora_merge_cache:
                LOG.warning("LoRA detected but --lora-merge-cache not set. This might fail.")
            return ref.merged(cache_dir=lora_merge_cache).model.path
        if not os.path.exists(ref.model.path):
            # Best-effort Hub download; on failure return the raw path so the
            # later tensor load surfaces a clear error message.
            try:
                from huggingface_hub import snapshot_download
                return snapshot_download(ref.model.path, allow_patterns=["*.safetensors", "*.bin", "*.json"])
            except Exception:
                # Was a bare `except:` — that would also swallow
                # KeyboardInterrupt/SystemExit; narrowed to Exception.
                return ref.model.path
        return ref.model.path

    base_path = resolve_path(base_model_ref)
    donor_paths = [resolve_path(d) for d in donor_refs]
    # 4. Identify Target Tensor (Layer 0)
    base_index = ShardedTensorIndex.from_disk(base_path)
    base_loader = LazyTensorLoader(base_index, lazy_unpickle=True)
    target_tensor_name = find_layer0_tensor(base_loader)
    LOG.info(f"Selected audit tensor: {target_tensor_name}")
    LOG.info("Loading tensors into memory...")
    # 5. Load All Tensors
    base_tensor = load_tensor_safe(base_path, target_tensor_name, device)
    donor_tensors = []
    valid_donor_refs = []
    for d_path, d_ref in zip(tqdm(donor_paths, desc="Loading Donors"), donor_refs):
        dt = load_tensor_safe(d_path, target_tensor_name, device)
        # V3: Catch shape mismatches (e.g. a 7B model mixed into a 12B merge)
        if dt.shape != base_tensor.shape:
            LOG.warning(f"\n[!] Shape mismatch for {d_ref.model.path}: expected {base_tensor.shape}, got {dt.shape}. Skipping this model.")
            continue
        donor_tensors.append(dt)
        valid_donor_refs.append(d_ref)
    if not donor_tensors:
        LOG.error("No valid donor tensors found with matching shapes. Exiting.")
        sys.exit(1)
    # 6. Perform RSCE Audit Math. The base contributes a zero task vector so
    # it appears in the report as the anchor row.
    LOG.info("Calculating Task Vector Energy...")
    base_tv = torch.zeros_like(base_tensor)
    donor_tvs = [dt - base_tensor for dt in donor_tensors]
    all_tvs = torch.stack([base_tv] + donor_tvs, dim=0)
    raw_weights = rsce_weight(all_tvs)
    display_names = ["Base Model (Anchor)"] + [d.model.path for d in valid_donor_refs]
    # 7. Output
    log_rsce_audit(target_tensor_name, raw_weights, display_names)
    LOG.info("Audit complete.")
if __name__ == "__main__":
main()