Spaces:
Running on Zero
Running on Zero
File size: 3,231 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import torch
def bislerp(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Resize a 4-D tensor with separable spherical-linear interpolation.

    The input is unpacked as ``(n, c, h, w)``. It is resampled along the last
    axis to ``width`` first, then along the second-to-last axis to ``height``.
    For every output position the two neighbouring source vectors (taken
    across the ``c`` axis) are blended with a slerp whose magnitude is
    linearly interpolated.

    Args:
        samples: Tensor whose shape unpacks into four dimensions.
        width: Target size of the last axis.
        height: Target size of the second-to-last axis.

    Returns:
        Tensor of shape ``(n, c, height, width)`` cast back to the dtype of
        ``samples``; the arithmetic itself is carried out in float32.
    """

    def _slerp_rows(b1, b2, r):
        """Slerp each row of ``b1`` towards the matching row of ``b2`` by ``r``.

        ``b1`` and ``b2`` are indexed as ``(rows, channels)``; ``r`` holds one
        ratio per row in a trailing singleton column.
        """
        channels = b1.shape[-1]

        # Per-row magnitudes and unit directions.
        mag1 = torch.norm(b1, dim=-1, keepdim=True)
        mag2 = torch.norm(b2, dim=-1, keepdim=True)
        unit1 = b1 / mag1
        unit2 = b2 / mag2

        # Rows with zero magnitude divide to NaN; force their direction to zero.
        unit1[mag1.expand(-1, channels) == 0.0] = 0.0
        unit2[mag2.expand(-1, channels) == 0.0] = 0.0

        cos_angle = (unit1 * unit2).sum(1)
        angle = torch.acos(cos_angle)
        sin_angle = torch.sin(angle)

        t = r.squeeze(1)
        weight1 = (torch.sin((1.0 - t) * angle) / sin_angle).unsqueeze(1)
        weight2 = (torch.sin(t * angle) / sin_angle).unsqueeze(1)
        blended = weight1 * unit1 + weight2 * unit2

        # Restore a magnitude that is linearly interpolated between the inputs.
        blended *= (mag1 * (1.0 - r) + mag2 * r).expand(-1, channels)

        # Degenerate angles: nearly identical directions keep b1 unchanged,
        # nearly opposite directions fall back to plain linear interpolation.
        nearly_same = cos_angle > 1 - 1e-5
        nearly_opposite = cos_angle < 1e-5 - 1
        blended[nearly_same] = b1[nearly_same]
        blended[nearly_opposite] = (b1 * (1.0 - r) + b2 * r)[nearly_opposite]
        return blended

    def _axis_lookup(length_old, length_new, device):
        """Build blend ratios and both neighbour index tables for one axis."""
        positions = torch.arange(length_old, dtype=torch.float32, device=device).reshape(1, 1, 1, -1)

        first = torch.nn.functional.interpolate(positions, size=(1, length_new), mode="bilinear")
        fractions = first - first.floor()
        first_idx = first.to(torch.int64)

        # The second neighbour is one step further along, except at the last
        # source position, where the increment is undone so it stays in range.
        second = positions + 1
        second[:, :, :, -1] -= 1
        second = torch.nn.functional.interpolate(second, size=(1, length_new), mode="bilinear")
        second_idx = second.to(torch.int64)
        return fractions, first_idx, second_idx

    orig_dtype = samples.dtype
    work = samples.float()
    n, c, h, w = work.shape
    device = work.device

    # Pass 1: resample along the last (width) axis.
    fractions, first_idx, second_idx = _axis_lookup(w, width, device)
    first_idx = first_idx.expand(n, c, h, -1)
    second_idx = second_idx.expand(n, c, h, -1)
    rows_a = work.gather(-1, first_idx).movedim(1, -1).reshape(-1, c)
    rows_b = work.gather(-1, second_idx).movedim(1, -1).reshape(-1, c)
    row_ratio = fractions.expand(n, 1, h, -1).movedim(1, -1).reshape(-1, 1)
    widened = _slerp_rows(rows_a, rows_b, row_ratio)
    widened = widened.reshape(n, h, width, c).movedim(-1, 1)

    # Pass 2: resample along the second-to-last (height) axis.
    fractions, first_idx, second_idx = _axis_lookup(h, height, device)
    first_idx = first_idx.reshape(1, 1, -1, 1).expand(n, c, -1, width)
    second_idx = second_idx.reshape(1, 1, -1, 1).expand(n, c, -1, width)
    rows_a = widened.gather(-2, first_idx).movedim(1, -1).reshape(-1, c)
    rows_b = widened.gather(-2, second_idx).movedim(1, -1).reshape(-1, c)
    row_ratio = fractions.reshape(1, 1, -1, 1).expand(n, 1, -1, width).movedim(1, -1).reshape(-1, 1)
    resized = _slerp_rows(rows_a, rows_b, row_ratio)
    resized = resized.reshape(n, height, width, c).movedim(-1, 1)

    return resized.to(orig_dtype)
def common_upscale(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Resize ``samples`` to ``width`` x ``height`` by delegating to ``bislerp``.

    Args:
        samples: Tensor to resize; forwarded unchanged.
        width: Target width, forwarded unchanged.
        height: Target height, forwarded unchanged.

    Returns:
        Whatever ``bislerp`` returns for these arguments.
    """
    resized = bislerp(samples, width, height)
    return resized
class LatentUpscale:
    """Resize the tensor stored under the ``"samples"`` key of a latent dict."""

    def upscale(self, samples: dict, width: int, height: int, upscale_method: str = "bislerp",
                downscale_factor: int = 8) -> tuple:
        """Return a one-element tuple holding the (possibly resized) latent dict.

        Args:
            samples: Mapping that carries the tensor under ``"samples"``.
            width: Requested width before division by ``downscale_factor``;
                values below 64 are raised to 64.
            height: Requested height, treated the same way as ``width``.
            upscale_method: Accepted for interface compatibility; this method
                does not read it.
            downscale_factor: Integer divisor applied to the clamped sizes.

        Returns:
            ``(samples,)`` with the very same object when both ``width`` and
            ``height`` are 0. Otherwise a tuple holding a shallow copy of
            ``samples`` whose ``"samples"`` entry is the output of
            ``common_upscale``.
        """
        nothing_requested = width == 0 and height == 0
        if nothing_requested:
            return (samples,)

        # Shallow copy so the caller's mapping keeps its original tensor.
        updated = samples.copy()
        latent = samples["samples"]
        target_width = max(64, width) // downscale_factor
        target_height = max(64, height) // downscale_factor
        updated["samples"] = common_upscale(latent, target_width, target_height)
        return (updated,)
|