Spaces:
Running
on
Zero
Running
on
Zero
| from tqdm import tqdm | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from accelerate import init_empty_weights | |
| from safetensors.torch import load_file | |
| from .denoiser import JiT | |
| from .class_encoder import ClassEncoder | |
| from .config import JiTConfig, ClassContextConfig | |
| # from .text_encoder import TextEncoder | |
| # from ...modules.quant import replace_by_prequantized_weights | |
| # from ...utils import tensor as tensor_utils | |
def tensor_to_images(
    tensor: torch.Tensor,
) -> list[Image.Image]:
    """Convert a batch of images with values in [-1, 1] into PIL images.

    Args:
        tensor: Image batch laid out as [B, C, H, W]; values outside
            [-1, 1] are clamped before denormalization.

    Returns:
        A list of B PIL images.
    """
    # Clamp into range, then denormalize -1~1 -> 0~255.
    normalized = tensor.clamp(-1.0, 1.0)
    normalized = (normalized + 1.0) / 2.0 * 255.0
    # PIL expects channels-last: [B, C, H, W] -> [B, H, W, C].
    channels_last = normalized.permute(0, 2, 3, 1)
    frames = channels_last.cpu().float().numpy().astype(np.uint8)
    images: list[Image.Image] = []
    for frame in frames:
        images.append(Image.fromarray(frame))
    return images
