| |
| |
| import os |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def _linear_from_weight(weight):
    """Build a bias-free ``nn.Linear`` whose weight is a copy of *weight*."""
    # nn.Linear stores its weight as (out_features, in_features), so the
    # first axis of the matrix is the output size, not the input size.
    out_features, in_features = weight.shape
    layer = nn.Linear(in_features, out_features, bias=False)
    layer.load_state_dict({"weight": weight})
    return layer


def cal_cross_attn(to_q, to_k, to_v, rand_input):
    """Compute an attention-style fingerprint of three projection matrices.

    ``rand_input`` is projected through bias-free linear layers built from
    ``to_q``, ``to_k`` and ``to_v``.  The query/key projections are combined
    into a score matrix that is soft-maxed over its last axis, and the result
    is scaled element-wise by the column sums of the value projection::

        out[i, k] = softmax(q @ k.T)[i, k] * sum_j v[j, k]

    This is deliberately kept as the original einsum expression (it is not a
    conventional ``softmax(QK^T) @ V`` product) so results stay comparable.

    Args:
        to_q: 2-D weight of shape ``(out_features, in_features)``.
        to_k: 2-D weight; its first axis must match ``to_q``'s first axis.
        to_v: 2-D weight; its first axis must equal the number of rows of
            ``rand_input`` for the final element-wise einsum to line up.
        rand_input: 2-D tensor whose last axis equals the second axis of
            each weight.

    Returns:
        A 2-D tensor with one row and one column per row of ``rand_input``.
        It does not carry autograd history.

    Raises:
        RuntimeError: if the shapes above are not mutually compatible.
    """
    # Each layer is sized from its own weight; the layers are created in
    # q, k, v order exactly as before.
    attn_to_q = _linear_from_weight(to_q)
    attn_to_k = _linear_from_weight(to_k)
    attn_to_v = _linear_from_weight(to_v)

    # Pure inference: no need to record an autograd graph.
    with torch.no_grad():
        scores = torch.einsum(
            "ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)
        )
        return torch.einsum(
            "ik, jk -> ik",
            F.softmax(scores, dim=-1),
            attn_to_v(rand_input),
        )
|
|
|
|
def model_hash(filename):
    """Return a short fingerprint for the file at *filename*.

    Only a 0x10000-byte window starting at offset 0x100000 is hashed, so the
    whole file is never read.  A file shorter than that offset contributes no
    bytes and therefore yields the SHA-256 prefix of the empty string.

    Args:
        filename: path of the file to fingerprint.

    Returns:
        The first eight hex characters of the SHA-256 digest of the window,
        or the literal string ``"NOFILE"`` when the file does not exist.
    """
    import hashlib

    digest = hashlib.sha256()
    try:
        with open(filename, "rb") as stream:
            # Skip the first MiB and sample the following 64 KiB.
            stream.seek(0x100000)
            window = stream.read(0x10000)
    except FileNotFoundError:
        return "NOFILE"

    digest.update(window)
    return digest.hexdigest()[:8]
|
|
|
|
def eval(model, n, input):
    """Run ``cal_cross_attn`` on the q/k/v weights of attention layer *n*.

    NOTE: the function name and the ``input`` parameter shadow Python
    builtins; they are kept unchanged because callers rely on them.

    Args:
        model: mapping from parameter name to tensor.  The entries
            ``enc_p.encoder.attn_layers.{n}.conv_{q,k,v}.weight`` are read;
            each is indexed as a 3-D tensor and reduced to 2-D by taking
            index 0 of its last axis.
        n: index of the attention layer to read.
        input: tensor forwarded unchanged to ``cal_cross_attn``.

    Returns:
        Whatever ``cal_cross_attn`` returns for those three matrices.
    """
    prefix = f"enc_p.encoder.attn_layers.{n}"
    # Looked up in q, k, v order; the trailing axis is dropped via [:, :, 0].
    q_w, k_w, v_w = (
        model[f"{prefix}.conv_{kind}.weight"][:, :, 0] for kind in ("q", "k", "v")
    )
    return cal_cross_attn(q_w, k_w, v_w, input)
|
|
|
|
def main(path, root, n_layers=6):
    """Log how similar each checkpoint under *root* is to the one at *path*.

    For every attention layer a random probe input is drawn once, pushed
    through the query model's q/k/v weights (see ``eval``) and cached.  Each
    regular file in *root* is then loaded, fed the same probes, and the mean
    cosine similarity across layers is logged as a percentage at INFO level.

    Args:
        path: checkpoint file of the query model.
        root: directory whose files are compared against the query model.
        n_layers: number of attention layers to compare (default 6, the
            value that used to be hard-coded).

    Returns:
        None.  Results are only emitted through ``logger.info``.

    Raises:
        Any error raised by ``torch.load`` or by a missing ``"weight"`` /
        layer key propagates; files in *root* that are not compatible
        checkpoints are not skipped.
    """
    # Fixed seed so the random probe inputs are reproducible between runs.
    torch.manual_seed(114514)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code.  Only point this tool at trusted checkpoints.
    model_a = torch.load(path, map_location="cpu")["weight"]

    logger.info("Query:\t\t%s\t%s", path, model_hash(path))

    map_attn_a = {}
    map_rand_input = {}
    for n in range(n_layers):
        hidden_dim, embed_dim, _ = model_a[
            f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
        ].shape
        rand_input = torch.randn([embed_dim, hidden_dim])

        map_attn_a[n] = eval(model_a, n, rand_input)
        map_rand_input[n] = rand_input

    # The query weights are no longer needed once the responses are cached.
    del model_a

    for name in sorted(os.listdir(root)):
        ref_path = os.path.join(root, name)
        # Sub-directories cannot be loaded as checkpoints; skip them.
        if not os.path.isfile(ref_path):
            continue
        model_b = torch.load(ref_path, map_location="cpu")["weight"]

        sims = [
            torch.mean(
                torch.cosine_similarity(
                    map_attn_a[n], eval(model_b, n, map_rand_input[n])
                )
            )
            for n in range(n_layers)
        ]

        logger.info(
            "Reference:\t%s\t%s\t%s",
            ref_path,
            model_hash(ref_path),
            f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%",
        )
|
|
|
|
if __name__ == "__main__":
    # main() reports its results via logger.info.  With no logging
    # configuration the effective threshold is WARNING, so nothing would be
    # printed; configure a plain message-only handler at INFO level.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    # Build the paths with os.path.join so they also resolve on platforms
    # where "\" is not a path separator (the result is unchanged on Windows).
    query_path = os.path.join("assets", "weights", "mi v3.pth")
    reference_root = os.path.join("assets", "weights")
    main(query_path, reference_root)
|
|