| import torch |
| import numpy as np |
|
|
| from einops import rearrange |
| from typing import Optional, Tuple, Union |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput |
|
|
| from .models.cogvideox.custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel |
| from .models.cogvideox.enhance_a_video.globals import set_num_frames |
|
|
def poly1d(coefficients, x):
    """Evaluate a polynomial at ``x`` element-wise and return its absolute value.

    Args:
        coefficients: sequence of scalars ordered from the highest power down
            to the constant term (the ``numpy.poly1d`` ordering).
        x: tensor of evaluation points.

    Returns:
        Tensor shaped like ``x`` holding ``|sum_k c_k * x**k|``.
    """
    degree = len(coefficients) - 1
    acc = torch.zeros_like(x)
    # Pair each coefficient with its exponent: degree, degree-1, ..., 0.
    for power, coeff in zip(range(degree, -1, -1), coefficients):
        acc += coeff * (x ** power)
    return acc.abs()
|
|
def fft(tensor):
    """Split the centred 2-D spectrum of ``tensor`` into low- and high-frequency parts.

    The FFT is taken over the last two dimensions (H, W). Any leading
    dimensions are allowed; callers in this module pass ``(B*T, C, H, W)``.

    Args:
        tensor: tensor of rank >= 2 whose last two dimensions are (H, W).

    Returns:
        ``(low_freq_fft, high_freq_fft)``: complex tensors shaped like
        ``tensor``. ``low_freq_fft`` keeps the shifted spectrum inside a disk of
        radius ``min(H, W) // 5`` around the centre; ``high_freq_fft`` keeps
        everything else, so the two sum to the shifted spectrum.
    """
    tensor_fft = torch.fft.fft2(tensor)
    # NOTE: fftshift without ``dim`` rolls *every* dimension, not only (H, W).
    # Callers undo it with an equally dim-less ``ifftshift``, so the pair must
    # stay in sync; the (H, W) mask below is unaffected by the extra rolls.
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)

    H, W = tensor.shape[-2:]
    radius = min(H, W) // 5
    center_x, center_y = W // 2, H // 2

    # Build the disk mask straight on the input's device (no per-call host->
    # device copy). Shape (H, W) broadcasts against any number of leading dims.
    ys = torch.arange(H, device=tensor.device).unsqueeze(1)
    xs = torch.arange(W, device=tensor.device).unsqueeze(0)
    low_freq_mask = (xs - center_x) ** 2 + (ys - center_y) ** 2 <= radius ** 2
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask
    return low_freq_fft, high_freq_fft
|
|
def _teacache_should_calc(transformer, emb):
    """Update the TeaCache accumulator with ``emb`` and decide whether the blocks must run.

    Returns True when the transformer blocks must be executed this step and
    False when the cached residual may be reused. ``emb`` is always stored as
    the reference embedding for the next call.
    """
    if not hasattr(transformer, 'accumulated_rel_l1_distance'):
        # First call on this module: nothing to compare against yet.
        should_calc = True
        transformer.accumulated_rel_l1_distance = 0
    else:
        try:
            # Polynomial rescaling of the raw relative-L1 change; one fit per
            # positional-embedding variant of the model.
            if not transformer.config.use_rotary_positional_embeddings:
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
            previous = transformer.previous_modulated_input
            rel_l1 = (emb - previous).abs().mean() / previous.abs().mean()
            transformer.accumulated_rel_l1_distance += poly1d(coefficients, rel_l1)
            if transformer.accumulated_rel_l1_distance < transformer.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                transformer.accumulated_rel_l1_distance = 0
        except Exception:
            # Best effort: any failure here (e.g. ``rel_l1_thresh`` never set)
            # just forces a full recompute. Narrowed from a bare ``except`` so
            # KeyboardInterrupt/SystemExit still propagate.
            should_calc = True
            transformer.accumulated_rel_l1_distance = 0

    transformer.previous_modulated_input = emb
    return should_calc


def _controlnet_block_weight(controlnet_weights, i):
    """Return the ControlNet weight for block ``i`` (per-block sequence, scalar, or 1.0)."""
    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
        return controlnet_weights[i]
    if isinstance(controlnet_weights, (float, int)):
        return controlnet_weights
    return 1.0


def _run_transformer_blocks(transformer, hidden_states, encoder_hidden_states, temb,
                            image_rotary_emb, controlnet_states, controlnet_weights,
                            video_flow_features, cond_only):
    """Run every transformer block, adding weighted ControlNet states after each one.

    When ``cond_only`` is True every block input (and flow feature) is sliced
    to batch row 0, matching the FasterCache cond-only path.
    """
    for i, block in enumerate(transformer.transformer_blocks):
        flow_feature = None
        if video_flow_features is not None:
            flow_feature = video_flow_features[i][:1] if cond_only else video_flow_features[i]

        hidden_states, encoder_hidden_states = block(
            hidden_states=hidden_states[:1] if cond_only else hidden_states,
            encoder_hidden_states=encoder_hidden_states[:1] if cond_only else encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            video_flow_feature=flow_feature,
            fuser=transformer.fuser_list[i] if transformer.fuser_list is not None else None,
            block_use_fastercache=i <= transformer.fastercache_num_blocks_to_cache,
            fastercache_counter=transformer.fastercache_counter,
            fastercache_start_step=transformer.fastercache_start_step,
            fastercache_device=transformer.fastercache_device,
        )

        if controlnet_states is not None and i < len(controlnet_states):
            weight = _controlnet_block_weight(controlnet_weights, i)
            hidden_states = hidden_states + controlnet_states[i] * weight

    return hidden_states, encoder_hidden_states


