File size: 976 Bytes
b59223f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# ztrain/stats.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import os
import torch
from typing import Optional

def gen_stats(delta : torch.Tensor, base : Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarize a delta tensor, optionally relative to a base tensor.

    Args:
        delta: Tensor whose element range is reported. When ``base`` is given,
            ``delta`` is added to it to form the "rebuilt" tensor.
        base: Optional tensor that ``delta`` is applied to. Presumably the same
            shape as ``delta`` (or broadcastable) — this is not checked here.

    Returns:
        ``(norm, cosine, min, max)``:

        * ``norm`` is the norm (torch's default, i.e. the 2-norm over all
          elements) of the rebuilt tensor. The rebuilt tensor is
          ``base + delta``, or ``delta`` alone when ``base`` is None.
        * ``cosine`` is the cosine similarity between the rebuilt tensor and
          ``base`` along dim 0, averaged to a scalar. It is ``0.0`` when
          ``base`` is None.
        * ``min`` and ``max`` are the smallest and largest elements of
          ``delta``.
    """
    # Pure measurement: no autograd graph is needed for these statistics,
    # so avoid recording one even if the inputs require grad.
    with torch.no_grad():
        rebuilt = delta if base is None else base + delta
        norm = rebuilt.norm().item()
        if base is None:
            # Float (not int 0) so the result matches the annotated return type.
            cosine = 0.0
        else:
            # Similarity is taken along dim 0, then averaged over any remaining
            # positions to yield a single scalar.
            cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()
        # Named lo/hi to avoid shadowing the min/max builtins.
        lo = delta.min().item()
        hi = delta.max().item()
    return norm, cosine, lo, hi

def get_report(m0: torch.Tensor, stack : torch.Tensor, model_list : list[str]):
    """Print summary statistics for a base tensor and each slice of a stack.

    The first printed line reports the norm, min and max of ``m0``. Each
    following line is labelled with the basename of the matching
    ``model_list`` entry. It reports, via ``gen_stats(s, m0)``, the norm of
    ``m0 + s``, the mean cosine similarity of ``m0 + s`` to ``m0``, and the
    min/max of ``s``.

    Args:
        m0: Base tensor (presumably the base model's weights — confirm with
            the caller).
        stack: Iterated along its first dimension; each slice ``s`` is treated
            as a delta applied to ``m0``.
        model_list: Paths or names labelling each slice of ``stack``, in
            order. Entries beyond the length of ``stack`` are ignored.

    Raises:
        IndexError: if ``model_list`` has fewer entries than ``stack`` has
            slices. This is checked before anything is printed, so a
            mismatch never produces a partial report.
    """
    if len(model_list) < len(stack):
        raise IndexError(
            f"model_list has {len(model_list)} entries but stack has {len(stack)} slices"
        )

    # The cosine slot is meaningless without a base, so it is discarded.
    base_norm, _, base_lo, base_hi = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_lo} {base_hi}")

    # zip pairs each slice with its label; the length check above guarantees
    # no slice is silently dropped.
    for s, path in zip(stack, model_list):
        model_name = os.path.basename(path)
        norm, cosine, lo, hi = gen_stats(s, m0)
        print(f"{model_name} {norm} {cosine} {lo} {hi}")