| import numpy as np | |
| import math | |
| import PIL | |
def postprocess(x):
    """Map float values nominally in [0, 1] to uint8 pixel values in [0, 255].

    Values are scaled by 255 and clipped to [0, 255]. They are then cast to
    uint8, which truncates the fractional part rather than rounding.

    Args:
        x: Array-like of floats, nominally in [0, 1]. Values outside that
            range are clipped.

    Returns:
        A ``np.uint8`` array with the same shape as ``x``.
    """
    # np.asarray so plain Python sequences are scaled element-wise
    # (``255 * [...]`` would repeat the list instead of multiplying it).
    x = np.clip(255 * np.asarray(x), 0, 255)
    # ndarray.astype replaces np.cast, which was removed in NumPy 2.0.
    return x.astype(np.uint8)
def tile(X, rows, cols):
    """Arrange a batch of images on a ``rows`` x ``cols`` grid for display.

    Images fill the grid in row-major order. Grid cells without an image
    stay zero, and images beyond the first ``rows * cols`` are ignored.

    Args:
        X: Array indexed as ``(N, H, W, C)``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        Array of shape ``(rows * H, cols * W, C)`` with the same dtype as X.
    """
    height, width, channels = X.shape[1], X.shape[2], X.shape[3]
    grid = np.zeros((rows * height, cols * width, channels), dtype=X.dtype)
    # Only the first min(N, rows * cols) images have a slot on the grid.
    for k in range(min(X.shape[0], rows * cols)):
        r, c = divmod(k, cols)
        top, left = r * height, c * width
        grid[top:top + height, left:left + width, :] = X[k, ...]
    return grid
def plot_batch(X, out_path):
    """Tile a batch of images into a square grid and save it as one image.

    Args:
        X: Array indexed as ``(N, H, W, C)`` with float values nominally in
            [0, 1] (they are passed through ``postprocess``). If ``C > 3``,
            three distinct channels are chosen at random for display.
        out_path: Destination path handed to PIL's ``Image.save``.

    Raises:
        ValueError: If ``X`` contains no images.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
    # ``PIL.Image.fromarray`` could raise AttributeError; import it directly.
    from PIL import Image

    if X.shape[0] == 0:
        raise ValueError("plot_batch: X contains no images")
    n_channels = X.shape[3]
    if n_channels > 3:
        # replace=False: never show the same channel twice in the preview.
        X = X[:, :, :, np.random.choice(n_channels, size=3, replace=False)]
    X = postprocess(X)
    # Smallest square grid that holds every image.
    rows = cols = math.ceil(math.sqrt(X.shape[0]))
    canvas = tile(X, rows, cols)
    # Drop only a singleton channel axis (grayscale -> 2-D for PIL); a plain
    # np.squeeze would also remove height/width axes of size 1.
    if canvas.shape[-1] == 1:
        canvas = canvas[..., 0]
    Image.fromarray(canvas).save(out_path)