Improved memory requirements
Files changed:
- bounded_attention.py (+36 -30)
- injection_utils.py (+15 -122)
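In outline, the commit does two things: bounded_attention.py gains a `max_resolution` knob and routes attention layers above that resolution through `F.scaled_dot_product_attention` instead of materializing the full similarity matrix, and injection_utils.py drops the unused `AttentionStore` class and LDM registration path while simplifying the editor call signature. A minimal sketch (not repo code; shapes are illustrative) of why the fused call is the memory win the commit message refers to:

import torch
import torch.nn.functional as F

def explicit_attention(q, k, v, scale):
    # materializes a (batch*heads, n, n) similarity tensor through the softmax
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    return torch.bmm(sim.softmax(-1), v)

def fused_attention(q, k, v):
    # PyTorch >= 2.0: dispatches to a flash / memory-efficient kernel when it can,
    # never holding the full n x n attention matrix at once
    return F.scaled_dot_product_attention(q, k, v)

b_h, n, dim = 16, 64 * 64, 40   # e.g. a 64x64 latent, batch*heads = 16
q, k, v = (torch.randn(b_h, n, dim) for _ in range(3))
out1 = explicit_attention(q, k, v, dim ** -0.5)
out2 = fused_attention(q, k, v)
print(torch.allclose(out1, out2, atol=1e-4))   # True, up to numerical tolerance

The explicit path keeps a `(batch*heads, n, n)` float tensor alive through the softmax; at a 64x64 latent that is 16 x 4096 x 4096 fp32 values, roughly 1 GiB per attention layer, which the fused kernel avoids by computing the softmax in tiles.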
bounded_attention.py
CHANGED
@@ -44,6 +44,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         pca_rank=None,
         num_clusters=None,
         num_clusters_per_box=3,
+        max_resolution=32,
         map_dir=None,
         debug=False,
         delta_debug_attention_steps=20,
@@ -81,6 +82,7 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
         self.centers = None
 
+        self.max_resolution = max_resolution
         self.map_dir = map_dir
         self.debug = debug
         self.delta_debug_attention_steps = delta_debug_attention_steps
@@ -124,24 +126,34 @@ class BoundedAttention(injection_utils.AttentionBase):
         self.clear_values(include_maps=True)
         super().reset()
 
-    def forward(self, q, k, v, …
-        …
-        …
-        …
-        …
+    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+        dtype = q.dtype
+        device = q.device
         if is_cross:
-            …
+            masks = self._hide_other_subjects_from_tokens(batch_size // 2, n, d, dtype, device)
         else:
-            …
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+            masks = self._hide_other_subjects_from_subjects(batch_size // 2, n, dtype, device)
+
+        if int(n ** 0.5) > self.max_resolution:
+            q = q.reshape(batch_size, num_heads, n, -1)
+            k = k.reshape(batch_size, num_heads, d, -1)
+            v = v.reshape(batch_size, num_heads, d, -1)
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=masks)
+            out = out.reshape(batch_size * num_heads, n, -1)
+        else:
+            sim = torch.einsum('b i d, b j d -> b i j', q, k) * kwargs['scale']
+            attn = sim.softmax(-1)
+            self._display_attention_maps(attn, is_cross, num_heads)
+            sim = sim.reshape(batch_size, num_heads, n, d) + masks
+            attn = sim.reshape(-1, n, d).softmax(-1)
+            self._save(attn, is_cross, num_heads)
+            self._display_attention_maps(attn, is_cross, num_heads, prefix='masked')
+            self._debug_hook(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+            out = torch.bmm(attn, v)
 
-        out = torch.bmm(attn, v)
         out = einops.rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out
 
@@ -235,16 +247,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         references = references.reshape(-1, *references_unconditional.shape[2:])
         return batch, references
 
-    def _hide_other_subjects_from_tokens(self, …
-        …
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_tokens(self, batch_size, n, d, dtype, device):  # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
         include_background = self.optimized or (not self.mask_cross_during_guidance and self.cur_step < self.max_guidance_iter_per_step)
         subject_masks = torch.logical_or(subject_masks, background_masks.unsqueeze(1)) if include_background else subject_masks
-        min_value = torch.finfo(…
-        sim_masks = torch.…
+        min_value = torch.finfo(dtype).min
+        sim_masks = torch.zeros((batch_size, n, d), dtype=dtype, device=device)  # b i j
         for token_indices in (*self.subject_token_indices, self.filter_token_indices):
             sim_masks[:, :, token_indices] = min_value
 
@@ -257,16 +266,13 @@ class BoundedAttention(injection_utils.AttentionBase):
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, background_mask, self.eos_token_index] = min_value
 
-        return …
+        return torch.cat((torch.zeros_like(sim_masks), sim_masks)).unsqueeze(1)
 
-    def _hide_other_subjects_from_subjects(self, …
-        …
-        device = sim.device
-        batch_size = sim.size(0)
-        resolution = int(sim.size(2) ** 0.5)
+    def _hide_other_subjects_from_subjects(self, batch_size, n, dtype, device):  # b h i j
+        resolution = int(n ** 0.5)
         subject_masks, background_masks = self._obtain_masks(resolution, batch_size=batch_size, device=device)  # b s n
         min_value = torch.finfo(dtype).min
-        sim_masks = torch.…
+        sim_masks = torch.zeros((batch_size, n, n), dtype=dtype, device=device)  # b i j
         for batch_index, background_mask in zip(range(batch_size), background_masks):
             sim_masks[batch_index, ~background_mask, ~background_mask] = min_value
 
@@ -276,7 +282,7 @@ class BoundedAttention(injection_utils.AttentionBase):
             condition = torch.logical_or(subject_sim_mask == 0, subject_mask.unsqueeze(0))
             sim_masks[batch_index, subject_mask] = torch.where(condition, 0, min_value).to(dtype=dtype)
 
-        return …
+        return torch.cat((sim_masks, sim_masks)).unsqueeze(1)
 
     def _save(self, attn, is_cross, num_heads):
         _, attn = attn.chunk(2)
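The rewritten mask helpers now return additive float masks (`0` keeps a position, `torch.finfo(dtype).min` hides it) shaped `(2*batch, 1, n, d)`: the leading concatenation covers the unconditional and conditional halves of the classifier-free-guidance batch, and the singleton axis broadcasts over heads. One tensor thus serves both branches of the new `forward`: added to `sim` at low resolutions, passed as `attn_mask` to SDPA at high ones. A hedged sketch of that dual use (toy shapes and masked columns, not the repo's mask logic):

import torch
import torch.nn.functional as F

batch, heads, n, d, dim = 2, 4, 16, 77, 40        # toy sizes
q = torch.randn(batch * heads, n, dim)
k = torch.randn(batch * heads, d, dim)
v = torch.randn(batch * heads, d, dim)

min_value = torch.finfo(q.dtype).min
masks = torch.zeros(batch, 1, n, d)               # additive: 0 keeps, min hides
masks[..., 5:10] = min_value                      # hypothetical hidden columns

scale = dim ** -0.5
# Explicit path (what the diff keeps for resolutions <= max_resolution):
sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
sim = sim.reshape(batch, heads, n, d) + masks
attn = sim.reshape(-1, n, d).softmax(-1)
out_explicit = torch.bmm(attn, v)

# Fused path (what the diff adds above max_resolution):
out_fused = F.scaled_dot_product_attention(
    q.reshape(batch, heads, n, -1),
    k.reshape(batch, heads, d, -1),
    v.reshape(batch, heads, d, -1),
    attn_mask=masks,
).reshape(batch * heads, n, -1)

print(torch.allclose(out_explicit, out_fused, atol=1e-4))   # True

Both paths compute softmax(q @ k.T * scale + mask) @ v, so the outputs agree up to numerical tolerance.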
injection_utils.py
CHANGED
@@ -22,21 +22,29 @@ class AttentionBase:
     def after_step(self):
         pass
 
-    def __call__(self, q, k, v, …
+    def __call__(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        if self.cur_att_layer == 0:
             self.before_step()
 
-        out = self.forward(q, k, v, …
+        out = self.forward(q, k, v, is_cross, place_in_unet, num_heads, **kwargs)
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers:
             self.cur_att_layer = 0
             self.cur_step += 1
-            # after step
             self.after_step()
+
         return out
 
     def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        …
+        batch_size = q.size(0) // num_heads
+        n = q.size(1)
+        d = k.size(1)
+
+        q = q.reshape(batch_size, num_heads, n, -1)
+        k = k.reshape(batch_size, num_heads, d, -1)
+        v = v.reshape(batch_size, num_heads, d, -1)
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=kwargs['mask'])
+        out = out.reshape(batch_size * num_heads, n, -1)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
         return out
 
@@ -45,42 +53,6 @@ class AttentionBase:
         self.cur_att_layer = 0
 
 
-class AttentionStore(AttentionBase):
-    def __init__(self, res=[32], min_step=0, max_step=1000):
-        super().__init__()
-        self.res = res
-        self.min_step = min_step
-        self.max_step = max_step
-        self.valid_steps = 0
-
-        self.self_attns = []  # store the all attns
-        self.cross_attns = []
-
-        self.self_attns_step = []  # store the attns in each step
-        self.cross_attns_step = []
-
-    def after_step(self):
-        if self.cur_step > self.min_step and self.cur_step < self.max_step:
-            self.valid_steps += 1
-            if len(self.self_attns) == 0:
-                self.self_attns = self.self_attns_step
-                self.cross_attns = self.cross_attns_step
-            else:
-                for i in range(len(self.self_attns)):
-                    self.self_attns[i] += self.self_attns_step[i]
-                    self.cross_attns[i] += self.cross_attns_step[i]
-        self.self_attns_step.clear()
-        self.cross_attns_step.clear()
-
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        if attn.shape[1] <= 64 ** 2:  # avoid OOM
-            if is_cross:
-                self.cross_attns_step.append(attn)
-            else:
-                self.self_attns_step.append(attn)
-        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-
-
 def regiter_attention_editor_diffusers(model, editor: AttentionBase):
     """
     Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
@@ -109,21 +81,9 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             k = self.to_k(context)
             v = self.to_v(context)
             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
             out = editor(
-                q, k, v, …
-                self.heads, scale=self.scale)
+                q, k, v, is_cross, place_in_unet,
+                self.heads, scale=self.scale, mask=mask)
 
             return to_out(out)
 
@@ -146,74 +106,7 @@ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
             cross_att_count += register_editor(net, 0, "mid")
         elif "up" in net_name:
             cross_att_count += register_editor(net, 0, "up")
+
     editor.num_att_layers = cross_att_count
     editor.model = model
     model.editor = editor
-
-
-def regiter_attention_editor_ldm(model, editor: AttentionBase):
-    """
-    Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
-    """
-    def ca_forward(self, place_in_unet):
-        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
-            """
-            The attention is similar to the original implementation of LDM CrossAttention class
-            except adding some modifications on the attention
-            """
-            if encoder_hidden_states is not None:
-                context = encoder_hidden_states
-            if attention_mask is not None:
-                mask = attention_mask
-
-            to_out = self.to_out
-            if isinstance(to_out, nn.modules.container.ModuleList):
-                to_out = self.to_out[0]
-            else:
-                to_out = self.to_out
-
-            h = self.heads
-            q = self.to_q(x)
-            is_cross = context is not None
-            context = context if is_cross else x
-            k = self.to_k(context)
-            v = self.to_v(context)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-            if mask is not None:
-                mask = rearrange(mask, 'b ... -> b (...)')
-                max_neg_value = -torch.finfo(sim.dtype).max
-                mask = repeat(mask, 'b j -> (b h) () j', h=h)
-                mask = mask[:, None, :].repeat(h, 1, 1)
-                sim.masked_fill_(~mask, max_neg_value)
-
-            attn = sim.softmax(dim=-1)
-            # the only difference
-            out = editor(
-                q, k, v, sim, attn, is_cross, place_in_unet,
-                self.heads, scale=self.scale)
-
-            return to_out(out)
-
-        return forward
-
-    def register_editor(net, count, place_in_unet):
-        for name, subnet in net.named_children():
-            if net.__class__.__name__ == 'CrossAttention':  # spatial Transformer layer
-                net.forward = ca_forward(net, place_in_unet)
-                return count + 1
-            elif hasattr(net, 'children'):
-                count = register_editor(subnet, count, place_in_unet)
-        return count
-
-    cross_att_count = 0
-    for net_name, net in model.model.diffusion_model.named_children():
-        if "input" in net_name:
-            cross_att_count += register_editor(net, 0, "input")
-        elif "middle" in net_name:
-            cross_att_count += register_editor(net, 0, "middle")
-        elif "output" in net_name:
-            cross_att_count += register_editor(net, 0, "output")
-    editor.num_att_layers = cross_att_count
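With `AttentionStore` and `regiter_attention_editor_ldm` gone, an `AttentionBase` subclass now receives raw `q`, `k`, `v` plus `scale` and `mask` in kwargs and performs the attention itself; the precomputed `sim`/`attn` arguments are no longer supplied by the patched layers. (Note the diff leaves the base class's own `forward` with the old `(q, k, v, sim, attn, …)` signature as a context line, so in practice subclasses override it with the new argument order, as `BoundedAttention.forward` does.) A hypothetical minimal editor against the new call protocol; the wiring lines at the bottom are illustrative, not from this Space:

import torch.nn.functional as F
from einops import rearrange

from injection_utils import AttentionBase, regiter_attention_editor_diffusers

class PlainSDPAEditor(AttentionBase):
    # Hypothetical editor: runs plain fused attention and ignores the extra
    # kwargs (scale/mask); SDPA's default scaling (head_dim ** -0.5) matches
    # self.scale for standard CrossAttention layers.
    def forward(self, q, k, v, is_cross, place_in_unet, num_heads, **kwargs):
        batch_size = q.size(0) // num_heads
        n, d = q.size(1), k.size(1)
        out = F.scaled_dot_product_attention(
            q.reshape(batch_size, num_heads, n, -1),
            k.reshape(batch_size, num_heads, d, -1),
            v.reshape(batch_size, num_heads, d, -1),
        )
        out = out.reshape(batch_size * num_heads, n, -1)
        return rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)

# Illustrative wiring (requires a diffusers pipeline with CrossAttention layers):
# pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
# regiter_attention_editor_diffusers(pipe, PlainSDPAEditor())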