From https://huggingface.co/spaces/ginigen/framepack-i2v/edit/main/diffusers_helper/models/hunyuan_video_packed.py
diffusers_helper/models/hunyuan_video_packed.py (CHANGED)
@@ -122,21 +122,17 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq
         x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
         return x
 
-
-
-
-
-    v = v.flatten(0, 1)
-
+    batch_size = q.shape[0]
+    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
+    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
+    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
     if sageattn_varlen is not None:
         x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     elif flash_attn_varlen_func is not None:
         x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     else:
         raise NotImplementedError('No Attn Installed!')
-
-    x = x.unflatten(0, (B, L))
-
+    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
     return x
 
 
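Note on the restored block: the var-len kernels (sageattn_varlen, flash_attn_varlen_func) expect q/k/v with the batch and sequence dimensions packed into a single token axis, with cu_seqlens marking where each sample's rows begin and end; the final view(batch_size, max_seqlen_q, ...) unpacks the result, which assumes every sample occupies exactly max_seqlen_q rows. A minimal sketch of that packed layout, separate from the commit and with illustrative shapes:

import torch

B, L, H, D = 2, 4, 3, 8                 # batch, length, heads, head dim (made up)
q = torch.randn(B, L, H, D)

q_packed = q.view(B * L, H, D)          # same reshape as attn_varlen_func above
cu_seqlens = torch.arange(0, (B + 1) * L, L, dtype=torch.int32)  # tensor([0, 4, 8])

for i in range(B):
    # sample i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor
    assert torch.equal(q_packed[cu_seqlens[i]:cu_seqlens[i + 1]], q[i])

Packing this way lets one kernel launch cover sequences of different lengths without per-sample padding work.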
@@ -362,7 +358,7 @@ class HunyuanVideoIndividualTokenRefiner(nn.Module):
             batch_size = attention_mask.shape[0]
             seq_len = attention_mask.shape[1]
             attention_mask = attention_mask.to(hidden_states.device).bool()
-            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
             self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
             self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
             self_attn_mask[:, :, :, 0] = True
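The repaired line completes the token refiner's self-attention mask: self_attn_mask_1 broadcasts the key-padding mask across query positions, its transpose does the same across key positions, and their AND keeps only query/key pairs where both tokens are valid; forcing column 0 to True guarantees every query row attends to at least one key, so softmax never sees an all-masked row. A small sketch with a made-up padding pattern:

import torch

attention_mask = torch.tensor([[True, True, False]])   # 1 sample, last token is padding
batch_size, seq_len = attention_mask.shape

m1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
m2 = m1.transpose(2, 3)
mask = m1 & m2
mask[:, :, :, 0] = True      # every row keeps at least one attendable key

print(mask[0, 0])
# tensor([[ True,  True, False],
#         [ True,  True, False],
#         [ True, False, False]])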
@@ -930,22 +926,23 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
         encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
         encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
 
-
-
-
-
-
-
-
-
-
+        with torch.no_grad():
+            if batch_size == 1:
+                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
+                # If they are not same, then their impls are wrong. Ours are always the correct one.
+                text_len = encoder_attention_mask.sum().item()
+                encoder_hidden_states = encoder_hidden_states[:, :text_len]
+                attention_mask = None, None, None, None
+            else:
+                img_seq_len = hidden_states.shape[1]
+                txt_seq_len = encoder_hidden_states.shape[1]
 
-
-
-
-
+                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
+                cu_seqlens_kv = cu_seqlens_q
+                max_seqlen_q = img_seq_len + txt_seq_len
+                max_seqlen_kv = max_seqlen_q
 
-
+                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
 
         if self.enable_teacache:
             modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
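The restored branch above relies on an equivalence: with batch_size == 1 there are no other samples whose padding must be masked out, so cropping encoder_hidden_states to the number of valid text tokens (encoder_attention_mask.sum()) yields exactly the result masked attention would, and attention_mask becomes four Nones, which routes attn_varlen_func to the dense scaled_dot_product_attention path. For larger batches, get_cu_seqlens (defined elsewhere in this file) builds the cumulative offsets the var-len kernels need. As a hypothetical illustration of such offsets, not the file's implementation, assuming padding is carved into its own segments so samples never attend across boundaries:

import torch

img_seq_len = 5
text_mask = torch.tensor([[1, 1, 1, 0],     # sample 0: 3 valid text tokens
                          [1, 1, 0, 0]])    # sample 1: 2 valid text tokens
batch, text_pad_len = text_mask.shape
max_len = text_pad_len + img_seq_len        # padded length of every sample

offsets = [0]
for i in range(batch):
    valid = int(text_mask[i].sum()) + img_seq_len   # real tokens of sample i
    offsets += [i * max_len + valid,                # end of sample i's real tokens
                (i + 1) * max_len]                  # end of sample i's padding
cu_seqlens = torch.tensor(offsets, dtype=torch.int32)
print(cu_seqlens)   # tensor([ 0,  8,  9, 16, 18], dtype=torch.int32)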