File size: 3,730 Bytes
35839a1 8a6ed33 35839a1 8a6ed33 35839a1 8a6ed33 35839a1 8a6ed33 35839a1 8a6ed33 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | import numpy as np
import torch
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from torchvision import transforms
from torch.utils.data import Dataset
def unorm(x):
    """Min-max normalise an image to [0, 1], independently per channel.

    Args:
        x: array whose first two axes are spatial (h, w); the min/max are
           taken over those axes, so each remaining (channel) slice is
           normalised on its own. Typically (h, w, 3).

    Returns:
        An array shaped like ``x`` with values in [0, 1]. A channel whose
        values are all equal has zero range and is mapped to 0 instead of
        NaN.
    """
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    span = xmax - xmin
    # A constant channel has zero span: divide by 1 instead so it becomes 0
    # rather than 0/0 = NaN (which also emits a RuntimeWarning).
    span = np.where(span == 0, 1, span)
    return (x - xmin) / span
def norm_all(store, n_t, n_s):
    """Apply ``unorm`` to every (timestep, sample) image in ``store``.

    Args:
        store: array indexed ``store[t, s, ...]``, where each ``store[t, s]``
               is a single image accepted by ``unorm``.
        n_t: number of leading timesteps to normalise.
        n_s: number of leading samples per timestep to normalise.

    Returns:
        A new array with the same shape as ``store``. Entries outside the
        first ``n_t`` x ``n_s`` block are left as zeros. Floating inputs
        keep their dtype; integer inputs produce float64.
    """
    # zeros_like(store) alone would inherit an integer dtype and silently
    # truncate the [0, 1] floats returned by unorm to 0/1.
    if np.issubdtype(store.dtype, np.floating):
        out_dtype = store.dtype
    else:
        out_dtype = np.float64
    nstore = np.zeros_like(store, dtype=out_dtype)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore
def norm_torch(x_all):
    """Min-max normalise each channel of each image in a batch to [0, 1].

    Args:
        x_all: tensor of shape (n_samples, C, h, w) in torch image layout.
               It may live on any device and may require grad.

    Returns:
        A CPU tensor of the same shape that does not require grad. Floating
        inputs keep their dtype; integer inputs are promoted to float64. A
        channel whose values are all equal maps to 0 instead of NaN.
    """
    # detach() so tensors that require grad work (Tensor.numpy() refused them).
    x = x_all.detach().cpu()
    if not x.is_floating_point():
        # Match the previous NumPy true-division result dtype for integer input.
        x = x.to(torch.float64)
    # Reduce over the spatial dims only, so each (sample, channel) slice is
    # normalised on its own; keepdim lets the result broadcast back.
    xmax = torch.amax(x, dim=(2, 3), keepdim=True)
    xmin = torch.amin(x, dim=(2, 3), keepdim=True)
    span = xmax - xmin
    # A constant channel has zero span: divide by 1 so it becomes 0, not NaN.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - xmin) / span
def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Tile a batch of images into one grid, save it as a PNG and return it.

    Args:
        x: image batch of shape (n_sample, 3, h, w); it is min-max
           normalised per channel with ``norm_torch`` before tiling.
        n_sample: number of images in the batch.
        n_rows: number of grid rows; each row holds ``n_sample // n_rows``
                images.
        save_dir: prefix the file name is appended to by plain string
                  concatenation, so it should end with a path separator.
        w: value embedded in the output file name.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    per_row = n_sample // n_rows
    # make_grid's `nrow` means "images per row", i.e. the column count.
    grid = make_grid(norm_torch(x), nrow=per_row)
    out_path = save_dir + f"run_image_w{w}.png"
    save_image(grid, out_path)
    print('saved image at ' + out_path)
    return grid
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Animate a stack of sample snapshots as a grid of images over time.

    Args:
        x_gen_store: array indexed [frame, sample, ...]; axis 2 is moved to
            the last position before plotting, so it is presumably
            (n_frames, n_sample, C, h, w) with channels first
            (TODO confirm against callers).
        n_sample: number of samples shown; laid out as ``nrows`` rows of
            ``n_sample // nrows`` columns.
        nrows: number of grid rows.
        save_dir: prefix the gif name is appended to by plain string
            concatenation, so it should end with a path separator.
        fn: base name of the gif file.
        w: value embedded in the gif file name.
        save: when True, write the animation to a gif with PillowWriter.

    Returns:
        The ``FuncAnimation`` object. Its figure is closed, so it is not
        also shown as a static image.
    """
    ncols = n_sample // nrows
    # Move axis 2 to the end so each frame is channels-last, as imshow expects.
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)
    # squeeze=False keeps `axs` 2-D even when nrows or ncols is 1; otherwise
    # subplots returns a 1-D array (or a single Axes) and axs[row, col] fails.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                            figsize=(ncols, nrows), squeeze=False)

    def animate_diff(i, store):
        # Redraw every cell with frame i of the matching sample.
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200,
                        blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    # Close this specific figure so it is not rendered as a static image.
    plt.close(fig)
    if save:
        gif_path = save_dir + f"{fn}_w{w}.gif"
        ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + gif_path)
    return ani
class CustomDataset(Dataset):
    """Map-style dataset over sprite images and labels loaded from .npy files.

    Args:
        sfilename: path to the .npy file holding the sprite images.
        lfilename: path to the .npy file holding the labels; indexed in
            parallel with the sprites.
        transform: callable applied to each sprite, or a falsy value
            (e.g. None) to return the raw array.
        null_context: when True, every label is replaced by an int64 zero.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # Python objects -- only load files from a trusted source.
        self.sprites = np.load(sfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        self.slabels = np.load(lfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``.

        The image is ``transform(sprites[idx])`` when a transform is set,
        otherwise the raw sprite. The label is an int64 tensor: zero when
        ``null_context`` is set, otherwise ``slabels[idx]``.
        """
        # Start from the raw sprite so a falsy transform still yields an image
        # (binding `image` only inside the branch raised UnboundLocalError).
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            label = torch.tensor(0, dtype=torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return the shapes of the sprite array and the label array."""
        return self.sprites_shape, self.slabel_shape
|