File size: 4,601 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import math
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor as tv_to_tensor


def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
    """Return how many tiles a tiled pass over a ``width`` x ``height`` input visits.

    Tiles advance by ``tile - overlap`` along each axis, so each axis needs
    ``ceil(extent / stride)`` tiles; the total is the product of both axes.
    """
    stride_y = tile_y - overlap
    stride_x = tile_x - overlap
    rows = math.ceil(height / stride_y)
    cols = math.ceil(width / stride_x)
    return rows * cols


def _feather_mask(tile: torch.Tensor, feather: int) -> torch.Tensor:
    """Return a blending mask shaped like ``tile`` that fades in over its outer ``feather`` pixels.

    For ``t < feather``, the row/column ``t`` pixels in from each of the four edges is
    scaled by ``(t + 1) / feather``. Where ramps from different edges overlap (corners, or
    tiles narrower than ``2 * feather``) the factors multiply, so every weight stays > 0.
    """
    mask = torch.ones_like(tile)
    for t in range(feather):
        f = (1.0 / feather) * (t + 1)
        # ``-t or None`` turns an end index of 0 into None so t == 0 selects the last row/column.
        mask[:, :, t:1 + t, :] *= f
        mask[:, :, -1 - t:-t or None, :] *= f
        mask[:, :, :, t:1 + t] *= f
        mask[:, :, :, -1 - t:-t or None] *= f
    return mask


@torch.inference_mode()
def tiled_scale(samples: torch.Tensor, function: callable, tile_x: int = 64, tile_y: int = 64,
                overlap: int = 8, upscale_amount: float = 4, out_channels: int = 3, pbar=None) -> torch.Tensor:
    """Upscale ``samples`` tile by tile with ``function`` and blend the tiles back together.

    Each batch item is processed on its own. The input is walked in ``tile_y`` x ``tile_x``
    tiles that advance by ``tile - overlap`` along each axis. Each tile is passed to
    ``function``, and the result is moved to the CPU and accumulated at the tile's position
    scaled by ``upscale_amount``. Tiles are weighted by a linear feather mask over their
    outer ``round(overlap * upscale_amount)`` pixels. The accumulated sum is then divided by
    the accumulated weights.

    Args:
        samples: 4-D tensor; dims 2 and 3 are tiled as height and width.
        function: callable applied to each tile. Assumed to return ``out_channels``
            channels at ``upscale_amount`` times the tile's spatial size (TODO confirm
            behaviour for non-integer scales, where rounding may not line up).
        tile_x, tile_y: tile width/height in input pixels.
        overlap: overlap between neighbouring tiles in input pixels; must be smaller
            than both tile sizes.
        upscale_amount: spatial scale factor applied by ``function``.
        out_channels: channel count of the output tensor.
        pbar: optional progress object. ``pbar.update(1)`` is called once per processed
            tile, i.e. ``get_tiled_scale_steps(...)`` times per batch item.

    Returns:
        CPU tensor of shape ``(batch, out_channels, round(H * upscale_amount),
        round(W * upscale_amount))``.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``tile_x`` and ``tile_y``.
    """
    if tile_x <= overlap or tile_y <= overlap:
        # The tile stride is ``tile - overlap``: a zero stride makes range() fail and a
        # negative one visits no tiles at all, leaving 0/0 = NaN in the output.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_x ({tile_x}) and tile_y ({tile_y})"
        )

    h_up, w_up = round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)
    output = torch.empty((samples.shape[0], out_channels, h_up, w_up), device="cpu")
    # Feather width in output pixels; it depends only on the arguments, so compute it once.
    feather = round(overlap * upscale_amount)

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]
        out = torch.zeros((s.shape[0], out_channels, h_up, w_up), device="cpu")
        out_div = torch.zeros_like(out)

        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                ps = function(s[:, :, y:y + tile_y, x:x + tile_x]).cpu()
                mask = _feather_mask(ps, feather)

                # Edge tiles may extend past the image; slicing clamps the end indices.
                y_start, y_end = round(y * upscale_amount), round((y + tile_y) * upscale_amount)
                x_start, x_end = round(x * upscale_amount), round((x + tile_x) * upscale_amount)
                out[:, :, y_start:y_end, x_start:x_end] += ps * mask
                out_div[:, :, y_start:y_end, x_start:x_end] += mask

                if pbar is not None:
                    pbar.update(1)

        # Normalise the weighted sum by the total weight each output pixel received.
        output[b:b + 1] = out / out_div
    return output


def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """Composite ``img`` over a solid ``bgcolor`` background and return an RGB image.

    RGB input is returned unchanged (the same object). Any other mode (RGBA, LA, P, L, ...)
    is first converted to RGBA, because ``Image.alpha_composite`` requires both operands
    to be RGBA. Modes without an alpha channel therefore come out fully opaque.
    """
    if img.mode == "RGB":
        return img
    background = Image.new("RGBA", img.size, bgcolor)
    return Image.alpha_composite(background, img.convert("RGBA")).convert("RGB")


