Upload folder using huggingface_hub
Browse files- ztrain/__init__.py +0 -0
- ztrain/io.py +30 -0
- ztrain/model.py +39 -0
- ztrain/signal.py +79 -0
- ztrain/stats.py +30 -0
- ztrain/tensors.py +258 -0
- ztrain/util.py +37 -0
ztrain/__init__.py
ADDED
|
File without changes
|
ztrain/io.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/io.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from glob import glob
|
| 6 |
+
|
| 7 |
+
def flatten_index(model_paths : list[str], allow_list : list[str]):
    """
    Indexes the allowed entries of a path listing by their base name.

    Parameters:
    - model_paths: Paths to inspect. They are visited in sorted order.
    - allow_list: Base names to keep; every other path is ignored.

    Returns:
    - A tuple (index, flat, subtype): `index` maps each kept base name to its
      position, `flat` lists the kept base names in visiting order, and
      `subtype` holds 'base', 'instruct' or 'other' for each kept entry.
    """
    def classify(full_path):
        # The substring test runs against the whole path, 'base' first.
        if 'base' in full_path:
            return 'base'
        if 'instruct' in full_path:
            return 'instruct'
        return 'other'

    index = {}
    flat = []
    subtype = []
    for full_path in sorted(model_paths):
        leaf = os.path.basename(full_path)
        if leaf not in allow_list:
            continue
        # The next free position always equals the number of entries kept so far.
        index[leaf] = len(flat)
        flat.append(leaf)
        subtype.append(classify(full_path))
    return index, flat, subtype
|
| 25 |
+
|
| 26 |
+
def list_for_path(path: str, include_folders: list[str], search: str = "/**/*") -> tuple[list[str], list[str], list[str], dict[str, int]]:
    """
    Lists the entries found under `path` and indexes the allowed ones.

    Parameters:
    - path: Root folder; `search` is appended to it to form the glob pattern.
    - include_folders: Base names to keep, passed to `flatten_index` as its allow list.
    - search: Glob suffix appended to `path`.

    Returns:
    - A tuple (model_names, subtypes, model_list, group_idx): the kept base
      names, their subtypes, the full sorted glob result, and the mapping from
      kept base name to position.
    """
    # NOTE: glob is called without recursive=True, so "**" in the default
    # pattern matches a single path segment, exactly like "*".
    model_list = sorted(glob(path + search))
    group_idx, model_names, subtypes = flatten_index(model_list, include_folders)
    return model_names, subtypes, model_list, group_idx
|
ztrain/model.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/model.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
def generate_merge_group(group_data : list, parents : "list[int] | None" = None):
    """
    Walks a nested list and yields every non-list item with its index path.

    Parameters:
    - group_data: A list whose items are either nested lists or leaf values.
    - parents: Index path of `group_data` inside the outermost list. Leave it
      unset for the top-level call.

    Yields:
    - (item, path): the leaf value and the list of indices leading to it,
      outermost index first.
    """
    # A None default avoids sharing one mutable list object between calls.
    prefix = [] if parents is None else list(parents)
    for i, g in enumerate(group_data):
        path = prefix + [i]
        if isinstance(g, list):
            # drill down until we reach the leaves
            yield from generate_merge_group(g, path)
        else:
            yield g, path
|
| 14 |
+
|
| 15 |
+
def merge_groups(group_data : list):
    """
    Groups the leaves of a nested list by the index path of their parent list.

    Parameters:
    - group_data: A nested list, walked with `generate_merge_group`.

    Returns:
    - A defaultdict mapping each parent index path (as a tuple) to the list of
      leaves found directly inside that parent.
    """
    grouped = defaultdict(list)
    for leaf, path in generate_merge_group(group_data):
        # Dropping the last index turns the leaf's path into its parent's path.
        grouped[tuple(path[:-1])].append(leaf)
    return grouped
|
| 21 |
+
|
| 22 |
+
# Layer keys with three dotted segments after the layer number,
# e.g. "model.layers.3.self_attn.q_proj.weight".
_LAYER_KEY_NESTED = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)")
# Layer keys with two dotted segments after the layer number,
# e.g. "model.layers.3.input_layernorm.weight".
_LAYER_KEY_FLAT = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)")
# Keys that do not belong to a numbered layer, mapped to their block name.
_TOP_LEVEL_KEYS = {
    "model.norm.weight": "norm",
    "model.embed_tokens.weight": "embed_tokens",
    "lm_head.weight": "lm_head",
}


