| import numpy as np |
| import torch |
| from torchvision.utils import save_image, make_grid |
| import matplotlib.pyplot as plt |
| from matplotlib.animation import FuncAnimation, PillowWriter |
| from torchvision import transforms |
| from torch.utils.data import Dataset |
|
|
|
|
|
|
def unorm(x):
    """Min-max scale ``x`` into the range [0, 1].

    The minimum and maximum are reduced over axes 0 and 1 only, so every
    remaining (trailing) index is scaled independently of the others.

    Args:
        x: NumPy array with at least two dimensions.
           Presumably an (H, W, C) image -- confirm against callers.

    Returns:
        Array of the same shape as ``x``. Any slice whose values are all
        identical (max == min) maps to zeros instead of NaN/inf.
    """
    # keepdims keeps the reduced axes as size-1 so the result is always an
    # array (even for 2-D input) and broadcasts back against ``x``.
    lo = x.min(axis=(0, 1), keepdims=True)
    span = x.max(axis=(0, 1), keepdims=True) - lo
    # A constant slice has zero range; dividing by 1 instead yields 0 for
    # that slice rather than a 0/0 NaN.
    span[span == 0] = 1
    return (x - lo) / span
|
|
def norm_all(store, n_t, n_s):
    """Apply ``unorm`` to every ``store[t, s]`` entry.

    Args:
        store: NumPy array indexed as ``store[t, s]``.
               Presumably (timestep, sample, H, W, C) -- confirm with caller.
        n_t: Number of leading-axis entries to process.
        n_s: Number of second-axis entries to process.

    Returns:
        Array with the same shape and dtype as ``store``. Entries outside
        the first ``n_t`` x ``n_s`` block are left as zeros.
    """
    normalized = np.zeros_like(store)
    # ndindex walks (t, s) in the same row-major order as two nested loops.
    for step, sample in np.ndindex(n_t, n_s):
        normalized[step, sample] = unorm(store[step, sample])
    return normalized
|
|
def norm_torch(x_all):
    """Min-max scale a tensor into [0, 1] over its axes 2 and 3.

    The tensor is moved to the CPU and processed with NumPy. Each
    ``x_all[i, j]`` slice is scaled independently using its own minimum and
    maximum.

    Args:
        x_all: Torch tensor with at least four dimensions.
               Presumably (batch, channel, H, W) -- confirm with caller.

    Returns:
        A CPU tensor of the same shape as ``x_all``. Slices whose values are
        all identical (max == min) map to zeros instead of NaN/inf.
    """
    # detach() so tensors that still require grad can be converted to NumPy.
    arr = x_all.detach().cpu().numpy()
    # keepdims leaves size-1 axes in place so the stats broadcast directly.
    lo = arr.min(axis=(2, 3), keepdims=True)
    span = arr.max(axis=(2, 3), keepdims=True) - lo
    # Guard constant slices against a 0/0 division.
    span[span == 0] = 1
    return torch.from_numpy((arr - lo) / span)
|
|
|
|
def plot_grid(x, n_sample, n_rows, save_dir, w):
    """Normalize a batch of images, tile them into a grid and save it as PNG.

    Args:
        x: Image tensor accepted by ``norm_torch``.
        n_sample: Total number of images; with ``n_rows`` it fixes the number
            of images per grid row (``n_sample // n_rows``).
        n_rows: Number of grid rows.
        save_dir: Prefix the file name is appended to by plain string
            concatenation (no path separator is inserted).
        w: Value embedded in the output file name.

    Returns:
        The grid tensor produced by ``make_grid``.
    """
    images_per_row = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=images_per_row)
    out_path = save_dir + f"run_image_w{w}.png"
    save_image(grid, out_path)
    print('saved image at ' + out_path)
    return grid
|
|
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    """Build (and optionally save) an animation of stored sample frames.

    Every animation frame redraws an ``nrows`` x ``n_sample // nrows`` grid
    of images taken from one entry along the first axis of ``x_gen_store``.

    Args:
        x_gen_store: NumPy array of stored samples. Axis 2 is moved to
            position 4 before plotting.
            Presumably (frames, samples, C, H, W) -- confirm with caller.
        n_sample: Number of samples shown per frame.
        nrows: Number of grid rows.
        save_dir: Prefix the gif file name is appended to by plain string
            concatenation (no path separator is inserted).
        fn: Base file name of the gif.
        w: Value embedded in the gif file name.
        save: When true, write the animation to disk as a gif.

    Returns:
        The ``FuncAnimation`` object.
    """
    ncols = n_sample // nrows
    # Move axis 2 to the end (presumably channels-first -> channels-last,
    # which is the layout imshow expects for colour images).
    frames = np.moveaxis(x_gen_store, 2, 4)
    frames = norm_all(frames, frames.shape[0], n_sample)

    # squeeze=False keeps ``axs`` two-dimensional even for a single row or
    # column, so ``axs[row, col]`` below is always valid.
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols, nrows),
        squeeze=False,
    )

    def animate_diff(i, store):
        """Redraw every grid cell with the images of frame ``i``."""
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[frames],
        interval=200,
        blit=False,
        repeat=True,
        frames=frames.shape[0],
    )
    # Close this specific figure so it is not also shown as a static plot.
    plt.close(fig)
    if save:
        gif_path = save_dir + f"{fn}_w{w}.gif"
        ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + gif_path)
    return ani
|
|
|
|
class CustomDataset(Dataset):
    """Dataset of sprite images and labels loaded from two NumPy files."""

    def __init__(self, sfilename, lfilename, transform, null_context=False):
        """Load sprites and labels and remember the transform settings.

        Args:
            sfilename: Path of the NumPy file holding the sprites.
            lfilename: Path of the NumPy file holding the labels.
            transform: Callable applied to each sprite in ``__getitem__``;
                a falsy value (e.g. ``None``) returns the sprite unchanged.
            null_context: When true, every item gets the constant label 0
                instead of its stored label.
        """
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load files from a trusted source.
        self.sprites = np.load(sfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        self.slabels = np.load(lfilename, allow_pickle=True, fix_imports=True, encoding='latin1')
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    def __len__(self):
        """Return the number of sprites."""
        return len(self.sprites)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at ``idx``.

        The image is the transformed sprite, or the raw sprite when no
        transform is configured. The label is an int64 tensor: the constant
        0 when ``null_context`` is set, otherwise the stored label.
        """
        image = self.sprites[idx]
        if self.transform:
            image = self.transform(image)
        # The label never depends on the transform, so it is built
        # unconditionally; this keeps both names defined on every path.
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    def getshapes(self):
        """Return the shapes of the sprite and label arrays as a tuple."""
        return self.sprites_shape, self.slabel_shape
|
|