# model_tools/Audits/audit_della.py
# DELLA / task-arithmetic compatibility audit for mergekit configs.
import yaml
import torch
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from safetensors import safe_open
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import argparse
# --- CONFIGURATION ---
# Weight tensors sampled from each model as a cheap "fingerprint"
# (see get_model_fingerprint): one mid-network MLP projection and the
# output head.
PROBE_LAYERS = [
    "model.layers.12.mlp.down_proj.weight",  # Mid-model logic
    "lm_head.weight"  # Output semantics
]
# All console output is tee'd into this file via the Logger class.
LOG_FILENAME = "della_scan.log"
# ---------------------
class Logger:
    """Tee for stdout: mirrors every message to the terminal and a log file.

    Intended to be assigned to sys.stdout; it keeps a reference to the
    stream that was stdout at construction time.
    """

    def __init__(self, filename):
        # Capture the current stdout before the caller replaces it with us.
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        # Echo to both sinks; flush the file immediately so the log
        # survives a crash mid-run.
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        for stream in (self.terminal, self.log):
            stream.flush()

    def close(self):
        self.log.close()
def load_yaml_config(config_path):
    """Parse a mergekit YAML config into (base_model, donor_paths).

    Args:
        config_path: Path to the mergekit YAML file.

    Returns:
        Tuple of (base_model, models): the config's 'base_model' entry
        (None when absent) and the list of donor model paths taken from
        the 'model' key of each entry in the 'models' section.
    """
    print(f"Loading config: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        # yaml.safe_load returns None for an empty/blank file; normalize to
        # {} so the lookups below don't raise TypeError on membership tests.
        config = yaml.safe_load(f) or {}

    base_model = config.get('base_model')
    # A present-but-empty 'models:' key parses to None — treat it as [].
    # Each entry is a mapping with a 'model' key (mergekit schema).
    models = [m['model'] for m in (config.get('models') or [])]
    return base_model, models
def get_model_fingerprint(model_path, probe_layers):
    """Build a 1-D fingerprint tensor by sampling probe layers from a model.

    Scans every .safetensors shard under model_path (sorted by filename),
    extracts each probe layer it finds, flattens it to float32, keeps every
    10th element, and concatenates the pieces into one vector.

    Returns:
        torch.Tensor fingerprint, or None when the path does not exist or
        no probe layer was found in any shard.
    """
    # Guard clause: missing model directory -> no fingerprint.
    if not os.path.exists(model_path):
        return None

    shards = sorted(
        name for name in os.listdir(model_path) if name.endswith('.safetensors')
    )

    pieces = []
    hits = 0
    for shard in shards:
        shard_path = os.path.join(model_path, shard)
        try:
            with safe_open(shard_path, framework="pt", device="cpu") as reader:
                available = reader.keys()
                for layer in probe_layers:
                    if layer in available:
                        flat = reader.get_tensor(layer).float().view(-1)
                        pieces.append(flat[::10])  # Downsample
                        hits += 1
        except Exception as e:
            # Best-effort: report the unreadable shard and keep scanning.
            print(f"Error reading {shard}: {e}")

    if hits == 0 or not pieces:
        return None
    return torch.cat(pieces)
def analyze_task_vectors(base_fp, donor_fps):
    """Compute the task-vector geometry of donors relative to a base model.

    Args:
        base_fp: 1-D fingerprint tensor of the base model.
        donor_fps: list of 1-D fingerprint tensors, one per donor.

    Returns:
        Tuple of (norms, cos_sim, coords, var_ratio, donor_sizes):
        L2 norm of each delta, pairwise cosine-similarity matrix,
        2-D PCA coordinates, explained-variance ratios, and the
        original (pre-truncation) donor fingerprint sizes.
    """
    # 0. Manifold alignment: fingerprints may differ in length (different
    #    architectures/shards), so truncate everything to the shortest.
    base_size = base_fp.numel()
    donor_sizes = [fp.numel() for fp in donor_fps]
    min_size = min(donor_sizes + [base_size])
    if base_size != min_size or any(size != min_size for size in donor_sizes):
        print(f"\n[!] SIZE MISMATCH DETECTED")
        print(f" Base Size: {base_size}")
        print(f" Min Donor: {min(donor_sizes)}")
        print(f" Action: Truncating all models to {min_size} for audit.")
    trimmed_base = base_fp[:min_size]

    # 1. Task vectors (delta = donor - base), stacked to [N_donors, N_features].
    data_matrix = torch.stack(
        [fp[:min_size] - trimmed_base for fp in donor_fps]
    ).numpy()

    # 2. Magnitude of each delta.
    norms = np.linalg.norm(data_matrix, axis=1)

    # 3. Directional alignment between deltas.
    cos_sim = cosine_similarity(data_matrix)

    # 4. 2-D PCA projection of the mean-centered deltas.
    centered = data_matrix - np.mean(data_matrix, axis=0)
    if len(donor_fps) > 1:
        pca = PCA(n_components=2)
        coords = pca.fit_transform(centered)
        var_ratio = pca.explained_variance_ratio_
    else:
        # A single donor has no spread to decompose; pin it at the origin.
        coords = np.zeros((1, 2))
        var_ratio = [1.0, 0.0]
    return norms, cos_sim, coords, var_ratio, donor_sizes
def plot_results(model_ids, norms, cos_sim, coords, var_ratio):
    """Render the three audit charts in one figure and block on plt.show().

    Args:
        model_ids: donor ID numbers, used as labels on every chart.
        norms: L2 norm of each donor's task vector.
        cos_sim: pairwise cosine-similarity matrix of the task vectors.
        coords: 2-D PCA coordinates of the (centered) task vectors.
        var_ratio: explained-variance ratios for PC1 and PC2.
    """
    labels = [str(mid) for mid in model_ids]
    fig = plt.figure(figsize=(20, 12))
    fig.suptitle(f"DELLA/Task Arithmetic Compatibility Audit ({len(model_ids)} Donors)\nRefer to della_scan.log for ID Key", fontsize=16)
    # --- Plot 1: Task Vector Manifold (PCA) ---
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.scatter(coords[:, 0], coords[:, 1], c='purple', s=80, alpha=0.6)
    # Tag each point with its donor ID, offset slightly from the marker.
    for i, txt in enumerate(labels):
        ax1.annotate(txt, (coords[i, 0], coords[i, 1]), xytext=(3, 3), textcoords='offset points', fontsize=8, fontweight='bold')
    ax1.set_title(f"Task Vector Map (PCA of Deltas)\nClusters = Redundant Skills")
    ax1.set_xlabel(f"PC1 ({var_ratio[0]:.1%} variance)")
    ax1.set_ylabel(f"PC2 ({var_ratio[1]:.1%} variance)")
    ax1.grid(True, alpha=0.3)
    # Plot Origin (Base Model reference relative to centered data)
    # NOTE(review): the PCA input was mean-centered, so -mean(coords) is
    # roughly where the base model (delta = 0) would project.
    center_offset = -np.mean(coords, axis=0)
    ax1.scatter(center_offset[0], center_offset[1], c='red', marker='x', s=100, label='Base Model (Ref)')
    ax1.legend()
    # --- Plot 2: Cosine Similarity Heatmap ---
    ax2 = fig.add_subplot(2, 2, 2)
    # For Task Vectors, negative similarity is common (conflicting directions)
    im = ax2.imshow(cos_sim, cmap='coolwarm', vmin=-1.0, vmax=1.0)
    ax2.set_xticks(np.arange(len(labels)))
    ax2.set_yticks(np.arange(len(labels)))
    ax2.set_xticklabels(labels, rotation=90, fontsize=6)
    ax2.set_yticklabels(labels, fontsize=6)
    ax2.set_title("Task Vector Alignment (Blue=Opposed, Red=Aligned)")
    plt.colorbar(im, ax=ax2)
    # --- Plot 3: Delta Magnitude (L2 Norm) ---
    # Full-width bottom subplot (2 rows x 1 col grid, slot 2).
    ax3 = fig.add_subplot(2, 1, 2)
    bars = ax3.bar(labels, norms, color='orange', alpha=0.6)
    ax3.set_title("Task Vector Magnitude (L2 Norm)\nHigh bars = Drastic deviation from Base Model")
    ax3.set_ylabel("Delta L2 Norm")
    ax3.set_xlabel("Donor ID")
    ax3.grid(axis='y', alpha=0.3)
    # Print each bar's value above it, rotated to avoid label overlap.
    for bar in bars:
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height,
            f'{height:.1f}', ha='center', va='bottom', fontsize=6, rotation=90)
    plt.tight_layout()
    plt.show()
