| import os |
| import sys |
| from collections import defaultdict |
|
|
| import click |
| import numpy as np |
| import safetensors.torch |
| import scipy |
| import torch |
|
|
| from mergekit.architecture import ArchitectureInfoUtils, _template_substitution |
| from mergekit.common import ModelReference |
|
|
|
|
def calc_correlation_matrix(feats):
    """Return the (d, d) correlation matrix over the last feature dimension.

    All leading dimensions are flattened so each row is one observation of
    the d features; ``torch.corrcoef`` then correlates features pairwise.

    Args:
        feats: tensor of shape (..., d).

    Returns:
        (d, d) correlation matrix.
    """
    # reshape (not view) so non-contiguous inputs (e.g. transposed tensors)
    # are handled as well; identical to view for contiguous inputs.
    feats = feats.reshape(-1, feats.shape[-1])

    return torch.corrcoef(feats.T)
|
|
|
|
def match_tensors_permute(
    absval=False,
    correlation_matrix=None,
):
    """
    Compute permutation-based merge/unmerge matrices from a cross-model
    correlation matrix of concatenated features (layout: first half model 1,
    second half model 2).

    This function is adapted from ZipIt! (https://github.com/gstoica27/ZipIt)

    Args:
        absval: match on absolute correlation values if True.
        correlation_matrix: (2*Om, 2*Om) correlation matrix.

    Returns:
        Tuple of (merge, unmerge): merge has shape (Om, 2*Om) and is
        column-normalized; unmerge has shape (2*Om, Om).
    """
    half_dim = correlation_matrix.shape[0] // 2
    device = correlation_matrix.device

    # Cross-model block: rows index model-1 features, columns model-2 features.
    cross_corr = correlation_matrix[:half_dim, half_dim:].cpu().numpy()
    if absval:
        cross_corr = np.absolute(cross_corr)

    # Optimal one-to-one feature matching between the two models.
    _, col_ind = scipy.optimize.linear_sum_assignment(cross_corr, maximize=True)
    perm_mat = torch.eye(half_dim, device=device)[
        torch.tensor(col_ind).long().to(device)
    ]

    # Model 1 keeps identity; model 2 is permuted into model 1's feature order.
    transform_mats = [torch.eye(half_dim, device=device), perm_mat.T]

    unmerge = torch.cat(transform_mats, dim=0)

    merge = torch.cat(transform_mats, dim=0)
    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)

    return merge.T, unmerge
|
|
|
|
def match_tensors_permute_MHA(
    n_heads=32,
    absval=False,
    correlation_matrix=None,
):
    """
    Handles different head permutations in attention.
    Modified version of the function here: https://github.com/nverma1/merging-text-transformers/blob/main/matching_functions.py#L76

    Two-level assignment: first, for every (model-1 head, model-2 head) pair,
    solve the optimal within-head feature matching and record its score; then
    solve an outer assignment over heads using those scores, and compose the
    resulting head-level and within-head permutations into one matrix.

    Args:
        n_heads: number of attention heads the feature dim splits into.
        absval: match on absolute correlation values if True.
        correlation_matrix: (2*Om, 2*Om) matrix (first half model 1,
            second half model 2).

    Returns:
        Tuple of (merge, unmerge) with shapes (Om, 2*Om) and (2*Om, Om).
    """
    half_dim = correlation_matrix.shape[0] // 2
    device = correlation_matrix.device
    head_dim = half_dim // n_heads

    identity = torch.eye(half_dim, device=device)

    # Score of the best within-head matching for every head pair.
    pair_scores = np.ones((n_heads, n_heads)) * -sys.maxsize
    # Column assignment realizing that best matching, keyed [head1][head2].
    pair_col_inds = defaultdict(lambda: defaultdict(int))

    for h1 in range(n_heads):
        for h2 in range(n_heads):
            rows = slice(head_dim * h1, head_dim * (h1 + 1))
            cols = slice(half_dim + head_dim * h2, half_dim + head_dim * (h2 + 1))

            block = correlation_matrix[rows, cols].cpu().numpy()
            if absval:
                block = np.absolute(block)

            r_ind, c_ind = scipy.optimize.linear_sum_assignment(
                block, maximize=True
            )

            pair_scores[h1, h2] = block[r_ind, c_ind].sum()
            pair_col_inds[h1][h2] = c_ind

    # Outer assignment: match whole heads across the two models.
    outer_rows, outer_cols = scipy.optimize.linear_sum_assignment(
        pair_scores, maximize=True
    )

    per_head_perms = []
    for idx in range(n_heads):
        h1 = outer_rows[idx]
        h2 = outer_cols[idx]

        # Shift the within-head permutation to head h2's slot in the full dim.
        inner_perm = pair_col_inds[h1][h2]
        per_head_perms.append(torch.tensor(inner_perm + head_dim * h2))

    full_perm = torch.cat(per_head_perms).clone().detach().long().to(device)
    perm_mat = identity[full_perm]

    # Model 1 keeps identity; model 2 is permuted head- and feature-wise.
    transform_mats = [identity, perm_mat.T]

    unmerge = torch.cat(transform_mats, dim=0)
    merge = torch.cat(transform_mats, dim=0)
    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)

    return merge.T, unmerge