def get_layer_type(k : str) -> tuple[int, str, str, str]:
    """
    Splits a parameter key into (layer index, block, sub-block, parameter).

    Parameters:
    - k: The parameter key, e.g. "model.layers.3.self_attn.q_proj.weight".

    Returns:
    - A tuple (layer, block, sub_block, parameter). `layer` is -1 for keys
      outside the numbered layers, and `sub_block` is "" when the key has no
      sub-block. Unrecognised keys are reported on stdout and returned as
      (-1, "unknown", "unknown", "unknown").
    """
    m = _LAYER_KEY_NESTED.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)

    # The shorter pattern has to be matched explicitly; reusing the previous
    # (failed) match object would make this branch unreachable.
    m = _LAYER_KEY_FLAT.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    block = _TOP_LEVEL_KEYS.get(k)
    if block is not None:
        return -1, block, "", "weight"

    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"
|
ztrain/signal.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/signal.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
def gaussian_kernel(size, sigma=1.0):
    """
    Generates a 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (converted with int()). The kernel always
            spans 2 * (size // 2) + 1 pixels per side, so an even size yields
            a kernel one pixel larger than requested; use an odd number.
    - sigma: The standard deviation of the Gaussian distribution. Must be non-zero.

    Returns:
    - A 2D PyTorch tensor representing the Gaussian kernel, normalised so that
      its elements sum to 1.
    """
    radius = int(size) // 2
    coords = torch.arange(-radius, radius + 1)
    # Passing indexing explicitly keeps the historical "ij" layout and avoids
    # the deprecation warning raised when the argument is omitted.
    x, y = torch.meshgrid(coords, coords, indexing="ij")
    g = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    return g / g.sum()
