Spaces:
Sleeping
Sleeping
| import cv2 | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from matplotlib import rc, colors | |
# Module-wide matplotlib styling applied to every figure rendered below.
rc("font", family="serif", serif=["Roman"])
rc("image", interpolation="none")
# LaTeX text rendering is deliberately left off; uncomment both lines to opt in:
# rc("text", usetex=True)
# rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
class MidpointNormalize(colors.Normalize):
    """Normalize data symmetrically around ``midpoint`` for diverging colormaps.

    The mapping is piecewise linear through three anchors::

        -v_ext -> 0.0,    midpoint -> 0.5,    +v_ext -> 1.0

    where ``v_ext = max(|vmin|, |vmax|)``. Values of equal magnitude on either
    side therefore get equally saturated colours. ``np.interp`` clamps inputs
    outside ``[-v_ext, v_ext]`` to 0 / 1.

    Args:
        vmin, vmax: data limits. Any limit left as ``None`` is filled in from
            the first data passed to ``__call__`` (standard ``Normalize``
            autoscaling).
        midpoint: data value mapped to the colormap centre (0.5). ``None`` is
            treated as 0, the centre of the symmetric range. It must lie within
            ``[-v_ext, v_ext]`` for the anchors to be increasing.
        clip: forwarded to ``Normalize``. Output is always clamped to [0, 1]
            regardless of this flag.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def _anchors(self):
        """Return the (data, normalized) anchor lists of the piecewise map."""
        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        mid = 0.0 if self.midpoint is None else self.midpoint
        return [-v_ext, mid, v_ext], [0, 0.5, 1]

    def __call__(self, value, clip=None):
        # Fill unset vmin/vmax from the data, as the base Normalize does.
        # Without this, np.abs(None) raises for a norm built without limits.
        self.autoscale_None(value)
        x, y = self._anchors()
        # Carry the input mask through so masked pixels still render as "bad".
        return np.ma.masked_array(np.interp(value, x, y), mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data space.

        This mirrors ``__call__``. The inherited linear ``inverse`` would
        disagree with the piecewise forward map, for example on colorbars.
        """
        x, y = self._anchors()
        return np.interp(value, y, x)
def postprocess(x):
    """Rescale a tensor from the [-1, 1] range to [0, 255] and return NumPy data.

    Singleton dimensions are squeezed away. The result is detached from autograd
    and copied to host memory. No clamping or rounding is applied.
    """
    scaled = 127.5 * (x + 1.0)
    return scaled.squeeze().detach().cpu().numpy()
def vae_preprocess(args, pa):
    """Broadcast parent attributes into spatial conditioning maps for the VAE.

    For each key in ``args.parents_x`` (in that order), ``pa[key]`` gets a
    trailing feature axis if it has at most one dimension. All parents are then
    concatenated along dim 1, and the result is tiled over an
    ``args.input_res`` x ``args.input_res`` grid.

    Args:
        args: config object exposing ``parents_x`` (keys into ``pa``),
            ``input_res`` (spatial size) and ``device``.
        pa: mapping from parent name to tensor. Assumes each value is
            (batch,) or (batch, features) -- TODO confirm with callers. A
            higher-rank value makes the concatenation >2-D, and ``repeat``
            below then raises.

    Returns:
        float32 tensor of shape (batch, total_features, input_res, input_res)
        on ``args.device``.
    """
    pa = torch.cat(
        [pa[k] if pa[k].dim() > 1 else pa[k][..., None] for k in args.parents_x],
        dim=1,
    )
    # Move and cast the compact (batch, features) tensor *before* tiling. The
    # transfer and dtype conversion then touch B*C values instead of B*C*H*W,
    # and the output values are identical.
    pa = pa.to(args.device).float()
    return pa[..., None, None].repeat(1, 1, args.input_res, args.input_res)
def get_fig_arr(x, width=4.2, height=4.2, dpi=100, cmap="Greys_r", norm=None):
    """Render a 2-D array through matplotlib and return the canvas pixels.

    ``x`` is first resized to 420x420 with bicubic interpolation. It is then
    drawn on a frameless, axis-free figure of ``width`` x ``height`` inches at
    ``dpi``, and the figure is always closed before returning.

    Args:
        x: image array accepted by ``cv2.resize``.
        width, height: figure size in inches.
        dpi: figure resolution in dots per inch.
        cmap: colormap name. The default ``"Greys_r"`` is drawn with a fixed
            display range of vmin=0, vmax=255. Any other cmap uses ``norm``
            instead (matplotlib autoscales when ``norm`` is ``None``).
        norm: optional matplotlib normalization, used only for non-grey cmaps.

    Returns:
        NumPy copy of the canvas RGBA buffer. Its pixel size is set by
        width/height/dpi.
    """
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        # Axes cover the whole figure with no frame, so the image fills the canvas.
        ax = fig.add_axes([0, 0, 1, 1], frameon=False)
        # NOTE(review): 420x420 equals the default width*dpi; confirm whether it
        # should follow width/height/dpi for non-default figure sizes.
        x = cv2.resize(x, (420, 420), interpolation=cv2.INTER_CUBIC)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        fig.canvas.draw()
        # np.array copies the buffer, so the result stays valid after close().
        img = np.array(fig.canvas.buffer_rgba())
    finally:
        # Release the figure even if resize/draw raises, so repeated calls
        # cannot accumulate open figures.
        plt.close(fig)
    return img