prestigeAI-7b / scripts /ABM /extract_permutation_matrices.py
Danaasa's picture
Upload folder using huggingface_hub
656b04b verified
import os
import sys
from collections import defaultdict
import click
import numpy as np
import safetensors.torch
import scipy
import torch
from mergekit.architecture import ArchitectureInfoUtils, _template_substitution
from mergekit.common import ModelReference
def calc_correlation_matrix(feats):
feats = feats.view(-1, feats.shape[-1])
return torch.corrcoef(feats.T)
def match_tensors_permute(
absval=False,
correlation_matrix=None,
):
"""
This function is adapted from ZipIt! (https://github.com/gstoica27/ZipIt)
"""
Om = correlation_matrix.shape[0] // 2
device = correlation_matrix.device
mats = [torch.eye(Om, device=device)]
corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy()
if absval:
corr_submatrix = np.absolute(corr_submatrix)
_, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True)
new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
mats.append(new_mat.T)
unmerge_mats = mats
unmerge = torch.cat(unmerge_mats, dim=0)
merge = torch.cat(mats, dim=0)
merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
return merge.T, unmerge
def match_tensors_permute_MHA(
n_heads=32,
absval=False,
correlation_matrix=None,
):
"""
Handles different head permutations in attention.
Modified version of the function here: https://github.com/nverma1/merging-text-transformers/blob/main/matching_functions.py#L76
"""
Om = correlation_matrix.shape[0] // 2
device = correlation_matrix.device
query_size = Om // n_heads
mats = [torch.eye(Om, device=device)]
head_perms = []
costs = np.ones((n_heads, n_heads)) * -sys.maxsize
col_inds_storage = defaultdict(lambda: defaultdict(int))
for j in range(n_heads):
for k in range(n_heads):
head1_idx = [query_size * j, query_size * (j + 1)]
head2_idx = [query_size * k, query_size * (k + 1)]
corr_submatrix = (
correlation_matrix[
head1_idx[0] : head1_idx[1],
(Om + head2_idx[0]) : (Om + head2_idx[1]),
]
.cpu()
.numpy()
)
if absval:
corr_submatrix = np.absolute(corr_submatrix)
# compute perm for head j & head k
row_ind, col_ind = scipy.optimize.linear_sum_assignment(
corr_submatrix, maximize=True
)
costs[j, k] = corr_submatrix[row_ind, col_ind].sum()
col_inds_storage[j][k] = col_ind
outer_row_ind, outer_col_ind = scipy.optimize.linear_sum_assignment(
costs, maximize=True
)
for j in range(n_heads):
head_1 = outer_row_ind[j]
head_2 = outer_col_ind[j]
head_perm = col_inds_storage[head_1][head_2]
head_perms.append(torch.tensor(head_perm + query_size * head_2))
new_mat = torch.eye(Om, device=device)[
torch.cat(head_perms).clone().detach().long().to(device)
]
mats.append(new_mat.T)
unmerge_mats = mats
unmerge = torch.cat(unmerge_mats, dim=0)
merge = torch.cat(mats, dim=0)
merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
return merge.T, unmerge
@click.command("mergekit-abm-extract-permutations")
@click.argument("model1-ft", type=str, required=True)
@click.argument("model2-ft", type=str, required=True)
@click.option("--model_path", type=str, required=True, help="Model information")
@click.option(
"--out_path", required=True, type=str, help="Output path for metric tensors"
)
@click.option(
"--absval/--no-absval",
required=False,
default=False,
help="Use absolute value on correlation matrices/submatrices while calculating merge/unmerge matrices",
)
@click.option(
"--device",
"-d",
type=str,
default="cpu",
help="Device to compute on (default: cpu)",
)
def main(model1_ft, model2_ft, model_path, out_path, absval, device):
os.makedirs(out_path, exist_ok=True)
model = ModelReference.model_validate(model_path)
model_config = model.config()
model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)
_json = model_arch_info.definition
residual_space = None
kq_space = None
v_space = None
# extract the residual, attention related spaces
for weight in _json.layer_templates.weights:
if weight.is_kq:
kq_space = weight.output_space
residual_space = weight.input_space
continue
# assuming order is observed
if (
not weight.is_kq
and weight.head_split
and (weight.input_space == residual_space)
):
v_space = weight.output_space
continue
num_layers = model_arch_info.num_layers(model_config)
kq_spaces = []
v_spaces = []
for j in range(num_layers):
kq_spaces.append(
_template_substitution(kq_space, num_layers=num_layers, layer_idx=j)
)
v_spaces.append(
_template_substitution(v_space, num_layers=num_layers, layer_idx=j)
)
model1_features = safetensors.torch.load_file(model1_ft, device=device)
model2_features = safetensors.torch.load_file(model2_ft, device=device)
model1_features.pop("attention_mask")
model2_features.pop("attention_mask")
for feature_space in model1_features.keys():
concatenated_feature = torch.cat(
(model1_features[feature_space], model2_features[feature_space]), dim=-1
)
correlation_matrix = calc_correlation_matrix(concatenated_feature)
if feature_space in (kq_spaces + v_spaces):
merge, unmerge = match_tensors_permute_MHA(
correlation_matrix=correlation_matrix,
n_heads=model_config.num_attention_heads,
absval=absval,
)
else:
merge, unmerge = match_tensors_permute(
correlation_matrix=correlation_matrix,
absval=absval,
)
safetensors.torch.save_file(
{feature_space: merge.contiguous()},
f"{out_path}/{feature_space}_merge.safetensor",
)
safetensors.torch.save_file(
{feature_space: unmerge.contiguous()},
f"{out_path}/{feature_space}_unmerge.safetensor",
)
del merge, unmerge, correlation_matrix, concatenated_feature
if __name__ == "__main__":
main()