# Tiling, cropping, and PIL/tensor conversion helpers for tiled image upscaling.
import math

import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor as tv_to_tensor
def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
    """Return how many tiles a tiled scale of a width x height image visits.

    Tiles advance by ``tile - overlap`` pixels per step along each axis.
    """
    rows = math.ceil(height / (tile_y - overlap))
    cols = math.ceil(width / (tile_x - overlap))
    return rows * cols
def tiled_scale(samples: torch.Tensor, function: callable, tile_x: int = 64, tile_y: int = 64,
                overlap: int = 8, upscale_amount: float = 4, out_channels: int = 3, pbar=None) -> torch.Tensor:
    """Upscale an NCHW batch tile-by-tile, feather-blending overlapping seams.

    Args:
        samples: input batch of shape (N, C, H, W).
        function: callable mapping one (1, C, th, tw) tile to its upscaled tile.
        tile_x, tile_y: tile width/height in input pixels; must exceed ``overlap``.
        overlap: overlap between neighbouring tiles, in input pixels.
        upscale_amount: scale factor that ``function`` applies spatially.
        out_channels: channel count of the upscaled result.
        pbar: optional progress bar; ``pbar.update(1)`` is called once per tile.

    Returns:
        CPU tensor of shape (N, out_channels, round(H*upscale), round(W*upscale)).
    """
    h_up = round(samples.shape[2] * upscale_amount)
    w_up = round(samples.shape[3] * upscale_amount)
    output = torch.empty((samples.shape[0], out_channels, h_up, w_up), device="cpu")
    for b in range(samples.shape[0]):
        s = samples[b:b + 1]
        # Weighted-sum accumulator and the total blend weight per pixel.
        out = torch.zeros((s.shape[0], out_channels, h_up, w_up), device="cpu")
        out_div = torch.zeros_like(out)
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                ps = function(s[:, :, y:y + tile_y, x:x + tile_x]).cpu()
                mask = torch.ones_like(ps)
                feather = round(overlap * upscale_amount)
                # Linearly ramp the mask toward every tile edge so overlapping
                # tiles cross-fade instead of leaving visible seams.
                for t in range(feather):
                    f = (1.0 / feather) * (t + 1)
                    mask[:, :, t:1 + t, :] *= f
                    mask[:, :, -1 - t:-t or None, :] *= f  # "-t or None": t=0 must mean "last row", not an empty slice
                    mask[:, :, :, t:1 + t] *= f
                    mask[:, :, :, -1 - t:-t or None] *= f
                y_start, y_end = round(y * upscale_amount), round((y + tile_y) * upscale_amount)
                x_start, x_end = round(x * upscale_amount), round((x + tile_x) * upscale_amount)
                out[:, :, y_start:y_end, x_start:x_end] += ps * mask
                out_div[:, :, y_start:y_end, x_start:x_end] += mask
                # Bug fix: pbar was accepted but never advanced.
                if pbar is not None:
                    pbar.update(1)
        output[b:b + 1] = out / out_div
    return output
def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """Replace any transparency in *img* with *bgcolor* and return an RGB image.

    Args:
        img: source image in any PIL mode; RGB images are returned unchanged.
        bgcolor: background color name/spec accepted by ``Image.new``.

    Returns:
        An RGB image with transparency composited over ``bgcolor``.
    """
    if img.mode == "RGB":
        return img
    # Bug fix: Image.alpha_composite requires BOTH operands in RGBA mode;
    # inputs such as "L", "P" or "LA" previously raised ValueError here.
    img = img.convert("RGBA")
    background = Image.new("RGBA", img.size, bgcolor)
    return Image.alpha_composite(background, img).convert("RGB")
# Kernel size in pixels for mask blurring.
# NOTE(review): no consumer is visible in this chunk — presumably used by a
# blur/feather step elsewhere in the file; confirm before changing.
BLUR_KERNEL_SIZE = 15
def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """Extract one image from a batched tensor and return it as a PIL image.

    Accepts either CHW or HWC layout; HWC is detected by a trailing
    dimension of 1, 3 or 4 channels and permuted to CHW. Values are
    clamped to [0, 1] before conversion.
    """
    selected = img_tensor[batch_index]
    is_channels_last = selected.dim() == 3 and selected.shape[-1] in (1, 3, 4)
    if is_channels_last:
        selected = selected.permute(2, 0, 1)
    return to_pil_image(selected.clamp(0, 1))
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batched channels-last float tensor (1, H, W, C).

    RGBA images are composited over a white background via their alpha
    channel; any other non-RGB mode is converted to RGB first.
    """
    if image.mode == 'RGBA':
        flattened = Image.new('RGB', image.size, (255, 255, 255))
        flattened.paste(image, mask=image.split()[-1])
        image = flattened
    elif image.mode != 'RGB':
        image = image.convert('RGB')
    chw = tv_to_tensor(image)          # (C, H, W) in [0, 1]
    nchw = chw.unsqueeze(0)            # (1, C, H, W)
    return nchw.permute(0, 2, 3, 1)    # (1, H, W, C)
def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """Return the padded bounding box of the non-zero area of *mask*.

    The padded box is clamped to the image and then passed through
    ``fix_crop_region``. An empty mask yields the inverted sentinel
    box (width, height, 0, 0) before padding, matching the original
    fallback behavior.
    """
    box = mask.getbbox()
    if box is None:
        # Empty mask: degenerate inverted region, as in the original code.
        box = (mask.width, mask.height, 0, 0)
    left, top, right, bottom = box
    left = max(left - pad, 0)
    top = max(top - pad, 0)
    right = min(right + pad, mask.width)
    bottom = min(bottom + pad, mask.height)
    return fix_crop_region((left, top, right, bottom), (mask.width, mask.height))
def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """Trim one pixel from the right/bottom edge unless it touches the border.

    Compensates for the extra pixel that bounding-box coordinates carry
    when the box does not reach the image edge.
    """
    width, height = image_size
    left, top, right, bottom = region
    if right < width:
        right -= 1
    if bottom < height:
        bottom -= 1
    return left, top, right, bottom
def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
    """Grow *region* to exactly target_width x target_height within the image.

    Expansion is attempted symmetrically; whatever cannot fit on one side
    (because an image border was hit) is pushed to the opposite side.

    Returns:
        ((x1, y1, x2, y2), (target_width, target_height))
    """
    x1, y1, x2, y2 = region

    # --- horizontal ---
    x2 = min(x2 + (target_width - (x2 - x1)) // 2, width)   # half of the deficit to the right
    x1 = max(x1 - (target_width - (x2 - x1)), 0)            # remainder to the left
    x2 = min(x2 + (target_width - (x2 - x1)), width)        # spill back right if left clamped

    # --- vertical ---
    y2 = min(y2 + (target_height - (y2 - y1)) // 2, height)
    y1 = max(y1 - (target_height - (y2 - y1)), 0)
    y2 = min(y2 + (target_height - (y2 - y1)), height)

    return (x1, y1, x2, y2), (target_width, target_height)
def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple,
              tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
    """Crop conditioning data to match region.

    NOTE(review): the region/size/padding arguments are currently unused —
    each ``[embedding, extras]`` pair is returned with its extras dict
    shallow-copied so callers can mutate entries independently. Confirm
    whether spatial cropping of the conditioning is still intended.
    """
    cropped = []
    for embedding, extras in cond:
        cropped.append([embedding, extras.copy()])
    return cropped