class JiTModel(nn.Module):
    """JiT denoiser plus its class-label context encoder.

    Bundles checkpoint loading, weight initialization, and a velocity-field
    sampling loop (`generate`) that integrates predicted velocities from pure
    noise at t=0 toward a clean image at t=1.
    """

    # Set in __init__; annotated here for type checkers.
    denoiser: JiT
    # Subclasses may substitute a different denoiser implementation.
    denoiser_class: type[JiT] = JiT
    class_encoder: ClassEncoder

    def __init__(
        self,
        config: JiTConfig,
    ):
        super().__init__()
        self.config = config
        self.denoiser = self.denoiser_class(config.denoiser)

        if isinstance(config.context_encoder, ClassContextConfig):
            self.class_encoder = ClassEncoder(
                label2id=config.context_encoder.label2id,
                embedding_dim=config.denoiser.context_dim,
            )
        else:
            # NOTE: text-encoder based contexts are intentionally disabled
            # in this version (see commented imports at the top of the file).
            raise NotImplementedError(
                "Only ClassContextConfig is supported in this version."
            )

        # Progress-bar factory used by generate(); swappable for testing/UI.
        self.progress_bar = tqdm

    @staticmethod
    def _sub_state_dict(
        state_dict: dict[str, torch.Tensor],
        prefix: str,
    ) -> dict[str, torch.Tensor]:
        """Select entries whose key starts with `prefix`, with the prefix stripped."""
        return {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

    def _load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
    ):
        """Load a safetensors checkpoint whose keys are prefixed per submodule.

        Args:
            checkpoint_path: Path to a `.safetensors` file with keys prefixed
                "denoiser." / "class_encoder.".
            strict: Forwarded to `load_state_dict`; when True, missing or
                unexpected keys raise.
        """
        state_dict = load_file(checkpoint_path)
        # replace_by_prequantized_weights(self, state_dict)

        # assign=True replaces meta-device placeholders created under
        # init_empty_weights() with the real checkpoint tensors.
        self.denoiser.load_state_dict(
            self._sub_state_dict(state_dict, "denoiser."),
            strict=strict,
            assign=True,
        )
        if self.class_encoder is not None:
            self.class_encoder.load_state_dict(
                self._sub_state_dict(state_dict, "class_encoder."),
                strict=strict,
                assign=True,
            )
        # NOTE: text-encoder weights are intentionally not loaded in this
        # version (text-encoder support is disabled).

    @classmethod
    def from_pretrained(
        cls,
        config: JiTConfig,
        checkpoint_path: str,
    ) -> "JiTModel":
        """Construct the model on the meta device and fill it from a checkpoint.

        Was missing @classmethod: calling it on the class bound `config`
        to `cls` and crashed.
        """
        with init_empty_weights():
            model = cls(config)
        model._load_checkpoint(checkpoint_path)
        return model

    @classmethod
    def new_with_config(
        cls,
        config: JiTConfig,
    ) -> "JiTModel":
        """Construct a freshly initialized (untrained) model from a config.

        Was missing @classmethod (same defect as from_pretrained).
        """
        with init_empty_weights():
            model = cls(config)

        # Materialize real storage on CPU, then run each module's own
        # weight initialization.
        model.denoiser.to_empty(device="cpu")
        model.denoiser.initialize_weights()

        if isinstance(config.context_encoder, ClassContextConfig):
            model.class_encoder.to_empty(device="cpu")
            model.class_encoder.initialize_weights()
        else:
            # model.text_encoder = TextEncoder.from_remote(
            #     repo_id=config.context_encoder.pretrained_model,
            # )
            raise NotImplementedError(
                "Only ClassContextConfig is supported in this version."
            )

        return model

    def prepare_noisy_image(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int | None = None,
    ) -> torch.Tensor:
        """Sample the initial Gaussian noise image, optionally with a fixed seed.

        Returns:
            A [batch_size, 3, height, width] tensor of standard-normal noise.
        """
        generator: torch.Generator | None = None
        if seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)
        # torch.randn accepts generator=None, so both paths share one call.
        return torch.randn(
            (batch_size, 3, height, width),
            dtype=dtype,
            device=device,
            generator=generator,
        )

    def prepare_timesteps(
        self,
        num_inference_steps: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Return num_inference_steps + 1 evenly spaced timesteps in [0, 1]."""
        return torch.linspace(
            0.0,
            1.0,
            num_inference_steps + 1,
            device=device,
        )

    def prepare_context_embeddings(
        self,
        prompts: str | list[str],
        negative_prompt: str | list[str],
        max_token_length: int = 64,
        do_cfg: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode prompts (and, for CFG, negative prompts) into context tensors.

        Returns:
            (embeddings, attention_mask). With do_cfg, both are positive and
            negative halves concatenated along the batch dimension, positives
            first — matching the chunk(2) order in generate().
        """
        if self.class_encoder is None:
            raise NotImplementedError("Only ClassEncoder is supported in this version.")

        embeddings, attention_mask = self.class_encoder.encode_prompts(
            prompts,
            max_token_length=max_token_length,
        )
        if do_cfg:
            # Only encode negatives when CFG actually needs them
            # (previously computed unconditionally and discarded).
            negative_embeddings, _ = self.class_encoder.encode_prompts(
                negative_prompt,
                max_token_length=max_token_length,
            )
            prompt_embeddings = torch.cat(
                [embeddings, negative_embeddings],
                dim=0,
            )
            # Positive and negative prompts share the same mask layout here.
            attention_mask = torch.cat(
                [attention_mask, attention_mask],
                dim=0,
            )
        else:
            prompt_embeddings = embeddings
        return prompt_embeddings, attention_mask

    def to_pil_images(self, tensor: torch.Tensor) -> list[Image.Image]:
        """Convert a [-1, 1] image batch tensor to PIL images."""
        return tensor_to_images(tensor)

    def image_to_velocity(
        self,
        image: torch.Tensor,
        noisy: torch.Tensor,
        timestep: torch.Tensor,
        clamp_eps: float = 1e-5,
    ) -> torch.Tensor:
        """Convert a clean-image prediction into a velocity at `timestep`.

        v = (x1_pred - x_t) / (1 - t); the denominator is clamped by
        clamp_eps to stay finite as t -> 1.
        """
        return (image - noisy) / (1 - timestep.view(-1, 1, 1, 1)).clamp_min(clamp_eps)

    def renorm_cfg(
        self,
        positive_velocity: torch.Tensor,
        cfg_velocity: torch.Tensor,
    ) -> torch.Tensor:
        """Rescale the CFG velocity to the norm of the positive velocity.

        NOTE(review): norms are taken over dim=-1 (the width axis of a
        [B, C, H, W] tensor) — confirm this per-row renorm is intended
        rather than a per-image norm.
        """
        positive_norm = torch.norm(positive_velocity, dim=-1, keepdim=True)
        cfg_norm = torch.norm(cfg_velocity, dim=-1, keepdim=True)
        return cfg_velocity * (positive_norm / cfg_norm)

    def dynamic_thresholding(
        self,
        images: torch.Tensor,
        percentile: float = 0.995,
    ) -> torch.Tensor:
        """Apply dynamic thresholding to the images.

        Per sample, clamp to the `percentile` absolute value s (never below
        1.0) and rescale by s, keeping results within [-1, 1].

        Args:
            images (torch.Tensor): The input images tensor, [B, C, H, W].
            percentile (float): The percentile value for thresholding.

        Returns:
            torch.Tensor: The thresholded images tensor.
        """
        batch_size = images.shape[0]
        flattened_images = images.view(batch_size, -1)
        abs_images = torch.abs(flattened_images)
        s = torch.quantile(abs_images, percentile, dim=1, keepdim=True)
        # min=1.0 so values already within [-1, 1] pass through unchanged.
        s = torch.clamp(s, min=1.0).view(batch_size, 1, 1, 1)
        return torch.clamp(images, -s, s) / s

    def normalize_prompts(
        self,
        prompt: str | list[str],
    ) -> list[str]:
        """Wrap a bare string into a single-element list; pass lists through."""
        return prompt if isinstance(prompt, list) else [prompt]

    def generate(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        width: int = 256,
        height: int = 256,
        num_inference_steps: int = 20,
        cfg_scale: float = 2.0,
        max_token_length: int = 64,
        seed: int | None = None,
        execution_dtype: torch.dtype = torch.bfloat16,
        device: torch.device | str = torch.device("cuda"),
        do_cfg_renorm: bool = False,
        do_dynamic_thresholding: bool = False,
        # Tuple default: the previous list default was a shared mutable object.
        cfg_time_range: tuple[float, float] = (0.0, 1.0),
        # do_offloading: bool = False,
    ) -> list[Image.Image]:
        """Sample images for the given prompts.

        Integrates the denoiser's velocity predictions over
        `num_inference_steps` Euler steps from t=0 (noise) to t=1 (image),
        with optional classifier-free guidance (active when cfg_scale > 1.0
        and t falls within cfg_time_range), CFG renormalization, and dynamic
        thresholding.

        Raises:
            ValueError: If the number of negative prompts is neither 1 nor
                equal to the batch size.
        """
        # 1. Prepare args
        execution_device: torch.device = (
            torch.device(device) if isinstance(device, str) else device
        )
        do_cfg = cfg_scale > 1.0
        timesteps = self.prepare_timesteps(
            num_inference_steps=num_inference_steps,
            device=execution_device,
        )
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # 2. Prepare initial noise
        noisy_image = self.prepare_noisy_image(
            batch_size=batch_size,
            height=height,
            width=width,
            dtype=execution_dtype,
            device=execution_device,
            seed=seed,
        )

        # 3. Encode context (positive + optional negative)
        negative_prompts = [""] if negative_prompt is None else negative_prompt
        negative_prompts = self.normalize_prompts(negative_prompts)
        if len(negative_prompts) != batch_size:
            if len(negative_prompts) == 1:
                # Broadcast a single negative prompt across the batch.
                negative_prompts = negative_prompts * batch_size
            else:
                raise ValueError(
                    "Number of negative prompts must be 1 or match the batch size."
                )
        prompt_embeddings, attention_mask = self.prepare_context_embeddings(
            prompts=prompt,
            negative_prompt=negative_prompts,
            max_token_length=max_token_length,
            do_cfg=do_cfg,
        )

        # 4. Denoising loop (Euler integration of the velocity field)
        with self.progress_bar(total=num_inference_steps) as pbar:
            for i, timestep in enumerate(timesteps[:-1]):
                # Duplicate the batch so positive/negative contexts share one pass.
                image_input = torch.cat([noisy_image] * 2) if do_cfg else noisy_image
                batch_timestep = timestep.expand(image_input.shape[0])

                # The denoiser predicts the clean image (x1), not the velocity.
                model_pred = self.denoiser(
                    image=image_input,
                    timestep=batch_timestep,
                    context=prompt_embeddings,
                    context_mask=attention_mask,
                )

                if do_cfg and cfg_time_range[0] <= float(timestep) <= cfg_time_range[1]:
                    image_pred_positive, image_pred_negative = model_pred.chunk(2)
                    v_pred_positive = self.image_to_velocity(
                        image=image_pred_positive,
                        noisy=noisy_image,
                        timestep=timestep.expand(batch_size),
                    )
                    v_pred_negative = self.image_to_velocity(
                        image=image_pred_negative,
                        noisy=noisy_image,
                        timestep=timestep.expand(batch_size),
                    )
                    # Guidance: push away from the negative prediction.
                    velocity = v_pred_positive + cfg_scale * (
                        v_pred_positive - v_pred_negative
                    )
                    if do_cfg_renorm:
                        velocity = self.renorm_cfg(
                            positive_velocity=v_pred_positive,
                            cfg_velocity=velocity,
                        )
                    if do_dynamic_thresholding:
                        # Re-derive the image prediction implied by the guided
                        # velocity, threshold it, and convert back.
                        image_pred = noisy_image + velocity * (1 - timestep)
                        image_pred = self.dynamic_thresholding(image_pred)
                        velocity = self.image_to_velocity(
                            image=image_pred,
                            noisy=noisy_image,
                            timestep=timestep.expand(batch_size),
                        )
                else:
                    # No guidance this step; when do_cfg the first half of the
                    # doubled batch holds the positive predictions.
                    velocity = self.image_to_velocity(
                        image=model_pred[:batch_size],
                        noisy=noisy_image,
                        timestep=timestep.expand(batch_size),
                    )

                # Euler step toward the next timestep.
                noisy_image = noisy_image + velocity * (timesteps[i + 1] - timestep)
                pbar.update()

        # 5. At t=1 the integrated state is the clean image.
        clean_image = noisy_image
        return self.to_pil_images(clean_image.cpu())