Spaces:
Runtime error
Runtime error
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import random | |
| import os | |
| from scipy.stats import chi2 | |
def manual_seed(seed, fix_cudnn=True):
    """Seed every random number generator this module relies on.

    Args:
        seed: Value handed to Python's ``random``, NumPy, and torch (CPU and
            all CUDA devices); its string form is also written to the
            ``PYTHONHASHSEED`` environment variable.
        fix_cudnn: When true, additionally switch cuDNN to deterministic mode
            and turn its benchmark autotuner off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    if not fix_cudnn:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def spcor(x, y):
    """Evaluate ``1 - 6 * sum((x - y)**2) / (n * (n**2 - 1))`` with ``n = len(x)``.

    This is the closed-form Spearman rank-correlation expression.
    NOTE(review): it presumably expects ``x`` and ``y`` to already hold ranks
    (without ties) -- confirm against callers.

    Args:
        x: Tensor of values for the first variable.
        y: Tensor of values for the second variable, broadcastable with ``x``.

    Returns:
        A tensor holding the coefficient, computed without gradient tracking.
    """
    count = len(x)
    denominator = count * (count**2 - 1)
    with torch.no_grad():
        squared_gaps = torch.square(x - y)
        penalty = torch.sum(6 * squared_gaps) / denominator
        coefficient = 1 - penalty
    return coefficient
def pdists(x, y):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 * a.b`` so the whole
    matrix comes from a single matrix product. Because of floating-point
    round-off, entries that should be exactly zero may come out as tiny
    positive or negative numbers.

    Args:
        x: Tensor whose last dimension indexes features (one row per point).
        y: Tensor with the same feature dimension as ``x``.

    Returns:
        A CPU tensor where entry ``[i, j]`` is the squared distance between
        ``x[i]`` and ``y[j]``.
    """
    # Use the GPU when one exists, but do not crash on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
        x_sq = torch.sum(torch.square(x), dim=-1).view(-1, 1)
        y_sq = torch.sum(torch.square(y), dim=-1).view(1, -1)
        dists = x_sq + y_sq - 2 * x @ y.T
    return dists.cpu()
def cossim(x, y):
    """Return the matrix of cosine similarities between rows of x and y.

    Entry ``[i, j]`` is ``dot(x[i], y[j]) / (|x[i]| * |y[j]|)``. Rows with
    zero norm produce non-finite values because of the division by zero.

    Args:
        x: 2-D tensor with one vector per row.
        y: 2-D tensor with the same number of columns as ``x``.

    Returns:
        A CPU tensor of shape ``(len(x), len(y))``.
    """
    # Use the GPU when one exists, but do not crash on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
        x_norms = torch.linalg.norm(x, dim=-1).view(-1, 1)
        y_norms = torch.linalg.norm(y, dim=-1).view(1, -1)
        # The numerator must be the pairwise dot products, not x itself;
        # dividing x alone by the (n, m) norm grid is not a similarity.
        similarities = (x @ y.T) / (x_norms * y_norms)
    return similarities.cpu()
def fisher(p):
    """Combine p-values with Fisher's method, skipping NaN entries.

    The statistic is ``-2 * sum(log(p_i))`` over the non-NaN p-values and is
    compared against a chi-squared distribution with two degrees of freedom
    per p-value used.

    Args:
        p: Iterable of p-values; NaN entries are ignored.

    Returns:
        The chi-squared survival function evaluated at the statistic.
    """
    usable = [pvalue for pvalue in p if not np.isnan(pvalue)]
    statistic = 0
    # Accumulate one term at a time, in input order.
    for pvalue in usable:
        statistic -= 2 * np.log(pvalue)
    return chi2.sf(statistic, df=2 * len(usable))
def normalize_mc_midpoint(mid, base, ft):
    """Express ``mid`` relative to the halfway point between ``base`` and ``ft``.

    Computes ``mid - 0.5 * (ft - base) - base``. The subtraction uses
    augmented assignment, so a mutable ``mid`` (for example a tensor or an
    array) is updated in place and the same object is returned.

    Args:
        mid: Value measured at the midpoint.
        base: Value at the first endpoint.
        ft: Value at the second endpoint.

    Returns:
        ``mid`` after removing the linear-interpolation midpoint.
    """
    half_rise = (ft - base) * 0.5
    mid -= half_rise
    mid -= base
    return mid
def normalize_trace(trace, alphas):
    """Remove the straight line joining the first and last trace values.

    Each element becomes ``trace[i] - slope * alphas[i] - start``, where
    ``start`` is the first element and ``slope`` is last minus first. The
    result is written back into ``trace``, which is also returned.

    Args:
        trace: Mutable sequence of values (floats, tensors, array rows, ...).
        alphas: Sequence with at least ``len(trace)`` entries giving the
            position of each trace value along the line.

    Returns:
        The same ``trace`` object with normalized contents. An empty trace is
        returned unchanged.

    Raises:
        IndexError: If ``alphas`` has fewer entries than ``trace``.
    """
    if len(trace) == 0:
        return trace

    start = trace[0]
    slope = trace[-1] - start
    # Compute every new value before touching `trace`. When elements are
    # tensors or array rows, `start` aliases trace[0]; updating trace[0]
    # first would silently change `start` for all later elements.
    normalized = [
        value - slope * alphas[i] - start for i, value in enumerate(trace)
    ]
    for i, value in enumerate(normalized):
        trace[i] = value
    return trace
def output_hook(m, inp, op, name, feats):
    """Store a detached copy of a module's output in ``feats`` under ``name``.

    The first three parameters follow the ``(module, input, output)`` shape
    of a forward hook; only the output is used here.

    Args:
        m: Unused.
        inp: Unused.
        op: Output object exposing ``detach()``.
        name: Key under which the detached output is stored.
        feats: Mapping that receives the detached output.
    """
    detached = op.detach()
    feats[name] = detached
def get_submodule(module, submodule_string):
    """Resolve a dotted attribute path starting from ``module``.

    Args:
        module: Object to start the lookup from.
        submodule_string: Attribute names joined by ``"."``, e.g.
            ``"encoder.layer"``.

    Returns:
        The object reached after following every attribute in order.

    Raises:
        AttributeError: If any name along the path does not exist.
    """
    current = module
    for part in submodule_string.split("."):
        current = getattr(current, part)
    return current
def plot_trace(losses, alphas, normalize, model_a_name, model_b_name, plot_path):
    """Plot losses against alphas and save the figure as ``<plot_path>.png``.

    Args:
        losses: Loss values, one per alpha.
        alphas: X-axis positions for the loss values.
        normalize: When true, pass ``losses`` through ``normalize_trace``
            before plotting.
        model_a_name: Name shown as the left-hand model in the title.
        model_b_name: Name shown as the right-hand model in the title.
        plot_path: Output path without extension; ``.png`` is appended.
    """
    fig = plt.figure(figsize=(8, 6))
    try:
        if normalize:
            losses = normalize_trace(losses, alphas)
        plt.plot(alphas, losses, "o-")
        plt.xlabel("Alpha")
        plt.ylabel("Loss")
        plt.title(f"{model_a_name} (Left) vs {model_b_name} (Right)")
        plot_filename = f"{plot_path}.png"
        plt.savefig(plot_filename, dpi=300, bbox_inches="tight")
    finally:
        # Close this figure even if plotting or saving fails, so repeated
        # calls do not pile up open figures.
        plt.close(fig)