import inspect
import math
import os
from typing import List, Optional, Set, Tuple

import numpy as np
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import ProjectConfiguration
from PIL import Image, ImageFilter
from tqdm import tqdm
|
|
def paste_image_back_with_feathering(
    resized_background_image: Image.Image,
    image_to_paste: Image.Image,
    crop_box: Tuple[int, int, int, int],
    feather_radius: int = 50,
) -> Tuple[Image.Image, Image.Image]:
    """
    Paste an image back into a background at the given position, feathering the
    rectangular seam so the two blend smoothly.

    This version builds a rectangular mask internally and feathers it; it does
    not depend on an externally supplied mask shape.

    Args:
        resized_background_image (Image.Image):
            The resized background image.
        image_to_paste (Image.Image):
            The image to paste back.
        crop_box (Tuple[int, int, int, int]):
            Coordinates (left, top, right, bottom) of the paste region, used to
            position the paste and bound the mask.
        feather_radius (int, optional):
            Radius of the Gaussian blur controlling the width and softness of
            the feathered edge. Defaults to 50.

    Returns:
        Tuple[Image.Image, Image.Image]:
            A tuple containing:
            - final_image (Image.Image): The full background with the new image
              pasted in and its edges blended.
            - feather_mask (Image.Image): The full-size feathered mask used for
              compositing.
    """
    # Build a hard rectangular mask over the paste region, then blur it to feather.
    mask = Image.new("L", resized_background_image.size, 0)
    mask.paste(255, crop_box)
    feather_mask = mask.filter(ImageFilter.GaussianBlur(radius=feather_radius))

    # Paste the image at the crop position, then composite through the soft mask.
    image_with_paste = resized_background_image.copy()
    paste_position = (crop_box[0], crop_box[1])
    image_with_paste.paste(image_to_paste, paste_position)

    final_image = Image.composite(
        image_with_paste, resized_background_image, feather_mask
    )

    return final_image, feather_mask
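
# A minimal usage sketch (hypothetical file names and coordinates, not part of
# this module): `patch` must match the crop box size, here 768x1024.
#
#   background = Image.open("background.jpg").convert("RGB")
#   patch = Image.open("patch.jpg").convert("RGB").resize((768, 1024))
#   blended, soft_mask = paste_image_back_with_feathering(
#       background, patch, crop_box=(100, 100, 868, 1124), feather_radius=50
#   )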
|
|
def get_bounding_box(mask_pil: Image.Image) -> Optional[Tuple[int, int, int, int]]:
    """
    Compute the minimal bounding rectangle of the non-zero region of a mask.

    Args:
        mask_pil (Image.Image): Input mask image, single- or multi-channel.

    Returns:
        Optional[Tuple[int, int, int, int]]:
            If the mask is non-empty, a tuple (xmin, ymin, xmax, ymax) giving
            the top-left and bottom-right coordinates. Note that xmax and ymax
            are exclusive, matching the convention of PIL operations such as
            crop (i.e. width = xmax - xmin).
            Returns None if the mask is empty.
    """
    if mask_pil.mode != "L":
        mask_pil = mask_pil.convert("L")

    mask_np = np.array(mask_pil)
    if not np.any(mask_np > 0):
        return None

    # Collapse rows/columns to find the extent of the non-zero region.
    rows = np.any(mask_np > 0, axis=1)
    cols = np.any(mask_np > 0, axis=0)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]

    # +1 makes the right/bottom bounds exclusive.
    return (int(xmin), int(ymin), int(xmax + 1), int(ymax + 1))
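
# A minimal sketch of the exclusive-bound convention: for a 100x100 mask with a
# single white pixel at (10, 20), the box is (10, 20, 11, 21), so cropping with
# it yields a 1x1 image.
#
#   m = Image.new("L", (100, 100), 0)
#   m.putpixel((10, 20), 255)
#   assert get_bounding_box(m) == (10, 20, 11, 21)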
|
|
def adjust_input_image(
    image: Image.Image,
    mask: Image.Image,
    target_size: Tuple[int, int] = (768, 1024),
    padding_ratio: float = 0.05,
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int, int, int, int]]:
    """
    Adjust and crop an image/mask pair to the aspect ratio of a target size.

    The function first fits a box with the target aspect ratio around the mask
    content, then adds some padding, and finally rescales the whole image and
    crops out that region.

    Args:
        image (Image.Image): The original image.
        mask (Image.Image): The original mask.
        target_size (Tuple[int, int], optional): Target output size as
            (width, height). Defaults to (768, 1024).
        padding_ratio (float, optional): Padding ratio added around the
            aspect-adjusted box. Defaults to 0.05.

    Returns:
        Tuple[Image.Image, Image.Image, Image.Image, Tuple[int, int, int, int]]:
            - image_new (Image.Image): The full rescaled image.
            - cropped_image (Image.Image): The final cropped image.
            - cropped_mask (Image.Image): The final cropped mask.
            - crop_box (Tuple[int, int, int, int]): The crop box on the
              rescaled image.
    """
    img_w, img_h = image.size
    target_w, target_h = target_size
    target_ratio = target_w / target_h

    bbox = get_bounding_box(mask)
    if bbox is None:
        raise ValueError("Input mask is empty; nothing to adjust.")
    x_min, y_min, x_max, y_max = bbox
    box_w = x_max - x_min
    box_h = y_max - y_min

    # Expand the bounding box to the target aspect ratio.
    ideal_w = max(box_h * target_ratio, box_w)
    ideal_h = max(box_w / target_ratio, box_h)

    # Recenter the aspect-adjusted box on the mask.
    x_center = (x_min + x_max) / 2
    y_center = (y_min + y_max) / 2
    x_min = x_center - ideal_w / 2
    y_min = y_center - ideal_h / 2
    x_max = x_center + ideal_w / 2
    y_max = y_center + ideal_h / 2

    # Cap the padding by the slack left on each axis (clamped at 0 so the box
    # is never shrunk below the mask when it already touches the borders).
    max_padding_ratio = max(
        0.0,
        min(
            padding_ratio,
            (x_min + img_w - x_max) / (ideal_w * 2),
            (y_min + img_h - y_max) / (ideal_h * 2),
        ),
    )
    x_padding = int(ideal_w * max_padding_ratio)
    y_padding = int(ideal_h * max_padding_ratio)
    x_min = x_min - x_padding
    y_min = y_min - y_padding
    x_max = x_max + x_padding
    y_max = y_max + y_padding

    # Shift the box back inside the image if it overflows a border.
    if x_min < 0:
        x_max -= x_min
        x_min = 0
    if y_min < 0:
        y_max -= y_min
        y_min = 0
    if x_max > img_w:
        x_min -= x_max - img_w
        x_max = img_w
    if y_max > img_h:
        y_min -= y_max - img_h
        y_max = img_h

    # Rescale the whole image so the box maps exactly to the target size.
    scale = target_w / (x_max - x_min)
    img_new_w = int(img_w * scale)
    img_new_h = int(img_h * scale)

    image_new = image.resize((img_new_w, img_new_h), Image.Resampling.LANCZOS)
    mask_new = mask.resize((img_new_w, img_new_h), Image.Resampling.NEAREST)

    crop_x_start = int(x_min * scale)
    crop_y_start = int(y_min * scale)
    crop_box = (
        crop_x_start,
        crop_y_start,
        crop_x_start + target_w,
        crop_y_start + target_h,
    )

    cropped_image = image_new.crop(crop_box)
    cropped_mask = mask_new.crop(crop_box)

    return image_new, cropped_image, cropped_mask, crop_box
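
# A minimal round-trip sketch (hypothetical inputs): crop a 768x1024 working
# region around the mask, edit it, then blend the edit back with the feathered
# paste above. `edit` is a stand-in for whatever processing produces a new crop.
#
#   image_new, crop_img, crop_mask, box = adjust_input_image(image, mask)
#   edited = edit(crop_img, crop_mask)          # same size as crop_img
#   result, _ = paste_image_back_with_feathering(image_new, edited, box)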
|
|
def prepare_extra_step_kwargs(noise_scheduler, generator, eta):
    # Prepare extra kwargs for the scheduler step, since not all schedulers
    # share the same signature. `eta` corresponds to η in the DDIM paper
    # (https://arxiv.org/abs/2010.02502) and should be in [0, 1]; it is ignored
    # by other schedulers.
    accepts_eta = "eta" in set(
        inspect.signature(noise_scheduler.step).parameters.keys()
    )
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # Check whether the scheduler's step accepts a generator for reproducible noise.
    accepts_generator = "generator" in set(
        inspect.signature(noise_scheduler.step).parameters.keys()
    )
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
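
# A minimal sketch, assuming the diffusers schedulers: DDIMScheduler.step
# accepts both `eta` and `generator`, so both keys are returned; a scheduler
# whose `step` lacks `eta` would get only the keys it accepts.
#
#   from diffusers import DDIMScheduler
#   sched = DDIMScheduler()
#   kwargs = prepare_extra_step_kwargs(sched, generator=None, eta=0.0)
#   # kwargs == {"eta": 0.0, "generator": None}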
|
|
def init_accelerator(config):
    accelerator_project_config = ProjectConfiguration(
        project_dir=config.project_name,
        logging_dir=os.path.join(config.project_name, "logs"),
    )
    accelerator_ddp_config = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[accelerator_ddp_config],
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name=config.project_name,
            config={
                "learning_rate": config.learning_rate,
                "train_batch_size": config.train_batch_size,
                "image_size": f"{config.width}x{config.height}",
            },
        )
    return accelerator
|
|
def init_weight_dtype(weight_dtype):
    # Map the Accelerate mixed-precision flag to the matching torch dtype.
    return {
        "no": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }[weight_dtype]
|
|
def prepare_image(image, device="cuda", dtype=torch.float32, do_normalize=True):
    if isinstance(image, torch.Tensor):
        # Add a batch dimension to a single image tensor.
        if image.ndim == 3:
            image = image.unsqueeze(0)
        image = image.to(dtype=torch.float32)
    else:
        # Wrap a single PIL image or ndarray into a batch list.
        if isinstance(image, (Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], Image.Image):
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # NHWC -> NCHW, then scale to [-1, 1] (normalized) or [0, 1].
        image = image.transpose(0, 3, 1, 2)
        if do_normalize:
            image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
        else:
            image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
    return image.to(device, dtype=dtype)
|
|
def prepare_mask_image(mask_image, device="cuda", dtype=torch.float32):
    if isinstance(mask_image, torch.Tensor):
        if mask_image.ndim == 2:
            # A single (H, W) mask: add batch and channel dimensions.
            mask_image = mask_image.unsqueeze(0).unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
            # A single (1, H, W) mask: treat the leading dim as the channel
            # and add a batch dimension.
            mask_image = mask_image.unsqueeze(0)
        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
            # A batch of (B, H, W) masks: add a channel dimension.
            mask_image = mask_image.unsqueeze(1)

        # Binarize the mask.
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
    else:
        # Wrap a single PIL image or ndarray into a batch list.
        if isinstance(mask_image, (Image.Image, np.ndarray)):
            mask_image = [mask_image]

        if isinstance(mask_image, list) and isinstance(mask_image[0], Image.Image):
            mask_image = np.concatenate(
                [np.array(m.convert("L"))[None, None, :] for m in mask_image], axis=0
            )
            mask_image = mask_image.astype(np.float32) / 255.0
        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):
            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)

        # Binarize the mask.
        mask_image[mask_image < 0.5] = 0
        mask_image[mask_image >= 0.5] = 1
        mask_image = torch.from_numpy(mask_image)

    return mask_image.to(device, dtype=dtype)
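
# A minimal shape sketch (CPU for illustration): a 768x1024 PIL image becomes a
# (1, 3, 1024, 768) tensor in [-1, 1], and its mask a binarized
# (1, 1, 1024, 768) tensor in {0, 1}.
#
#   img = Image.new("RGB", (768, 1024))
#   msk = Image.new("L", (768, 1024))
#   x = prepare_image(img, device="cpu")        # torch.Size([1, 3, 1024, 768])
#   m = prepare_mask_image(msk, device="cpu")   # torch.Size([1, 1, 1024, 768])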
|
|
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for single-channel (grayscale) images.
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
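
# A minimal sketch: float arrays in [0, 1] with shape (H, W, C) or (B, H, W, C);
# a single-channel batch comes back as mode-"L" images.
#
#   arr = np.random.rand(2, 64, 64, 3)
#   pils = numpy_to_pil(arr)   # list of two 64x64 RGB images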
|
|
def scan_files_in_dir(
    directory, postfix: Optional[Set[str]] = None, progress_bar: Optional[tqdm] = None
) -> List[os.DirEntry]:
    """Recursively collect files under `directory`, optionally filtered by extension."""
    file_list = []
    progress_bar = (
        tqdm(total=0, desc="Scanning", ncols=100) if progress_bar is None else progress_bar
    )
    for entry in os.scandir(directory):
        if entry.is_file():
            if postfix is None or os.path.splitext(entry.path)[1] in postfix:
                file_list.append(entry)
                progress_bar.total += 1
                progress_bar.update(1)
        elif entry.is_dir():
            file_list += scan_files_in_dir(entry.path, postfix=postfix, progress_bar=progress_bar)
    return file_list
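
# A minimal sketch (hypothetical directory): collect image files; extensions
# must include the leading dot to match os.path.splitext.
#
#   entries = scan_files_in_dir("data/images", postfix={".jpg", ".png"})
#   paths = [e.path for e in entries]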
|
|
def compute_dream_and_update_latents(
    unet,
    noise_scheduler,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    mask_latent: torch.Tensor,
    masked_target_latent: torch.Tensor,
    target: torch.Tensor,
    attention_mask: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to make training more
    efficient and accurate at the cost of an extra forward step without gradients.

    Args:
        `unet`: The UNet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `mask_latent`: The mask latent concatenated to the UNet input.
        `masked_target_latent`: The masked target latent concatenated to the UNet input.
        `target`: The ground-truth tensor to predict after eps is removed.
        `attention_mask`: Optional attention mask passed to the UNet.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in the paper.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    pred = None
    with torch.no_grad():
        # Gradient-free forward pass on the inpainting-style concatenated input.
        input_noisy_latents = torch.cat(
            [noisy_latents, mask_latent, masked_target_latent], dim=1
        )
        pred = unet(
            input_noisy_latents,
            timesteps,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
        ).sample

    _noisy_latents, _target = (None, None)
    if noise_scheduler.config.prediction_type == "epsilon":
        predicted_noise = pred
        delta_noise = (noise - predicted_noise).detach()
        delta_noise.mul_(dream_lambda)
        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
        _target = target.add(delta_noise)
    elif noise_scheduler.config.prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    return _noisy_latents, _target
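
# In epsilon-prediction form the update above is, with
# lambda = (1 - alpha_bar_t) ** (p / 2) and delta = lambda * (noise - pred):
#
#   noisy_latents' = noisy_latents + sqrt(1 - alpha_bar_t) * delta
#   target'        = target + delta
#
# i.e. training is rectified toward the network's own prediction at sampling
# time, without backpropagating through the extra forward pass.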
|
|
def tensor_to_image(tensor: torch.Tensor):
    """
    Converts a torch tensor to a PIL Image.
    """
    assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
    assert tensor.dtype == torch.float32, "Input tensor should be float32."
    assert (
        tensor.min() >= 0 and tensor.max() <= 1
    ), "Input tensor should be in range [0, 1]."
    # CHW float in [0, 1] -> HWC uint8 in [0, 255].
    tensor = tensor.cpu()
    tensor = tensor * 255
    tensor = tensor.permute(1, 2, 0)
    tensor = tensor.numpy().astype(np.uint8)
    image = Image.fromarray(tensor)
    return image
|
|
def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
    """
    Concatenates images into a grid with `cols` columns, separated by
    `divider` pixels of black between cells.
    """
    widths = [image.size[0] for image in images]
    heights = [image.size[1] for image in images]
    total_width = cols * max(widths)
    total_width += divider * (cols - 1)

    rows = math.ceil(len(images) / cols)
    total_height = max(heights) * rows
    total_height += divider * (rows - 1)

    concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))

    x_offset = 0
    y_offset = 0
    for i, image in enumerate(images):
        concat_image.paste(image, (x_offset, y_offset))
        x_offset += image.size[0] + divider
        if (i + 1) % cols == 0:
            # Start a new row.
            x_offset = 0
            y_offset += image.size[1] + divider

    return concat_image
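
# A minimal sketch: eight 100x100 tiles in a 4-wide grid give a 2x4 sheet of
# (4*100 + 3*4) x (2*100 + 1*4) = 412 x 204 pixels.
#
#   tiles = [Image.new("RGB", (100, 100), (i * 30, 0, 0)) for i in range(8)]
#   sheet = concat_images(tiles, divider=4, cols=4)
#   assert sheet.size == (412, 204)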
|
|
def save_tensors_to_npz(tensors: torch.Tensor, paths: List[str]):
    assert len(tensors) == len(paths), "Length of tensors and paths should be the same!"
    for tensor, path in zip(tensors, paths):
        np.savez_compressed(path, latent=tensor.cpu().numpy())
|
|
def resize_and_crop(image, size=None):
    w, h = image.size
    if size is not None:
        # Center-crop to the target aspect ratio, then resize to `size`.
        target_w, target_h = size
        if w / h < target_w / target_h:
            new_w = w
            new_h = w * target_h // target_w
        else:
            new_h = h
            new_w = h * target_w // target_h
        image = image.crop(
            ((w - new_w) // 2, (h - new_h) // 2, (w + new_w) // 2, (h + new_h) // 2)
        )
        image = image.resize(size, Image.Resampling.LANCZOS)
    else:
        # No target size: center-crop both dimensions down to multiples of 16.
        new_w = (w // 16) * 16
        new_h = (h // 16) * 16

        if new_w == 0 or new_h == 0:
            raise ValueError(
                f"Image dimensions ({w}x{h}) are too small to be cropped to a multiple of 16. "
                "Minimum size is 16x16."
            )

        left = (w - new_w) // 2
        top = (h - new_h) // 2
        right = left + new_w
        bottom = top + new_h

        image = image.crop((left, top, right, bottom))
    return image
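
# A minimal sketch: a 1000x500 image cropped to 3:4 keeps the full height and
# 500 * 768 // 1024 = 375 of the width before resizing to 768x1024.
#
#   wide = Image.new("RGB", (1000, 500))
#   out = resize_and_crop(wide, (768, 1024))
#   assert out.size == (768, 1024)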
|
|
def resize_and_padding(image, size):
    # Fit the image inside `size` without cropping, padding the rest with white.
    w, h = image.size
    target_w, target_h = size
    if w / h < target_w / target_h:
        new_h = target_h
        new_w = w * target_h // h
    else:
        new_w = target_w
        new_h = h * target_w // w
    image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Center the resized image on a white canvas of the target size.
    padding = Image.new("RGB", size, (255, 255, 255))
    padding.paste(image, ((target_w - new_w) // 2, (target_h - new_h) // 2))
    return padding
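
# A minimal sketch: a square 512x512 image fit into 768x1024 is resized to
# 768x768 and centered with 128 px of white above and below.
#
#   square = Image.new("RGB", (512, 512))
#   out = resize_and_padding(square, (768, 1024))
#   assert out.size == (768, 1024)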
|