| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Dict, Optional, Tuple, Union |
|
|
| import torch |
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput |
|
|
|
|
def cachecrossattnupblock2d_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_0: torch.FloatTensor,
    res_hidden_states_1: torch.FloatTensor,
    res_hidden_states_2: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """Run a cross-attention up block whose skip tensors arrive as three separate arguments.

    Each ``(resnet, attention)`` pair in ``self.resnets`` / ``self.attentions``
    first concatenates the running activation with one skip tensor along
    dim 1, then applies the resnet followed by the attention layer. Skip
    tensors are consumed last-to-first: ``res_hidden_states_2`` feeds the
    first pair, ``res_hidden_states_0`` the third.

    NOTE(review): the skips are presumably passed individually (rather than
    as a tuple) so that each one is a named graph input for export - confirm.

    Args:
        hidden_states: Activation entering the block.
        res_hidden_states_0: Skip tensor used by the third resnet/attention pair.
        res_hidden_states_1: Skip tensor used by the second resnet/attention pair.
        res_hidden_states_2: Skip tensor used by the first resnet/attention pair.
        temb: Embedding passed as the second positional argument to each resnet.
        encoder_hidden_states: Forwarded to every attention layer.
        cross_attention_kwargs: Forwarded to every attention layer.
        upsample_size: Passed as the second positional argument to each upsampler.
        attention_mask: Forwarded to every attention layer.
        encoder_attention_mask: Forwarded to every attention layer.

    Returns:
        The activation after all resnet/attention pairs and, when
        ``self.upsamplers`` is not ``None``, after every upsampler.

    Raises:
        IndexError: If the block holds more than three resnet/attention pairs.
    """
    skips = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)

    # Negative indexing walks the skips from the last one backwards.
    for depth, (resnet, attn) in enumerate(zip(self.resnets, self.attentions), start=1):
        skip = skips[-depth]
        merged = torch.cat([hidden_states, skip], dim=1)
        hidden_states = resnet(merged, temb)

        # The attention layer is asked for a tuple; only its first entry is kept.
        attn_outputs = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )
        hidden_states = attn_outputs[0]

    if self.upsamplers is None:
        return hidden_states

    for upsampler in self.upsamplers:
        hidden_states = upsampler(hidden_states, upsample_size)
    return hidden_states
|
|
|
|
def cacheupblock2d_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_0: torch.FloatTensor,
    res_hidden_states_1: torch.FloatTensor,
    res_hidden_states_2: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    upsample_size: Optional[int] = None,
) -> torch.FloatTensor:
    """Run an attention-free up block whose skip tensors arrive as three separate arguments.

    For every resnet in ``self.resnets`` the running activation is
    concatenated with one skip tensor along dim 1 and passed through the
    resnet. Skip tensors are consumed last-to-first: ``res_hidden_states_2``
    feeds the first resnet, ``res_hidden_states_0`` the third.

    Args:
        hidden_states: Activation entering the block.
        res_hidden_states_0: Skip tensor used by the third resnet.
        res_hidden_states_1: Skip tensor used by the second resnet.
        res_hidden_states_2: Skip tensor used by the first resnet.
        temb: Embedding passed as the second positional argument to each resnet.
        upsample_size: Passed as the second positional argument to each upsampler.

    Returns:
        The activation after all resnets and, when ``self.upsamplers`` is not
        ``None``, after every upsampler.

    Raises:
        IndexError: If the block holds more than three resnets.
    """
    skips = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)

    # Negative indexing walks the skips from the last one backwards.
    for depth, resnet in enumerate(self.resnets, start=1):
        merged = torch.cat([hidden_states, skips[-depth]], dim=1)
        hidden_states = resnet(merged, temb)

    if self.upsamplers is None:
        return hidden_states

    for upsampler in self.upsamplers:
        hidden_states = upsampler(hidden_states, upsample_size)
    return hidden_states