|
| 22 |
+
|
| 23 |
+
def laplacian_kernel(size, scale=1.0):
    """
    Builds a scaled, zero-sum Laplacian-style kernel for edge detection.

    Parameters:
    - size: Side length of the square kernel. Must be odd so that the kernel
            has a central pixel.
    - scale: Multiplier applied to every kernel entry before the centre is
             re-balanced.

    Returns:
    - A 2D float32 PyTorch tensor of shape (size, size) whose entries sum to 0.

    Raises:
    - ValueError: If size is even.
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")

    mid = size // 2
    last = size - 1
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # Classic 4-neighbour stencil: -4 in the middle, 1 left/right/up/down.
    kernel[mid, mid] = -4.0
    for row, col in ((mid, mid - 1), (mid, mid + 1), (mid - 1, mid), (mid + 1, mid)):
        kernel[row, col] = 1.0

    # Larger kernels additionally get a frame of ones along the outer border.
    if size > 3:
        kernel[0, :] = 1.0
        kernel[last, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, last] = 1.0

    kernel *= scale

    # Re-balance the centre so it cancels the sum of every other entry.
    kernel[mid, mid] = -torch.sum(kernel) + kernel[mid, mid]

    return kernel
|
| 63 |
+
|
| 64 |
+
def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Parameters:
    - input: A tensor representing the FFT output; its first two dimensions are shifted.

    Returns:
    - A tensor with the zero-frequency component moved from index 0 to index
      n // 2 along each of the first two dimensions.
    """
    # The zero-frequency bin starts at index 0 and belongs at index n // 2 for
    # both even and odd n. Rolling by (n + 1) // 2 would be the inverse shift
    # and leaves odd-sized inputs off-centre by one.
    for dim in range(2):  # assuming input is 2D
        input = torch.roll(input, shifts=input.shape[dim] // 2, dims=dim)
    return input
|
ztrain/stats.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/stats.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
def gen_stats(delta : torch.Tensor, base : Optional[torch.Tensor]) -> tuple[float, float, float, float]:
    """
    Computes summary statistics for a tensor, optionally applied on top of a base.

    Parameters:
    - delta: The tensor to describe. With a base it is added to the base first.
    - base: Optional tensor that `delta` is added to. Pass None to describe
            `delta` on its own.

    Returns:
    - A tuple (norm, cosine, minimum, maximum): the norm of the rebuilt tensor
      (`base + delta`, or `delta` without a base), the mean cosine similarity
      along dim 0 between the rebuilt tensor and the base (0.0 without a base),
      and the smallest and largest element of `delta`.
    """
    if base is None:
        rebuilt = delta
        cosine = 0.0
    else:
        rebuilt = base + delta
        cosine = torch.nn.functional.cosine_similarity(rebuilt, base, dim=0).mean().item()
    norm = rebuilt.norm().item()
    # Local names avoid shadowing the builtins min() and max().
    smallest = delta.min().item()
    largest = delta.max().item()
    return norm, cosine, smallest, largest
|
| 22 |
+
|
| 23 |
+
def get_report(m0: torch.Tensor, stack : torch.Tensor, model_list : list[str]):
    """
    Prints the statistics of a base tensor and of every entry in a stack.

    Parameters:
    - m0: The base tensor; reported on its own first.
    - stack: An iterable of tensors, each reported relative to `m0`.
    - model_list: One path per stack entry; its base name labels the entry.
    """
    base_norm, _, base_low, base_high = gen_stats(m0, None)
    print(f"Base Model {base_norm} {base_low} {base_high}")

    for position, entry in enumerate(stack):
        label = os.path.basename(model_list[position])
        entry_norm, entry_cosine, entry_low, entry_high = gen_stats(entry, m0)
        print(f"{label} {entry_norm} {entry_cosine} {entry_low} {entry_high}")
|
ztrain/tensors.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/tensors.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from typing import Generator, Tuple
|
| 6 |
+
|
| 7 |
+
def normalize_to(m1 : torch.Tensor, norm : torch.Tensor) -> tuple[torch.Tensor, float, float]:
    """
    Rescales a tensor so that its norm equals a target value.

    Parameters:
    - m1: The tensor to rescale. It is converted to float32 first.
    - norm: The target norm, as a scalar tensor. A plain Python number is
            accepted as well.

    Returns:
    - A tuple (scaled, target, ratio): the rescaled float32 tensor, the target
      norm as a Python float, and the factor `m1` was multiplied by.

    Note: a tensor whose norm is 0 cannot be rescaled; the ratio is then not finite.
    """
    m1 = m1.to(torch.float32)
    ratio = (norm / torch.norm(m1)).item()
    # float() covers both scalar tensors and plain numbers, unlike .item().
    return m1 * ratio, float(norm), ratio
|
| 13 |
+
|
| 14 |
+
def norm_ratio(m1 : torch.Tensor, m2 : torch.Tensor) -> float:
    """
    Returns the norm of `m1` divided by the norm of `m2`.

    Both norms and the resulting ratio are also printed.

    Parameters:
    - m1: Tensor whose norm is the numerator.
    - m2: Tensor whose norm is the denominator.

    Returns:
    - The ratio of the two norms as a Python float.
    """
    numerator = m1.norm()
    denominator = m2.norm()
    quotient = (numerator / denominator).item()
    print(f"Norms {numerator} {denominator} {quotient}")
    return quotient
|
| 20 |
+
|
| 21 |
+
def merge_tensors_fft2(v0: torch.Tensor, v1: torch.Tensor, t: float) -> torch.Tensor:
    """
    Merges two tensors by combining their Fourier coefficients.

    A 1-D input is transformed along its only dimension; any other input is
    transformed along its last two dimensions. For every coefficient the real
    and imaginary parts are merged separately:
    - where the real parts of both inputs have the same sign, the parts are
      linearly interpolated with weight t;
    - elsewhere, the part with the larger magnitude is kept (v1 wins ties).
    The sign test on the real parts drives the choice for the imaginary parts too.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor, same shape as v0.
    - t (float): Interpolation parameter (0 <= t <= 1); 0 favours v0, 1 favours v1.

    Returns:
    - torch.Tensor: The real part of the inverse FFT of the merged
      coefficients. It lives on "cuda:0" when CUDA is available, otherwise on
      the device of v0.
    """
    # Work on the first GPU when there is one; otherwise stay on v0's device
    # instead of failing on hosts without CUDA.
    device = torch.device("cuda:0") if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # One code path serves both cases: only the transformed dimensions differ.
    dims = (-1,) if len(v0.shape) == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)
    del v0, v1

    # True where the real parts of both spectra agree in sign.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()

    def blend(part_v0, part_v1):
        # Interpolate where the signs agree, otherwise keep the larger magnitude.
        interpolated = (1 - t) * part_v0 + t * part_v1
        dominant = torch.where(part_v0.abs() > part_v1.abs(), part_v0, part_v1)
        return torch.where(sign_mask, interpolated, dominant)

    result_fft = torch.empty_like(fft_v0)
    result_fft.real.copy_(blend(fft_v0.real, fft_v1.real))
    result_fft.imag.copy_(blend(fft_v0.imag, fft_v1.imag))
    del fft_v0, fft_v1

    # Perform the inverse FFT to go back to the spatial domain
    return torch.fft.ifftn(result_fft, dim=dims).real
