manbeast3b committed
Commit f97a2cb · verified · 1 Parent(s): cdd39c3

Create caching.py

Files changed (1)
  1. src/caching.py +173 -0
src/caching.py ADDED
@@ -0,0 +1,173 @@
import functools
import unittest.mock
import contextlib
import dataclasses
from collections import defaultdict
from typing import DefaultDict, Dict

import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline, FluxTransformer2DModel

@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    @torch.compiler.disable
    def get_buffer(self, name):
        return self.buffers.get(name)

    @torch.compiler.disable
    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()

_current_cache_context = None

def create_cache_context():
    return CacheContext()

def get_current_cache_context():
    return _current_cache_context

def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context

@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context

@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold=0.85):
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold

class CachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        residual_diff_threshold=0.05,
        return_hidden_states_first=True,
    ):
        super().__init__()
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.return_hidden_states_first = return_hidden_states_first

    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        # Caching disabled: run every block as usual.
        if self.residual_diff_threshold <= 0.0:
            for block in self.transformer_blocks:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
            return (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states)

        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        # Always run the first block; its residual decides whether the remaining blocks can be skipped.
        first_block = self.transformer_blocks[0]
        hidden_states, encoder_hidden_states = first_block(hidden_states, encoder_hidden_states, *args, **kwargs)
        if not self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states

        first_hidden_states_residual = hidden_states - original_hidden_states

        cache_context = get_current_cache_context()
        prev_residual = cache_context.get_buffer("first_hidden_states_residual")
        can_use_cache = prev_residual is not None and are_two_tensors_similar(
            prev_residual, first_hidden_states_residual, threshold=self.residual_diff_threshold
        )

        if can_use_cache:
            # First-block residual barely changed since the last step: reuse the cached
            # residuals of the remaining blocks instead of recomputing them.
            hidden_states_residual = cache_context.get_buffer("hidden_states_residual")
            encoder_hidden_states_residual = cache_context.get_buffer("encoder_hidden_states_residual")
            hidden_states = hidden_states_residual + hidden_states
            encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
        else:
            cache_context.set_buffer("first_hidden_states_residual", first_hidden_states_residual)
            for block in self.transformer_blocks[1:]:
                hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states

            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]

            cache_context.set_buffer("hidden_states_residual", hidden_states - original_hidden_states)
            cache_context.set_buffer("encoder_hidden_states_residual", encoder_hidden_states - original_encoder_hidden_states)

        return (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states)

def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.05):
    cached_blocks = torch.nn.ModuleList([
        CachedTransformerBlocks(
            transformer.transformer_blocks,
            transformer.single_transformer_blocks if hasattr(transformer, 'single_transformer_blocks') else None,
            transformer=transformer,
            residual_diff_threshold=residual_diff_threshold,
        )
    ])

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(self, *args, **kwargs):
        with unittest.mock.patch.object(self, "transformer_blocks", cached_blocks):
            if hasattr(self, 'single_transformer_blocks'):
                with unittest.mock.patch.object(self, "single_transformer_blocks", torch.nn.ModuleList()):
                    return original_forward(*args, **kwargs)
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    return transformer

def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
    original_call = pipe.__class__.__call__

    if not getattr(original_call, "_is_cached", False):
        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context(create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        new_call._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    pipe._is_cached = True
    return pipe
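
A minimal usage sketch, not part of the commit: it assumes a Flux pipeline loaded through diffusers; the checkpoint id, dtype, prompt, and step count are illustrative only.

# Hypothetical usage of src/caching.py; model id and settings are assumptions.
import torch
from diffusers import FluxPipeline

from src.caching import apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# Patch the pipeline: __call__ gets a fresh CacheContext per run, and the
# transformer blocks are wrapped so later blocks are skipped whenever the
# first block's residual barely changes between denoising steps.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)

image = pipe("a photo of a cat", num_inference_steps=4).images[0]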