imaginable-imaging / app_utils.py
fabio-deep
first commit
3275658
import cv2
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import rc, colors
# Module-wide matplotlib defaults, applied once when this module is imported.
# Serif text. NOTE(review): "Roman" must be an installed font family name,
# otherwise matplotlib falls back to another serif font — confirm.
rc("font", family="serif", serif=["Roman"])
# Default imshow interpolation: draw pixels without smoothing.
rc("image", interpolation="none")
# LaTeX text rendering is currently disabled; to enable it, uncomment:
# rc("text", usetex=True)
# rc("text.latex", preamble=r"\usepackage{amsmath} \usepackage{amssymb}")
class MidpointNormalize(colors.Normalize):
    """Piecewise-linear normalization centred on ``midpoint`` (diverging maps).

    The data range is made symmetric, ``[-v_ext, v_ext]`` with
    ``v_ext = max(|vmin|, |vmax|)``, and mapped so that ``-v_ext -> 0``,
    ``midpoint -> 0.5`` and ``v_ext -> 1``, linearly on each side. Values
    beyond the range are clamped to 0 / 1 by ``np.interp``, so the output
    always lies in ``[0, 1]`` regardless of ``clip``.

    Args:
        vmin, vmax: Data limits. If left as ``None`` they are filled in from
            the first data passed to ``__call__``.
        midpoint: Data value placed at the centre of the colormap; ``None`` is
            treated as 0. Should lie inside ``(-v_ext, v_ext)`` for the mapping
            to be monotonic.
        clip: Stored by ``colors.Normalize``; not otherwise used here.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def _anchors(self):
        """Return the data-space sample points ``[-v_ext, midpoint, v_ext]``."""
        v_ext = np.max([np.abs(self.vmin), np.abs(self.vmax)])
        # None means "centre on zero" instead of failing inside np.interp.
        mid = 0.0 if self.midpoint is None else self.midpoint
        return [-v_ext, mid, v_ext]

    def __call__(self, value, clip=None):
        """Map ``value`` into [0, 1]; ``clip`` is accepted for API compatibility only."""
        # Fill vmin/vmax from the data when unset (no-op once both are set);
        # otherwise np.abs(None) below would raise.
        self.autoscale_None(value)
        normed = np.interp(value, self._anchors(), [0, 0.5, 1])
        # Carry over the input mask so masked entries stay masked.
        return np.ma.masked_array(normed, mask=np.ma.getmask(value))

    def inverse(self, value):
        """Inverse of ``__call__``: map values in [0, 1] back to data space."""
        return np.interp(value, [0, 0.5, 1], self._anchors())
def postprocess(x):
    """Rescale a tensor from [-1, 1] to the [0, 255] pixel range as a NumPy array.

    Singleton dimensions are squeezed away, and the result is detached from
    the autograd graph and copied to host memory. No clamping or integer cast
    is applied, so inputs outside [-1, 1] land outside [0, 255].
    """
    pixels = x.add(1.0).mul(127.5)
    return pixels.squeeze().detach().cpu().numpy()
def vae_preprocess(args, pa):
    """Build the spatial conditioning tensor for the VAE from parent variables.

    The tensors ``pa[k]`` for ``k`` in ``args.parents_x`` are concatenated
    along dim 1 (1-D tensors first get a trailing singleton dim). The result
    is then tiled over an ``args.input_res`` x ``args.input_res`` grid so that
    every pixel carries the same conditioning vector.

    Args:
        args: Config object providing ``parents_x`` (ordered keys into ``pa``),
            ``input_res`` (int spatial size) and ``device``.
        pa: Mapping from parent name to tensor. Presumably batch-first, shaped
            (B,) or (B, C_k) — TODO confirm against callers.

    Returns:
        Contiguous float32 tensor of shape (B, sum_k C_k, input_res, input_res)
        on ``args.device``.
    """
    cond = torch.cat(
        [pa[k] if pa[k].dim() > 1 else pa[k][..., None] for k in args.parents_x],
        dim=1,
    )
    # Move/cast the small (B, C) tensor *before* tiling: the values are
    # identical, but the device copy is input_res**2 times smaller.
    cond = cond.to(device=args.device, dtype=torch.float32)
    res = args.input_res
    return cond[..., None, None].repeat(1, 1, res, res)
def get_fig_arr(x, width=4.2, height=4.2, dpi=100, cmap="Greys_r", norm=None):
    """Render a 2-D array as an image and return the rendered canvas pixels.

    ``x`` is upsampled with bicubic interpolation to 420x420 and drawn,
    without axes or frame, on an axes filling an off-screen figure of
    ``width`` x ``height`` inches at ``dpi``.

    Args:
        x: Image array accepted by ``cv2.resize`` (e.g. ``postprocess`` output).
        width, height: Figure size in inches.
        dpi: Figure resolution in dots per inch.
        cmap: Colormap name. For the default ``"Greys_r"`` the colour limits
            are fixed to 0..255; otherwise ``norm`` decides the mapping.
        norm: Optional ``matplotlib.colors.Normalize`` used for non-grey maps.

    Returns:
        A NumPy copy of the canvas RGBA buffer.
    """
    # 420 px equals the default width * dpi (4.2 in * 100); note it is not
    # recomputed when width/height/dpi differ.
    x = cv2.resize(x, (420, 420), interpolation=cv2.INTER_CUBIC)
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    try:
        # Attach the axes to this figure explicitly instead of relying on
        # pyplot's global "current figure"; [0, 0, 1, 1] fills the canvas.
        ax = fig.add_axes([0, 0, 1, 1], frameon=False)
        if cmap == "Greys_r":
            ax.imshow(x, cmap=cmap, vmin=0, vmax=255)
        else:
            ax.imshow(x, cmap=cmap, norm=norm)
        ax.axis("off")
        fig.canvas.draw()
        # np.array copies, so the pixels outlive the closed figure.
        img = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if drawing fails, so pyplot does
        # not accumulate open figures.
        plt.close(fig)
    return img