Spaces:
Running on Zero
Running on Zero
File size: 6,388 Bytes
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
# Names under which the text K/V cache hooks are registered in a module's
# ``HookRegistry``: one for the top-level transformer hook, one for the hook
# attached to each transformer block.
_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
@dataclass
class TextKVCacheConfig:
    """Configuration marker for exact (lossless) text K/V caching.

    With this cache enabled, each block's text key and value projections are
    computed once and then reused on every denoising step instead of being
    recomputed. Positive and negative prompts are told apart by a stable cache
    key that a transformer-level hook records before any intermediate tensors
    are allocated.

    The configuration currently carries no fields.
    """
class TextKVCacheState(BaseState):
    """State shared by the transformer-level hook and every block-level hook.

    The transformer hook stores the ``data_ptr()`` of the incoming
    ``encoder_hidden_states`` tensor here (captured *before* ``txt_norm``), and
    the block hooks read it back as the cache key identifying the current
    prompt across denoising steps.
    """

    def __init__(self) -> None:
        # ``None`` means no prompt tensor has been observed yet.
        self.key: int | None = None

    def reset(self) -> None:
        """Forget the previously captured cache key."""
        self.key = None
class TextKVCacheBlockState(BaseState):
    """State owned by a single block: its cached text key/value projections."""

    def __init__(self) -> None:
        # Maps a prompt cache key to the ``(txt_key, txt_value)`` pair computed
        # for that prompt by this block.
        self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}

    def reset(self) -> None:
        """Drop every cached ``(txt_key, txt_value)`` pair, keeping the same dict object."""
        self.kv_cache.clear()