|
| 107 |
+
|
| 108 |
+
def correlate_pairs(tensors : torch.Tensor, work_device : str = "cuda:0", store_device : str = "cpu") -> torch.Tensor:
    """
    Builds a symmetric matrix of pairwise cosine similarities.

    Parameters:
    - tensors: A stack of tensors; entry i is `tensors[i]`.
    - work_device: Device each pair is moved to for the similarity computation.
    - store_device: Device the resulting matrix is allocated on.

    Returns:
    - An (n, n) tensor on `store_device` where entry [i, j] is the mean cosine
      similarity (along dim 0, NaNs replaced by 0) between entries i and j.
      The diagonal is left at 0.
    """
    n = tensors.shape[0]
    matrix = torch.zeros(n, n, device=store_device)
    for i in range(n):
        a = tensors[i].to(work_device)
        for j in range(i + 1, n):
            b = tensors[j].to(work_device)
            similarity = torch.nn.functional.cosine_similarity(a, b, dim=0).nan_to_num(0).mean().item()
            matrix[i, j] = matrix[j, i] = similarity
            # No transfer back is needed: `a` and `b` are temporaries, and a
            # .to() call whose result is dropped would only copy data for nothing.
    return matrix
|
| 119 |
+
|
| 120 |
+
def least_correlated_pairs(correlation_tensor: torch.Tensor) -> Generator[Tuple[int, int, float], None, None]:
    """
    Generates tuples of indices and their corresponding least correlation coefficient
    from a given correlation matrix, ensuring that once an index is used, it is no longer
    considered in future tuples.

    Pairs are chosen greedily: each step picks the remaining pair (x < y) with
    the smallest absolute correlation, the first one in row-major order on ties.

    Args:
        correlation_tensor (torch.Tensor): A 2D square tensor representing the correlation matrix.

    Yields:
        Tuple[int, int, float]: A tuple containing the x-index, y-index, and the (signed) correlation
        coefficient of the least correlated pairs in the matrix.
    """
    n = correlation_tensor.size(0)
    # Strict upper triangle: excludes the diagonal and mirrored duplicates.
    # The mask is built on the input's device so it can index the input directly.
    mask = torch.triu(
        torch.ones(n, n, dtype=torch.bool, device=correlation_tensor.device),
        diagonal=1,
    )
    # The magnitudes never change, so compute them once outside the loop.
    magnitudes = correlation_tensor.abs()

    while torch.any(mask):
        # Smallest absolute correlation among the pairs still available
        min_val = magnitudes[mask].min()

        # Available positions holding that value, in row-major order
        candidates = torch.nonzero((magnitudes == min_val) & mask)
        if candidates.shape[0] == 0:
            # Only happens when the minimum is NaN, which equals nothing.
            break

        # Yield the first index pair (greedy approach) along with the correlation coefficient
        x, y = candidates[0].tolist()
        yield (x, y, correlation_tensor[x, y].item())

        # Mask out the entire row and column for both indices
        mask[x, :] = False
        mask[:, x] = False
        mask[y, :] = False
        mask[:, y] = False
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def merge_tensors_fft2_autoscale(v0: torch.Tensor, v1: torch.Tensor, t: float) -> tuple[torch.Tensor, float, float]:
    """
    Merges two norm-scaled tensors by combining their Fourier coefficients.

    Each input is first divided by its own norm (an input with norm 0 is left
    untouched). A 1-D input is then transformed along its only dimension; any
    other input is transformed along its last two dimensions. For every
    coefficient the real and imaginary parts are merged separately:
    - where the real parts of both inputs have the same sign, the parts are
      linearly interpolated with weight t;
    - elsewhere, the part with the larger magnitude is kept (v1 wins ties).
    The sign test on the real parts drives the choice for the imaginary parts too.

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor, same shape as v0.
    - t (float): Interpolation parameter (0 <= t <= 1); 0 favours v0, 1 favours v1.

    Returns:
    - A tuple (merged, norm_v0, norm_v1): the real part of the inverse FFT of
      the merged coefficients, followed by the original norms of v0 and v1 as
      Python floats. The merged tensor is NOT scaled back; callers that need
      the original magnitude have to rescale it using the returned norms. It
      lives on "cuda:0" when CUDA is available, otherwise on the device of v0.
    """
    # Work on the first GPU when there is one; otherwise stay on v0's device
    # instead of failing on hosts without CUDA.
    device = torch.device("cuda:0") if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # Calculate norms of each tensor
    norm_v0_t = v0.norm()
    norm_v1_t = v1.norm()

    # Scale tensors by their norms; an all-zero tensor cannot be scaled
    v0 = v0 / norm_v0_t if norm_v0_t != 0 else v0
    v1 = v1 / norm_v1_t if norm_v1_t != 0 else v1

    norm_v0 = norm_v0_t.item()
    norm_v1 = norm_v1_t.item()
    del norm_v0_t, norm_v1_t

    # One code path serves both cases: only the transformed dimensions differ.
    dims = (-1,) if len(v0.shape) == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)
    del v0, v1

    # True where the real parts of both spectra agree in sign.
    sign_mask = fft_v0.real.sign() == fft_v1.real.sign()

    def blend(part_v0, part_v1):
        # Interpolate where the signs agree, otherwise keep the larger magnitude.
        interpolated = (1 - t) * part_v0 + t * part_v1
        dominant = torch.where(part_v0.abs() > part_v1.abs(), part_v0, part_v1)
        return torch.where(sign_mask, interpolated, dominant)

    result_fft = torch.empty_like(fft_v0)
    result_fft.real.copy_(blend(fft_v0.real, fft_v1.real))
    result_fft.imag.copy_(blend(fft_v0.imag, fft_v1.imag))
    del fft_v0, fft_v1

    # Perform the inverse FFT to go back to the spatial domain
    merged_tensor = torch.fft.ifftn(result_fft, dim=dims).real

    return merged_tensor, norm_v0, norm_v1
