| import math |
| import torch |
| import comfy.ldm.common_dit |
| import comfy.model_management as mm |
|
|
| from torch import Tensor |
| from einops import repeat |
| from typing import Optional |
| from unittest.mock import patch |
|
|
| from comfy.ldm.flux.layers import timestep_embedding, apply_mod |
| from comfy.ldm.lightricks.model import precompute_freqs_cis |
| from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords |
| from comfy.ldm.wan.model import sinusoidal_embedding_1d |
|
|
|
|
| SUPPORTED_MODELS_COEFFICIENTS = { |
| "flux": [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01], |
| "flux-kontext": [-1.04655119e+03, 3.12563399e+02, -1.69500694e+01, 4.10995971e-01, 3.74537863e-02], |
| "ltxv": [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03], |
| "lumina_2": [-8.74643948e+02, 4.66059906e+02, -7.51559762e+01, 5.32836175e+00, -3.27258296e-02], |
| "hunyuan_video": [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02], |
| "hidream_i1_full": [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01], |
| "hidream_i1_dev": [1.39997273, -4.30130469, 5.01534416, -2.20504164, 0.93942874], |
| "hidream_i1_fast": [2.26509623, -6.88864563, 7.61123826, -3.10849353, 0.99927602], |
| "wan2.1_t2v_1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], |
| "wan2.1_t2v_14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], |
| "wan2.1_i2v_480p_14B": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], |
| "wan2.1_i2v_720p_14B": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], |
| "wan2.1_t2v_1.3B_ret_mode": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], |
| "wan2.1_t2v_14B_ret_mode": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], |
| "wan2.1_i2v_480p_14B_ret_mode": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], |
| "wan2.1_i2v_720p_14B_ret_mode": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], |
| } |
|
|
| def poly1d(coefficients, x): |
| result = torch.zeros_like(x) |
| for i, coeff in enumerate(coefficients): |
| result += coeff * (x ** (len(coefficients) - 1 - i)) |
| return result |
|
|
| def teacache_flux_forward( |
| self, |
| img: Tensor, |
| img_ids: Tensor, |
| txt: Tensor, |
| txt_ids: Tensor, |
| timesteps: Tensor, |
| y: Tensor, |
| guidance: Tensor = None, |
| control = None, |
| transformer_options={}, |
| attn_mask: Tensor = None, |
| ) -> Tensor: |
| patches_replace = transformer_options.get("patches_replace", {}) |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| if y is None: |
| y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) |
| |
| if img.ndim != 3 or txt.ndim != 3: |
| raise ValueError("Input img and txt tensors must have 3 dimensions.") |
|
|
| |
| img = self.img_in(img) |
| vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) |
| if self.params.guidance_embed: |
| if guidance is not None: |
| vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) |
|
|
| vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) |
| txt = self.txt_in(txt) |
|
|
| if img_ids is not None: |
| ids = torch.cat((txt_ids, img_ids), dim=1) |
| pe = self.pe_embedder(ids) |
| else: |
| pe = None |
|
|
| blocks_replace = patches_replace.get("dit", {}) |
|
|
| |
| img_mod1, _ = self.double_blocks[0].img_mod(vec) |
| modulated_inp = self.double_blocks[0].img_norm1(img) |
| modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift).to(cache_device) |
| ca_idx = 0 |
|
|
| if not hasattr(self, 'accumulated_rel_l1_distance'): |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
| else: |
| try: |
| self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())).abs() |
| if self.accumulated_rel_l1_distance < rel_l1_thresh: |
| should_calc = False |
| else: |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
| except: |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
|
|
| self.previous_modulated_input = modulated_inp |
|
|
| if not enable_teacache: |
| should_calc = True |
|
|
| if not should_calc: |
| img += self.previous_residual.to(img.device) |
| else: |
| ori_img = img.to(cache_device) |
| for i, block in enumerate(self.double_blocks): |
| if ("double_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"], out["txt"] = block(img=args["img"], |
| txt=args["txt"], |
| vec=args["vec"], |
| pe=args["pe"], |
| attn_mask=args.get("attn_mask")) |
| return out |
|
|
| out = blocks_replace[("double_block", i)]({"img": img, |
| "txt": txt, |
| "vec": vec, |
| "pe": pe, |
| "attn_mask": attn_mask}, |
| {"original_block": block_wrap}) |
| txt = out["txt"] |
| img = out["img"] |
| else: |
| img, txt = block(img=img, |
| txt=txt, |
| vec=vec, |
| pe=pe, |
| attn_mask=attn_mask) |
|
|
| if control is not None: |
| control_i = control.get("input") |
| if i < len(control_i): |
| add = control_i[i] |
| if add is not None: |
| img += add |
|
|
| |
| if getattr(self, "pulid_data", {}): |
| if i % self.pulid_double_interval == 0: |
| |
| for _, node_data in self.pulid_data.items(): |
| if torch.any((node_data['sigma_start'] >= timesteps) |
| & (timesteps >= node_data['sigma_end'])): |
| img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) |
| ca_idx += 1 |
|
|
| if img.dtype == torch.float16: |
| img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) |
|
|
| img = torch.cat((txt, img), 1) |
|
|
| for i, block in enumerate(self.single_blocks): |
| if ("single_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"] = block(args["img"], |
| vec=args["vec"], |
| pe=args["pe"], |
| attn_mask=args.get("attn_mask")) |
| return out |
|
|
| out = blocks_replace[("single_block", i)]({"img": img, |
| "vec": vec, |
| "pe": pe, |
| "attn_mask": attn_mask}, |
| {"original_block": block_wrap}) |
| img = out["img"] |
| else: |
| img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) |
|
|
| if control is not None: |
| control_o = control.get("output") |
| if i < len(control_o): |
| add = control_o[i] |
| if add is not None: |
| img[:, txt.shape[1] :, ...] += add |
|
|
| |
| if getattr(self, "pulid_data", {}): |
| real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] |
| if i % self.pulid_single_interval == 0: |
| |
| for _, node_data in self.pulid_data.items(): |
| if torch.any((node_data['sigma_start'] >= timesteps) |
| & (timesteps >= node_data['sigma_end'])): |
| real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) |
| ca_idx += 1 |
| img = torch.cat((txt, real_img), 1) |
|
|
| img = img[:, txt.shape[1] :, ...] |
| self.previous_residual = img.to(cache_device) - ori_img |
|
|
| img = self.final_layer(img, vec) |
| |
| return img |
|
|
| def teacache_hidream_forward( |
| self, |
| x: torch.Tensor, |
| t: torch.Tensor, |
| y: Optional[torch.Tensor] = None, |
| context: Optional[torch.Tensor] = None, |
| encoder_hidden_states_llama3=None, |
| image_cond=None, |
| control = None, |
| transformer_options = {}, |
| ) -> torch.Tensor: |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| cond_or_uncond = transformer_options.get("cond_or_uncond") |
| model_type = transformer_options.get("model_type") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| bs, c, h, w = x.shape |
| if image_cond is not None: |
| x = torch.cat([x, image_cond], dim=-1) |
| hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) |
| timesteps = t |
| pooled_embeds = y |
| T5_encoder_hidden_states = context |
|
|
| img_sizes = None |
|
|
| |
| batch_size = hidden_states.shape[0] |
| hidden_states_type = hidden_states.dtype |
|
|
| |
| timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) |
| timesteps = self.t_embedder(timesteps, hidden_states_type) |
| p_embedder = self.p_embedder(pooled_embeds) |
| adaln_input = timesteps + p_embedder |
|
|
| hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) |
| if image_tokens_masks is None: |
| pH, pW = img_sizes[0] |
| img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) |
| img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] |
| img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] |
| img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) |
| hidden_states = self.x_embedder(hidden_states) |
|
|
| |
| encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) |
| encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] |
|
|
| if self.caption_projection is not None: |
| new_encoder_hidden_states = [] |
| for i, enc_hidden_state in enumerate(encoder_hidden_states): |
| enc_hidden_state = self.caption_projection[i](enc_hidden_state) |
| enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) |
| new_encoder_hidden_states.append(enc_hidden_state) |
| encoder_hidden_states = new_encoder_hidden_states |
| T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) |
| T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) |
| encoder_hidden_states.append(T5_encoder_hidden_states) |
|
|
| txt_ids = torch.zeros( |
| batch_size, |
| encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], |
| 3, |
| device=img_ids.device, dtype=img_ids.dtype |
| ) |
| ids = torch.cat((img_ids, txt_ids), dim=1) |
| rope = self.pe_embedder(ids) |
|
|
| |
| modulated_inp = timesteps.to(cache_device) if "full" in model_type else hidden_states.to(cache_device) |
| if not hasattr(self, 'teacache_state'): |
| self.teacache_state = { |
| 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, |
| 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} |
| } |
|
|
| def update_cache_state(cache, modulated_inp): |
| if cache['previous_modulated_input'] is not None: |
| try: |
| cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) |
| if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: |
| cache['should_calc'] = False |
| else: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| except: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| cache['previous_modulated_input'] = modulated_inp |
| |
| b = int(len(hidden_states) / len(cond_or_uncond)) |
|
|
| for i, k in enumerate(cond_or_uncond): |
| update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) |
|
|
| if enable_teacache: |
| should_calc = False |
| for k in cond_or_uncond: |
| should_calc = (should_calc or self.teacache_state[k]['should_calc']) |
| else: |
| should_calc = True |
|
|
| if not should_calc: |
| for i, k in enumerate(cond_or_uncond): |
| hidden_states[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(hidden_states.device) |
| else: |
| |
| ori_hidden_states = hidden_states.to(cache_device) |
| block_id = 0 |
| initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) |
| initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] |
| for bid, block in enumerate(self.double_stream_blocks): |
| cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] |
| cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) |
| hidden_states, initial_encoder_hidden_states = block( |
| image_tokens = hidden_states, |
| image_tokens_masks = image_tokens_masks, |
| text_tokens = cur_encoder_hidden_states, |
| adaln_input = adaln_input, |
| rope = rope, |
| ) |
| initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] |
| block_id += 1 |
|
|
| image_tokens_seq_len = hidden_states.shape[1] |
| hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) |
| hidden_states_seq_len = hidden_states.shape[1] |
| if image_tokens_masks is not None: |
| encoder_attention_mask_ones = torch.ones( |
| (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), |
| device=image_tokens_masks.device, dtype=image_tokens_masks.dtype |
| ) |
| image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) |
|
|
| for bid, block in enumerate(self.single_stream_blocks): |
| cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] |
| hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) |
| hidden_states = block( |
| image_tokens=hidden_states, |
| image_tokens_masks=image_tokens_masks, |
| text_tokens=None, |
| adaln_input=adaln_input, |
| rope=rope, |
| ) |
| hidden_states = hidden_states[:, :hidden_states_seq_len] |
| block_id += 1 |
|
|
| hidden_states = hidden_states[:, :image_tokens_seq_len, ...] |
| for i, k in enumerate(cond_or_uncond): |
| self.teacache_state[k]['previous_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)[i*b:(i+1)*b] |
|
|
| output = self.final_layer(hidden_states, adaln_input) |
| output = self.unpatchify(output, img_sizes) |
| return -output[:, :, :h, :w] |
|
|
| def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| cond_or_uncond = transformer_options.get("cond_or_uncond") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| t = 1.0 - timesteps |
| cap_feats = context |
| cap_mask = attention_mask |
| bs, c, h, w = x.shape |
| x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) |
| |
| t = self.t_embedder(t, dtype=x.dtype) |
| adaln_input = t |
|
|
| cap_feats = self.cap_embedder(cap_feats) |
|
|
| x_is_tensor = isinstance(x, torch.Tensor) |
| x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) |
| freqs_cis = freqs_cis.to(x.device) |
|
|
| |
| modulated_inp = t.to(cache_device) |
| if not hasattr(self, 'teacache_state'): |
| self.teacache_state = { |
| 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, |
| 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} |
| } |
|
|
| def update_cache_state(cache, modulated_inp): |
| if cache['previous_modulated_input'] is not None: |
| try: |
| cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) |
| if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: |
| cache['should_calc'] = False |
| else: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| except: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| cache['previous_modulated_input'] = modulated_inp |
|
|
| b = int(len(x) / len(cond_or_uncond)) |
|
|
| for i, k in enumerate(cond_or_uncond): |
| update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) |
|
|
| if enable_teacache: |
| should_calc = False |
| for k in cond_or_uncond: |
| should_calc = (should_calc or self.teacache_state[k]['should_calc']) |
| else: |
| should_calc = True |
|
|
| if not should_calc: |
| for i, k in enumerate(cond_or_uncond): |
| x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) |
| else: |
| ori_x = x.to(cache_device) |
| |
| for layer in self.layers: |
| x = layer(x, mask, freqs_cis, adaln_input) |
| for i, k in enumerate(cond_or_uncond): |
| self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] |
| |
| x = self.final_layer(x, adaln_input) |
| x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] |
|
|
| return -x |
|
|
| def teacache_hunyuanvideo_forward( |
| self, |
| img: Tensor, |
| img_ids: Tensor, |
| txt: Tensor, |
| txt_ids: Tensor, |
| txt_mask: Tensor, |
| timesteps: Tensor, |
| y: Tensor, |
| guidance: Tensor = None, |
| guiding_frame_index=None, |
| ref_latent=None, |
| control=None, |
| transformer_options={}, |
| ) -> Tensor: |
| patches_replace = transformer_options.get("patches_replace", {}) |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| initial_shape = list(img.shape) |
| |
| img = self.img_in(img) |
| vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) |
|
|
| if ref_latent is not None: |
| ref_latent_ids = self.img_ids(ref_latent) |
| ref_latent = self.img_in(ref_latent) |
| img = torch.cat([ref_latent, img], dim=-2) |
| ref_latent_ids[..., 0] = -1 |
| ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1]) |
| img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2) |
|
|
| if guiding_frame_index is not None: |
| token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) |
| vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) |
| vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) |
| frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) |
| modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] |
| modulation_dims_txt = [(0, None, 1)] |
| else: |
| vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) |
| modulation_dims = None |
| modulation_dims_txt = None |
|
|
| if self.params.guidance_embed: |
| if guidance is not None: |
| vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) |
|
|
| if txt_mask is not None and not torch.is_floating_point(txt_mask): |
| txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max |
|
|
| txt = self.txt_in(txt, timesteps, txt_mask) |
|
|
| ids = torch.cat((img_ids, txt_ids), dim=1) |
| pe = self.pe_embedder(ids) |
|
|
| img_len = img.shape[1] |
| if txt_mask is not None: |
| attn_mask_len = img_len + txt.shape[1] |
| attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) |
| attn_mask[:, 0, img_len:] = txt_mask |
| else: |
| attn_mask = None |
|
|
| blocks_replace = patches_replace.get("dit", {}) |
|
|
| |
| img_mod1, _ = self.double_blocks[0].img_mod(vec) |
| modulated_inp = self.double_blocks[0].img_norm1(img) |
| modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift, modulation_dims).to(cache_device) |
|
|
| if not hasattr(self, 'accumulated_rel_l1_distance'): |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
| else: |
| try: |
| self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())) |
| if self.accumulated_rel_l1_distance < rel_l1_thresh: |
| should_calc = False |
| else: |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
| except: |
| should_calc = True |
| self.accumulated_rel_l1_distance = 0 |
|
|
| self.previous_modulated_input = modulated_inp |
|
|
| if not enable_teacache: |
| should_calc = True |
|
|
| if not should_calc: |
| img += self.previous_residual.to(img.device) |
| else: |
| ori_img = img.to(cache_device) |
| for i, block in enumerate(self.double_blocks): |
| if ("double_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) |
| return out |
|
|
| out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) |
| txt = out["txt"] |
| img = out["img"] |
| else: |
| img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) |
|
|
| if control is not None: |
| control_i = control.get("input") |
| if i < len(control_i): |
| add = control_i[i] |
| if add is not None: |
| img += add |
|
|
| img = torch.cat((img, txt), 1) |
|
|
| for i, block in enumerate(self.single_blocks): |
| if ("single_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) |
| return out |
|
|
| out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) |
| img = out["img"] |
| else: |
| img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) |
|
|
| if control is not None: |
| control_o = control.get("output") |
| if i < len(control_o): |
| add = control_o[i] |
| if add is not None: |
| img[:, : img_len] += add |
|
|
| img = img[:, : img_len] |
| self.previous_residual = (img.to(cache_device) - ori_img) |
|
|
| if ref_latent is not None: |
| img = img[:, ref_latent.shape[1]:] |
| |
| img = self.final_layer(img, vec, modulation_dims=modulation_dims) |
|
|
| shape = initial_shape[-3:] |
| for i in range(len(shape)): |
| shape[i] = shape[i] // self.patch_size[i] |
| img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) |
| img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) |
| img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) |
| return img |
|
|
| def teacache_ltxvmodel_forward( |
| self, |
| x, |
| timestep, |
| context, |
| attention_mask, |
| frame_rate=25, |
| transformer_options={}, |
| keyframe_idxs=None, |
| **kwargs |
| ): |
| patches_replace = transformer_options.get("patches_replace", {}) |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| cond_or_uncond = transformer_options.get("cond_or_uncond") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| orig_shape = list(x.shape) |
|
|
| x, latent_coords = self.patchifier.patchify(x) |
| pixel_coords = latent_to_pixel_coords( |
| latent_coords=latent_coords, |
| scale_factors=self.vae_scale_factors, |
| causal_fix=self.causal_temporal_positioning, |
| ) |
|
|
| if keyframe_idxs is not None: |
| pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs |
|
|
| fractional_coords = pixel_coords.to(torch.float32) |
| fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) |
|
|
| x = self.patchify_proj(x) |
| timestep = timestep * 1000.0 |
|
|
| if attention_mask is not None and not torch.is_floating_point(attention_mask): |
| attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max |
|
|
| pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) |
|
|
| batch_size = x.shape[0] |
| timestep, embedded_timestep = self.adaln_single( |
| timestep.flatten(), |
| {"resolution": None, "aspect_ratio": None}, |
| batch_size=batch_size, |
| hidden_dtype=x.dtype, |
| ) |
| |
| timestep = timestep.view(batch_size, -1, timestep.shape[-1]) |
| embedded_timestep = embedded_timestep.view( |
| batch_size, -1, embedded_timestep.shape[-1] |
| ) |
|
|
| |
| if self.caption_projection is not None: |
| batch_size = x.shape[0] |
| context = self.caption_projection(context) |
| context = context.view( |
| batch_size, -1, x.shape[-1] |
| ) |
|
|
| blocks_replace = patches_replace.get("dit", {}) |
|
|
| |
| inp = x.to(cache_device) |
| timestep_ = timestep.to(cache_device) |
| num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] |
| ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1) |
| shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2) |
| modulated_inp = comfy.ldm.common_dit.rms_norm(inp) |
| modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa |
|
|
| if not hasattr(self, 'teacache_state'): |
| self.teacache_state = { |
| 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, |
| 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} |
| } |
|
|
| def update_cache_state(cache, modulated_inp): |
| if cache['previous_modulated_input'] is not None: |
| try: |
| cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) |
| if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: |
| cache['should_calc'] = False |
| else: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| except: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| cache['previous_modulated_input'] = modulated_inp |
|
|
| b = int(len(x) / len(cond_or_uncond)) |
| |
| for i, k in enumerate(cond_or_uncond): |
| update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) |
|
|
| if enable_teacache: |
| should_calc = False |
| for k in cond_or_uncond: |
| should_calc = (should_calc or self.teacache_state[k]['should_calc']) |
| else: |
| should_calc = True |
| |
| if not should_calc: |
| for i, k in enumerate(cond_or_uncond): |
| x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) |
| else: |
| ori_x = x.to(cache_device) |
| for i, block in enumerate(self.transformer_blocks): |
| if ("double_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) |
| return out |
|
|
| out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) |
| x = out["img"] |
| else: |
| x = block( |
| x, |
| context=context, |
| attention_mask=attention_mask, |
| timestep=timestep, |
| pe=pe |
| ) |
|
|
| |
| scale_shift_values = ( |
| self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] |
| ) |
| shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] |
| x = self.norm_out(x) |
| |
| x = x * (1 + scale) + shift |
| for i, k in enumerate(cond_or_uncond): |
| self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] |
|
|
| x = self.proj_out(x) |
|
|
| x = self.patchifier.unpatchify( |
| latents=x, |
| output_height=orig_shape[3], |
| output_width=orig_shape[4], |
| output_num_frames=orig_shape[2], |
| out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), |
| ) |
|
|
| return x |
|
|
| def teacache_wanmodel_forward( |
| self, |
| x, |
| t, |
| context, |
| clip_fea=None, |
| freqs=None, |
| transformer_options={}, |
| **kwargs, |
| ): |
| patches_replace = transformer_options.get("patches_replace", {}) |
| rel_l1_thresh = transformer_options.get("rel_l1_thresh") |
| coefficients = transformer_options.get("coefficients") |
| cond_or_uncond = transformer_options.get("cond_or_uncond") |
| model_type = transformer_options.get("model_type") |
| enable_teacache = transformer_options.get("enable_teacache", True) |
| cache_device = transformer_options.get("cache_device") |
|
|
| |
| x = self.patch_embedding(x.float()).to(x.dtype) |
| grid_sizes = x.shape[2:] |
| x = x.flatten(2).transpose(1, 2) |
|
|
| |
| e = self.time_embedding( |
| sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) |
| e0 = self.time_projection(e).unflatten(1, (6, self.dim)) |
|
|
| |
| context = self.text_embedding(context) |
|
|
| context_img_len = None |
| if clip_fea is not None: |
| if self.img_emb is not None: |
| context_clip = self.img_emb(clip_fea) |
| context = torch.concat([context_clip, context], dim=1) |
| context_img_len = clip_fea.shape[-2] |
|
|
| blocks_replace = patches_replace.get("dit", {}) |
|
|
| |
| modulated_inp = e0.to(cache_device) if "ret_mode" in model_type else e.to(cache_device) |
| if not hasattr(self, 'teacache_state'): |
| self.teacache_state = { |
| 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, |
| 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} |
| } |
|
|
| def update_cache_state(cache, modulated_inp): |
| if cache['previous_modulated_input'] is not None: |
| try: |
| cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) |
| if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: |
| cache['should_calc'] = False |
| else: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| except: |
| cache['should_calc'] = True |
| cache['accumulated_rel_l1_distance'] = 0 |
| cache['previous_modulated_input'] = modulated_inp |
| |
| b = int(len(x) / len(cond_or_uncond)) |
|
|
| for i, k in enumerate(cond_or_uncond): |
| update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) |
|
|
| if enable_teacache: |
| should_calc = False |
| for k in cond_or_uncond: |
| should_calc = (should_calc or self.teacache_state[k]['should_calc']) |
| else: |
| should_calc = True |
|
|
| if not should_calc: |
| for i, k in enumerate(cond_or_uncond): |
| x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) |
| else: |
| ori_x = x.to(cache_device) |
| for i, block in enumerate(self.blocks): |
| if ("double_block", i) in blocks_replace: |
| def block_wrap(args): |
| out = {} |
| out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) |
| return out |
| out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) |
| x = out["img"] |
| else: |
| x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) |
| for i, k in enumerate(cond_or_uncond): |
| self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] |
|
|
| |
| x = self.head(x, e) |
|
|
| |
| x = self.unpatchify(x, grid_sizes) |
| return x |
|
|
| class TeaCache: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "model": ("MODEL", {"tooltip": "The diffusion model the TeaCache will be applied to."}), |
| "model_type": (["flux", "flux-kontext", "ltxv", "lumina_2", "hunyuan_video", "hidream_i1_full", "hidream_i1_dev", "hidream_i1_fast", "wan2.1_t2v_1.3B", "wan2.1_t2v_14B", "wan2.1_i2v_480p_14B", "wan2.1_i2v_720p_14B", "wan2.1_t2v_1.3B_ret_mode", "wan2.1_t2v_14B_ret_mode", "wan2.1_i2v_480p_14B_ret_mode", "wan2.1_i2v_720p_14B_ret_mode"], {"default": "flux", "tooltip": "Supported diffusion model."}), |
| "rel_l1_thresh": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), |
| "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps that will apply TeaCache."}), |
| "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps that will apply TeaCache."}), |
| "cache_device": (["cuda", "cpu"], {"default": "cuda", "tooltip": "Device where the cache will reside."}), |
| } |
| } |
| |
| RETURN_TYPES = ("MODEL",) |
| RETURN_NAMES = ("model",) |
| FUNCTION = "apply_teacache" |
| CATEGORY = "TeaCache" |
| TITLE = "TeaCache" |
| |
| def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_percent: float, end_percent: float, cache_device: str): |
| if rel_l1_thresh == 0: |
| return (model,) |
|
|
| new_model = model.clone() |
| if 'transformer_options' not in new_model.model_options: |
| new_model.model_options['transformer_options'] = {} |
| new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh |
| new_model.model_options["transformer_options"]["coefficients"] = SUPPORTED_MODELS_COEFFICIENTS[model_type] |
| new_model.model_options["transformer_options"]["model_type"] = model_type |
| new_model.model_options["transformer_options"]["cache_device"] = mm.get_torch_device() if cache_device == "cuda" else torch.device("cpu") |
| |
| diffusion_model = new_model.get_model_object("diffusion_model") |
|
|
| if "flux" in model_type: |
| is_cfg = False |
| context = patch.multiple( |
| diffusion_model, |
| forward_orig=teacache_flux_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| elif "lumina_2" in model_type: |
| is_cfg = True |
| context = patch.multiple( |
| diffusion_model, |
| forward=teacache_lumina_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| elif "hidream_i1" in model_type: |
| is_cfg = True if "full" in model_type else False |
| context = patch.multiple( |
| diffusion_model, |
| forward=teacache_hidream_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| elif "ltxv" in model_type: |
| is_cfg = True |
| context = patch.multiple( |
| diffusion_model, |
| forward=teacache_ltxvmodel_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| elif "hunyuan_video" in model_type: |
| is_cfg = False |
| context = patch.multiple( |
| diffusion_model, |
| forward_orig=teacache_hunyuanvideo_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| elif "wan2.1" in model_type: |
| is_cfg = True |
| context = patch.multiple( |
| diffusion_model, |
| forward_orig=teacache_wanmodel_forward.__get__(diffusion_model, diffusion_model.__class__) |
| ) |
| else: |
| raise ValueError(f"Unknown type {model_type}") |
| |
| def unet_wrapper_function(model_function, kwargs): |
| input = kwargs["input"] |
| timestep = kwargs["timestep"] |
| c = kwargs["c"] |
| |
| sigmas = c["transformer_options"]["sample_sigmas"] |
| matched_step_index = (sigmas == timestep[0]).nonzero() |
| if len(matched_step_index) > 0: |
| current_step_index = matched_step_index.item() |
| else: |
| current_step_index = 0 |
| for i in range(len(sigmas) - 1): |
| |
| if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: |
| current_step_index = i |
| break |
| |
| if current_step_index == 0: |
| if is_cfg: |
| |
| if hasattr(diffusion_model, 'teacache_state') and \ |
| diffusion_model.teacache_state[0]['previous_modulated_input'] is not None and \ |
| diffusion_model.teacache_state[1]['previous_modulated_input'] is not None: |
| delattr(diffusion_model, 'teacache_state') |
| else: |
| if hasattr(diffusion_model, 'teacache_state'): |
| delattr(diffusion_model, 'teacache_state') |
| if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): |
| delattr(diffusion_model, 'accumulated_rel_l1_distance') |
| |
| current_percent = current_step_index / (len(sigmas) - 1) |
| c["transformer_options"]["current_percent"] = current_percent |
| if start_percent <= current_percent <= end_percent: |
| c["transformer_options"]["enable_teacache"] = True |
| else: |
| c["transformer_options"]["enable_teacache"] = False |
| |
| with context: |
| return model_function(input, timestep, **c) |
|
|
| new_model.set_model_unet_function_wrapper(unet_wrapper_function) |
|
|
| return (new_model,) |
| |
| def patch_optimized_module(): |
| try: |
| from torch._dynamo.eval_frame import OptimizedModule |
| except ImportError: |
| return |
|
|
| if getattr(OptimizedModule, "_patched", False): |
| return |
|
|
| def __getattribute__(self, name): |
| if name == "_orig_mod": |
| return object.__getattribute__(self, "_modules")[name] |
| if name in ( |
| "__class__", |
| "_modules", |
| "state_dict", |
| "load_state_dict", |
| "parameters", |
| "named_parameters", |
| "buffers", |
| "named_buffers", |
| "children", |
| "named_children", |
| "modules", |
| "named_modules", |
| ): |
| return getattr(object.__getattribute__(self, "_orig_mod"), name) |
| return object.__getattribute__(self, name) |
|
|
| def __delattr__(self, name): |
| return delattr(self._orig_mod, name) |
|
|
| @classmethod |
| def __instancecheck__(cls, instance): |
| return isinstance(instance, OptimizedModule) or issubclass( |
| object.__getattribute__(instance, "__class__"), cls |
| ) |
|
|
| OptimizedModule.__getattribute__ = __getattribute__ |
| OptimizedModule.__delattr__ = __delattr__ |
| OptimizedModule.__instancecheck__ = __instancecheck__ |
| OptimizedModule._patched = True |
|
|
| def patch_same_meta(): |
| try: |
| from torch._inductor.fx_passes import post_grad |
| except ImportError: |
| return |
|
|
| same_meta = getattr(post_grad, "same_meta", None) |
| if same_meta is None: |
| return |
|
|
| if getattr(same_meta, "_patched", False): |
| return |
|
|
| def new_same_meta(a, b): |
| try: |
| return same_meta(a, b) |
| except Exception: |
| return False |
|
|
| post_grad.same_meta = new_same_meta |
| new_same_meta._patched = True |
|
|
| class CompileModel: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}), |
| "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), |
| "backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), |
| "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), |
| "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), |
| } |
| } |
| |
| RETURN_TYPES = ("MODEL",) |
| RETURN_NAMES = ("model",) |
| FUNCTION = "apply_compile" |
| CATEGORY = "TeaCache" |
| TITLE = "Compile Model" |
| |
| def apply_compile(self, model, mode: str, backend: str, fullgraph: bool, dynamic: bool): |
| patch_optimized_module() |
| patch_same_meta() |
| torch._dynamo.config.suppress_errors = True |
| |
| new_model = model.clone() |
| new_model.add_object_patch( |
| "diffusion_model", |
| torch.compile( |
| new_model.get_model_object("diffusion_model"), |
| mode=mode, |
| backend=backend, |
| fullgraph=fullgraph, |
| dynamic=dynamic |
| ) |
| ) |
| |
| return (new_model,) |
| |
|
|
| NODE_CLASS_MAPPINGS = { |
| "TeaCache": TeaCache, |
| "CompileModel": CompileModel |
| } |
|
|
| NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()} |
|
|