# JiT-AnimeFace-Demo / model/pipeline.py
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
from accelerate import init_empty_weights
from safetensors.torch import load_file
from .denoiser import JiT
from .class_encoder import ClassEncoder
from .config import JiTConfig, ClassContextConfig
# from .text_encoder import TextEncoder
# from ...modules.quant import replace_by_prequantized_weights
# from ...utils import tensor as tensor_utils
def tensor_to_images(
tensor: torch.Tensor,
) -> list[Image.Image]:
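    """Map a [-1, 1] image batch of shape [B, C, H, W] to PIL images.

    Minimal example: ``tensor_to_images(torch.zeros(2, 3, 64, 64))`` returns
    two mid-gray 64x64 images.
    """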
    # denormalize from [-1, 1] to [0, 255]
tensor = tensor.clamp(-1.0, 1.0)
tensor = (tensor + 1.0) / 2.0 * 255.0
# permute
tensor = tensor.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
# convert to numpy array
image_array = tensor.cpu().float().numpy().astype(np.uint8)
return [Image.fromarray(image) for image in image_array]
class JiTModel(nn.Module):
denoiser: JiT
denoiser_class: type[JiT] = JiT
class_encoder: ClassEncoder
def __init__(
self,
config: JiTConfig,
):
super().__init__()
self.config = config
self.denoiser = self.denoiser_class(config.denoiser)
if isinstance(config.context_encoder, ClassContextConfig):
self.class_encoder = ClassEncoder(
label2id=config.context_encoder.label2id,
embedding_dim=config.denoiser.context_dim,
)
else:
raise NotImplementedError(
"Only ClassContextConfig is supported in this version."
)
self.progress_bar = tqdm
def _load_checkpoint(
self,
checkpoint_path: str,
strict: bool = True,
):
state_dict = load_file(checkpoint_path)
# replace_by_prequantized_weights(self, state_dict)
self.denoiser.load_state_dict(
{
key[len("denoiser.") :]: value
for key, value in state_dict.items()
if key.startswith("denoiser.")
},
strict=strict,
assign=True,
)
if self.class_encoder is not None:
self.class_encoder.load_state_dict(
{
key[len("class_encoder.") :]: value
for key, value in state_dict.items()
if key.startswith("class_encoder.")
},
strict=strict,
assign=True,
)
# if self.text_encoder is not None:
# self.text_encoder.model.load_state_dict(
# {
# key[len("text_encoder.") :]: value
# for key, value in state_dict.items()
# if key.startswith("text_encoder.")
# },
# strict=strict,
# assign=True,
# )
@classmethod
def from_pretrained(
cls,
config: JiTConfig,
checkpoint_path: str,
) -> "JiTModel":
with init_empty_weights():
model = cls(config)
model._load_checkpoint(checkpoint_path)
return model
@classmethod
def new_with_config(
cls,
config: JiTConfig,
) -> "JiTModel":
with init_empty_weights():
model = cls(config)
model.denoiser.to_empty(device="cpu")
model.denoiser.initialize_weights()
if isinstance(config.context_encoder, ClassContextConfig):
model.class_encoder.to_empty(device="cpu")
model.class_encoder.initialize_weights()
else:
# model.text_encoder = TextEncoder.from_remote(
# repo_id=config.context_encoder.pretrained_model,
# )
raise NotImplementedError(
"Only ClassContextConfig is supported in this version."
)
return model
    def prepare_noisy_image(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int | None = None,
    ):
        # seeded generator for reproducible sampling; None falls back to the
        # global RNG state
        generator: torch.Generator | None = None
        if seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)
        noise = torch.randn(
            (batch_size, 3, height, width),
            dtype=dtype,
            device=device,
            generator=generator,
        )
        return noise
def prepare_timesteps(
self,
num_inference_steps: int,
device: torch.device,
):
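        """Build a uniform time grid from t = 0 (pure noise) to t = 1 (clean).

        ``num_inference_steps=4`` yields ``[0.0, 0.25, 0.5, 0.75, 1.0]``; the
        sampling loop in ``generate`` integrates across each adjacent pair.
        """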
timesteps = torch.linspace(
0.0,
1.0,
num_inference_steps + 1,
device=device,
)
return timesteps
def prepare_context_embeddings(
self,
prompts: str | list[str],
negative_prompt: str | list[str],
max_token_length: int = 64,
do_cfg: bool = False,
):
# if self.text_encoder is not None:
# encoder_output = self.text_encoder.encode_prompts(
# prompts,
# negative_prompts=negative_prompt,
# use_negative_prompts=do_cfg,
# max_token_length=max_token_length,
# )
# if do_cfg:
# prompt_embeddings = torch.cat(
# [
# encoder_output.positive_embeddings,
# encoder_output.negative_embeddings,
# ]
# )
# attention_mask = torch.cat(
# [
# encoder_output.positive_attention_mask,
# encoder_output.negative_attention_mask,
# ]
# )
# else:
# prompt_embeddings = encoder_output.positive_embeddings
# attention_mask = encoder_output.positive_attention_mask
if self.class_encoder is not None:
embeddings, attention_mask = self.class_encoder.encode_prompts(
prompts,
max_token_length=max_token_length,
)
            if do_cfg:
                # negatives are only needed for classifier-free guidance;
                # batch them after the positives
                negative_embeddings, _ = self.class_encoder.encode_prompts(
                    negative_prompt,
                    max_token_length=max_token_length,
                )
                prompt_embeddings = torch.cat(
                    [embeddings, negative_embeddings],
                    dim=0,
                )
                attention_mask = torch.cat(
                    [attention_mask, attention_mask],
                    dim=0,
                )
            else:
                prompt_embeddings = embeddings
else:
raise NotImplementedError("Only ClassEncoder is supported in this version.")
return prompt_embeddings, attention_mask
def to_pil_images(self, tensor: torch.Tensor) -> list[Image.Image]:
return tensor_to_images(tensor)
def image_to_velocity(
self,
image: torch.Tensor,
noisy: torch.Tensor,
timestep: torch.Tensor,
clamp_eps: float = 1e-5,
):
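        """Convert a predicted clean image into a flow-matching velocity.

        With the linear path x_t = (1 - t) * noise + t * image, the velocity
        is v = (image - x_t) / (1 - t); the denominator is clamped to
        ``clamp_eps`` to avoid division by zero as t -> 1.
        """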
return (image - noisy) / (1 - timestep.view(-1, 1, 1, 1)).clamp_min_(clamp_eps)
def renorm_cfg(
self,
positive_velocity: torch.Tensor,
cfg_velocity: torch.Tensor,
) -> torch.Tensor:
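        """Rescale the CFG-combined velocity so its magnitude matches the
        positive velocity's, which mitigates over-saturation at high
        ``cfg_scale``. Norms are taken over the last dimension.
        """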
positive_norm = torch.norm(positive_velocity, dim=-1, keepdim=True)
cfg_norm = torch.norm(cfg_velocity, dim=-1, keepdim=True)
new_cfg_velocity = cfg_velocity * (positive_norm / cfg_norm)
return new_cfg_velocity
def dynamic_thresholding(
self,
images: torch.Tensor,
percentile: float = 0.995,
) -> torch.Tensor:
"""
Apply dynamic thresholding to the images.
Args:
images (torch.Tensor): The input images tensor.
percentile (float): The percentile value for thresholding.
Returns:
torch.Tensor: The thresholded images tensor.
"""
batch_size = images.shape[0]
flattened_images = images.view(batch_size, -1)
abs_images = torch.abs(flattened_images)
s = torch.quantile(abs_images, percentile, dim=1, keepdim=True)
s = torch.clamp(s, min=1.0).view(batch_size, 1, 1, 1)
thresholded_images = torch.clamp(images, -s, s) / s
return thresholded_images
def normalize_prompts(
self,
prompt: str | list[str],
) -> list[str]:
return prompt if isinstance(prompt, list) else [prompt]
@torch.inference_mode()
def generate(
self,
prompt: str | list[str],
negative_prompt: str | list[str] | None = None,
width: int = 256,
height: int = 256,
num_inference_steps: int = 20,
cfg_scale: float = 2.0,
max_token_length: int = 64,
seed: int | None = None,
execution_dtype: torch.dtype = torch.bfloat16,
device: torch.device | str = torch.device("cuda"),
do_cfg_renorm: bool = False,
do_dynamic_thresholding: bool = False,
        cfg_time_range: tuple[float, float] = (0.0, 1.0),
# do_offloading: bool = False,
):
        # 1. Prepare device, timesteps, and batch size
execution_device: torch.device = (
torch.device(device) if isinstance(device, str) else device
)
do_cfg = cfg_scale > 1.0
timesteps = self.prepare_timesteps(
num_inference_steps=num_inference_steps,
device=execution_device,
)
batch_size = len(prompt) if isinstance(prompt, list) else 1
        # 2. Prepare noise
noisy_image = self.prepare_noisy_image(
batch_size=batch_size,
height=height,
width=width,
dtype=execution_dtype,
device=execution_device,
seed=seed,
)
negative_prompts = [""] if negative_prompt is None else negative_prompt
negative_prompts = self.normalize_prompts(negative_prompts)
if len(negative_prompts) != batch_size and len(negative_prompts) == 1:
negative_prompts = negative_prompts * batch_size
prompt_embeddings, attention_mask = self.prepare_context_embeddings(
prompts=prompt,
negative_prompt=negative_prompts,
max_token_length=max_token_length,
do_cfg=do_cfg,
)
# 4. Denoising loop
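        # With CFG the batch is doubled: positives in the first half,
        # negatives in the second, matching prompt_embeddings above.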
with self.progress_bar(total=num_inference_steps) as pbar:
for i, timestep in enumerate(timesteps[:-1]):
image_input = torch.cat([noisy_image] * 2) if do_cfg else noisy_image
batch_timestep = timestep.expand(image_input.shape[0])
model_pred = self.denoiser(
image=image_input,
timestep=batch_timestep,
context=prompt_embeddings,
context_mask=attention_mask,
)
if do_cfg and cfg_time_range[0] <= float(timestep) <= cfg_time_range[1]:
image_pred_positive, image_pred_negative = model_pred.chunk(2)
v_pred_positive = self.image_to_velocity(
image=image_pred_positive,
noisy=noisy_image,
timestep=timestep.expand(batch_size),
)
v_pred_negative = self.image_to_velocity(
image=image_pred_negative,
noisy=noisy_image,
timestep=timestep.expand(batch_size),
)
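                    # classifier-free guidance in velocity space:
                    #   v = v_pos + cfg_scale * (v_pos - v_neg)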
velocity = v_pred_positive + cfg_scale * (
v_pred_positive - v_pred_negative
)
if do_cfg_renorm:
velocity = self.renorm_cfg(
positive_velocity=v_pred_positive,
cfg_velocity=velocity,
)
if do_dynamic_thresholding:
# re-calculate the image prediction after cfg
image_pred = noisy_image + velocity * (1 - timestep)
image_pred = self.dynamic_thresholding(image_pred)
velocity = self.image_to_velocity(
image=image_pred,
noisy=noisy_image,
timestep=timestep.expand(batch_size),
)
else:
velocity = self.image_to_velocity(
image=model_pred[:batch_size],
noisy=noisy_image,
timestep=timestep.expand(batch_size),
)
                # Euler step of the flow ODE: x_{next} = x_t + v * (t_{next} - t)
noisy_image = noisy_image + velocity * (timesteps[i + 1] - timestep)
pbar.update()
# now it should be clean
clean_image = noisy_image
# to PIL images
pil_images = self.to_pil_images(clean_image.cpu())
return pil_images
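

# Minimal usage sketch, kept as a comment. ``JiTConfig(...)`` construction and
# the file names are placeholders; the generate() arguments follow the
# signature above.
#
#     config = JiTModel  # placeholder: build/load your JiTConfig here
#     config = JiTConfig(...)
#     model = JiTModel.from_pretrained(config, "model.safetensors")
#     model.to(device="cuda", dtype=torch.bfloat16)
#     images = model.generate(
#         prompt="1girl, blue hair, smile",
#         negative_prompt="lowres",
#         num_inference_steps=20,
#         cfg_scale=2.0,
#         seed=42,
#     )
#     images[0].save("sample.png")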