def _norm_and_unpatchify(transformer, hidden_states, encoder_hidden_states, temb,
                         text_seq_length, out_batch, num_frames, height, width):
    """Apply the output norms/projection and fold patch tokens back to (batch, frames, C, H, W)."""
    if not transformer.config.use_rotary_positional_embeddings:
        hidden_states = transformer.norm_final(hidden_states)
    else:
        # Normalise text and video tokens jointly, then keep only the video part.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = transformer.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

    hidden_states = transformer.norm_out(hidden_states, temb=temb)
    hidden_states = transformer.proj_out(hidden_states)

    p = transformer.config.patch_size
    p_t = transformer.config.patch_size_t
    # ``-1`` lets the output channel count differ from the input's (presumably
    # why the input ``channels`` is not used here).
    if p_t is None:
        output = hidden_states.reshape(out_batch, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
    else:
        output = hidden_states.reshape(
            out_batch, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
        )
        output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
    return output


def _fastercache_append_uncond(transformer, output):
    """Synthesise the unconditional prediction from ``output`` (cond only) plus cached FFT deltas."""
    bb, tt, cc, hh, ww = output.shape
    cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
    lf_c, hf_c = fft(cond.float())

    # Scale the cached deltas by 1.1 on each call that falls inside their
    # configured step windows.
    if transformer.fastercache_counter <= transformer.fastercache_lf_step:
        transformer.delta_lf = transformer.delta_lf * 1.1
    if transformer.fastercache_counter >= transformer.fastercache_hf_step:
        transformer.delta_hf = transformer.delta_hf * 1.1

    new_hf_uc = transformer.delta_hf + hf_c
    new_lf_uc = transformer.delta_lf + lf_c
    combine_uc = new_lf_uc + new_hf_uc
    # Dim-less ifftshift mirrors the dim-less fftshift performed inside ``fft``.
    recovered_uncond = torch.fft.ifft2(torch.fft.ifftshift(combine_uc)).real
    recovered_uncond = rearrange(
        recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww
    )
    return torch.cat([output, recovered_uncond])


def _fastercache_update_deltas(transformer, output):
    """Cache the (uncond - cond) low/high-frequency spectra from a full CFG output."""
    bb, tt, cc, hh, ww = output.shape
    # Row 0 is treated as conditional and row 1 as unconditional; ``B=bb // 2``
    # means this only succeeds for a batch of exactly two -- TODO confirm intent.
    cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)
    uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)

    lf_c, hf_c = fft(cond)
    lf_uc, hf_uc = fft(uncond)

    transformer.delta_lf = lf_uc - lf_c
    transformer.delta_hf = hf_uc - hf_c


