| # ztrain/stats.py | |
| # Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted | |
| import os | |
| import torch | |
| from typing import Optional | |
def gen_stats(delta: torch.Tensor, base: Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """Summarise a delta tensor, optionally relative to a base tensor.

    When ``base`` is given, the full tensor is rebuilt as ``base + delta``;
    otherwise ``delta`` itself is treated as the full tensor.

    Args:
        delta: Difference tensor to summarise.
        base: Tensor the delta is added to, or ``None``.

    Returns:
        ``(norm, cosine, lo, hi)`` as Python floats:

        - ``norm``: ``Tensor.norm()`` (default p) of the rebuilt tensor.
        - ``cosine``: cosine similarity between the rebuilt tensor and
          ``base`` along dim 0, averaged with ``.mean()``. It is ``0.0``
          when ``base`` is ``None``.
        - ``lo`` / ``hi``: smallest and largest element of ``delta``
          (not of the rebuilt tensor).
    """
    if base is None:
        rebuilt = delta
        cosine = 0.0  # float, matching the annotated return type
    else:
        rebuilt = base + delta
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()
    norm = rebuilt.norm().item()
    # Drop the (possibly large) rebuilt tensor before the remaining reductions;
    # when base is None this only removes an alias of delta.
    del rebuilt
    lo = delta.min().item()
    hi = delta.max().item()
    return norm, cosine, lo, hi
def get_report(m0: torch.Tensor, stack: torch.Tensor, model_list: list[str]):
    """Print one line of ``gen_stats`` output for a base tensor and each delta.

    The first line reports ``m0`` on its own, labelled ``Base Model``.
    Each following line reports one entry of ``stack`` (iterated along
    dim 0) applied on top of ``m0``. It is labelled with the basename of
    the matching path in ``model_list``.

    Args:
        m0: Base tensor.
        stack: Deltas; each item produced by iterating it is passed to
            ``gen_stats`` together with ``m0``.
        model_list: Paths labelling the deltas; entry ``i`` labels item
            ``i`` of ``stack``. Extra trailing entries are ignored.

    Raises:
        IndexError: If ``model_list`` has fewer entries than ``stack``.
            This is checked before anything is printed, so no partial
            report is emitted.
    """
    if len(model_list) < len(stack):
        raise IndexError(
            f"model_list has {len(model_list)} entries but stack has {len(stack)}"
        )
    norm, _, lo, hi = gen_stats(m0, None)
    print(f"Base Model {norm} {lo} {hi}")
    for delta, path in zip(stack, model_list):
        norm, cosine, lo, hi = gen_stats(delta, m0)
        print(f"{os.path.basename(path)} {norm} {cosine} {lo} {hi}")