# Module-level constant; by its name, presumably the kernel size for a blur step.
# NOTE(review): not referenced anywhere in this section of the file — confirm where it is
# consumed before changing it.
BLUR_KERNEL_SIZE: int = 15


def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """Turn one item of a batched image tensor into a PIL image.

    The selected item is treated as channels-last (HWC) when it is 3-D and its trailing
    dimension is 1, 3 or 4, and is moved to channels-first before conversion. Values are
    clamped to [0, 1] and handed to torchvision's ``to_pil_image``.

    NOTE(review): a channels-first item whose width happens to be 1, 3 or 4 also matches
    the HWC test and would be permuted wrongly.
    """
    item = img_tensor[batch_index]
    looks_channels_last = item.dim() == 3 and item.shape[-1] in [1, 3, 4]
    chw = item.permute(2, 0, 1) if looks_channels_last else item
    return to_pil_image(chw.clamp(0, 1))


def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a ``1 x H x W x 3`` float tensor using torchvision.

    Transparent images are composited onto a white background first. Besides RGBA, this
    also applies to LA images and palette images that carry a ``transparency`` entry.
    Those are converted to RGBA so their alpha is honoured rather than silently dropped
    by a direct ``convert('RGB')``. Any other non-RGB mode is converted straight to RGB.
    """
    if image.mode == 'LA' or (image.mode == 'P' and 'transparency' in image.info):
        image = image.convert('RGBA')
    if image.mode == 'RGBA':
        background = Image.new('RGB', image.size, (255, 255, 255))
        # Use the alpha band as the paste mask so transparent pixels stay white.
        background.paste(image, mask=image.split()[-1])
        image = background
    elif image.mode != 'RGB':
        image = image.convert('RGB')
    # to_tensor yields C x H x W; add a batch dim and move channels last (B x H x W x C).
    return tv_to_tensor(image).unsqueeze(0).permute(0, 2, 3, 1)


def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """Return the padded bounding box of the non-zero area of ``mask``.

    The box from ``mask.getbbox()`` is grown by ``pad`` on every side and clamped to the
    mask bounds, then passed through ``fix_crop_region``. For an empty mask the starting
    box is the inverted ``(width, height, 0, 0)``, so the result is typically inverted
    (left >= right); callers should be prepared for that.
    """
    width, height = mask.width, mask.height
    box = mask.getbbox()
    if not box:
        box = (width, height, 0, 0)
    left, top, right, bottom = box
    padded = (
        max(left - pad, 0),
        max(top - pad, 0),
        min(right + pad, width),
        min(bottom + pad, height),
    )
    return fix_crop_region(padded, (width, height))


def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """Pull the right/bottom edges of ``region`` in by one pixel unless they touch the border.

    ``region`` is ``(x1, y1, x2, y2)`` and ``image_size`` is ``(width, height)``.
    ``x2`` is decremented only when it is below the width, and ``y2`` only when it is
    below the height. ``x1`` and ``y1`` are returned untouched.
    """
    img_w, img_h = image_size
    x1, y1, x2, y2 = region
    if x2 < img_w:
        x2 -= 1
    if y2 < img_h:
        y2 -= 1
    return x1, y1, x2, y2


def _grow_span(lo: int, hi: int, limit: int, target: int) -> tuple:
    """Stretch the span ``[lo, hi)`` toward length ``target`` while staying inside ``[0, limit]``.

    First grow the high end by half the shortfall, then grow the low end by whatever is
    still missing, then try the high end once more for pixels the low end could not take.
    """
    hi = min(hi + (target - (hi - lo)) // 2, limit)
    lo = max(lo - (target - (hi - lo)), 0)
    hi = min(hi + (target - (hi - lo)), limit)
    return lo, hi


def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
    """Grow the crop ``region`` toward ``target_width`` x ``target_height`` inside the image.

    Each axis is expanded independently (see ``_grow_span``), clamped to the image size.

    Returns:
        ``((x1, y1, x2, y2), (target_width, target_height))``. The returned box can be
        smaller than the target when the image itself is too small.
    """
    left, top, right, bottom = region
    left, right = _grow_span(left, right, width, target_width)
    top, bottom = _grow_span(top, bottom, height, target_height)
    return (left, top, right, bottom), (target_width, target_height)


def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple,
              tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
    """Return a per-entry copy of the conditioning list ``cond``.

    NOTE(review): despite its name, this implementation performs no cropping.
    ``region``, ``init_size``, ``canvas_size``, ``tile_size``, ``w_pad`` and ``h_pad`` are
    accepted but ignored. Each entry is unpacked as a two-item ``(embedding, extras)``
    pair and rebuilt as a new ``[embedding, extras.copy()]`` list. The embedding object is
    shared with the input. ``extras`` (presumably a dict of options — confirm with
    callers) is shallow-copied, so mutating the copy's top-level keys leaves the input
    untouched.
    """
    return [[embedding, extras.copy()] for embedding, extras in cond]