import numpy as np
import torch
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from torchvision import transforms
from torch.utils.data import Dataset


def _safe_range(xmax, xmin):
    """Return ``xmax - xmin`` with zero entries replaced by 1.

    Used as a divisor so that constant-valued channels normalize to 0
    instead of producing 0/0 = NaN.
    """
    span = np.asarray(xmax - xmin)
    return np.where(span == 0, np.ones_like(span), span)


def unorm(x):
    """Unity-normalize ``x`` into the range [0, 1].

    ``x`` is assumed to be an (h, w, 3) array. Min and max are taken over
    axes 0 and 1, so each entry along the last axis (each channel) is
    scaled independently. A channel whose values are all equal maps to 0
    rather than NaN.
    """
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / _safe_range(xmax, xmin)


def norm_all(store, n_t, n_s):
    """Run :func:`unorm` on every ``store[t, s]`` for ``t < n_t`` and ``s < n_s``.

    ``store`` is indexed as (timestep, sample, ...). Each ``store[t, s]`` is
    presumably an (h, w, 3) image, as :func:`unorm` expects. Entries outside
    the ``n_t`` x ``n_s`` range are left as zeros.

    Returns a new array with the shape of ``store``. Its dtype is floating
    even for integer input, so the [0, 1] results are not truncated.
    """
    # Integer input would otherwise truncate every normalized value to 0.
    out_dtype = np.result_type(store.dtype, np.float32)
    nstore = np.zeros(store.shape, dtype=out_dtype)
    for t, s in np.ndindex(n_t, n_s):
        nstore[t, s] = unorm(store[t, s])
    return nstore


def norm_torch(x_all):
    """Unity-normalize a batch of torch images, per sample and per channel.

    ``x_all`` is expected in torch image format (n_samples, 3, h, w). Min
    and max are taken over the spatial axes (2, 3). The tensor is detached
    and moved to CPU first, so grad-tracking or GPU tensors are accepted.
    Constant channels map to 0 rather than NaN.

    Returns a CPU ``torch.Tensor`` of the same shape.
    """
    # .numpy() raises on tensors that require grad, so detach first.
    x = x_all.detach().cpu().numpy()
    xmax = x.max((2, 3), keepdims=True)
    xmin = x.min((2, 3), keepdims=True)
    nstore = (x - xmin) / _safe_range(xmax, xmin)
    return torch.from_numpy(nstore)


def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Save a normalized image grid of ``x`` and return it.

    ``x`` is an (n_sample, 3, h, w) tensor. The grid has ``n_rows`` rows
    and ``n_sample // n_rows`` columns. It is written to
    ``save_dir + "run_image_w{w}.png"``. Plain string concatenation is
    used, so ``save_dir`` should end with a path separator.
    """
    ncols = n_sample // n_rows
    # make_grid's `nrow` is the number of images per row, i.e. the column count.
    grid = make_grid(norm_torch(x), nrow=ncols)
    path = save_dir + f"run_image_w{w}.png"
    save_image(grid, path)
    print('saved image at ' + path)
    return grid


def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Build a ``FuncAnimation`` of the sampling process, optionally saving a GIF.

    ``x_gen_store`` is a numpy array whose axis 2 is moved to the end, so
    it needs at least 5 dims. It is presumably (n_frames, n_sample, 3, h, w);
    confirm against callers. Frames are laid out on an
    ``nrows`` x ``n_sample // nrows`` grid of axes. When ``save`` is true,
    the GIF is written to ``save_dir + f"{fn}_w{w}.gif"``.

    Returns the ``FuncAnimation`` object.
    """
    ncols = n_sample // nrows
    # Channel axis to last position, as expected by unorm / imshow.
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)
    # squeeze=False keeps `axs` 2-D even when nrows == 1 or ncols == 1,
    # so axs[row, col] indexing always works.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                            figsize=(ncols, nrows), squeeze=False)

    def animate_diff(i, store):
        # Redraw every subplot with frame i of its sample.
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200,
                        blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    # Close the figure so it is not displayed as a static plot; the
    # animation keeps its own reference to `fig` for saving/rendering.
    plt.close(fig)
    if save:
        path = save_dir + f"{fn}_w{w}.gif"
        ani.save(path, dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + path)
    return ani


class CustomDataset(Dataset):
    """Dataset of sprites and labels loaded from two numpy files.

    ``__getitem__`` returns ``(image, label)``. ``image`` is the sprite,
    passed through ``transform`` when one is given. ``label`` is an int64
    tensor, or the constant ``0`` when ``null_context`` is true.
    """

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code from a crafted file.
        # Only load trusted sprite/label files.
        self.sprites = np.load(sfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        self.slabels = np.load(lfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Return the number of sprites in the dataset."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``."""
        # Fall back to the raw sprite when no transform is configured
        # (previously this raised UnboundLocalError).
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        if self.null_context:
            # Null context: every sample gets the same label 0.
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return the shapes of the loaded sprite and label arrays."""
        return self.sprites_shape, self.slabel_shape