manbeast3b committed (verified)
Commit ec800bc · 1 Parent(s): d000ac9

Create caching.py

Files changed (1)
  1. src/caching.py +294 -0
src/caching.py ADDED
@@ -0,0 +1,294 @@
# caching.py

import functools
import unittest.mock
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict
import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel


# Per-call cache state: named tensor buffers (the cached residuals) plus
# counters for generating incremental buffer names.
@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    @torch.compiler.disable
    def get_buffer(self, name):
        return self.buffers.get(name)

    @torch.compiler.disable
    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()


# Module-level helpers that read and write buffers on the currently active
# cache context (installed via cache_context() below).
@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context


# Installs a CacheContext for the duration of a pipeline call and restores the
# previous one afterwards.
@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context


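# Hypothetical manual use (apply_cache_on_pipe further below installs this
# wrapper around pipe.__call__ automatically; `pipe` and `prompt` are
# illustrative names, not defined in this file):
#
#     with cache_context(create_cache_context()):
#         image = pipe(prompt).images[0]

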
# Reuse the residuals cached on a previous step instead of running the
# remaining transformer blocks.
@torch.compiler.disable
def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before"
    hidden_states = hidden_states_residual + hidden_states

    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
    assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
    encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


# Relative mean absolute difference: |t1 - t2|.mean() / |t1|.mean() < threshold.
def are_two_tensors_similar(t1, t2, *, threshold=0.85):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold

@torch.compiler.disable
def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
    # `parallelized` is accepted for API compatibility but is not used here.
    # The cache is usable only if a residual from a previous step exists and is
    # close enough to the current first-block residual.
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
    )
    return can_use_cache


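# Illustrative numbers (not from this file): with residual_diff_threshold=0.15
# (the default used in apply_cache_on_transformer below), two consecutive
# first-block residuals
#
#     prev = torch.tensor([1.0, 2.0, 3.0])
#     curr = torch.tensor([1.05, 2.05, 3.05])
#
# differ by 0.05 / 2.0 = 0.025 < 0.15 in relative mean absolute difference, so
# the remaining blocks are skipped and the cached residuals are reused.

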
class CachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        residual_diff_threshold,
        return_hidden_states_first=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        # With a non-positive threshold, caching is disabled: run every block.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
            return (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )

        # Always run the first block, then compare its residual with the one
        # from the previous step to decide whether the rest can be skipped.
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_transformer_block(
            hidden_states, encoder_hidden_states, *args, **kwargs
        )
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
            parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states
            )
        else:
            set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        return (
            (hidden_states, encoder_hidden_states)
            if self.return_hidden_states_first
            else (encoder_hidden_states, hidden_states)
        )

    def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self.transformer_blocks[1:]:
            hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            encoder_hidden_states, hidden_states = hidden_states.split(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )

        # hidden_states_shape = hidden_states.shape
        # encoder_hidden_states_shape = encoder_hidden_states.shape
        hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        # hidden_states = hidden_states.contiguous()
        # encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states

        hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
        )

        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
    *,
    residual_diff_threshold=0.15,
):
    # Wrap the double and single blocks into a single cached module, then patch
    # the transformer's forward so it only sees that module.
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            CachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ), unittest.mock.patch.object(
            self,
            "single_transformer_blocks",
            dummy_single_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    **kwargs,
):
    # Patch the pipeline's __call__ (once per class) so every call runs inside
    # a fresh CacheContext, then optionally patch the transformer itself.
    original_call = pipe.__class__.__call__

    if not getattr(original_call, "_is_cached", False):

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context(create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call

        new_call._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    pipe._is_cached = True

    return pipe
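
For context, a minimal usage sketch of apply_cache_on_pipe (not part of this commit); the import path, checkpoint id, dtype, prompt, and step count are illustrative assumptions:

    import torch
    from diffusers import FluxPipeline

    from src.caching import apply_cache_on_pipe  # assumed module path for this file

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",  # illustrative checkpoint id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Patch the pipeline: each call now runs inside a fresh cache context, and
    # the transformer skips its remaining blocks whenever the first block's
    # residual changes by less than residual_diff_threshold between steps.
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.15)

    image = pipe("a photo of a cat", num_inference_steps=4).images[0]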