# NOTE: removed web-page residue ("Spaces:" / "Running" hosting-UI status
# badges) that was captured with the source and is not part of the program.
| # Copyright (C) 2025 Arcee AI & Kraken Architect | |
| # SPDX-License-Identifier: BUSL-1.1 | |
| import logging | |
| import os | |
| import sys | |
| from typing import List, Optional | |
| import click | |
| import torch | |
| import yaml | |
| from tqdm import tqdm | |
| from mergekit.common import ModelReference | |
| from mergekit.config import MergeConfiguration | |
| from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex | |
| from mergekit.merge_methods.easy_define import merge_method | |
# Configure bare-message logging so the audit report stays clean on stdout.
logging.basicConfig(level=logging.INFO, format="%(message)s")
# Module-level logger shared by the helpers below.
LOG = logging.getLogger("donor_audit")
| def _donor_audit_registration(tensors: List[torch.Tensor]) -> torch.Tensor: | |
| """Placeholder to register the method name.""" | |
| return tensors[0] | |
def rsce_weight(tvs: torch.Tensor) -> torch.Tensor:
    """Return per-model weights from task-vector energy (copied from RSCE v3).

    Reduces over every dimension except dim 0, so each entry is the mean
    squared energy of one stacked task vector; the result is normalized to
    sum to 1. If the total energy is numerically zero, a uniform
    distribution is returned instead.
    """
    reduce_dims = list(range(1, tvs.dim()))
    energies = torch.mean(tvs * tvs, dim=reduce_dims)
    total = torch.sum(energies).item()
    if abs(total) < 1e-8:
        # Degenerate case: every task vector is (numerically) zero.
        return torch.ones_like(energies) / energies.shape[0]
    return energies / total
def log_rsce_audit(layer_name: str, weights: torch.Tensor, names: List[str]):
    """Print a bar chart of donor influence and append it to rsce_audit.log.

    Args:
        layer_name: Name of the audited tensor (used in the report header).
        weights: 1-D tensor of influence weights, aligned with ``names``.
        names: Display names (paths); only the basename is shown, truncated
            to 30 characters.
    """
    w_list = weights.tolist()
    bar_char = "█"
    # Header
    print(f"\n{'='*60}")
    print("RSCE DONOR AUDIT REPORT")
    print(f"Target Tensor: {layer_name}")
    print(f"{'='*60}")
    # Bars are scaled relative to the loudest model so the max fills the bar.
    # Fix: max(w_list) is loop-invariant -- hoisted out of the loop (it was
    # recomputed twice per iteration); also guards against an empty list.
    max_val = max(w_list) if w_list and max(w_list) > 0 else 1.0
    lines = []
    for name, w in zip(names, w_list):
        pct = w * 100
        bar = bar_char * int((w / max_val) * 40)
        # Truncate name for clean display
        clean_name = os.path.basename(name)
        if len(clean_name) > 30:
            clean_name = clean_name[:27] + "..."
        lines.append(f"{clean_name:<30} | {bar:<40} | {pct:6.2f}% (Raw: {w:.4f})")
    log_entry = "\n".join(lines)
    print(log_entry)
    print(f"{'='*60}\n")
    # Append to file so repeated runs accumulate an audit history.
    with open("rsce_audit.log", "a", encoding="utf-8") as f:
        f.write(f"\n[Audit {layer_name}]\n" + log_entry + "\n")
def find_layer0_tensor(loader: LazyTensorLoader) -> str:
    """
    Scan a model loader for a Layer-0 weight tensor suitable for auditing.

    Dense projections are preferred in this priority order: down_proj,
    gate_proj, then c_attn (GPT-NeoX / Qwen naming). Falls back to the
    first Layer-0 weight found; raises RuntimeError when none exist.
    """
    layer0_markers = (".layers.0.", ".h.0.", ".blocks.0.")
    candidates = [
        key
        for key in loader.index.tensor_paths.keys()
        if key.endswith(".weight") and any(m in key for m in layer0_markers)
    ]
    # One pass per priority marker, mirroring the original search order.
    for marker in ("down_proj", "gate_proj", "c_attn"):
        for key in candidates:
            if marker in key:
                return key
    if not candidates:
        raise RuntimeError("Could not find any Layer 0 weights in the base model.")
    return candidates[0]
def load_tensor_safe(model_path: str, tensor_name: str, device="cpu") -> torch.Tensor:
    """Load one tensor (as float32) from a model path; exits the process on failure.

    A file path is opened via ShardedTensorIndex.from_file, a directory via
    from_disk. If the exact tensor name is missing, a basic fuzzy match on
    the Layer-0 suffix is attempted before loading.
    """
    try:
        # ShardedTensorIndex is used directly to skip LoaderCache overhead
        # for this one-shot script.
        index = (
            ShardedTensorIndex.from_file(model_path)
            if os.path.isfile(model_path)
            else ShardedTensorIndex.from_disk(model_path)
        )
        loader = LazyTensorLoader(index, lazy_unpickle=True)
        if tensor_name not in index.tensor_paths:
            # Basic fallback for slightly different architectures: match any
            # Layer-0 key sharing the same suffix.
            suffix = tensor_name.split("layers.0.")[-1]
            match = next(
                (
                    k
                    for k in index.tensor_paths.keys()
                    if k.endswith(suffix) and ("layers.0." in k or "h.0." in k)
                ),
                None,
            )
            if match is not None:
                tensor_name = match
        # float32 keeps the downstream math consistent regardless of dtype.
        return loader.get_tensor(tensor_name, device=device).float()
    except Exception as e:
        LOG.error(f"Failed to load {tensor_name} from {model_path}: {e}")
        sys.exit(1)
