| from typing import Optional | |
| import diffusers | |
| import torch | |
| def patch_time_text_image_embedding_forward() -> None: | |
| _patch_time_text_image_embedding_forward() | |
| def _patch_time_text_image_embedding_forward() -> None: | |
| diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = ( | |
| _patched_WanTimeTextImageEmbedding_forward | |
| ) | |
| def _patched_WanTimeTextImageEmbedding_forward( | |
| self, | |
| timestep: torch.Tensor, | |
| encoder_hidden_states: torch.Tensor, | |
| encoder_hidden_states_image: Optional[torch.Tensor] = None, | |
| ): | |
| # Some code has been removed compared to original implementation in Diffusers | |
| # Also, timestep is typed as that of encoder_hidden_states | |
| timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states) | |
| temb = self.time_embedder(timestep).type_as(encoder_hidden_states) | |
| timestep_proj = self.time_proj(self.act_fn(temb)) | |
| encoder_hidden_states = self.text_embedder(encoder_hidden_states) | |
| if encoder_hidden_states_image is not None: | |
| encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) | |
| return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image | |