class TextKVCacheTransformerHook(ModelHook):
    """Record the prompt's cache key at the transformer level.

    On every forward call the hook reads ``encoder_hidden_states.data_ptr()``
    (i.e. before ``txt_norm`` is applied) and writes it to the shared state so
    the block hooks can use it as their cache key.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager):
        super().__init__()
        self.state_manager = state_manager

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Capture the cache key, then delegate to the wrapped forward.

        Args:
            module: The module this hook is attached to (not used directly).
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged;
                ``encoder_hidden_states`` is inspected when present.

        Returns:
            Whatever the original forward returns.
        """
        manager = self.state_manager
        # Make sure a context exists before asking the manager for state.
        if manager._current_context is None:
            manager.set_context("inference")

        # NOTE(review): only a keyword-passed ``encoder_hidden_states`` is
        # seen here; if it is passed positionally the shared key is left
        # unchanged — confirm all callers use the keyword form.
        text_states = kwargs.get("encoder_hidden_states")
        if text_states is not None:
            shared: TextKVCacheState = manager.get_state()
            shared.key = text_states.data_ptr()

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Reset the shared state manager and return ``module``."""
        self.state_manager.reset()
        return module
class TextKVCacheBlockHook(ModelHook):
    """Cache ``(txt_key, txt_value)`` per block and per unique prompt.

    The cache key is read from the shared state written by the
    transformer-level hook. On a cache miss the text key/value projections are
    computed from the block's own sub-modules, stored, and then handed to the
    block's forward through ``attention_kwargs`` as ``cached_txt_key`` and
    ``cached_txt_value``.

    NOTE(review): the key is a tensor ``data_ptr()``; this presumably relies
    on the prompt embeddings staying alive (so the address is not reused) for
    as long as the cache is populated — confirm against the calling pipeline.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
        super().__init__()
        # Shared with the transformer hook: holds the current prompt's key.
        self.state_manager = state_manager
        # Private to this block: holds this block's cached projections.
        self.block_state_manager = block_state_manager

    @staticmethod
    def _project_text_kv(module: torch.nn.Module, encoder_hidden_states, image_rotary_emb):
        """Compute the text key/value projections for one block.

        Args:
            module: The transformer block; its ``encoder_proj`` and ``attn``
                sub-modules are used.
            encoder_hidden_states: Text hidden states fed to ``encoder_proj``.
            image_rotary_emb: ``None``, or a pair whose second item holds the
                text rotary frequencies applied to the key.

        Returns:
            A ``(txt_key, txt_value)`` tuple.
        """
        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus

        context = module.encoder_proj(encoder_hidden_states)
        attn = module.attn
        head_dim = attn.inner_dim // attn.heads
        num_kv_heads = attn.inner_kv_dim // head_dim

        # Split the last dimension into (kv_heads, head_dim).
        txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
        txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))

        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)
        if image_rotary_emb is not None:
            # Only the text half of the rotary pair applies to the text key.
            _, txt_freqs = image_rotary_emb
            txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
        return txt_key, txt_value

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Inject cached text K/V into ``attention_kwargs`` and run the block.

        Args:
            module: The transformer block this hook wraps.
            *args: Positional arguments of the block's forward;
                ``encoder_hidden_states`` is read from index 1 and
                ``image_rotary_emb`` from index 3 when not passed by keyword.
            **kwargs: Keyword arguments of the block's forward.

        Returns:
            Whatever the original forward returns.

        Raises:
            IndexError: If ``encoder_hidden_states`` is supplied neither by
                keyword nor as the second positional argument.
        """
        if self.state_manager._current_context is None:
            self.state_manager.set_context("inference")
        if self.block_state_manager._current_context is None:
            self.block_state_manager.set_context("inference")

        if "encoder_hidden_states" in kwargs:
            encoder_hidden_states = kwargs["encoder_hidden_states"]
        else:
            encoder_hidden_states = args[1]

        if "image_rotary_emb" in kwargs:
            image_rotary_emb = kwargs["image_rotary_emb"]
        elif len(args) > 3:
            image_rotary_emb = args[3]
        else:
            image_rotary_emb = None

        state: TextKVCacheState = self.state_manager.get_state()
        cache_key = state.key
        block_state: TextKVCacheBlockState = self.block_state_manager.get_state()

        if cache_key is None:
            # No key was captured for this call, so there is nothing that
            # identifies the prompt. Compute fresh projections and do not
            # store them: caching under ``None`` would hand the same K/V to
            # every later prompt.
            txt_key, txt_value = self._project_text_kv(module, encoder_hidden_states, image_rotary_emb)
        else:
            cached = block_state.kv_cache.get(cache_key)
            if cached is None:
                cached = self._project_text_kv(module, encoder_hidden_states, image_rotary_emb)
                block_state.kv_cache[cache_key] = cached
            txt_key, txt_value = cached

        # Work on a shallow copy so the caller's ``attention_kwargs`` dict is
        # not mutated (and does not end up holding this block's tensors).
        attn_kwargs = dict(kwargs.get("attention_kwargs") or {})
        attn_kwargs["cached_txt_key"] = txt_key
        attn_kwargs["cached_txt_value"] = txt_value
        kwargs["attention_kwargs"] = attn_kwargs

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Reset this block's state manager and return ``module``."""
        self.block_state_manager.reset()
        return module
def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
    """Register the text K/V cache hooks on ``module`` and its transformer blocks.

    One ``TextKVCacheTransformerHook`` is registered on ``module`` itself, and
    one ``TextKVCacheBlockHook`` on every ``NucleusMoEImageTransformerBlock``
    found among its sub-modules. All hooks share a single state manager for
    the cache key; each block additionally gets its own state manager for its
    cached projections.

    Args:
        module: The transformer model to instrument.
        config: Cache configuration. It is not read by this function.
    """
    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock

    # Shared by the transformer hook (writer) and all block hooks (readers).
    state_manager = StateManager(TextKVCacheState)

    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(TextKVCacheTransformerHook(state_manager), _TEXT_KV_CACHE_TRANSFORMER_HOOK)

    for submodule in module.modules():
        if not isinstance(submodule, NucleusMoEImageTransformerBlock):
            continue
        # Each block keeps its own K/V cache, hence its own state manager.
        block_state_manager = StateManager(TextKVCacheBlockState)
        block_hook = TextKVCacheBlockHook(state_manager, block_state_manager)
        block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
        block_registry.register_hook(block_hook, _TEXT_KV_CACHE_BLOCK_HOOK)
|