|
ztrain/util.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ztrain/util.py
|
| 2 |
+
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@contextlib.contextmanager
def cuda_memory_profiler(display : bool = True):
    """
    A context manager for profiling CUDA memory usage in PyTorch.

    Parameters:
    - display: When false, the wrapped code runs without any profiling or output.

    On exit (also when the wrapped code raises) it prints the peak allocation
    and the allocated memory at start and end, all in MB. Without CUDA it
    prints a notice and runs the wrapped code unprofiled.
    """
    if not display:
        yield
        return

    if not torch.cuda.is_available():
        print("CUDA is not available, skipping memory profiling")
        yield
        return

    # Reset the peak counter so the reported peak covers only the wrapped code.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start_memory = torch.cuda.memory_allocated()

    try:
        yield
    finally:
        # Wait for pending kernels so the readings reflect the finished work.
        torch.cuda.synchronize()
        end_memory = torch.cuda.memory_allocated()
        print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at start: {start_memory / (1024 ** 2):.2f} MB")
        print(f"Memory allocated at end: {end_memory / (1024 ** 2):.2f} MB")
        print(f"Net memory change: {(end_memory - start_memory) / (1024 ** 2):.2f} MB")
|
| 35 |
+
|
| 36 |
+
def get_device():
    """
    Picks the preferred torch device: CUDA first, then MPS, then the CPU.

    Returns:
    - A torch.device for "cuda", "mps" or "cpu", whichever is available first
      in that order.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
|