"""Audit MergeKit donor models for DELLA / Task Arithmetic compatibility.

Reads a mergekit YAML config, fingerprints the base model and every donor
by sampling a few probe layers from their .safetensors shards, then
analyses the task vectors (donor - base): L2 magnitude, pairwise cosine
alignment, and a 2-D PCA map.  All console output is mirrored to
LOG_FILENAME so the plotted donor IDs can be matched back to model names.
"""

import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from safetensors import safe_open
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# --- CONFIGURATION ---
PROBE_LAYERS = [
    "model.layers.12.mlp.down_proj.weight",  # Mid-model logic
    "lm_head.weight",                        # Output semantics
]
LOG_FILENAME = "della_scan.log"
# ---------------------


class Logger:
    """Tee for stdout: every write goes to the terminal AND a log file."""

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()  # Keep the log current even if the run crashes.

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        self.log.close()


def load_yaml_config(config_path):
    """Parse a mergekit YAML config.

    Returns:
        (base_model, models): the optional ``base_model`` path (or None)
        and the list of donor model paths from the ``models`` section.
    """
    print(f"Loading config: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    models = []
    base_model = None

    # Extract base model
    if 'base_model' in config:
        base_model = config['base_model']

    # Extract models list
    if 'models' in config:
        for m in config['models']:
            models.append(m['model'])

    return base_model, models


def get_model_fingerprint(model_path, probe_layers):
    """Build a 1-D fingerprint tensor for a model directory.

    Scans the directory's .safetensors shards for the requested probe
    layers, flattens each to float32, downsamples by 10x, and concatenates
    them.  Tensors are concatenated in ``probe_layers`` order — NOT shard
    file order — so fingerprints from differently-sharded models stay
    comparable element-for-element.  (The original concatenated in file
    order, which could silently misalign two models.)

    Returns:
        A 1-D torch.Tensor, or None when the path is missing or no probe
        layer could be read.
    """
    if not os.path.exists(model_path):
        return None

    shard_files = sorted(
        f for f in os.listdir(model_path) if f.endswith('.safetensors')
    )

    # Collect by layer name so the final concat order is deterministic.
    found = {}
    for file in shard_files:
        full_path = os.path.join(model_path, file)
        try:
            with safe_open(full_path, framework="pt", device="cpu") as f:
                keys = f.keys()
                for layer in probe_layers:
                    if layer in keys and layer not in found:
                        t = f.get_tensor(layer).float().view(-1)
                        found[layer] = t[::10]  # Downsample
        except Exception as e:
            # Best-effort: a corrupt shard shouldn't kill the whole scan.
            print(f"Error reading {file}: {e}")
        if len(found) == len(probe_layers):
            break  # All probes located; skip remaining shards.

    if not found:
        return None
    return torch.cat([found[layer] for layer in probe_layers if layer in found])


def analyze_task_vectors(base_fp, donor_fps):
    """Compute task-vector geometry for the donors relative to the base.

    Args:
        base_fp: 1-D fingerprint tensor of the base model.
        donor_fps: list of 1-D fingerprint tensors, one per donor.

    Returns:
        (norms, cos_sim, coords, var_ratio, donor_sizes):
        L2 norms of each delta, pairwise cosine similarity matrix,
        2-D PCA coordinates, PCA explained-variance ratios, and the
        original (pre-truncation) donor fingerprint sizes.
    """
    # 0. Handle size mismatches (Manifold Alignment)
    base_size = base_fp.numel()
    donor_sizes = [f.numel() for f in donor_fps]
    min_size = min([base_size] + donor_sizes)

    if any(s != min_size for s in donor_sizes) or base_size != min_size:
        print(f"\n[!] SIZE MISMATCH DETECTED")
        print(f"    Base Size: {base_size}")
        print(f"    Min Donor: {min(donor_sizes)}")
        print(f"    Action: Truncating all models to {min_size} for audit.")

    # Align fingerprints by truncating to the common prefix length.
    # NOTE(review): prefix truncation assumes all models share layer layout
    # up to min_size — confirm for cross-architecture merges.
    aligned_base = base_fp[:min_size]
    aligned_donors = [f[:min_size] for f in donor_fps]

    # 1. Calculate Task Vectors (Delta = Donor - Base)
    task_vectors = [d_fp - aligned_base for d_fp in aligned_donors]

    # Stack into matrix [N_donors, N_features]
    data_matrix = torch.stack(task_vectors).numpy()

    # 2. Norm Analysis (Magnitude of the Delta)
    norms = np.linalg.norm(data_matrix, axis=1)

    # 3. Cosine Similarity Matrix (Directional Alignment)
    cos_sim = cosine_similarity(data_matrix)

    # 4. PCA Projection (2D) — center the task vectors first.
    centered_data = data_matrix - np.mean(data_matrix, axis=0)
    if len(donor_fps) > 1:
        pca = PCA(n_components=2)
        coords = pca.fit_transform(centered_data)
        var_ratio = pca.explained_variance_ratio_
    else:
        # A single donor has no spread to decompose; place it at origin.
        coords = np.zeros((1, 2))
        var_ratio = [1.0, 0.0]

    return norms, cos_sim, coords, var_ratio, donor_sizes


def plot_results(model_ids, norms, cos_sim, coords, var_ratio):
    """Render the three audit charts: PCA map, cosine heatmap, norm bars."""
    labels = [str(mid) for mid in model_ids]

    fig = plt.figure(figsize=(20, 12))
    # Use LOG_FILENAME so the caption always names the real log file.
    fig.suptitle(
        f"DELLA/Task Arithmetic Compatibility Audit ({len(model_ids)} Donors)"
        f"\nRefer to {LOG_FILENAME} for ID Key",
        fontsize=16,
    )

    # --- Plot 1: Task Vector Manifold (PCA) ---
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.scatter(coords[:, 0], coords[:, 1], c='purple', s=80, alpha=0.6)
    for i, txt in enumerate(labels):
        ax1.annotate(txt, (coords[i, 0], coords[i, 1]), xytext=(3, 3),
                     textcoords='offset points', fontsize=8, fontweight='bold')
    ax1.set_title(f"Task Vector Map (PCA of Deltas)\nClusters = Redundant Skills")
    ax1.set_xlabel(f"PC1 ({var_ratio[0]:.1%} variance)")
    ax1.set_ylabel(f"PC2 ({var_ratio[1]:.1%} variance)")
    ax1.grid(True, alpha=0.3)

    # Plot Origin (Base Model reference relative to centered data)
    center_offset = -np.mean(coords, axis=0)
    ax1.scatter(center_offset[0], center_offset[1], c='red', marker='x',
                s=100, label='Base Model (Ref)')
    ax1.legend()

    # --- Plot 2: Cosine Similarity Heatmap ---
    ax2 = fig.add_subplot(2, 2, 2)
    # For Task Vectors, negative similarity is common (conflicting directions)
    im = ax2.imshow(cos_sim, cmap='coolwarm', vmin=-1.0, vmax=1.0)
    ax2.set_xticks(np.arange(len(labels)))
    ax2.set_yticks(np.arange(len(labels)))
    ax2.set_xticklabels(labels, rotation=90, fontsize=6)
    ax2.set_yticklabels(labels, fontsize=6)
    ax2.set_title("Task Vector Alignment (Blue=Opposed, Red=Aligned)")
    plt.colorbar(im, ax=ax2)

    # --- Plot 3: Delta Magnitude (L2 Norm) ---
    ax3 = fig.add_subplot(2, 1, 2)
    bars = ax3.bar(labels, norms, color='orange', alpha=0.6)
    ax3.set_title("Task Vector Magnitude (L2 Norm)\nHigh bars = Drastic deviation from Base Model")
    ax3.set_ylabel("Delta L2 Norm")
    ax3.set_xlabel("Donor ID")
    ax3.grid(axis='y', alpha=0.3)
    for bar in bars:
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width() / 2., height, f'{height:.1f}',
                 ha='center', va='bottom', fontsize=6, rotation=90)

    plt.tight_layout()
    plt.show()


