Update src/caching.py
src/caching.py CHANGED (+63 −50)
@@ -63,7 +63,8 @@ def are_two_tensors_similar(t1, t2, *, threshold=0.85):
     mean_t1 = t1.abs().mean()
     diff = mean_diff / mean_t1
     return diff.item() < threshold
-
+
+
 class CachedTransformerBlocks(torch.nn.Module):
     def __init__(
         self,
@@ -72,7 +73,7 @@ class CachedTransformerBlocks(torch.nn.Module):
         *,
         transformer=None,
         residual_diff_threshold=0.05,
-        return_hidden_states_first=False,
+        return_hidden_states_first=False,
     ):
         super().__init__()
         self.transformer = transformer
@@ -82,82 +83,94 @@ class CachedTransformerBlocks(torch.nn.Module):
         self.return_hidden_states_first = return_hidden_states_first
 
     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
-        #
-
-
-
-        if self.residual_diff_threshold <= 0.0:
-            for block in self.transformer_blocks:
-                if self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
-                else:
-                    encoder_hidden_states, hidden_states = block(encoder_hidden_states, hidden_states, *args, **kwargs)
+        # Store original states before any transformations
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
 
-
-
-        original_encoder_states = encoder_hidden_states
+        # Process first block
         first_block = self.transformer_blocks[0]
+        hidden_states, encoder_hidden_states = first_block(
+            hidden_states, encoder_hidden_states, *args, **kwargs
+        )
 
-
-
-            first_residual = hidden_states - original_encoder_states
-        else:
-            encoder_hidden_states, hidden_states = first_block(encoder_hidden_states, hidden_states, *args, **kwargs)
-            first_residual = encoder_hidden_states - original_encoder_states
+        # Calculate residual from first block
+        first_hidden_states_residual = hidden_states - original_hidden_states
 
         cache_context = get_current_cache_context()
-        prev_residual = cache_context.get_buffer("
+        prev_residual = cache_context.get_buffer("first_hidden_states_residual")
+
         can_use_cache = prev_residual is not None and are_two_tensors_similar(
-            prev_residual,
+            prev_residual,
+            first_hidden_states_residual,
+            threshold=self.residual_diff_threshold
        )
 
         if can_use_cache:
-
-
-
-
-
+            # Use cached residuals
+            hidden_states_residual = cache_context.get_buffer("hidden_states_residual")
+            encoder_hidden_states_residual = cache_context.get_buffer("encoder_hidden_states_residual")
+
+            hidden_states = hidden_states + hidden_states_residual
+            encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_residual
         else:
-
-
+            # Process remaining blocks and cache results
+            cache_context.set_buffer("first_hidden_states_residual", first_hidden_states_residual)
 
             for block in self.transformer_blocks[1:]:
-
-                hidden_states, encoder_hidden_states
-
-
-
-
-
-
-
-
-
-
-
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+            if self.single_transformer_blocks is not None:
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, *args, **kwargs)
+
+            # Store residuals for future use
+            cache_context.set_buffer(
+                "hidden_states_residual",
+                hidden_states - original_hidden_states
+            )
+            cache_context.set_buffer(
+                "encoder_hidden_states_residual",
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+
+        return hidden_states, encoder_hidden_states
+
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    residual_diff_threshold=0.05,
+):
+    cached_transformer_blocks = torch.nn.ModuleList([
         CachedTransformerBlocks(
             transformer.transformer_blocks,
             transformer.single_transformer_blocks if hasattr(transformer, 'single_transformer_blocks') else None,
             transformer=transformer,
             residual_diff_threshold=residual_diff_threshold,
-            return_hidden_states_first=False # Specifically for Flux
         )
     ])
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
 
     original_forward = transformer.forward
 
-    @functools.wraps(
+    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
-        with unittest.mock.patch.object(
-
-
-
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ), unittest.mock.patch.object(
+            self,
+            "single_transformer_blocks",
+            dummy_single_transformer_blocks,
+        ):
            return original_forward(*args, **kwargs)
 
     transformer.forward = new_forward.__get__(transformer)
     return transformer
 
+
 def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
     original_call = pipe.__class__.__call__
 
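Note: the hunks above lean on helpers defined elsewhere in src/caching.py that this commit does not touch: are_two_tensors_similar (only its last three lines are visible in the first hunk) and a per-inference cache context returned by get_current_cache_context() with get_buffer/set_buffer. The sketch below shows one plausible shape for those helpers; only the names and call signatures come from the diff, while mean_diff being the mean absolute difference and the dict-backed context are assumptions.

import torch


def are_two_tensors_similar(t1, t2, *, threshold=0.85):
    # Assumed: relative mean absolute difference, matching the visible tail
    # of the function (diff = mean_diff / mean_t1; similar when below threshold).
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold


class CacheContext:
    # Assumed dict-backed buffer store; the diff only shows get_buffer/set_buffer.
    def __init__(self):
        self.buffers = {}

    def get_buffer(self, name):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer


_current_cache_context = CacheContext()


def get_current_cache_context():
    # Assumed: the real module likely scopes this per denoising run rather
    # than keeping a single module-level instance.
    return _current_cache_context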
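apply_cache_on_pipe is only partially visible in this diff. Assuming it patches the pipeline's transformer via apply_cache_on_transformer and forwards keyword arguments such as residual_diff_threshold, usage on a Flux pipeline might look like the sketch below; FluxPipeline, the model id, the src.caching import path, and the kwargs forwarding are illustrative assumptions, not part of this commit.

import torch
from diffusers import FluxPipeline

from src.caching import apply_cache_on_pipe  # import path assumed from the file location

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# Patch the pipeline so that, whenever the first transformer block's residual
# barely changes between denoising steps, the remaining blocks are skipped and
# the cached residuals are reused instead.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)  # kwargs forwarding assumed

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
).images[0]
image.save("astronaut.png")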
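One Python detail in the last hunk worth a note: transformer.forward = new_forward.__get__(transformer) binds the plain function to that specific transformer instance, so self inside new_forward is the transformer and the patched forward behaves like an ordinary bound method. A tiny standalone illustration (the class and names are made up for the example):

class Widget:
    name = "flux"


def shout(self):
    # `self` is whatever instance the function was bound to via __get__.
    return self.name.upper()


w = Widget()
w.shout = shout.__get__(w)  # bind the free function as a method of this instance
print(w.shout())            # prints "FLUX"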