@click.command("rsce-donor-audit")
@click.argument("config_file", type=click.Path(exists=True))
@click.option(
    "--lora-merge-cache",
    type=str,
    default=None,
    help="Cache directory used when a model reference includes a LoRA adapter.",
)
@click.option(
    "--cuda",
    is_flag=True,
    default=False,
    help="Run the audit math on GPU when available.",
)
def main(config_file, lora_merge_cache, cuda):
    """
    RSCE Donor Audit Tool V3.
    Loads Layer 0 from all models in the config and calculates their
    Task Vector magnitude/energy contribution relative to the base model.
    """
    # Fix: the click decorators above were missing -- `click` was imported
    # but unused, and the bare `main()` call in the __main__ guard would have
    # raised TypeError (three required parameters, zero arguments). The L145
    # warning text already referenced the `--lora-merge-cache` option name.
    device = "cuda" if cuda and torch.cuda.is_available() else "cpu"
    LOG.info(f"Running audit on {device}...")

    # 1. Parse Config
    with open(config_file, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    config = MergeConfiguration.model_validate(config_data)

    # 2. Identify Models
    base_model_ref = config.base_model
    if not base_model_ref:
        LOG.error("Config must specify a `base_model` for RSCE auditing.")
        sys.exit(1)

    # Extract donor models from slices or models list
    donor_refs = []
    if config.models:
        donor_refs = [m.model for m in config.models]
    elif config.slices:
        # Flatten slices to a unique, order-preserving model list.
        seen = set()
        for s in config.slices:
            for source in s.sources:
                if source.model != base_model_ref and source.model not in seen:
                    donor_refs.append(source.model)
                    seen.add(source.model)
    # Filter out base model if it appeared in donors
    donor_refs = [d for d in donor_refs if d != base_model_ref]
    LOG.info(f"Base Model: {base_model_ref.model.path}")
    LOG.info(f"Found {len(donor_refs)} donor models.")

    # 3. Resolve Paths (Handle LoRAs if necessary)
    def resolve_path(ref: ModelReference):
        # LoRA references are merged into a plain checkpoint first.
        if ref.lora:
            if not lora_merge_cache:
                LOG.warning("LoRA detected but --lora-merge-cache not set. This might fail.")
            return ref.merged(cache_dir=lora_merge_cache).model.path
        if not os.path.exists(ref.model.path):
            # Not local: best-effort Hub download; on any failure fall back to
            # the raw reference and let the loader surface the error.
            try:
                from huggingface_hub import snapshot_download

                return snapshot_download(
                    ref.model.path,
                    allow_patterns=["*.safetensors", "*.bin", "*.json"],
                )
            except Exception:
                # Fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                return ref.model.path
        return ref.model.path

    base_path = resolve_path(base_model_ref)
    donor_paths = [resolve_path(d) for d in donor_refs]

    # 4. Identify Target Tensor (Layer 0)
    base_index = ShardedTensorIndex.from_disk(base_path)
    base_loader = LazyTensorLoader(base_index, lazy_unpickle=True)
    target_tensor_name = find_layer0_tensor(base_loader)
    LOG.info(f"Selected audit tensor: {target_tensor_name}")
    LOG.info("Loading tensors into memory...")

    # 5. Load All Tensors
    base_tensor = load_tensor_safe(base_path, target_tensor_name, device)
    donor_tensors = []
    valid_donor_refs = []
    for d_path, d_ref in zip(tqdm(donor_paths, desc="Loading Donors"), donor_refs):
        dt = load_tensor_safe(d_path, target_tensor_name, device)
        # V3: Catch shape mismatches (e.g. a 7B model mixed into a 12B merge)
        if dt.shape != base_tensor.shape:
            LOG.warning(f"\n[!] Shape mismatch for {d_ref.model.path}: expected {base_tensor.shape}, got {dt.shape}. Skipping this model.")
            continue
        donor_tensors.append(dt)
        valid_donor_refs.append(d_ref)
    if not donor_tensors:
        LOG.error("No valid donor tensors found with matching shapes. Exiting.")
        sys.exit(1)

    # 6. Perform RSCE Audit Math
    LOG.info("Calculating Task Vector Energy...")
    # The base model participates as an all-zero task vector (the anchor).
    base_tv = torch.zeros_like(base_tensor)
    donor_tvs = [dt - base_tensor for dt in donor_tensors]
    all_tvs = torch.stack([base_tv] + donor_tvs, dim=0)
    raw_weights = rsce_weight(all_tvs)
    display_names = ["Base Model (Anchor)"] + [d.model.path for d in valid_donor_refs]

    # 7. Output
    log_rsce_audit(target_tensor_name, raw_weights, display_names)
    LOG.info("Audit complete.")
# Run the audit when the file is executed as a script.
if __name__ == "__main__":
    main()