def main():
    """CLI entry point for the DELLA audit.

    Tees stdout into LOG_FILENAME, loads the mergekit config named on the
    command line, fingerprints the base model and every donor, prints the
    ID key and magnitude tables, and displays the charts.

    Fix over the original: stdout is now restored and the log file closed
    on every exit path (the early returns previously left sys.stdout
    pointing at the Logger and leaked its file handle).
    """
    # Hook stdout so every print below also lands in the log file.
    logger = Logger(LOG_FILENAME)
    sys.stdout = logger
    try:
        parser = argparse.ArgumentParser(description="Audit MergeKit models for DELLA/Task Arithmetic compatibility.")
        parser.add_argument("config", help="Path to the mergekit yaml config file")
        args = parser.parse_args()

        print(f"--- DELLA AUDIT V2 START ---")
        base_model_path, donor_paths = load_yaml_config(args.config)
        if not base_model_path:
            print("Error: No 'base_model' found in config. DELLA requires a base model.")
            return
        print(f"Base Model: {base_model_path}")
        print(f"Donors: {len(donor_paths)}")

        print("\nExtracting BASE MODEL fingerprint...")
        base_fp = get_model_fingerprint(base_model_path, PROBE_LAYERS)
        if base_fp is None:
            print("Failed to load base model. Exiting.")
            return

        donor_fps = []
        valid_donors = []
        valid_ids = []
        print("\nExtracting DONOR fingerprints...")
        for i, path in enumerate(tqdm(donor_paths)):
            fp = get_model_fingerprint(path, PROBE_LAYERS)
            if fp is not None:
                donor_fps.append(fp)
                valid_donors.append(path)
                valid_ids.append(i + 1)  # 1-based IDs match the printed key
            else:
                print(f"Skipping {path} (failed to load)")

        if len(valid_donors) < 1:
            print("Need at least 1 valid donor.")
            return

        print("\nComputing Task Vector geometry...")
        norms, cos_sim, coords, var_ratio, sizes = analyze_task_vectors(base_fp, donor_fps)

        # --- LOGGING THE KEY (ID -> model name) ---
        print("\n" + "=" * 80)
        print(f"{'ID':<5} | {'Model Name'}")
        print("-" * 80)
        for i, path in enumerate(valid_donors):
            name = os.path.basename(path).replace("!models--", "")
            print(f"#{valid_ids[i]:<4} | {name}")
        print("=" * 80 + "\n")

        # --- MAGNITUDE ANALYSIS ---
        print("--- MAGNITUDE ANALYSIS & DATA POINTS ---")
        print(f"{'ID':<5} | {'Status':<10} | {'Delta Norm':<12} | {'Orig Size':<12} | {'Model Name'}")
        print("-" * 100)
        mean_norm = np.mean(norms)
        std_norm = np.std(norms)
        for i, model in enumerate(valid_donors):
            name = os.path.basename(model).replace("!models--", "")
            # Flag donors whose delta magnitude is an outlier (z > 1.5) —
            # such donors tend to clobber shared weights in a merge.
            # The 1e-8 guards against division by zero when all norms match.
            z_score = (norms[i] - mean_norm) / (std_norm + 1e-8)
            status = "HIGH MAG" if z_score > 1.5 else "OK"
            print(f"#{valid_ids[i]:<4} | {status:<10} | {norms[i]:<12.4f} | {sizes[i]:<12} | {name}")

        print("\nLog saved to: " + LOG_FILENAME)
        print("Displaying charts...")
        sys.stdout.flush()  # flush both terminal and log before blocking on the GUI
        plot_results(valid_ids, norms, cos_sim, coords, var_ratio)
    finally:
        # Restore the real stdout and close the log no matter how we exit.
        sys.stdout = logger.terminal
        logger.close()
# Run the audit only when executed as a script, not when imported.
if __name__ == "__main__":
    main()