def _run():
    """The actual audit flow; assumes stdout is already tee'd to the log."""
    parser = argparse.ArgumentParser(
        description="Audit MergeKit models for DELLA/Task Arithmetic compatibility.")
    parser.add_argument("config", help="Path to the mergekit yaml config file")
    args = parser.parse_args()

    print(f"--- DELLA AUDIT V2 START ---")

    base_model_path, donor_paths = load_yaml_config(args.config)

    if not base_model_path:
        print("Error: No 'base_model' found in config. DELLA requires a base model.")
        return

    print(f"Base Model: {base_model_path}")
    print(f"Donors: {len(donor_paths)}")

    print("\nExtracting BASE MODEL fingerprint...")
    base_fp = get_model_fingerprint(base_model_path, PROBE_LAYERS)
    if base_fp is None:
        print("Failed to load base model. Exiting.")
        return

    donor_fps = []
    valid_donors = []
    valid_ids = []

    print("\nExtracting DONOR fingerprints...")
    for i, path in enumerate(tqdm(donor_paths)):
        fp = get_model_fingerprint(path, PROBE_LAYERS)
        if fp is not None:
            donor_fps.append(fp)
            valid_donors.append(path)
            valid_ids.append(i + 1)  # 1-based IDs match the config order.
        else:
            print(f"Skipping {path} (failed to load)")

    if len(valid_donors) < 1:
        print("Need at least 1 valid donor.")
        return

    print("\nComputing Task Vector geometry...")
    norms, cos_sim, coords, var_ratio, sizes = analyze_task_vectors(base_fp, donor_fps)

    # --- LOGGING THE KEY ---
    print("\n" + "=" * 80)
    print(f"{'ID':<5} | {'Model Name'}")
    print("-" * 80)
    for i, path in enumerate(valid_donors):
        name = os.path.basename(path).replace("!models--", "")
        print(f"#{valid_ids[i]:<4} | {name}")
    print("=" * 80 + "\n")

    # --- MAGNITUDE ANALYSIS ---
    print("--- MAGNITUDE ANALYSIS & DATA POINTS ---")
    print(f"{'ID':<5} | {'Status':<10} | {'Delta Norm':<12} | {'Orig Size':<12} | {'Model Name'}")
    print("-" * 100)
    mean_norm = np.mean(norms)
    std_norm = np.std(norms)
    for i, model in enumerate(valid_donors):
        name = os.path.basename(model).replace("!models--", "")
        # Check if norm is significantly higher than average
        # (potential destroyer of weights); epsilon guards a zero std.
        z_score = (norms[i] - mean_norm) / (std_norm + 1e-8)
        status = "HIGH MAG" if z_score > 1.5 else "OK"
        print(f"#{valid_ids[i]:<4} | {status:<10} | {norms[i]:<12.4f} | {sizes[i]:<12} | {name}")

    print("\nLog saved to: " + LOG_FILENAME)
    print("Displaying charts...")
    sys.stdout.flush()  # Push everything out before the blocking plt.show().

    plot_results(valid_ids, norms, cos_sim, coords, var_ratio)


def main():
    """Entry point: tee stdout to the log file for the duration of the run.

    Unlike the original, the log file is always closed and the real stdout
    restored — including on early returns and exceptions.
    """
    original_stdout = sys.stdout
    logger = Logger(LOG_FILENAME)
    sys.stdout = logger
    try:
        _run()
    finally:
        sys.stdout = original_stdout
        logger.close()


if __name__ == "__main__":
    main()