File size: 976 Bytes
b59223f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
# ztrain/stats.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
import os
import torch
from typing import Optional
def gen_stats(delta : torch.Tensor, base : Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarize a tensor delta, optionally relative to a base tensor.

    The norm and cosine similarity are measured on the rebuilt tensor
    (``base + delta``, or ``delta`` itself when ``base`` is None). The
    min/max are measured on ``delta`` alone.

    Args:
        delta: Tensor of differences, or the raw tensor when ``base`` is None.
        base: Optional reference tensor that is added to ``delta``. It
            presumably has the same shape as ``delta`` (TODO confirm with
            callers); it must at least broadcast with it.

    Returns:
        ``(norm, cosine, min, max)`` as Python floats:

        - ``norm``: ``Tensor.norm()`` (default Frobenius / flattened 2-norm)
          of the rebuilt tensor.
        - ``cosine``: cosine similarity between the rebuilt tensor and
          ``base`` along dim 0, averaged to a scalar; ``0.0`` when ``base``
          is None.
        - ``min`` / ``max``: extrema of ``delta``.
    """
    rebuilt = delta if base is None else base + delta
    norm = rebuilt.norm().item()

    if base is None:
        # Nothing to compare against. Use a float so the return type matches
        # the annotation (the original returned the int 0 here).
        cosine = 0.0
    else:
        # dim=0 yields one similarity per position of the remaining axes
        # (a scalar for 1-D input); .mean() collapses that to one number.
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()

    # Extrema describe the delta itself, not the rebuilt tensor.
    delta_min = delta.min().item()
    delta_max = delta.max().item()
    return norm, cosine, delta_min, delta_max
def get_report(m0: torch.Tensor, stack : torch.Tensor, model_list : list[str]):
    """Print a one-line statistics summary for a base tensor and each stacked entry.

    Each line is produced by ``gen_stats``. The base line omits the cosine
    column because there is no reference to compare against.

    Args:
        m0: Base model tensor. It is reported on its own and then used as
            the base for every entry of ``stack``.
        stack: Iterated along its first axis. Each element is passed to
            ``gen_stats`` as a delta relative to ``m0``.
        model_list: Names or paths aligned index-for-index with ``stack``.
            Only the basename is printed. It must have at least as many
            entries as ``stack`` yields, otherwise ``model_list[i]`` raises
            IndexError.
    """
    # Cosine is always 0.0 without a base, so it is discarded rather than printed.
    base_norm, _, base_min, base_max = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_min} {base_max}")

    for i, delta in enumerate(stack):
        model_name = os.path.basename(model_list[i])
        norm, cosine, lo, hi = gen_stats(delta, m0)
        print(f"{model_name} {norm} {cosine} {lo} {hi}")
|