from typing import Dict, List, Optional, Union

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image
from tqdm import tqdm

from .unet_2d_condition import UNet2DConditionModel
from .utils import (
    numpy_to_pil,
    paste_image_back_with_feathering,
    prepare_extra_step_kwargs,
    prepare_image,
    prepare_mask_image,
    adjust_input_image,
)

# Class indices for each supported reference-garment category.
REF_LABEL_MAP = {
    "upper": 0,
    "lower": 1,
    "overall": 2,
    "shoe": 3,
    "bag": 4,
}


class FastFitPipeline:
    def __init__(
        self,
        base_model_path: str,
        device: Optional[str] = None,
        mixed_precision: str = "fp16",
        allow_tf32: bool = False,
    ):
        """
        Initialize the FastFit inference pipeline.

        Args:
            base_model_path: Path to the pretrained model.
            device: Device to run on, e.g. 'cuda' or 'cpu'; defaults to CUDA when available.
            mixed_precision: Mixed-precision mode, one of 'fp16', 'bf16', or 'fp32'.
            allow_tf32: Whether to allow TF32 matmuls on Ampere (and newer) GPUs.
        """
        self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"

        # Map the requested mixed-precision mode to a torch dtype.
        if mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16
        else:
            self.weight_dtype = torch.float32

        # Load the scheduler, VAE, and UNet from the pretrained checkpoint and
        # move the weight-bearing modules to the target device and dtype.
        self.noise_scheduler = DDPMScheduler.from_pretrained(base_model_path, subfolder="scheduler")

        self.vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae")
        self.vae.to(self.device, dtype=self.weight_dtype)

        self.unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
        self.unet.to(self.device, dtype=self.weight_dtype)
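
        # torch.compile pays a one-time compilation cost on the first forward
        # pass; fusing the QKV projections merges the attention input
        # projections into a single matmul for all subsequent calls.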
        self.vae = torch.compile(self.vae, mode="max-autotune")
        self.unet.fuse_qkv_projections()

        # Optionally enable TF32 matmuls (faster on Ampere+ GPUs at slightly
        # reduced mantissa precision).
        if allow_tf32 and torch.cuda.is_available():
            torch.set_float32_matmul_precision("high")
            torch.backends.cuda.matmul.allow_tf32 = True

    @torch.no_grad()
    def __call__(
        self,
        person: Union[torch.Tensor, Image.Image],
        mask: Union[torch.Tensor, Image.Image],
        ref_images: List[Union[torch.Tensor, Image.Image]],
        ref_labels: List[Union[int, str]],
        ref_attention_masks: List[Union[int, torch.Tensor]],
        pose: Optional[Union[torch.Tensor, Image.Image]] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 2.5,
        generator: Optional[torch.Generator] = None,
        cross_attention_kwargs: Optional[Dict] = None,
        eta: float = 1.0,
        return_pil: bool = True,
        do_adjust_input_image: bool = False,
    ):
        """
        Run FastFit inference.

        Args:
            person: Input person image.
            mask: Inpainting mask marking the try-on region.
            ref_images: List of reference (garment) images.
            ref_labels: Per-reference category labels, as ints or keys of REF_LABEL_MAP.
            ref_attention_masks: Per-reference attention masks (0 disables a reference).
            pose: Optional pose-guidance image.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale; values > 1.0 enable guidance.
            generator: Random number generator for reproducibility.
            cross_attention_kwargs: Extra kwargs forwarded to the UNet's cross-attention.
            eta: Eta parameter passed to the scheduler step.
            return_pil: Whether to return PIL images instead of tensors.
            do_adjust_input_image: Whether to crop/resize the input before inference
                and paste the result back afterwards.

        Returns:
            The generated try-on image(s).
        """
        if do_adjust_input_image:
            # Crop/resize the inputs to the working resolution (768x1024);
            # keep the crop box for pasting the result back later.
            background_img, person_img, mask_img, crop_box = adjust_input_image(
                person, mask, (768, 1024), 0.05
            )
        else:
            background_img = person
            person_img = person
            mask_img = mask
            crop_box = None

        # Map string labels (e.g. "upper") to their integer class indices.
        if isinstance(ref_labels[0], str):
            ref_labels = [REF_LABEL_MAP[label] for label in ref_labels]

        # Convert inputs to normalized tensors on the target device.
        person = prepare_image(person_img, self.device, self.weight_dtype)
        mask = prepare_mask_image(mask_img, self.device, self.weight_dtype)
        ref_images = [prepare_image(image, self.device, self.weight_dtype) for image in ref_images]
        if pose is not None:
            pose = prepare_image(pose, self.device, self.weight_dtype, do_normalize=False)
            if pose.shape[-2:] != person.shape[-2:]:
                pose = torch.nn.functional.interpolate(
                    pose,
                    size=person.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )

        # Keep original pixels outside the mask; inside it, use the pose map
        # (if provided) as structural guidance, otherwise zeros.
        masked_person = person * (1 - mask) + mask * pose if pose is not None else person * (1 - mask)

        # Normalize reference attention masks and labels to tensors on the device.
        if ref_attention_masks is not None:
            if isinstance(ref_attention_masks[0], int):
                ref_attention_masks = [torch.tensor([ref_attn_mask]).to(self.device) for ref_attn_mask in ref_attention_masks]
            else:
                ref_attention_masks = [ref_attn_mask.to(self.device) for ref_attn_mask in ref_attention_masks]
        if ref_labels is not None:
            if isinstance(ref_labels[0], int):
                ref_labels = [torch.tensor([ref_label]).to(self.device) for ref_label in ref_labels]
            else:
                ref_labels = [ref_label.to(self.device) for ref_label in ref_labels]

        # Encode the masked person and each reference into VAE latent space,
        # and downsample the mask to the latent resolution.
        masked_person_latent = self.vae.encode(masked_person).latent_dist.sample() * self.vae.config.scaling_factor
        ref_images_latent = [self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor for image in ref_images]
        mask_latent = torch.nn.functional.interpolate(
            mask.to(dtype=torch.float32),
            size=masked_person_latent.shape[-2:],
            mode="nearest",
        ).to(self.weight_dtype)

        # Start from pure Gaussian noise scaled by the scheduler's initial sigma.
        noise = randn_tensor(
            masked_person_latent.shape,
            generator=generator,
            device=torch.device(self.device),
            dtype=self.weight_dtype,
        )
        self.noise_scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.noise_scheduler.timesteps
        latents = noise * self.noise_scheduler.init_noise_sigma

        # Classifier-free guidance: duplicate the conditioning along the batch
        # dimension, pairing an unconditional copy (zeroed references) with the
        # conditional one. Without guidance nothing needs to change here, since
        # the masks and labels were already converted to device tensors above.
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            mask_latent = torch.cat([mask_latent] * 2)
            masked_person_latent = torch.cat([masked_person_latent] * 2)
            ref_images_latent = [torch.cat([torch.zeros_like(image), image]) for image in ref_images_latent]
            if ref_attention_masks is not None:
                ref_attention_masks = [torch.cat([ref_attn_mask] * 2) for ref_attn_mask in ref_attention_masks]
            if ref_labels is not None:
                ref_labels = [torch.cat([ref_label] * 2) for ref_label in ref_labels]

        extra_step_kwargs = prepare_extra_step_kwargs(self.noise_scheduler, generator, eta)
        num_warmup_steps = (
            len(timesteps) - num_inference_steps * self.noise_scheduler.order
        )
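
        # Encode each reference exactly once, caching its attention keys/values
        # inside the UNet (cache_kv=True). The cached K/V are reused at every
        # denoising step, so references never re-enter the network during the loop.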
        for ref_image_latent, ref_label, ref_attention_mask in zip(ref_images_latent, ref_labels, ref_attention_masks):
            # Skip references whose attention mask is all zeros (disabled).
            if ref_attention_mask.sum() == 0:
                continue
            self.unet(
                sample=ref_image_latent,
                timestep=None,
                class_labels=ref_label,
                attention_mask=ref_attention_mask,
                return_dict=False,
                cache_kv=True,
            )

        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                t = t.to(self.device)

                # Duplicate latents for CFG and apply scheduler input scaling.
                non_inpainting_latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(
                    non_inpainting_latent_model_input, t
                )

                # Inpainting input: noisy latents, mask, and masked-person
                # latents concatenated along the channel dimension.
                inpainting_latent_model_input = torch.cat(
                    [
                        non_inpainting_latent_model_input,
                        mask_latent,
                        masked_person_latent,
                    ],
                    dim=1,
                )

                noise_pred = self.unet(
                    inpainting_latent_model_input,
                    t,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                    attention_mask=torch.tensor([1] * inpainting_latent_model_input.shape[0]).to(self.device),
                )[0]

                # CFG: push the conditional prediction away from the unconditional one.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # Scheduler step: x_t -> x_{t-1}.
                latents = self.noise_scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps
                    and (i + 1) % self.noise_scheduler.order == 0
                ):
                    progress_bar.update()

        # Free the cached reference keys/values now that denoising is done.
        self.unet.clear_kv_cache()

        # Decode latents back to image space and composite with the original
        # person so pixels outside the mask are preserved exactly.
        latents = (1 / self.vae.config.scaling_factor * latents).to(self.vae.device, dtype=self.vae.dtype)
        image = self.vae.decode(latents).sample
        image = image * mask + (1 - mask) * person

        if return_pil:
            # Map from [-1, 1] to [0, 1], then convert to PIL images.
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
            image = numpy_to_pil(image)

        if do_adjust_input_image:
            # Paste the generated crop back into the full-resolution background.
            image, _ = paste_image_back_with_feathering(
                background_img, image[0], crop_box
            )
            image = [image]

        return image
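

# Minimal usage sketch (the checkpoint path and image files below are
# hypothetical placeholders, not assets shipped with this repository):
#
#     from PIL import Image
#
#     pipe = FastFitPipeline("path/to/fastfit-checkpoint", mixed_precision="fp16")
#     images = pipe(
#         person=Image.open("person.jpg"),
#         mask=Image.open("mask.png"),
#         ref_images=[Image.open("shirt.jpg")],
#         ref_labels=["upper"],
#         ref_attention_masks=[1],
#         do_adjust_input_image=True,
#     )
#     images[0].save("tryon.png")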