|
|
|
|
@click.command("mergekit-abm-extract-permutations")
@click.argument("model1-ft", type=str, required=True)
@click.argument("model2-ft", type=str, required=True)
@click.option("--model_path", type=str, required=True, help="Model information")
@click.option(
    "--out_path", required=True, type=str, help="Output path for metric tensors"
)
@click.option(
    "--absval/--no-absval",
    required=False,
    default=False,
    help="Use absolute value on correlation matrices/submatrices while calculating merge/unmerge matrices",
)
@click.option(
    "--device",
    "-d",
    type=str,
    default="cpu",
    help="Device to compute on (default: cpu)",
)
def main(model1_ft, model2_ft, model_path, out_path, absval, device):
    """Extract per-feature-space merge/unmerge permutation matrices.

    Loads stored activation features for two fine-tuned models (safetensors
    files ``model1_ft``/``model2_ft``), computes a correlation matrix per
    feature space, solves for permutation matrices (head-aware for attention
    K/Q/V spaces), and writes each merge/unmerge pair as safetensor files
    under ``out_path``.
    """
    os.makedirs(out_path, exist_ok=True)

    model = ModelReference.model_validate(model_path)

    model_config = model.config()

    model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)

    _json = model_arch_info.definition

    residual_space = None
    kq_space = None
    v_space = None

    # Identify the K/Q output space, the residual (input) space, and the V
    # output space from the architecture's per-layer weight templates.
    # NOTE(review): the V check reads residual_space, so this assumes a K/Q
    # weight appears before the matching V weight in the template list —
    # confirm against the architecture definitions.
    for weight in _json.layer_templates.weights:
        if weight.is_kq:
            kq_space = weight.output_space
            residual_space = weight.input_space
            continue

        # V projection: head-split weight reading from the residual stream.
        if (
            not weight.is_kq
            and weight.head_split
            and (weight.input_space == residual_space)
        ):
            v_space = weight.output_space
            continue

    num_layers = model_arch_info.num_layers(model_config)

    # Expand the templated space names into one concrete name per layer.
    kq_spaces = []
    v_spaces = []
    for j in range(num_layers):
        kq_spaces.append(
            _template_substitution(kq_space, num_layers=num_layers, layer_idx=j)
        )
        v_spaces.append(
            _template_substitution(v_space, num_layers=num_layers, layer_idx=j)
        )

    model1_features = safetensors.torch.load_file(model1_ft, device=device)
    model2_features = safetensors.torch.load_file(model2_ft, device=device)

    # The stored attention mask is not a feature space; drop it if present
    # (a missing key previously raised KeyError and aborted the run).
    model1_features.pop("attention_mask", None)
    model2_features.pop("attention_mask", None)

    for feature_space in model1_features.keys():
        # Concatenate both models' features on the last dim so the
        # correlation matrix has the 2*Om layout the matchers expect.
        concatenated_feature = torch.cat(
            (model1_features[feature_space], model2_features[feature_space]), dim=-1
        )

        correlation_matrix = calc_correlation_matrix(concatenated_feature)

        # Attention K/Q/V spaces need head-aware permutation matching.
        if feature_space in (kq_spaces + v_spaces):
            merge, unmerge = match_tensors_permute_MHA(
                correlation_matrix=correlation_matrix,
                n_heads=model_config.num_attention_heads,
                absval=absval,
            )
        else:
            merge, unmerge = match_tensors_permute(
                correlation_matrix=correlation_matrix,
                absval=absval,
            )

        safetensors.torch.save_file(
            {feature_space: merge.contiguous()},
            os.path.join(out_path, f"{feature_space}_merge.safetensor"),
        )

        safetensors.torch.save_file(
            {feature_space: unmerge.contiguous()},
            os.path.join(out_path, f"{feature_space}_unmerge.safetensor"),
        )

        # Free the large tensors before the next feature space to cap memory.
        del merge, unmerge, correlation_matrix, concatenated_feature
|
|
|
# Run the CLI entry point when executed as a script (click parses argv).
if __name__ == "__main__":
    main()
|
|