manbeast3b committed on
Commit 7f2e11c · verified · 1 Parent(s): f1e914e

Update src/caching.py

Files changed (1):
  src/caching.py +63 -50
src/caching.py CHANGED
@@ -63,7 +63,8 @@ def are_two_tensors_similar(t1, t2, *, threshold=0.85):
     mean_t1 = t1.abs().mean()
     diff = mean_diff / mean_t1
     return diff.item() < threshold
-
+
+
 class CachedTransformerBlocks(torch.nn.Module):
     def __init__(
         self,
@@ -72,7 +73,7 @@ class CachedTransformerBlocks(torch.nn.Module):
         *,
         transformer=None,
         residual_diff_threshold=0.05,
-        return_hidden_states_first=False,  # Changed default to False for Flux
+        return_hidden_states_first=False,
     ):
         super().__init__()
         self.transformer = transformer
@@ -82,82 +83,94 @@ class CachedTransformerBlocks(torch.nn.Module):
         self.return_hidden_states_first = return_hidden_states_first
 
     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
-        # For Flux architecture, we need to handle the order differently
-        if not self.return_hidden_states_first:
-            hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
-
-        if self.residual_diff_threshold <= 0.0:
-            for block in self.transformer_blocks:
-                if self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
-                else:
-                    encoder_hidden_states, hidden_states = block(encoder_hidden_states, hidden_states, *args, **kwargs)
-            return (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states)
-
-        original_encoder_states = encoder_hidden_states
+        # Store original states before any transformations
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        # Process first block
         first_block = self.transformer_blocks[0]
-        if self.return_hidden_states_first:
-            hidden_states, encoder_hidden_states = first_block(hidden_states, encoder_hidden_states, *args, **kwargs)
-            first_residual = hidden_states - original_encoder_states
-        else:
-            encoder_hidden_states, hidden_states = first_block(encoder_hidden_states, hidden_states, *args, **kwargs)
-            first_residual = encoder_hidden_states - original_encoder_states
+        hidden_states, encoder_hidden_states = first_block(
+            hidden_states, encoder_hidden_states, *args, **kwargs
+        )
+
+        # Calculate residual from first block
+        first_hidden_states_residual = hidden_states - original_hidden_states
 
         cache_context = get_current_cache_context()
-        prev_residual = cache_context.get_buffer("first_residual")
+        prev_residual = cache_context.get_buffer("first_hidden_states_residual")
+
         can_use_cache = prev_residual is not None and are_two_tensors_similar(
-            prev_residual, first_residual, threshold=self.residual_diff_threshold
+            prev_residual,
+            first_hidden_states_residual,
+            threshold=self.residual_diff_threshold
         )
 
         if can_use_cache:
-            residual = cache_context.get_buffer("residual")
-            if self.return_hidden_states_first:
-                hidden_states = residual + hidden_states
-            else:
-                encoder_hidden_states = residual + encoder_hidden_states
+            # Use cached residuals
+            hidden_states_residual = cache_context.get_buffer("hidden_states_residual")
+            encoder_hidden_states_residual = cache_context.get_buffer("encoder_hidden_states_residual")
+
+            hidden_states = hidden_states + hidden_states_residual
+            encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_residual
         else:
-            cache_context.set_buffer("first_residual", first_residual)
-            original_states = original_encoder_states
+            # Process remaining blocks and cache results
+            cache_context.set_buffer("first_hidden_states_residual", first_hidden_states_residual)
 
             for block in self.transformer_blocks[1:]:
-                if self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
-                else:
-                    encoder_hidden_states, hidden_states = block(encoder_hidden_states, hidden_states, *args, **kwargs)
-
-            if self.return_hidden_states_first:
-                cache_context.set_buffer("residual", hidden_states - original_states)
-            else:
-                cache_context.set_buffer("residual", encoder_hidden_states - original_states)
-
-        return (hidden_states, encoder_hidden_states) if self.return_hidden_states_first else (encoder_hidden_states, hidden_states)
-
-def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.05):
-    cached_blocks = torch.nn.ModuleList([
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, *args, **kwargs
+                )
+
+            if self.single_transformer_blocks is not None:
+                for block in self.single_transformer_blocks:
+                    hidden_states = block(hidden_states, *args, **kwargs)
+
+            # Store residuals for future use
+            cache_context.set_buffer(
+                "hidden_states_residual",
+                hidden_states - original_hidden_states
+            )
+            cache_context.set_buffer(
+                "encoder_hidden_states_residual",
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+
+        return hidden_states, encoder_hidden_states
+
+
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    residual_diff_threshold=0.05,
+):
+    cached_transformer_blocks = torch.nn.ModuleList([
         CachedTransformerBlocks(
             transformer.transformer_blocks,
             transformer.single_transformer_blocks if hasattr(transformer, 'single_transformer_blocks') else None,
            transformer=transformer,
             residual_diff_threshold=residual_diff_threshold,
-            return_hidden_states_first=False  # Specifically for Flux
         )
     ])
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
 
     original_forward = transformer.forward
 
-    @functools.wraps(transformer.__class__.forward)
+    @functools.wraps(original_forward)
     def new_forward(self, *args, **kwargs):
-        with unittest.mock.patch.object(self, "transformer_blocks", cached_blocks):
-            if hasattr(self, 'single_transformer_blocks'):
-                with unittest.mock.patch.object(self, "single_transformer_blocks", torch.nn.ModuleList()):
-                    return original_forward(*args, **kwargs)
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ), unittest.mock.patch.object(
+            self,
+            "single_transformer_blocks",
+            dummy_single_transformer_blocks,
+        ):
             return original_forward(*args, **kwargs)
 
     transformer.forward = new_forward.__get__(transformer)
     return transformer
 
+
 def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
     original_call = pipe.__class__.__call__