Spaces:
Running
Running
| import yaml | |
| import torch | |
| import os | |
| import sys | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from safetensors import safe_open | |
| from sklearn.decomposition import PCA | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from tqdm import tqdm | |
| import argparse | |
# --- CONFIGURATION ---
# Weight tensors sampled from each model to build its "fingerprint":
# one mid-network MLP projection and the output embedding.
PROBE_LAYERS = [
    "model.layers.12.mlp.down_proj.weight",  # Mid-model logic
    "lm_head.weight",                        # Output semantics
]
LOG_FILENAME = "della_scan.log"
# ---------------------
class Logger:
    """Tee stream used to hook ``sys.stdout``: every write is mirrored to
    the real terminal and to a log file.

    Attributes:
        terminal: the original ``sys.stdout`` captured at construction time.
        log: the opened log file handle.
    """

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        # BUGFIX: guard the file handle — writes arriving after close()
        # (late prints, atexit handlers) must not raise ValueError.
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()

    def flush(self):
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        # Idempotent: calling close() twice is a no-op.
        if not self.log.closed:
            self.log.close()
def load_yaml_config(config_path):
    """Parse a mergekit YAML config.

    Returns a tuple ``(base_model, models)`` where ``base_model`` is the
    config's base model path (or None if absent) and ``models`` is the
    list of donor model paths from the ``models:`` section.
    """
    print(f"Loading config: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # Missing 'base_model' key -> None, same as the explicit membership test.
    base_model = config.get('base_model')
    # Each entry under 'models:' is a mapping with a 'model' path.
    models = [entry['model'] for entry in config.get('models', [])]
    return base_model, models
def get_model_fingerprint(model_path, probe_layers):
    """Build a 1-D float32 fingerprint for a model by concatenating
    downsampled probe tensors read from its ``.safetensors`` shards.

    Args:
        model_path: local directory expected to contain ``*.safetensors``
            shard files (non-existent paths yield None).
        probe_layers: tensor names to sample; each is collected at most once.

    Returns:
        A 1-D ``torch.Tensor``, or None if the path is missing or no probe
        layer could be read.
    """
    if not os.path.exists(model_path):
        return None
    shard_files = sorted(
        f for f in os.listdir(model_path) if f.endswith('.safetensors')
    )
    tensors = []
    # BUGFIX: track layers still to find so a name that appears in several
    # shards is only appended once, and we can stop early.
    remaining = list(probe_layers)
    for file in shard_files:
        if not remaining:
            break  # all probe layers collected — skip the remaining shards
        full_path = os.path.join(model_path, file)
        try:
            with safe_open(full_path, framework="pt", device="cpu") as f:
                keys = f.keys()
                for layer in list(remaining):
                    if layer in keys:
                        t = f.get_tensor(layer).float().view(-1)
                        # Downsample 10x; clone() so the strided view does
                        # not keep the full weight's storage alive.
                        tensors.append(t[::10].clone())
                        remaining.remove(layer)
        except Exception as e:
            # Best-effort: a corrupt shard is reported but does not abort.
            print(f"Error reading {file}: {e}")
    if not tensors:
        return None
    return torch.cat(tensors)
def analyze_task_vectors(base_fp, donor_fps):
    """Compute the delta geometry of donor fingerprints vs. the base.

    Returns ``(norms, cos_sim, coords, var_ratio, donor_sizes)``: per-donor
    L2 delta norms, pairwise cosine similarity of the deltas, 2-D PCA
    coordinates, the PCA explained-variance ratios, and the original
    (pre-truncation) donor fingerprint sizes.
    """
    # 0. Handle size mismatches (Manifold Alignment)
    base_size = base_fp.numel()
    donor_sizes = [fp.numel() for fp in donor_fps]
    min_size = min(donor_sizes + [base_size])
    mismatch = base_size != min_size or any(s != min_size for s in donor_sizes)
    if mismatch:
        print(f"\n[!] SIZE MISMATCH DETECTED")
        print(f" Base Size: {base_size}")
        print(f" Min Donor: {min(donor_sizes)}")
        print(f" Action: Truncating all models to {min_size} for audit.")
    # Truncate every fingerprint to the common minimum length.
    aligned_base = base_fp[:min_size]
    # 1. Task vectors (delta = donor - base), stacked [N_donors, N_features]
    data_matrix = torch.stack(
        [fp[:min_size] - aligned_base for fp in donor_fps]
    ).numpy()
    # 2. Magnitude of each delta
    norms = np.linalg.norm(data_matrix, axis=1)
    # 3. Directional alignment between deltas
    cos_sim = cosine_similarity(data_matrix)
    # 4. 2-D PCA projection of the mean-centered deltas
    centered_data = data_matrix - np.mean(data_matrix, axis=0)
    if len(donor_fps) > 1:
        pca = PCA(n_components=2)
        coords = pca.fit_transform(centered_data)
        var_ratio = pca.explained_variance_ratio_
    else:
        # A single donor has no spread — pin it at the origin.
        coords = np.zeros((1, 2))
        var_ratio = [1.0, 0.0]
    return norms, cos_sim, coords, var_ratio, donor_sizes
def plot_results(model_ids, norms, cos_sim, coords, var_ratio):
    """Render the three audit charts: PCA delta map, cosine-similarity
    heatmap, and per-donor delta-magnitude bars."""
    labels = [str(mid) for mid in model_ids]
    fig = plt.figure(figsize=(20, 12))
    fig.suptitle(f"DELLA/Task Arithmetic Compatibility Audit ({len(model_ids)} Donors)\nRefer to della_scan.log for ID Key", fontsize=16)
    # --- Plot 1: Task Vector Manifold (PCA) ---
    ax_pca = fig.add_subplot(2, 2, 1)
    ax_pca.scatter(coords[:, 0], coords[:, 1], c='purple', s=80, alpha=0.6)
    for idx, label in enumerate(labels):
        ax_pca.annotate(label, (coords[idx, 0], coords[idx, 1]), xytext=(3, 3), textcoords='offset points', fontsize=8, fontweight='bold')
    ax_pca.set_title(f"Task Vector Map (PCA of Deltas)\nClusters = Redundant Skills")
    ax_pca.set_xlabel(f"PC1 ({var_ratio[0]:.1%} variance)")
    ax_pca.set_ylabel(f"PC2 ({var_ratio[1]:.1%} variance)")
    ax_pca.grid(True, alpha=0.3)
    # Mark where the base model sits relative to the centered deltas.
    origin = -np.mean(coords, axis=0)
    ax_pca.scatter(origin[0], origin[1], c='red', marker='x', s=100, label='Base Model (Ref)')
    ax_pca.legend()
    # --- Plot 2: Cosine Similarity Heatmap ---
    ax_heat = fig.add_subplot(2, 2, 2)
    # Task vectors can legitimately oppose each other, so span [-1, 1].
    heat = ax_heat.imshow(cos_sim, cmap='coolwarm', vmin=-1.0, vmax=1.0)
    ticks = np.arange(len(labels))
    ax_heat.set_xticks(ticks)
    ax_heat.set_yticks(ticks)
    ax_heat.set_xticklabels(labels, rotation=90, fontsize=6)
    ax_heat.set_yticklabels(labels, fontsize=6)
    ax_heat.set_title("Task Vector Alignment (Blue=Opposed, Red=Aligned)")
    plt.colorbar(heat, ax=ax_heat)
    # --- Plot 3: Delta Magnitude (L2 Norm) ---
    ax_bars = fig.add_subplot(2, 1, 2)
    bars = ax_bars.bar(labels, norms, color='orange', alpha=0.6)
    ax_bars.set_title("Task Vector Magnitude (L2 Norm)\nHigh bars = Drastic deviation from Base Model")
    ax_bars.set_ylabel("Delta L2 Norm")
    ax_bars.set_xlabel("Donor ID")
    ax_bars.grid(axis='y', alpha=0.3)
    # Label each bar with its numeric norm.
    for bar in bars:
        height = bar.get_height()
        ax_bars.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.1f}', ha='center', va='bottom', fontsize=6, rotation=90)
    plt.tight_layout()
    plt.show()
