X-Portrait45 / utils /utils.py
fffiloni's picture
Migrated from GitHub
bfed184 verified
Raw
History Blame Contribute Delete
2.22 kB
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
def merge_lists_by_index(list1, list2):
    """Join two equal-length sequences of strings element by element.

    Each output element is ``list1[i] + '. ' + list2[i]``.

    Raises:
        ValueError: if the two inputs have different lengths.
    """
    if len(list1) == len(list2):
        separator = '. '
        return [head + separator + tail for head, tail in zip(list1, list2)]
    raise ValueError("Both lists should have the same number of elements.")
def plot_x_y(x, y, x_label, y_label, save_path, **kwargs):
    """Plot ``y`` against ``x`` on the current Axes and save the figure.

    Args:
        x, y: data forwarded to ``Axes.plot``.
        x_label, y_label: axis label texts.
        save_path: destination passed to ``pyplot.savefig``.
        **kwargs: extra keyword arguments for ``Axes.plot``; legend entries
            come from a ``label=`` passed here.

    Note: the current figure is neither cleared nor closed, so successive
    calls keep drawing onto the same figure.
    """
    ax = plt.gca()
    ax.plot(x, y, **kwargs)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.legend()
    plt.savefig(save_path)
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
# count the parameters (in millions) of a list of modules
def count_param(module_list):
    """Return the total parameter count of ``module_list`` divided by 10**6.

    Each module's ``parameters()`` is summed independently, so a parameter
    shared between two listed modules is counted once per module.
    """
    total = 0
    for module in module_list:
        total += sum(param.numel() for param in module.parameters())
    return total / 10**6
# display the peak CUDA memory
def print_peak_memory(prefix, device):
    """Print peak CUDA memory allocated on ``device``, in MB (10**6 bytes).

    Only prints when ``device == 0`` (presumably so that a single process
    reports in multi-device runs — confirm with callers). The value is
    floor-divided by ``1e6``, so it is printed as a float such as ``123.0``.
    """
    is_reporting_device = device == 0
    if is_reporting_device:
        peak_mb = torch.cuda.max_memory_allocated(device) // 1e6
        print(f"{prefix}: {peak_mb}MB ")
def anal_tensor(tensor, name):
    """Print the mean, std, min and max of ``tensor`` on one line.

    Args:
        tensor: a torch tensor, assumed non-empty (min/max reductions fail
            on empty tensors).
        name: label printed in front of the statistics.

    ``mean``/``std`` are only defined by torch for floating-point and complex
    dtypes, so integer/bool tensors are upcast to float for those two
    statistics. ``min``/``max`` are taken on the original tensor and keep its
    dtype. ``std`` of a single-element tensor is NaN (unbiased estimator).
    """
    if tensor.is_floating_point() or tensor.is_complex():
        stats_src = tensor
    else:
        # Integer/bool tensors would raise in mean()/std(); upcast only for them.
        stats_src = tensor.float()
    sent = f" name: {name} mean: {stats_src.mean().item()} std: {stats_src.std().item()} min: {tensor.min().item()} max: {tensor.max().item()}"
    print(sent)
def split(lst, split_nbr):
    """Split a sliceable sequence into ``split_nbr`` contiguous, near-equal parts.

    The first ``len(lst) % split_nbr`` parts get one extra element. Empty
    parts are never emitted: when ``len(lst) < split_nbr`` the result holds
    ``len(lst)`` one-element slices, and an empty ``lst`` gives ``[]``.

    Args:
        lst: any sequence supporting ``len`` and slicing (list, tuple, str, ...).
        split_nbr: requested number of parts; must be at least 1.

    Returns:
        A list of slices of ``lst`` that concatenate back to ``lst``.

    Raises:
        ValueError: if ``split_nbr`` is less than 1.
    """
    # A non-positive count has no meaningful split; a negative one would
    # also make the slice boundaries walk backwards indefinitely.
    if split_nbr < 1:
        raise ValueError(f"split_nbr must be a positive integer, got {split_nbr!r}")
    base, extra = divmod(len(lst), split_nbr)
    results = []
    start = 0
    for i in range(split_nbr):
        if start >= len(lst):
            break  # fewer elements than parts: stop rather than emit empty slices
        end = start + base + (1 if i < extra else 0)
        results.append(lst[start:end])
        start = end
    return results
def chunk(iterable, chunk_size):
    """Lazily group items of ``iterable`` into lists of ``chunk_size`` items.

    Yields a fresh list for each group; the last one may be shorter when the
    input runs out, and nothing is yielded for an empty input. A group is
    emitted only when its length equals ``chunk_size`` exactly, so a
    ``chunk_size`` below 1 never matches and everything comes back as a
    single final list.
    """
    pending = []
    for item in iterable:
        pending.append(item)
        if len(pending) != chunk_size:
            continue
        yield pending
        pending = []  # new list; the yielded one now belongs to the caller
    if pending:
        yield pending
def image_concat_h(im1, im2):
    """Return a new RGB image with ``im2`` placed to the right of ``im1``.

    The canvas is as tall as the taller input, so neither image is cropped.
    Any area not covered by an image keeps ``Image.new``'s default (black)
    fill. Both images are top-aligned.
    """
    # Use the max height: sizing by im1 alone would crop a taller im2.
    dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst
def image_concat_v(im1, im2):
    """Return a new RGB image with ``im2`` placed below ``im1``.

    The canvas is as wide as the wider input, so neither image is cropped.
    Any area not covered by an image keeps ``Image.new``'s default (black)
    fill. Both images are left-aligned.
    """
    # Use the max width: sizing by im1 alone would crop a wider im2.
    dst = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst