"""First-block-cache helpers.

The first block of a model is always executed.  The change it produces (the
"first residual") is compared with the one stored from the previous call; when
the two are similar enough, the remaining blocks are skipped and the residual
stored from the previous full run is added instead.
"""

import contextlib
import dataclasses
import unittest
import unittest.mock
from collections import defaultdict
from typing import DefaultDict, Dict

import torch

from src.AutoEncoders.ResBlock import forward_timestep_embed1
from src.NeuralNetwork.unet import apply_control1
from src.sample.sampling_util import timestep_embedding

# Module-wide "current" cache context used by get_buffer() / set_buffer().
_current_cache_context = None


@dataclasses.dataclass
class CacheContext:
    """Named tensor buffers plus per-name counters for generating unique names."""

    # Buffer name -> stored tensor.
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    # Base name -> next index handed out by get_incremental_name().
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int)
    )

    def get_incremental_name(self, name=None):
        """Return ``"{name}_{n}"`` where ``n`` counts previous calls for ``name``.

        A falsy ``name`` is replaced by ``"default"``.
        """
        name = name or "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        """Forget all counters so numbering restarts at 0."""
        self.incremental_name_counters.clear()

    @torch.compiler.disable()
    def get_buffer(self, name):
        """Return the buffer stored under ``name``, or ``None`` if absent."""
        return self.buffers.get(name)

    @torch.compiler.disable()
    def set_buffer(self, name, buffer):
        """Store ``buffer`` under ``name``, replacing any previous value."""
        self.buffers[name] = buffer

    def clear_buffers(self):
        """Drop every stored buffer."""
        self.buffers.clear()


def create_cache_context():
    """Return a fresh, empty :class:`CacheContext`."""
    return CacheContext()


def get_current_cache_context():
    """Return the module-wide current cache context (may be ``None``)."""
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    """Replace the module-wide current cache context."""
    global _current_cache_context
    _current_cache_context = cache_context


@contextlib.contextmanager
def cache_context(ctx):
    """Temporarily make ``ctx`` the current cache context.

    The previous context is restored on exit, even if the body raises.
    """
    global _current_cache_context
    old = _current_cache_context
    _current_cache_context = ctx
    try:
        yield
    finally:
        _current_cache_context = old


@torch.compiler.disable()
def get_buffer(name):
    """Read buffer ``name`` from the current cache context.

    Raises:
        AssertionError: if no cache context is active.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache context must be set before calling get_buffer"
    return ctx.get_buffer(name)


@torch.compiler.disable()
def set_buffer(name, buffer):
    """Write ``buffer`` under ``name`` in the current cache context.

    Raises:
        AssertionError: if no cache context is active.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache context must be set before calling set_buffer"
    ctx.set_buffer(name, buffer)


@torch.compiler.disable()
def are_two_tensors_similar(t1, t2, *, threshold):
    """Return True when the mean absolute difference, relative to ``t1``, is below ``threshold``.

    Tensors of different shapes are never similar.  If ``t1`` is all zeros the
    ratio is ``nan`` or ``inf`` and the comparison yields False.
    """
    if t1.shape != t2.shape:
        return False
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    return (mean_diff / mean_t1).item() < threshold


@torch.compiler.disable()
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states=None):
    """Add the cached residual(s) to the given tensors.

    Returns:
        ``hidden_states`` alone when ``encoder_hidden_states`` is None;
        otherwise a ``(hidden_states, encoder_hidden_states)`` tuple whose
        second element is None when no encoder residual has been cached.
    """
    hidden_states = (get_buffer("hidden_states_residual") + hidden_states).contiguous()
    if encoder_hidden_states is None:
        return hidden_states
    enc_res = get_buffer("encoder_hidden_states_residual")
    if enc_res is None:
        return hidden_states, None
    return hidden_states, (enc_res + encoder_hidden_states).contiguous()


@torch.compiler.disable()
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    """Return True if a first residual is cached and is similar to the given one.

    ``parallelized`` is accepted for interface compatibility and is not used.
    """
    prev = get_buffer("first_hidden_states_residual")
    return prev is not None and are_two_tensors_similar(
        prev, first_hidden_states_residual, threshold=threshold
    )


class CachedTransformerBlocks(torch.nn.Module):
    """Wrap a stack of transformer blocks with first-block residual caching.

    The boolean flags describe the calling convention of the wrapped blocks:
    whether the image/hidden stream comes first in the arguments and in the
    return value, whether blocks return only the hidden states, and in which
    order the two streams are concatenated before the single-stream blocks.
    """

    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        residual_diff_threshold,
        validate_can_use_cache_function=None,
        return_hidden_states_first=True,
        accept_hidden_states_first=True,
        cat_hidden_states_first=False,
        return_hidden_states_only=False,
        clone_original_hidden_states=False,
    ):
        super().__init__()
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.validate_can_use_cache_function = validate_can_use_cache_function
        self.return_hidden_states_first = return_hidden_states_first
        self.accept_hidden_states_first = accept_hidden_states_first
        self.cat_hidden_states_first = cat_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.clone_original_hidden_states = clone_original_hidden_states

    def _extract_args(self, args, kwargs):
        """Pull the image and text tensors out of ``args`` / ``kwargs``.

        Positional arguments are consumed first; otherwise the tensor is popped
        from ``kwargs`` under the first recognised key.

        Returns:
            ``(img, txt, txt_key, remaining_args, remaining_kwargs)``.
        """
        if "img" in kwargs:
            img_key = "img"
        elif "hidden_states" in kwargs:
            img_key = "hidden_states"
        else:
            img_key = None

        if "txt" in kwargs:
            txt_key = "txt"
        elif "context" in kwargs:
            txt_key = "context"
        elif "encoder_hidden_states" in kwargs:
            txt_key = "encoder_hidden_states"
        else:
            txt_key = None

        args = list(args)
        if self.accept_hidden_states_first:
            img = args.pop(0) if args else kwargs.pop(img_key)
            txt = args.pop(0) if args else kwargs.pop(txt_key)
        else:
            txt = args.pop(0) if args else kwargs.pop(txt_key)
            img = args.pop(0) if args else kwargs.pop(img_key)
        return img, txt, txt_key, args, kwargs

    def _call_block(self, block, img, txt, txt_key, args, kwargs):
        """Invoke one block using the configured convention; return ``(img, txt)``."""
        if txt_key == "encoder_hidden_states":
            out = block(img, *args, encoder_hidden_states=txt, **kwargs)
        elif self.accept_hidden_states_first:
            out = block(img, txt, *args, **kwargs)
        else:
            out = block(txt, img, *args, **kwargs)

        if self.return_hidden_states_only:
            return out, txt

        img, txt = out
        if not self.return_hidden_states_first:
            img, txt = txt, img
        return img, txt

    def _process_single_blocks(self, img, txt, args, kwargs):
        """Run the single-stream blocks (if any) on the concatenated streams.

        Returns ``(img, txt)``: the image slice of the processed tensor together
        with the ``txt`` tensor that was passed in.
        """
        if self.single_transformer_blocks is None:
            return img, txt
        img = torch.cat([img, txt] if self.cat_hidden_states_first else [txt, img], dim=1)
        for block in self.single_transformer_blocks:
            img = block(img, *args, **kwargs)
        txt_len = txt.shape[1]
        if self.cat_hidden_states_first:
            # Layout along dim 1 is [img, txt]: the image part is the leading
            # slice.  (Both branches used to drop the first txt_len entries,
            # which is only right for the [txt, img] layout.)
            img = img[:, : img.shape[1] - txt_len]
        else:
            img = img[:, txt_len:]
        return img, txt

    def _format_output(self, img, txt):
        """Shape the return value according to the configured convention."""
        if self.return_hidden_states_only:
            return img
        return (img, txt) if self.return_hidden_states_first else (txt, img)

    def forward(self, *args, **kwargs):
        """Run the block stack, skipping everything after block 0 on a cache hit."""
        img, txt, txt_key, args, kwargs = self._extract_args(args, kwargs)

        # A non-positive threshold disables caching: run every block.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                img, txt = self._call_block(block, img, txt, txt_key, args, kwargs)
            img, txt = self._process_single_blocks(img, txt, args, kwargs)
            return self._format_output(img, txt)

        original_img = img.clone() if self.clone_original_hidden_states else img
        img, txt = self._call_block(self.transformer_blocks[0], img, txt, txt_key, args, kwargs)
        first_residual = img - original_img
        can_use_cache = get_can_use_cache(first_residual, threshold=self.residual_diff_threshold)
        if self.validate_can_use_cache_function:
            can_use_cache = self.validate_can_use_cache_function(can_use_cache)

        torch._dynamo.graph_break()
        if can_use_cache:
            result = apply_prev_hidden_states_residual(img, txt)
            # The helper returns a bare tensor when txt is None, a tuple otherwise.
            img, txt = (result, txt) if isinstance(result, torch.Tensor) else result
        else:
            set_buffer("first_hidden_states_residual", first_residual)
            img, txt, img_res, txt_res = self._call_remaining(img, txt, txt_key, args, kwargs)
            set_buffer("hidden_states_residual", img_res)
            if txt_res is not None:
                set_buffer("encoder_hidden_states_residual", txt_res)
        torch._dynamo.graph_break()
        return self._format_output(img, txt)

    def _call_remaining(self, img, txt, txt_key, args, kwargs):
        """Run every block after the first one.

        Returns:
            ``(img, txt, img_residual, txt_residual)`` where each residual is
            the output minus the tensor passed in; ``txt_residual`` is None
            when ``txt`` is None.
        """
        clone = self.clone_original_hidden_states
        orig_img = img.clone() if clone else img
        orig_txt = txt.clone() if clone and txt is not None else txt

        for block in self.transformer_blocks[1:]:
            img, txt = self._call_block(block, img, txt, txt_key, args, kwargs)

        if self.single_transformer_blocks:
            img = torch.cat([img, txt] if self.cat_hidden_states_first else [txt, img], dim=1)
            for block in self.single_transformer_blocks:
                img = block(img, *args, **kwargs)
            if self.cat_hidden_states_first:
                img, txt = img.split([img.shape[1] - txt.shape[1], txt.shape[1]], dim=1)
            else:
                txt, img = img.split([txt.shape[1], img.shape[1] - txt.shape[1]], dim=1)

        img = img.flatten().contiguous().reshape(img.shape)
        if txt is not None:
            txt = txt.flatten().contiguous().reshape(txt.shape)
        txt_residual = txt - orig_txt if txt is not None else None
        return img, txt, img - orig_img, txt_residual


def create_patch_unet_model__forward(
    model, *, residual_diff_threshold, validate_can_use_cache_function=None
):
    """Build a context manager that patches ``model._forward`` with a cached variant.

    The replacement always runs input blocks 0 and 1.  It compares the change
    produced by block 1 with the cached one and, on a hit, adds the cached
    residual instead of running the rest of the network.

    Returns:
        A zero-argument context-manager factory; the patch is active inside
        its ``with`` body.
    """

    def call_remaining_blocks(
        self, transformer_options, control, transformer_patches, hs, h, *args, **kwargs
    ):
        """Run input blocks from index 2 on, the middle block and all output blocks.

        Returns ``(h, h - original_h)`` where ``original_h`` is ``h`` on entry.
        """
        original_h = h
        for block_index, module in enumerate(self.input_blocks):
            if block_index < 2:
                # Blocks 0 and 1 were already executed by unet_forward.
                continue
            transformer_options["block"] = ("input", block_index)
            h = forward_timestep_embed1(module, h, *args, **kwargs)
            h = apply_control1(h, control, 'input')
            for p in transformer_patches.get("input_block_patch", []):
                h = p(h, transformer_options)
            hs.append(h)
            for p in transformer_patches.get("input_block_patch_after_skip", []):
                h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)
        if self.middle_block is not None:
            h = forward_timestep_embed1(self.middle_block, h, *args, **kwargs)
        h = apply_control1(h, control, 'middle')

        for block_index, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", block_index)
            hsp = apply_control1(hs.pop(), control, 'output')
            for p in transformer_patches.get("output_block_patch", []):
                h, hsp = p(h, hsp, transformer_options)
            h = torch.cat([h, hsp], dim=1)
            del hsp
            # Shape of the next skip tensor (None after the last one) is passed
            # as the extra positional argument following *args.
            output_shape = hs[-1].shape if hs else None
            h = forward_timestep_embed1(module, h, *args, output_shape, **kwargs)
        return h, h - original_h

    def unet_forward(
        self,
        x,
        timesteps=None,
        context=None,
        y=None,
        control=None,
        transformer_options=None,
        **kwargs,
    ):
        """Cached replacement for the model's ``_forward``."""
        # This function writes into transformer_options, so a shared ``{}``
        # default would leak state between calls.
        if transformer_options is None:
            transformer_options = {}
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["transformer_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

        block_kwargs = {
            "time_context": kwargs.get("time_context"),
            "num_video_frames": kwargs.get("num_video_frames", self.default_num_video_frames),
            "image_only_indicator": kwargs.get("image_only_indicator"),
        }

        assert (y is not None) == (self.num_classes is not None)
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)
        for p in transformer_patches.get("emb_patch", []):
            emb = p(emb, self.model_channels, transformer_options)
        if self.num_classes is not None:
            emb = emb + self.label_emb(y)

        hs, h = [], x
        for block_index, module in enumerate(self.input_blocks):
            if block_index >= 2:
                break
            transformer_options["block"] = ("input", block_index)
            if block_index == 1:
                original_h = h
            h = forward_timestep_embed1(
                module, h, emb, context, transformer_options, **block_kwargs
            )
            h = apply_control1(h, control, 'input')
            for p in transformer_patches.get("input_block_patch", []):
                h = p(h, transformer_options)
            hs.append(h)
            for p in transformer_patches.get("input_block_patch_after_skip", []):
                h = p(h, transformer_options)
            if block_index == 1:
                first_residual = h - original_h
                can_use_cache = get_can_use_cache(
                    first_residual, threshold=residual_diff_threshold
                )
                if validate_can_use_cache_function:
                    can_use_cache = validate_can_use_cache_function(can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual", first_residual)

        torch._dynamo.graph_break()
        if can_use_cache:
            h = apply_prev_hidden_states_residual(h)
        else:
            h, hidden_states_residual = call_remaining_blocks(
                self,
                transformer_options,
                control,
                transformer_patches,
                hs,
                h,
                emb,
                context,
                transformer_options,
                **block_kwargs,
            )
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h.type(x.dtype))

    new_forward = unet_forward.__get__(model)

    @contextlib.contextmanager
    def patch__forward():
        with unittest.mock.patch.object(model, "_forward", new_forward):
            yield

    return patch__forward


def create_patch_flux_forward_orig(
    model, *, residual_diff_threshold, validate_can_use_cache_function=None
):
    """Build a context manager that patches ``model.forward_orig`` with a cached variant.

    The replacement always runs double block 0.  On a cache hit it adds the
    cached residual instead of running the remaining double and single blocks.

    Returns:
        A zero-argument context-manager factory; the patch is active inside
        its ``with`` body.
    """

    def run_double_block(
        block, index, img, txt, vec, pe, extra_kwargs, blocks_replace, transformer_options
    ):
        """Run one double block, via its replace-patch if one is registered."""
        key = ("double_block", index)
        if key not in blocks_replace:
            return block(img=img, txt=txt, vec=vec, pe=pe, **extra_kwargs)

        def original_block(args):
            # Call the block once and unpack both outputs; calling it once per
            # output would double the compute.
            out_img, out_txt = block(
                img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], **extra_kwargs
            )
            return {"img": out_img, "txt": out_txt}

        out = blocks_replace[key](
            {"img": img, "txt": txt, "vec": vec, "pe": pe, **extra_kwargs},
            {"original_block": original_block, "transformer_options": transformer_options},
        )
        return out["img"], out["txt"]

    def run_single_block(
        block, index, img, vec, pe, extra_kwargs, blocks_replace, transformer_options
    ):
        """Run one single block, via its replace-patch if one is registered."""
        key = ("single_block", index)
        if key not in blocks_replace:
            return block(img, vec=vec, pe=pe, **extra_kwargs)

        def original_block(args):
            return {"img": block(args["img"], vec=args["vec"], pe=args["pe"], **extra_kwargs)}

        out = blocks_replace[key](
            {"img": img, "vec": vec, "pe": pe, **extra_kwargs},
            {"original_block": original_block, "transformer_options": transformer_options},
        )
        return out["img"]

    def call_remaining_blocks(
        self, blocks_replace, control, img, txt, vec, pe, attn_mask, ca_idx, timesteps,
        transformer_options,
    ):
        """Run double blocks from index 1 on, then every single block.

        Returns ``(img, img - original_img)`` where ``original_img`` is ``img``
        on entry.
        """
        original_img = img
        extra_kwargs = {"attn_mask": attn_mask} if attn_mask is not None else {}

        for i, block in enumerate(self.double_blocks):
            if i < 1:
                # Block 0 was already executed by forward_orig.
                continue
            img, txt = run_double_block(
                block, i, img, txt, vec, pe, extra_kwargs, blocks_replace, transformer_options
            )
            if control and i < len(control.get("input", [])) and control["input"][i] is not None:
                img += control["input"][i]
            if getattr(self, "pulid_data", {}) and i % self.pulid_double_interval == 0:
                for node_data in self.pulid_data.values():
                    if torch.any(
                        (node_data['sigma_start'] >= timesteps)
                        & (timesteps >= node_data['sigma_end'])
                    ):
                        img = img + node_data['weight'] * self.pulid_ca[ca_idx](
                            node_data['embedding'], img
                        )
                ca_idx += 1

        img = torch.cat((txt, img), 1)
        txt_len = txt.shape[1]
        for i, block in enumerate(self.single_blocks):
            img = run_single_block(
                block, i, img, vec, pe, extra_kwargs, blocks_replace, transformer_options
            )
            if control and i < len(control.get("output", [])) and control["output"][i] is not None:
                img[:, txt_len:, ...] += control["output"][i]
            if getattr(self, "pulid_data", {}) and i % self.pulid_single_interval == 0:
                real_img, txt_part = img[:, txt_len:, ...], img[:, :txt_len, ...]
                for node_data in self.pulid_data.values():
                    if torch.any(
                        (node_data['sigma_start'] >= timesteps)
                        & (timesteps >= node_data['sigma_end'])
                    ):
                        real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](
                            node_data['embedding'], real_img
                        )
                ca_idx += 1
                img = torch.cat((txt_part, real_img), 1)

        img = img[:, txt_len:, ...].contiguous()
        return img, img - original_img

    def forward_orig(
        self,
        img,
        img_ids,
        txt,
        txt_ids,
        timesteps,
        y,
        guidance=None,
        control=None,
        transformer_options=None,
        attn_mask=None,
    ):
        """Cached replacement for the model's ``forward_orig``.

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-dimensional, or if the
                model expects guidance and none is supplied.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        if transformer_options is None:
            transformer_options = {}
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input tensors must have 3 dimensions.")

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Missing guidance for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
        txt = self.txt_in(txt)
        pe = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

        ca_idx = 0
        extra_kwargs = {"attn_mask": attn_mask} if attn_mask is not None else {}
        blocks_replace = patches_replace.get("dit", {})

        for i, block in enumerate(self.double_blocks):
            if i >= 1:
                break
            img, txt = run_double_block(
                block, i, img, txt, vec, pe, extra_kwargs, blocks_replace, transformer_options
            )
            if control and i < len(control.get("input", [])) and control["input"][i] is not None:
                img += control["input"][i]
            if getattr(self, "pulid_data", {}) and i % self.pulid_double_interval == 0:
                for node_data in self.pulid_data.values():
                    if torch.any(
                        (node_data['sigma_start'] >= timesteps)
                        & (timesteps >= node_data['sigma_end'])
                    ):
                        img = img + node_data['weight'] * self.pulid_ca[ca_idx](
                            node_data['embedding'], img
                        )
                ca_idx += 1
            if i == 0:
                # The first block's output itself is used as the comparison
                # signal here (no subtraction of the block input).
                first_residual = img
                can_use_cache = get_can_use_cache(
                    first_residual, threshold=residual_diff_threshold
                )
                if validate_can_use_cache_function:
                    can_use_cache = validate_can_use_cache_function(can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual", first_residual)

        torch._dynamo.graph_break()
        if can_use_cache:
            img = apply_prev_hidden_states_residual(img)
        else:
            img, residual = call_remaining_blocks(
                self, blocks_replace, control, img, txt, vec, pe, attn_mask, ca_idx, timesteps,
                transformer_options,
            )
            set_buffer("hidden_states_residual", residual)
        torch._dynamo.graph_break()
        return self.final_layer(img, vec)

    new_forward = forward_orig.__get__(model)

    @contextlib.contextmanager
    def patch_forward():
        with unittest.mock.patch.object(model, "forward_orig", new_forward):
            yield

    return patch_forward