def teacache_cogvideox_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    controlnet_states: torch.Tensor = None,
    controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
    video_flow_features: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """TeaCache-aware replacement for the CogVideoX transformer ``forward``.

    Bound onto a transformer instance by ``TeaCacheForCogVideoX``. Per call it:

    1. builds the timestep embedding ``emb`` (plus the ``ofs`` embedding when
       the model has one);
    2. patch-embeds text and video tokens together and splits them again;
    3. asks TeaCache whether ``emb`` drifted enough to require running the
       transformer blocks; otherwise adds the residual cached from the last
       computed step;
    4. on FasterCache cond-only steps (counter >= start + 3 and not a multiple
       of 5) processes only batch row 0 and rebuilds the unconditional row in
       the frequency domain; on other steps past ``start + 1`` refreshes the
       cached uncond-minus-cond spectra.

    Args:
        hidden_states: video latents unpacked as (batch, frames, channels, height, width).
        encoder_hidden_states: text tokens; ``shape[1]`` is the text length.
        timestep, timestep_cond, ofs: forwarded to the time/ofs embeddings.
        image_rotary_emb: forwarded unchanged to every block.
        controlnet_states: optional per-block states added after block ``i``.
        controlnet_weights: per-block sequence/tensor or a single scalar.
        video_flow_features: optional per-block features for each block.
        return_dict: when False return a 1-tuple instead of the output object.

    Returns:
        ``Transformer2DModelOutput(sample=output)`` (or ``(output,)``) where
        ``output`` is (batch, frames, out_channels, height, width).
    """
    batch_size, num_frames, _channels, height, width = hidden_states.shape
    set_num_frames(num_frames)

    # 1. Time embedding (plus optional ofs embedding).
    t_emb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)
    if self.ofs_embedding is not None:
        ofs_emb = self.ofs_proj(ofs).to(dtype=hidden_states.dtype)
        emb = emb + self.ofs_embedding(ofs_emb)

    # 2. Joint patch embedding, then split back into text and video tokens.
    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)
    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]

    # 3. TeaCache: decide whether the blocks must run this step.
    should_calc = _teacache_should_calc(self, emb)

    # 4. FasterCache bookkeeping.
    if self.use_fastercache:
        self.fastercache_counter += 1
    if not hasattr(self, 'delta_lf'):
        self.delta_lf = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32)
    if not hasattr(self, 'delta_hf'):
        self.delta_hf = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32)

    cond_only = (
        self.fastercache_counter >= self.fastercache_start_step + 3
        and self.fastercache_counter % 5 != 0
    )
    if cond_only:
        # Work on batch row 0 only. Slicing *before* the cache/reuse step keeps
        # the residual and the cond-only reshape consistent; previously a reused
        # residual left the full batch in place and the batch-1 reshape folded
        # it into the channel axis.
        hidden_states = hidden_states[:1]
        encoder_hidden_states = encoder_hidden_states[:1]
        temb = emb[:1]
        out_batch = 1
    else:
        temb = emb
        out_batch = batch_size

    # 5. Transformer blocks, or the cached residual.
    if should_calc:
        ori_hidden_states = hidden_states.clone()
        ori_encoder_hidden_states = encoder_hidden_states.clone()
        hidden_states, encoder_hidden_states = _run_transformer_blocks(
            self,
            hidden_states,
            encoder_hidden_states,
            temb,
            image_rotary_emb,
            controlnet_states,
            controlnet_weights,
            video_flow_features,
            cond_only,
        )
        self.previous_residual = hidden_states - ori_hidden_states
        self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
    elif cond_only:
        # A residual cached on a full-batch step carries one row per sample;
        # only the conditional row applies here.
        hidden_states += self.previous_residual[:1]
        encoder_hidden_states += self.previous_residual_encoder[:1]
    else:
        hidden_states += self.previous_residual
        encoder_hidden_states += self.previous_residual_encoder

    # 6. Output norms, projection and unpatchify.
    output = _norm_and_unpatchify(
        self, hidden_states, encoder_hidden_states, temb,
        text_seq_length, out_batch, num_frames, height, width,
    )

    # 7. FasterCache post-processing.
    if cond_only:
        output = _fastercache_append_uncond(self, output)
    elif self.fastercache_counter >= self.fastercache_start_step + 1:
        _fastercache_update_deltas(self, output)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
|
|
class TeaCacheForCogVideoX:
    """Node that switches a CogVideoX transformer's ``forward`` to the TeaCache variant or back.

    The ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` attributes follow the
    node-registration convention of the host framework (ComfyUI-style,
    presumably -- confirm against the loader).
    """

    RETURN_TYPES = ("COGVIDEOMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "apply_teacache"
    CATEGORY = "TeaCache"
    TITLE = "TeaCache For CogVideoX"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the required inputs: the model, an on/off switch and the cache threshold."""
        model_spec = (
            "COGVIDEOMODEL",
            {"tooltip": "The CogVideoX model the TeaCache will be applied to."},
        )
        enable_spec = (
            "BOOLEAN",
            {"default": True, "tooltip": "Enable teacache will speed up inference but may lose visual quality."},
        )
        thresh_spec = (
            "FLOAT",
            {
                "default": 0.3,
                "min": 0.0,
                "max": 10.0,
                "step": 0.01,
                "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative.",
            },
        )
        required = {
            "model": model_spec,
            "enable_teacache": enable_spec,
            "rel_l1_thresh": thresh_spec,
        }
        return {"required": required}

    def apply_teacache(self, model, enable_teacache: bool, rel_l1_thresh: float):
        """Bind the chosen ``forward`` onto ``model["pipe"].transformer`` and return the model.

        When enabled, ``rel_l1_thresh`` is stored on the transformer and the
        TeaCache forward is bound; otherwise the stock
        ``CogVideoXTransformer3DModel.forward`` is re-bound. The model object is
        mutated in place and returned as a 1-tuple.
        """
        transformer = model["pipe"].transformer
        if enable_teacache:
            transformer.rel_l1_thresh = rel_l1_thresh
            forward_impl = teacache_cogvideox_forward
        else:
            forward_impl = CogVideoXTransformer3DModel.forward
        # Descriptor binding turns the plain function into a bound method.
        transformer.forward = forward_impl.__get__(transformer, transformer.__class__)
        return (model,)
| |
|
|
# Node registry: maps the node identifier to its implementing class
# (presumably discovered by the host node framework -- confirm with the loader).
NODE_CLASS_MAPPINGS = {"TeaCacheForCogVideoX": TeaCacheForCogVideoX}

# Display titles, taken from each registered class's TITLE attribute.
NODE_DISPLAY_NAME_MAPPINGS = {
    node_name: node_class.TITLE
    for node_name, node_class in NODE_CLASS_MAPPINGS.items()
}
|
|