# personalive/src/models/mutual_self_attention.py
# seawolf2357's picture
# Deploy from GitHub repository
# 7428365 verified
from typing import Any, Dict, Optional
import torch
from einops import rearrange
from src.models.attention import TemporalBasicTransformerBlock
from .attention import BasicTransformerBlock
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` and all of its descendants in depth-first pre-order.

    The module itself comes first, followed by each direct child's subtree in
    the order reported by ``model.children()``.
    """
    ordered = []
    pending = [model]
    while pending:
        node = pending.pop()
        ordered.append(node)
        # Push children reversed so the first child is popped (visited) first,
        # which reproduces the order of a recursive pre-order walk.
        pending.extend(reversed(list(node.children())))
    return ordered
class ReferenceAttentionControl:
def __init__(
self,
unet,
mode="write",
do_classifier_free_guidance=False,
attention_auto_machine_weight=float("inf"),
gn_auto_machine_weight=1.0,
style_fidelity=1.0,
reference_attn=True,
reference_adain=False,
fusion_blocks="midup",
batch_size=1,
cache_kv=False,
) -> None:
# 10. Modify self attention and group norm
self.unet = unet
assert mode in ["read", "write"]
assert fusion_blocks in ["midup", "full"]
self.reference_attn = reference_attn
self.reference_adain = reference_adain
self.fusion_blocks = fusion_blocks
self.cache_kv = cache_kv
self.register_reference_hooks(
mode,
do_classifier_free_guidance,
attention_auto_machine_weight,
gn_auto_machine_weight,
style_fidelity,
reference_attn,
reference_adain,
fusion_blocks,
batch_size=batch_size,
cache_kv=self.cache_kv,
)
def register_reference_hooks(
self,
mode,
do_classifier_free_guidance,
attention_auto_machine_weight,
gn_auto_machine_weight,
style_fidelity,
reference_attn,
reference_adain,
dtype=torch.float16,
batch_size=1,
num_images_per_prompt=1,
device=torch.device("cpu"),
fusion_blocks="midup",
cache_kv=False,
):
MODE = mode
do_classifier_free_guidance = do_classifier_free_guidance
attention_auto_machine_weight = attention_auto_machine_weight
gn_auto_machine_weight = gn_auto_machine_weight
style_fidelity = style_fidelity
reference_attn = reference_attn
reference_adain = reference_adain
fusion_blocks = fusion_blocks
num_images_per_prompt = num_images_per_prompt
cache_kv = cache_kv
dtype = dtype
if do_classifier_free_guidance:
uc_mask = (
torch.Tensor(
[1] * batch_size * num_images_per_prompt * 16
+ [0] * batch_size * num_images_per_prompt * 16
)
.to(device)
.bool()
)
else:
uc_mask = (
torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
.to(device)
.bool()
)
def hacked_basic_transformer_inner_forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
video_length=None,
):
if self.use_ada_layer_norm: # False
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = self.norm1(
hidden_states,
timestep,
class_labels,
hidden_dtype=hidden_states.dtype,
)
else:
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
# self.only_cross_attention = False
cross_attention_kwargs = (
cross_attention_kwargs if cross_attention_kwargs is not None else {}
)
if self.only_cross_attention:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if self.only_cross_attention
else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
else:
if MODE == "write":
self.bank.append(norm_hidden_states.clone())
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if self.only_cross_attention
else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if MODE == "read":
kv_cache = norm_hidden_states.clone()
kv_cache = rearrange(
kv_cache, "(b t) l c -> b t l c", t=video_length
)
self.kv_cache = kv_cache[:,:2,:,:]
bank_fea = [
rearrange(
d.unsqueeze(1).repeat(1, video_length, 1, 1),
"b t l c -> (b t) l c",
)
for d in self.bank
]
if self.kv_bank is not None and cache_kv:
ahead_fea = self.kv_bank.unsqueeze(1).repeat(1, video_length, 1, 1, 1)
ahead_fea = rearrange(ahead_fea, "b t n l c -> (b t) (n l) c")
bank_fea.append(ahead_fea)
modify_norm_hidden_states = torch.cat(
[norm_hidden_states] + bank_fea, dim=1
)
hidden_states_uc = (
self.attn1(
norm_hidden_states,
encoder_hidden_states=modify_norm_hidden_states,
attention_mask=attention_mask,
)
# self.attn1(
# modify_norm_hidden_states,
# encoder_hidden_states=modify_norm_hidden_states,
# attention_mask=attention_mask,
# )[:, : hidden_states.shape[-2], :]
+ hidden_states
)
if do_classifier_free_guidance:
hidden_states_c = hidden_states_uc.clone()
_uc_mask = uc_mask.clone()
if hidden_states.shape[0] != _uc_mask.shape[0]:
_uc_mask = (
torch.Tensor(
[1] * (hidden_states.shape[0] // 2)
+ [0] * (hidden_states.shape[0] // 2)
)
.to(device)
.bool()
)
if self.kv_bank is not None:
# if False:
modify_norm_hidden_states = torch.cat(
[norm_hidden_states, ahead_fea], dim=1
)
else:
modify_norm_hidden_states = norm_hidden_states
hidden_states_c[_uc_mask] = (
self.attn1(
norm_hidden_states[_uc_mask],
encoder_hidden_states=modify_norm_hidden_states[_uc_mask],
attention_mask=attention_mask,
)
+ hidden_states[_uc_mask]
)
hidden_states = hidden_states_c.clone()
else:
hidden_states = hidden_states_uc
# self.bank.clear()
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm_temp(hidden_states)
)
hidden_states = (
self.attn_temp(norm_hidden_states) + hidden_states
)
hidden_states = rearrange(
hidden_states, "(b d) f c -> (b f) d c", d=d
)
return hidden_states
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
if self.reference_attn:
if self.fusion_blocks == "midup":
attn_modules = [
module
for module in (
torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
)
if isinstance(module, BasicTransformerBlock)
or isinstance(module, TemporalBasicTransformerBlock)
]
elif self.fusion_blocks == "full":
attn_modules = [
module
for module in torch_dfs(self.unet)
if isinstance(module, BasicTransformerBlock)
or isinstance(module, TemporalBasicTransformerBlock)
]
attn_modules = sorted(
attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
for i, module in enumerate(attn_modules):
module._original_inner_forward = module.forward
if isinstance(module, BasicTransformerBlock):
module.forward = hacked_basic_transformer_inner_forward.__get__(
module, BasicTransformerBlock
)
if isinstance(module, TemporalBasicTransformerBlock):
module.forward = hacked_basic_transformer_inner_forward.__get__(
module, TemporalBasicTransformerBlock
)
module.bank = []
if(self.cache_kv):
module.kv_bank = None
module.kv_cache = None
module.attn_weight = float(i) / float(len(attn_modules))
def update(self, writer, dtype=torch.float16, drop_ratio=0.):
if self.reference_attn:
if self.fusion_blocks == "midup":
reader_attn_modules = [
module
for module in (
torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
)
if isinstance(module, TemporalBasicTransformerBlock)
]
writer_attn_modules = [
module
for module in (
torch_dfs(writer.unet.mid_block)
+ torch_dfs(writer.unet.up_blocks)
)
if isinstance(module, BasicTransformerBlock)
]
elif self.fusion_blocks == "full":
reader_attn_modules = [
module
for module in torch_dfs(self.unet)
if isinstance(module, TemporalBasicTransformerBlock)
]
writer_attn_modules = [
module
for module in torch_dfs(writer.unet)
if isinstance(module, BasicTransformerBlock)
]
reader_attn_modules = sorted(
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
writer_attn_modules = sorted(
writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
for r, w in zip(reader_attn_modules, writer_attn_modules):
if drop_ratio > 0:
r.bank = []
for v in w.bank:
N, L, D = v.shape # batch, length, dim
len_keep = int(L * (1 - drop_ratio))
noise = torch.rand(N, L) # noise in [0, 1]
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_keep = ids_shuffle[:, :len_keep].to(v.device)
visible_tokens = torch.gather(v.clone(), dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
r.bank.append((visible_tokens).to(dtype))
else:
r.bank = [v.clone().to(dtype) for v in w.bank]
# w.bank.clear()
def update_hkf(self, writer, dtype=torch.float16, drop_ratio=0.):
if self.reference_attn:
if self.fusion_blocks == "midup":
reader_attn_modules = [
module
for module in (
torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
)
if isinstance(module, TemporalBasicTransformerBlock)
]
writer_attn_modules = [
module
for module in (
torch_dfs(writer.unet.mid_block)
+ torch_dfs(writer.unet.up_blocks)
)
if isinstance(module, BasicTransformerBlock)
]
elif self.fusion_blocks == "full":
reader_attn_modules = [
module
for module in torch_dfs(self.unet)
if isinstance(module, TemporalBasicTransformerBlock)
]
writer_attn_modules = [
module
for module in torch_dfs(writer.unet)
if isinstance(module, BasicTransformerBlock)
]
reader_attn_modules = sorted(
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
writer_attn_modules = sorted(
writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
for r, w in zip(reader_attn_modules, writer_attn_modules):
if r.kv_bank is None:
r.kv_bank = torch.cat([v.clone().unsqueeze(1).to(dtype) for v in w.bank], dim=1)
else:
r.kv_bank = torch.cat([r.kv_bank] + [v.clone().unsqueeze(1).to(dtype) for v in w.bank], dim=1).to(dtype)
def output(self, dtype=torch.float16):
res = {}
for i in range(3):
for j in range(2):
res[f"d{i}{j}"] = torch.cat([v.clone().to(dtype=dtype, device=self.unet.device) for v in self.unet.down_blocks[i].attentions[j].transformer_blocks[0].bank], dim=1)
res["m"] = torch.cat([v.clone().to(dtype=dtype, device=self.unet.device) for v in self.unet.mid_block.attentions[0].transformer_blocks[0].bank], dim=1)
for i in range(1, 4):
for j in range(3):
res[f"u{i}{j}"] = torch.cat([v.clone().to(dtype=dtype, device=self.unet.device) for v in self.unet.up_blocks[i].attentions[j].transformer_blocks[0].bank], dim=1)
return res
def clear(self):
if self.reference_attn:
if self.fusion_blocks == "midup":
reader_attn_modules = [
module
for module in (
torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
)
if isinstance(module, BasicTransformerBlock)
or isinstance(module, TemporalBasicTransformerBlock)
]
elif self.fusion_blocks == "full":
reader_attn_modules = [
module
for module in torch_dfs(self.unet)
if isinstance(module, BasicTransformerBlock)
or isinstance(module, TemporalBasicTransformerBlock)
]
reader_attn_modules = sorted(
reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
)
for r in reader_attn_modules:
r.bank.clear()
if self.cache_kv:
r.kv_bank=None
r.kv_cache=None