import torch


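def bislerp(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Upscale a latent batch with bilinear spherical interpolation (bislerp).

    ``samples`` has shape (N, C, H, W); ``width`` and ``height`` are the target
    latent dimensions. Each pixel's C-dimensional channel vector is split into
    a direction and a magnitude: directions are slerped and magnitudes are
    linearly interpolated. The batch is resampled along the width axis first,
    then along the height axis.
    """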

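    def slerp(b1, b2, r):
        """Slerp flat batches b1, b2 of shape (M, C) by ratios r of shape (M, 1)."""
        c = b1.shape[-1]
        # Per-row vector norms, shape (M, 1).
        b1_n = torch.norm(b1, dim=-1, keepdim=True)
        b2_n = torch.norm(b2, dim=-1, keepdim=True)
        # Unit directions; rows with zero norm become NaN (0/0), so zero them.
        b1_norm = b1 / b1_n
        b2_norm = b2 / b2_n
        b1_norm[b1_n.expand(-1, c) == 0.0] = 0.0
        b2_norm[b2_n.expand(-1, c) == 0.0] = 0.0
        # Angle between the unit directions. acos can return NaN when rounding
        # pushes dot just past +/-1, and sin(omega) ~ 0 near +/-1; both cases
        # are overwritten by the edge-case assignments below.
        dot = (b1_norm * b2_norm).sum(1)
        omega = torch.acos(dot)
        so = torch.sin(omega)
        # Slerp the directions, then rescale by the linearly interpolated norm.
        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_norm + \
              (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_norm
        res *= (b1_n * (1.0 - r) + b2_n * r).expand(-1, c)
        # Edge cases: nearly parallel vectors return b1 unchanged, nearly
        # opposite vectors fall back to plain linear interpolation.
        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
        return res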

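    def gen_bilinear(length_old, length_new, device):
        """Return (ratios, c1, c2) for resampling an axis of length_old to length_new.

        c1 and c2 are the lower and upper source indices for each output
        position and ratios is the fractional offset between them; all three
        have shape (1, 1, 1, length_new).
        """
        # Bilinearly resize the ramp 0..length_old-1: each output value is the
        # (clamped) fractional source coordinate of that output position.
        c1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape(1, 1, 1, -1)
        c1 = torch.nn.functional.interpolate(c1, size=(1, length_new), mode="bilinear")
        ratios = c1 - c1.floor()
        c1 = c1.to(torch.int64)  # coordinates are >= 0, so truncation == floor
        # The same trick on a ramp shifted by one (last entry clamped back to
        # length_old - 1) yields the upper neighbour index.
        c2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape(1, 1, 1, -1) + 1
        c2[:, :, :, -1] -= 1
        c2 = torch.nn.functional.interpolate(c2, size=(1, length_new), mode="bilinear").to(torch.int64)
        return ratios, c1, c2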

    # Work in float32; the result is cast back to the caller's dtype at the end.
    orig_dtype = samples.dtype
    samples = samples.float()
    n, c, h, w = samples.shape

    # Pass 1 (width): gather the left/right source columns for every output
    # column, flatten to (N * H * width, C) rows of channel vectors, and slerp.
    ratios, c1, c2 = gen_bilinear(w, width, samples.device)
    p1 = samples.gather(-1, c1.expand(n, c, h, -1)).movedim(1, -1).reshape(-1, c)
    p2 = samples.gather(-1, c2.expand(n, c, h, -1)).movedim(1, -1).reshape(-1, c)
    result = slerp(p1, p2, ratios.expand(n, 1, h, -1).movedim(1, -1).reshape(-1, 1))
    result = result.reshape(n, h, width, c).movedim(-1, 1)  # (N, C, H, width)

    # Pass 2 (height): the same procedure on the rows of the intermediate result.
    ratios, c1, c2 = gen_bilinear(h, height, samples.device)
    p1 = result.gather(-2, c1.reshape(1, 1, -1, 1).expand(n, c, -1, width)).movedim(1, -1).reshape(-1, c)
    p2 = result.gather(-2, c2.reshape(1, 1, -1, 1).expand(n, c, -1, width)).movedim(1, -1).reshape(-1, c)
    result = slerp(p1, p2, ratios.reshape(1, 1, -1, 1).expand(n, 1, -1, width).movedim(1, -1).reshape(-1, 1))
    return result.reshape(n, height, width, c).movedim(-1, 1).to(orig_dtype)
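

# Reference sketch (hypothetical helper, not part of the original module):
# a pure-Python restatement of the nested ``slerp`` above for a single pair
# of channel vectors, assuming plain lists of floats instead of (M, C)
# tensors. It mirrors the tensor code step by step: normalise both vectors,
# slerp the unit directions, rescale by the linearly interpolated norm, and
# handle the two edge cases where sin(omega) is ~0. It is meant for checking
# the per-pixel math by hand and is not used by ``bislerp``.
import math


def reference_slerp(b1, b2, r):
    """Slerp two equal-length lists of floats by a ratio r in [0, 1]."""
    n1 = math.sqrt(sum(x * x for x in b1))
    n2 = math.sqrt(sum(x * x for x in b2))
    # A zero vector gets a zero direction, as in the tensor version.
    u1 = [x / n1 for x in b1] if n1 > 0.0 else [0.0] * len(b1)
    u2 = [x / n2 for x in b2] if n2 > 0.0 else [0.0] * len(b2)
    dot = sum(a * b for a, b in zip(u1, u2))
    if dot > 1 - 1e-5:
        # Nearly parallel: the tensor code returns b1 unchanged.
        return list(b1)
    if dot < 1e-5 - 1:
        # Nearly opposite: fall back to plain linear interpolation.
        return [a * (1.0 - r) + b * r for a, b in zip(b1, b2)]
    omega = math.acos(dot)
    so = math.sin(omega)
    w1 = math.sin((1.0 - r) * omega) / so
    w2 = math.sin(r * omega) / so
    # Scale the slerped direction by the linearly interpolated magnitude.
    mag = n1 * (1.0 - r) + n2 * r
    return [(w1 * a + w2 * b) * mag for a, b in zip(u1, u2)]

# Usage: reference_slerp([1.0, 0.0], [0.0, 1.0], 0.5) gives approximately
# [0.7071, 0.7071]. Both inputs have norm 1, so the midpoint keeps norm 1
# instead of shrinking to about 0.707 as a linear blend would.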


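# Reference sketch (hypothetical helper, not part of the original module):
# ``gen_bilinear`` above obtains its indices by running
# ``torch.nn.functional.interpolate`` (mode="bilinear", default
# align_corners=False) over a coordinate ramp. Assuming PyTorch's usual
# half-pixel mapping for that mode, output position i maps to
#     src = clamp((i + 0.5) * length_old / length_new - 0.5, 0, length_old - 1)
# so the lower index is floor(src), the upper index is
# min(floor(src) + 1, length_old - 1) and the ratio is src - floor(src).
# This version works in float64, so it may differ from the float32 tensor
# path when src lands extremely close to an integer.
def reference_bilinear_indices(length_old, length_new):
    """Return (ratios, lower, upper) lists for resampling one axis."""
    scale = length_old / length_new
    ratios, lower, upper = [], [], []
    for i in range(length_new):
        src = (i + 0.5) * scale - 0.5
        src = min(max(src, 0.0), float(length_old - 1))
        lo = int(src)  # src >= 0, so int() equals floor()
        lower.append(lo)
        upper.append(min(lo + 1, length_old - 1))
        ratios.append(src - lo)
    return ratios, lower, upper

# Usage: doubling an axis of length 2 with reference_bilinear_indices(2, 4)
# returns ([0.0, 0.25, 0.75, 0.0], [0, 0, 0, 1], [1, 1, 1, 1]); the edge
# outputs copy the border samples, the inner ones blend columns 0 and 1.

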
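def common_upscale(samples: torch.Tensor, width: int, height: int,
                   upscale_method: str = "bislerp") -> torch.Tensor:
    """Resize an (N, C, H, W) latent batch to (N, C, height, width).

    "bislerp" uses the spherical interpolation above; any other value is
    passed to torch.nn.functional.interpolate as its ``mode`` (for example
    "nearest-exact", "bilinear", "bicubic" or "area").
    """
    if upscale_method == "bislerp":
        return bislerp(samples, width, height)
    return torch.nn.functional.interpolate(samples, size=(height, width), mode=upscale_method)


class LatentUpscale:
    """Upscale latent codes stored under the "samples" key of a dict."""

    def upscale(self, samples: dict, width: int, height: int, upscale_method: str = "bislerp",
                downscale_factor: int = 8) -> tuple:
        """Resize samples["samples"] to a target size given in pixels.

        Each pixel size is raised to at least 64 and floor-divided by
        downscale_factor to get the latent size. If both width and height are
        0 the input is returned unchanged; if only one is 0, it is derived
        from the other so that the latent's aspect ratio is preserved.
        """
        if width == 0 and height == 0:
            return (samples,)
        latent = samples["samples"]
        if width == 0:
            height = max(64, height)
            width = max(64, round(latent.shape[3] * height / latent.shape[2]))
        elif height == 0:
            width = max(64, width)
            height = max(64, round(latent.shape[2] * width / latent.shape[3]))
        s = samples.copy()  # shallow copy: other keys are shared with the input
        s["samples"] = common_upscale(latent, max(64, width) // downscale_factor,
                                      max(64, height) // downscale_factor, upscale_method)
        return (s,)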
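

# Reference sketch (hypothetical helper, not part of the original module):
# how ``LatentUpscale.upscale`` turns a requested pixel size into a latent
# size when both width and height are non-zero. Each side is first raised
# to at least 64 pixels and then floor-divided by ``downscale_factor``
# (8 by default, matching the parameter above), so a latent side never
# drops below 64 // downscale_factor cells.
def pixel_to_latent_size(width, height, downscale_factor=8):
    """Return (latent_width, latent_height) for a pixel-space request."""
    return (max(64, width) // downscale_factor,
            max(64, height) // downscale_factor)

# Usage: pixel_to_latent_size(1024, 768) returns (128, 96), while
# pixel_to_latent_size(30, 1000) returns (8, 125) because 30 is raised to 64.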