Spaces: Running on Zero
| import math | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int: | |
| """#### Calculate the number of steps required for tiled scaling. | |
| #### Args: | |
| - `width` (int): The width of the image. | |
| - `height` (int): The height of the image. | |
| - `tile_x` (int): The width of each tile. | |
| - `tile_y` (int): The height of each tile. | |
| - `overlap` (int): The overlap between tiles. | |
| #### Returns: | |
| - `int`: The number of steps required for tiled scaling. | |
| """ | |
| return math.ceil((height / (tile_y - overlap))) * math.ceil( | |
| (width / (tile_x - overlap)) | |
| ) | |
def tiled_scale(
    samples: torch.Tensor,
    function: callable,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: float = 4,
    out_channels: int = 3,
    pbar: any = None,
) -> torch.Tensor:
    """#### Perform tiled scaling on a batch of samples.

    The image is processed tile by tile; overlapping tile borders are feathered
    with a linear ramp mask and the accumulated result is normalized by the
    accumulated mask weights, so seams between tiles blend smoothly.

    #### Args:
        - `samples` (torch.Tensor): The input samples, shape (batch, channels, height, width).
        - `function` (callable): The function to apply to each tile.
        - `tile_x` (int, optional): The width of each tile. Defaults to 64.
        - `tile_y` (int, optional): The height of each tile. Defaults to 64.
        - `overlap` (int, optional): The overlap between tiles. Defaults to 8.
        - `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
        - `out_channels` (int, optional): The number of output channels. Defaults to 3.
        - `pbar` (any, optional): Progress bar with an `update(n)` method. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The scaled output tensor (on CPU).
    """
    output = torch.empty(
        (
            samples.shape[0],
            out_channels,
            round(samples.shape[2] * upscale_amount),
            round(samples.shape[3] * upscale_amount),
        ),
        device="cpu",
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        # Weighted accumulator and weight-sum buffers for seam blending.
        out = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        out_div = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:, :, y : y + tile_y, x : x + tile_x]
                ps = function(s_in).cpu()
                # Feather mask: linear ramp over the (upscaled) overlap on all
                # four edges so adjacent tiles cross-fade instead of seaming.
                mask = torch.ones_like(ps)
                feather = round(overlap * upscale_amount)
                for t in range(feather):
                    mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
                    mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
                        1.0 / feather
                    ) * (t + 1)
                    mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
                    mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
                        1.0 / feather
                    ) * (t + 1)
                out[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += ps * mask
                out_div[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += mask
                # BUGFIX: pbar was accepted but never advanced; report one
                # step per processed tile (matches get_tiled_scale_steps).
                if pbar is not None:
                    pbar.update(1)
        # Normalize the weighted sum by the accumulated mask weights.
        output[b : b + 1] = out / out_div
    return output
def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """#### Replace transparency with a background color.

    #### Args:
        - `img` (Image.Image): The input image.
        - `bgcolor` (str): The background color.

    #### Returns:
        - `Image.Image`: The image with transparency replaced by the background color.
    """
    # BUGFIX: the original test was `img.mode in ("RGB")` — `("RGB")` is just
    # the string "RGB", so this was a substring-membership check (mode "R"
    # would match). Compare for equality instead.
    if img.mode == "RGB":
        return img
    # Image.alpha_composite requires both operands in RGBA; converting the
    # input first also makes paletted ("P") and greyscale+alpha ("LA")
    # images work instead of raising.
    background = Image.new("RGBA", img.size, bgcolor)
    return Image.alpha_composite(background, img.convert("RGBA")).convert("RGB")
# Default kernel size (in pixels) for blur operations; odd so the kernel has
# a well-defined center pixel. NOTE(review): no usage is visible in this
# chunk — confirm consumers elsewhere in the file/project.
BLUR_KERNEL_SIZE = 15
def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """#### Convert a tensor to a PIL image.

    Selects one image from the batch, scales float values from [0, 1] to
    [0, 255], clamps, and builds a PIL image from the uint8 array.

    #### Args:
        - `img_tensor` (torch.Tensor): The input tensor.
        - `batch_index` (int, optional): The batch index. Defaults to 0.

    #### Returns:
        - `Image.Image`: The converted PIL image.
    """
    single = img_tensor[batch_index].unsqueeze(0)
    scaled = 255.0 * single.cpu().numpy()
    pixels = np.clip(scaled, 0, 255).astype(np.uint8).squeeze()
    return Image.fromarray(pixels)
| def pil_to_tensor(image: Image.Image) -> torch.Tensor: | |
| """#### Convert a PIL image to a tensor. | |
| #### Args: | |
| - `image` (Image.Image): The input PIL image. | |
| #### Returns: | |
| - `torch.Tensor`: The converted tensor. | |
| """ | |
| image = np.array(image).astype(np.float32) / 255.0 | |
| image = torch.from_numpy(image).unsqueeze(0) | |
| return image | |
def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """#### Get the coordinates of the white rectangular mask region.

    #### Args:
        - `mask` (Image.Image): The input mask image in 'L' mode.
        - `pad` (int, optional): The padding to apply. Defaults to 0.

    #### Returns:
        - `tuple`: The coordinates of the crop region.
    """
    bbox = mask.getbbox()
    if bbox is None:
        # Fully-black mask: fall back to an inverted degenerate region.
        x1, y1, x2, y2 = mask.width, mask.height, 0, 0
    else:
        x1, y1, x2, y2 = bbox
    # Grow the box by `pad` on every side, clamped to the image bounds.
    padded = (
        max(x1 - pad, 0),
        max(y1 - pad, 0),
        min(x2 + pad, mask.width),
        min(y2 + pad, mask.height),
    )
    return fix_crop_region(padded, (mask.width, mask.height))
def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """#### Remove the extra pixel added by the get_crop_region function.

    The right/bottom edge is pulled in by one pixel unless it is flush with
    the image border.

    #### Args:
        - `region` (tuple): The crop region coordinates (x1, y1, x2, y2).
        - `image_size` (tuple): The size of the image (width, height).

    #### Returns:
        - `tuple`: The fixed crop region coordinates.
    """
    width, height = image_size
    x1, y1, x2, y2 = region
    return (
        x1,
        y1,
        x2 - 1 if x2 < width else x2,
        y2 - 1 if y2 < height else y2,
    )
def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
    """#### Expand a crop region to a specified target size.

    #### Args:
        - `region` (tuple): The crop region coordinates.
        - `width` (int): The width of the image.
        - `height` (int): The height of the image.
        - `target_width` (int): The desired width of the crop region.
        - `target_height` (int): The desired height of the crop region.

    #### Returns:
        - `tuple`: The expanded crop region coordinates and the target size.
    """
    x1, y1, x2, y2 = region

    def _grow(lo: int, hi: int, target: int, limit: int) -> tuple:
        # Push the far edge out by half the deficit, then the near edge by
        # whatever remains, then the far edge again with any leftover that
        # the near edge could not absorb — all clamped to [0, limit].
        hi = min(hi + (target - (hi - lo)) // 2, limit)
        lo = max(lo - (target - (hi - lo)), 0)
        hi = min(hi + (target - (hi - lo)), limit)
        return lo, hi

    x1, x2 = _grow(x1, x2, target_width, width)
    y1, y2 = _grow(y1, y2, target_height, height)
    return (x1, y1, x2, y2), (target_width, target_height)
def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
    """#### Crop conditioning data to match a specific region.

    NOTE(review): `region`, `init_size`, `canvas_size`, `tile_size`, `w_pad`
    and `h_pad` are currently unused — the conditioning dicts are shallow-
    copied unchanged. Confirm whether region-dependent cropping was intended.

    #### Args:
        - `cond` (list): The conditioning data as (embedding, dict) pairs.
        - `region` (tuple): The crop region coordinates.
        - `init_size` (tuple): The initial size of the image.
        - `canvas_size` (tuple): The size of the canvas.
        - `tile_size` (tuple): The size of the tile.
        - `w_pad` (int, optional): The width padding. Defaults to 0.
        - `h_pad` (int, optional): The height padding. Defaults to 0.

    #### Returns:
        - `list`: The cropped conditioning data.
    """
    return [[emb, meta.copy()] for emb, meta in cond]