def main():
    """CLI entry point: fingerprint the base model and all donors from a
    mergekit YAML config, then report task-vector geometry (norms, cosine
    alignment, PCA map) on stdout, in the log file, and as charts."""
    # Hook stdout so everything printed is also persisted to the log.
    logger = Logger(LOG_FILENAME)
    sys.stdout = logger
    try:
        parser = argparse.ArgumentParser(description="Audit MergeKit models for DELLA/Task Arithmetic compatibility.")
        parser.add_argument("config", help="Path to the mergekit yaml config file")
        args = parser.parse_args()
        print(f"--- DELLA AUDIT V2 START ---")
        base_model_path, donor_paths = load_yaml_config(args.config)
        if not base_model_path:
            print("Error: No 'base_model' found in config. DELLA requires a base model.")
            return
        print(f"Base Model: {base_model_path}")
        print(f"Donors: {len(donor_paths)}")
        print("\nExtracting BASE MODEL fingerprint...")
        base_fp = get_model_fingerprint(base_model_path, PROBE_LAYERS)
        if base_fp is None:
            print("Failed to load base model. Exiting.")
            return
        donor_fps = []
        valid_donors = []
        valid_ids = []
        print("\nExtracting DONOR fingerprints...")
        for i, path in enumerate(tqdm(donor_paths)):
            fp = get_model_fingerprint(path, PROBE_LAYERS)
            if fp is not None:
                donor_fps.append(fp)
                valid_donors.append(path)
                valid_ids.append(i + 1)  # 1-based IDs match the printed key
            else:
                print(f"Skipping {path} (failed to load)")
        if len(valid_donors) < 1:
            print("Need at least 1 valid donor.")
            return
        print("\nComputing Task Vector geometry...")
        norms, cos_sim, coords, var_ratio, sizes = analyze_task_vectors(base_fp, donor_fps)
        # --- LOGGING THE KEY (ID -> model name, for reading the charts) ---
        print("\n" + "="*80)
        print(f"{'ID':<5} | {'Model Name'}")
        print("-" * 80)
        for i, path in enumerate(valid_donors):
            name = os.path.basename(path).replace("!models--", "")
            print(f"#{valid_ids[i]:<4} | {name}")
        print("="*80 + "\n")
        # --- MAGNITUDE ANALYSIS ---
        print("--- MAGNITUDE ANALYSIS & DATA POINTS ---")
        print(f"{'ID':<5} | {'Status':<10} | {'Delta Norm':<12} | {'Orig Size':<12} | {'Model Name'}")
        print("-" * 100)
        mean_norm = np.mean(norms)
        std_norm = np.std(norms)
        for i, model in enumerate(valid_donors):
            name = os.path.basename(model).replace("!models--", "")
            # Flag donors whose delta is far above average (z > 1.5): they
            # risk dominating or destroying weights in the merge.
            z_score = (norms[i] - mean_norm) / (std_norm + 1e-8)
            status = "HIGH MAG" if z_score > 1.5 else "OK"
            print(f"#{valid_ids[i]:<4} | {status:<10} | {norms[i]:<12.4f} | {sizes[i]:<12} | {name}")
        print("\nLog saved to: " + LOG_FILENAME)
        print("Displaying charts...")
        logger.flush()
        plot_results(valid_ids, norms, cos_sim, coords, var_ratio)
    finally:
        # BUGFIX: the original replaced sys.stdout permanently and closed the
        # log only on the happy path. Restore the real stdout and close the
        # log even when argparse raises SystemExit or the audit fails mid-run.
        sys.stdout = logger.terminal
        logger.close()


if __name__ == "__main__":
    main()