|
|
|
|
def cacheunet_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    """UNet forward pass that runs each block either eagerly or through a per-block engine.

    The pass has four stages: embedding + ``conv_in``, down blocks, mid
    block, up blocks, followed by the output convolution. When
    ``self.use_trt_infer`` is set and truthy, each down/mid/up block is
    replaced by a call to ``self.engines[<block name>](feed_dict,
    self.cuda_stream)``; otherwise the block modules themselves are called.

    Up blocks are always handed exactly three skip tensors
    (``res_hidden_states_0`` .. ``res_hidden_states_2``), so every up block
    must consume at least three entries from the skip stack.

    Parameters accepted but never read by this implementation:
    ``class_labels``, ``down_block_additional_residuals``,
    ``mid_block_additional_residual`` and
    ``down_intrablock_additional_residuals``. On the engine path,
    ``attention_mask``, ``cross_attention_kwargs`` and
    ``encoder_attention_mask`` are likewise not forwarded.

    Args:
        sample: Input passed to ``self.get_time_embed`` and ``self.conv_in``.
        timestep: Passed to ``self.get_time_embed``.
        encoder_hidden_states: Conditioning input; replaced by the result of
            ``self.process_encoder_hidden_states`` before reaching any block.
        class_labels: Unused.
        timestep_cond: Second argument to ``self.time_embedding``.
        attention_mask: Forwarded to cross-attention blocks (eager path only).
        cross_attention_kwargs: Forwarded to cross-attention blocks (eager path only).
        added_cond_kwargs: Forwarded to ``self.get_aug_embed`` and
            ``self.process_encoder_hidden_states``.
        down_block_additional_residuals: Unused.
        mid_block_additional_residual: Unused.
        down_intrablock_additional_residuals: Unused.
        encoder_attention_mask: Forwarded to cross-attention blocks (eager path only).
        return_dict: Selects the return container for the normal path.

    Returns:
        ``(sample, encoder_hidden_states, emb)`` right after ``conv_in`` when
        ``self._export_precess_onnx`` is set and truthy; otherwise
        ``UNet2DConditionOutput(sample=...)`` if ``return_dict`` is truthy,
        else the 1-tuple ``(sample,)``.
    """
    # --- Stage 0: embeddings and input convolution -------------------------
    time_proj = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(time_proj, timestep_cond)

    aug_emb = self.get_aug_embed(
        emb=emb,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    )
    if aug_emb is not None:
        emb = emb + aug_emb

    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )

    sample = self.conv_in(sample)

    # Early exit: hand back only the pre-processing results.
    if getattr(self, "_export_precess_onnx", False):
        return (
            sample,
            encoder_hidden_states,
            emb,
        )

    run_engines = getattr(self, "use_trt_infer", False)

    # --- Stage 1: down blocks; every block pushes its skips on the stack ---
    skip_stack = (sample,)
    for idx, down_block in enumerate(self.down_blocks):
        with_cross_attn = getattr(down_block, "has_cross_attention", False)

        if run_engines:
            feed = {"hidden_states": sample, "temb": emb}
            if with_cross_attn:
                feed["encoder_hidden_states"] = encoder_hidden_states
            outputs = self.engines[f"down_blocks.{idx}"](feed, self.cuda_stream)
            sample = outputs["sample"]
            block_skips = (outputs["res_samples_0"], outputs["res_samples_1"])
            # A third skip output is optional and only present for some engines.
            if "res_samples_2" in outputs:
                block_skips += (outputs["res_samples_2"],)
        elif with_cross_attn:
            sample, block_skips = down_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample, block_skips = down_block(hidden_states=sample, temb=emb)

        skip_stack += block_skips

    # --- Stage 2: mid block ------------------------------------------------
    if run_engines:
        mid_feed = {
            "hidden_states": sample,
            "temb": emb,
            "encoder_hidden_states": encoder_hidden_states,
        }
        sample = self.engines["mid_block"](mid_feed, self.cuda_stream)["sample"]
    else:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )

    # --- Stage 3: up blocks; each pops len(block.resnets) skips ------------
    for idx, up_block in enumerate(self.up_blocks):
        n_skips = len(up_block.resnets)
        block_skips = skip_stack[-n_skips:]
        skip_stack = skip_stack[:-n_skips]

        # Plain indexing (not unpacking) so extra skips are tolerated.
        skip_0 = block_skips[0]
        skip_1 = block_skips[1]
        skip_2 = block_skips[2]

        with_cross_attn = getattr(up_block, "has_cross_attention", False)

        if run_engines:
            feed = {
                "hidden_states": sample,
                "res_hidden_states_0": skip_0,
                "res_hidden_states_1": skip_1,
                "res_hidden_states_2": skip_2,
                "temb": emb,
            }
            if with_cross_attn:
                feed["encoder_hidden_states"] = encoder_hidden_states
            sample = self.engines[f"up_blocks.{idx}"](feed, self.cuda_stream)["sample"]
        else:
            block_kwargs = dict(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_0=skip_0,
                res_hidden_states_1=skip_1,
                res_hidden_states_2=skip_2,
            )
            if with_cross_attn:
                block_kwargs.update(
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            sample = up_block(**block_kwargs)

    # --- Stage 4: output head ----------------------------------------------
    # Norm and activation are applied only when a norm layer is configured.
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if return_dict:
        return UNet2DConditionOutput(sample=sample)
    return (sample,)
|
|