Spaces:
Runtime error
Runtime error
| """ | |
| Implementation of activation matching algorithms for comparing neural network models. | |
| This module provides functions for matching neurons between two models based on | |
| their activation patterns, helping identify corresponding functional units despite | |
| permutation differences. | |
| """ | |
| import torch | |
| from collections import defaultdict | |
| import scipy | |
| import numpy as np | |
| from tracing.utils.evaluate import evaluate | |
| from tracing.utils.llama.matching import match_wmats | |
| def hook_in(m, inp, op, feats, name): | |
| """ | |
| Forward hook to capture input activations to model layers. | |
| Args: | |
| m: Module being hooked | |
| inp: Input to the module (tuple) | |
| op: Output from the module | |
| feats: Dictionary to store activations | |
| name: Key to store the activations under | |
| """ | |
| feats[name].append(inp[0].detach().cpu()) | |
| def hook_out(m, inp, op, feats, name): | |
| """ | |
| Forward hook to capture output activations from model layers. | |
| Args: | |
| m: Module being hooked | |
| inp: Input to the module | |
| op: Output from the module | |
| feats: Dictionary to store activations | |
| name: Key to store the activations under | |
| """ | |
| feats[name].append(op.detach().cpu()) | |
| def statistic(base_model, ft_model, dataloader, n_blocks=32): | |
| """ | |
| Compute neuron matching statistics across all transformer blocks. | |
| For each block, compares the gate and up projections to determine if | |
| the permutation patterns are consistent, which would indicate functionally | |
| corresponding neurons. | |
| Args: | |
| base_model: Base model to compare | |
| ft_model: Fine-tuned or target model to compare against the base model | |
| dataloader: DataLoader providing input data for activation collection | |
| n_blocks: Number of transformer blocks to analyze (default: 32) | |
| Returns: | |
| list: Spearman correlation p-values for each block | |
| """ | |
| stats = [] | |
| for i in range(n_blocks): | |
| gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i=i) | |
| up_match = mlp_matching_up(base_model, ft_model, dataloader, i=i) | |
| cor, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist()) | |
| print(i, pvalue, len(gate_match)) | |
| stats.append(pvalue) | |
| return stats | |
| def statistic_layer(base_model, ft_model, dataloader, i=0): | |
| """ | |
| Compute neuron matching statistics for a specific transformer block. | |
| Args: | |
| base_model: Base model to compare | |
| ft_model: Fine-tuned or target model to compare against the base model | |
| dataloader: DataLoader providing input data for activation collection | |
| i: Block index to analyze (default: 0) | |
| Returns: | |
| float: Spearman correlation p-value for the specified block | |
| """ | |
| gate_perm = mlp_matching_gate(base_model, ft_model, dataloader, i=i) | |
| up_perm = mlp_matching_up(base_model, ft_model, dataloader, i=i) | |
| cor, pvalue = scipy.stats.spearmanr(gate_perm.tolist(), up_perm.tolist()) | |
| return pvalue | |
| def mlp_matching_gate(base_model, ft_model, dataloader, i=0): | |
| """ | |
| Match neurons between models by comparing gate projection activations. | |
| Collects activations from the gate projection layer for both models | |
| and computes a permutation that would align corresponding neurons. | |
| Args: | |
| base_model: Base model to compare | |
| ft_model: Fine-tuned or target model to compare against the base model | |
| dataloader: DataLoader providing input data for activation collection | |
| i: Block index to analyze (default: 0) | |
| Returns: | |
| torch.Tensor: Permutation indices that match neurons between models | |
| """ | |
| feats = defaultdict(list) | |
| base_hook = lambda *args: hook_out(*args, feats, "base") | |
| base_handle = base_model.model.layers[i].mlp.gate_proj.register_forward_hook(base_hook) | |
| ft_hook = lambda *args: hook_out(*args, feats, "ft") | |
| ft_handle = ft_model.model.layers[i].mlp.gate_proj.register_forward_hook(ft_hook) | |
| evaluate(base_model, dataloader) | |
| evaluate(ft_model, dataloader) | |
| base_mat = torch.vstack(feats["base"]) | |
| ft_mat = torch.vstack(feats["ft"]) | |
| base_mat.to("cuda") | |
| ft_mat.to("cuda") | |
| base_mat = base_mat.view(-1, base_mat.shape[-1]).T | |
| ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T | |
| base_handle.remove() | |
| ft_handle.remove() | |
| perm = match_wmats(base_mat, ft_mat) | |
| return perm | |
| def mlp_matching_up(base_model, ft_model, dataloader, i=0): | |
| """ | |
| Match neurons between models by comparing up projection activations. | |
| Collects activations from the up projection layer for both models | |
| and computes a permutation that would align corresponding neurons. | |
| Args: | |
| base_model: Base model to compare | |
| ft_model: Fine-tuned or target model to compare against the base model | |
| dataloader: DataLoader providing input data for activation collection | |
| i: Block index to analyze (default: 0) | |
| Returns: | |
| torch.Tensor: Permutation indices that match neurons between models | |
| """ | |
| feats = defaultdict(list) | |
| base_hook = lambda *args: hook_out(*args, feats, "base") | |
| base_handle = base_model.model.layers[i].mlp.up_proj.register_forward_hook(base_hook) | |
| ft_hook = lambda *args: hook_out(*args, feats, "ft") | |
| ft_handle = ft_model.model.layers[i].mlp.up_proj.register_forward_hook(ft_hook) | |
| evaluate(base_model, dataloader) | |
| evaluate(ft_model, dataloader) | |
| base_mat = torch.vstack(feats["base"]) | |
| ft_mat = torch.vstack(feats["ft"]) | |
| base_mat.to("cuda") | |
| ft_mat.to("cuda") | |
| base_mat = base_mat.view(-1, base_mat.shape[-1]).T | |
| ft_mat = ft_mat.view(-1, ft_mat.shape[-1]).T | |
| base_handle.remove() | |
| ft_handle.remove() | |
| perm = match_wmats(base_mat, ft_mat) | |
| return perm | |
| def mlp_layers(base_model, ft_model, dataloader, i, j): | |
| """ | |
| Compare gate and up projections between specific layers of two models. | |
| Useful for comparing non-corresponding layers to find functional similarities. | |
| Args: | |
| base_model: Base model to compare | |
| ft_model: Fine-tuned or target model to compare against the base model | |
| dataloader: DataLoader providing input data for activation collection | |
| i: Layer index in the base model | |
| j: Layer index in the fine-tuned model | |
| Returns: | |
| float: Spearman correlation p-value between gate and up projections | |
| """ | |
| gate_match = mlp_matching_gate(base_model, ft_model, dataloader, i, j) | |
| up_match = mlp_matching_up(base_model, ft_model, dataloader, i, j) | |
| cor, pvalue = scipy.stats.spearmanr(gate_match.tolist(), up_match.tolist()) | |
| return pvalue | |
| def statistic_all(model_1, model_2, dataloader): | |
| """ | |
| Perform comprehensive layer matching between two models. | |
| Tests all combinations of layers between the models to identify corresponding | |
| functional units, regardless of their position in the network architecture. | |
| Args: | |
| model_1: First model to compare | |
| model_2: Second model to compare | |
| dataloader: DataLoader providing input data for activation collection | |
| Returns: | |
| None: Prints matching results during execution | |
| """ | |
| model_1_matched = np.zeros(model_1.config.num_hidden_layers) | |
| model_2_matched = np.zeros(model_2.config.num_hidden_layers) | |
| for i in range(model_1.config.num_hidden_layers): | |
| for j in range(model_2.config.num_hidden_layers): | |
| if model_1_matched[i] == 1 or model_2_matched[j] == 1: | |
| continue | |
| stat = mlp_layers(model_1, model_2, dataloader, i, j) | |
| print(i, j, stat) | |
| if stat < 0.000001: | |
| model_1_matched[i] = 1 | |
| model_2_matched[j] = 1 | |
| break | |