Spaces:
Running on Zero
Running on Zero
Bobby committed on
Commit ·
9ddbbcc
1
Parent(s): 8a888b5
attn bf16
Browse files
trellis/modules/attention/full_attn.py
CHANGED
|
@@ -122,12 +122,16 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|
| 122 |
k, v = kv.unbind(dim=2)
|
| 123 |
out = xops.memory_efficient_attention(q, k, v)
|
| 124 |
elif BACKEND == 'flash_attn':
|
|
|
|
| 125 |
if num_all_args == 1:
|
| 126 |
-
|
|
|
|
| 127 |
elif num_all_args == 2:
|
| 128 |
-
|
|
|
|
| 129 |
elif num_all_args == 3:
|
| 130 |
-
|
|
|
|
| 131 |
elif BACKEND == 'sdpa':
|
| 132 |
if num_all_args == 1:
|
| 133 |
q, k, v = qkv.unbind(dim=2)
|
|
|
|
| 122 |
k, v = kv.unbind(dim=2)
|
| 123 |
out = xops.memory_efficient_attention(q, k, v)
|
| 124 |
elif BACKEND == 'flash_attn':
|
| 125 |
+
_fa_dtype = torch.bfloat16
|
| 126 |
if num_all_args == 1:
|
| 127 |
+
_orig_dtype = qkv.dtype
|
| 128 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv.to(_fa_dtype)).to(_orig_dtype)
|
| 129 |
elif num_all_args == 2:
|
| 130 |
+
_orig_dtype = q.dtype
|
| 131 |
+
out = flash_attn.flash_attn_kvpacked_func(q.to(_fa_dtype), kv.to(_fa_dtype)).to(_orig_dtype)
|
| 132 |
elif num_all_args == 3:
|
| 133 |
+
_orig_dtype = q.dtype
|
| 134 |
+
out = flash_attn.flash_attn_func(q.to(_fa_dtype), k.to(_fa_dtype), v.to(_fa_dtype)).to(_orig_dtype)
|
| 135 |
elif BACKEND == 'sdpa':
|
| 136 |
if num_all_args == 1:
|
| 137 |
q, k, v = qkv.unbind(dim=2)
|
trellis/modules/sparse/attention/full_attn.py
CHANGED
|
@@ -219,12 +219,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
|
| 219 |
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
| 220 |
if num_all_args in [2, 3]:
|
| 221 |
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
|
|
|
| 222 |
if num_all_args == 1:
|
| 223 |
-
|
|
|
|
| 224 |
elif num_all_args == 2:
|
| 225 |
-
|
|
|
|
| 226 |
elif num_all_args == 3:
|
| 227 |
-
|
|
|
|
| 228 |
elif ATTN in {'sdpa', 'naive'}:
|
| 229 |
outs = []
|
| 230 |
q_start, kv_start = 0, 0
|
|
|
|
| 219 |
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
| 220 |
if num_all_args in [2, 3]:
|
| 221 |
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
| 222 |
+
_fa_dtype = torch.bfloat16
|
| 223 |
if num_all_args == 1:
|
| 224 |
+
_orig_dtype = qkv.dtype
|
| 225 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv.to(_fa_dtype), cu_seqlens_q, max(q_seqlen)).to(_orig_dtype)
|
| 226 |
elif num_all_args == 2:
|
| 227 |
+
_orig_dtype = q.dtype
|
| 228 |
+
out = flash_attn.flash_attn_varlen_kvpacked_func(q.to(_fa_dtype), kv.to(_fa_dtype), cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)).to(_orig_dtype)
|
| 229 |
elif num_all_args == 3:
|
| 230 |
+
_orig_dtype = q.dtype
|
| 231 |
+
out = flash_attn.flash_attn_varlen_func(q.to(_fa_dtype), k.to(_fa_dtype), v.to(_fa_dtype), cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)).to(_orig_dtype)
|
| 232 |
elif ATTN in {'sdpa', 'naive'}:
|
| 233 |
outs = []
|
| 234 |
q_start, kv_start = 0, 0
|
trellis/modules/sparse/attention/serialized_attn.py
CHANGED
|
@@ -193,7 +193,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
|
|
| 193 |
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 194 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 195 |
elif ATTN == 'flash_attn':
|
| 196 |
-
|
|
|
|
| 197 |
elif ATTN in {'sdpa', 'naive'}:
|
| 198 |
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 199 |
else:
|
|
@@ -210,7 +211,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
|
|
| 210 |
elif ATTN == 'flash_attn':
|
| 211 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 212 |
.to(qkv.device).int()
|
| 213 |
-
|
|
|
|
| 214 |
elif ATTN in {'sdpa', 'naive'}:
|
| 215 |
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 216 |
else:
|
|
|
|
| 193 |
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 194 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 195 |
elif ATTN == 'flash_attn':
|
| 196 |
+
_orig_dtype = qkv_feats.dtype
|
| 197 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats.to(torch.bfloat16)).to(_orig_dtype) # [B, N, H, C]
|
| 198 |
elif ATTN in {'sdpa', 'naive'}:
|
| 199 |
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 200 |
else:
|
|
|
|
| 211 |
elif ATTN == 'flash_attn':
|
| 212 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 213 |
.to(qkv.device).int()
|
| 214 |
+
_orig_dtype = qkv_feats.dtype
|
| 215 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats.to(torch.bfloat16), cu_seqlens, max(seq_lens)).to(_orig_dtype) # [M, H, C]
|
| 216 |
elif ATTN in {'sdpa', 'naive'}:
|
| 217 |
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 218 |
else:
|
trellis/modules/sparse/attention/windowed_attn.py
CHANGED
|
@@ -135,7 +135,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
|
| 135 |
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 136 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 137 |
elif ATTN == 'flash_attn':
|
| 138 |
-
|
|
|
|
| 139 |
elif ATTN in {'sdpa', 'naive'}:
|
| 140 |
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 141 |
else:
|
|
@@ -152,7 +153,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
|
| 152 |
elif ATTN == 'flash_attn':
|
| 153 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 154 |
.to(qkv.device).int()
|
| 155 |
-
|
|
|
|
| 156 |
elif ATTN in {'sdpa', 'naive'}:
|
| 157 |
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 158 |
else:
|
|
|
|
| 135 |
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 136 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 137 |
elif ATTN == 'flash_attn':
|
| 138 |
+
_orig_dtype = qkv_feats.dtype
|
| 139 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats.to(torch.bfloat16)).to(_orig_dtype) # [B, N, H, C]
|
| 140 |
elif ATTN in {'sdpa', 'naive'}:
|
| 141 |
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 142 |
else:
|
|
|
|
| 153 |
elif ATTN == 'flash_attn':
|
| 154 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 155 |
.to(qkv.device).int()
|
| 156 |
+
_orig_dtype = qkv_feats.dtype
|
| 157 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats.to(torch.bfloat16), cu_seqlens, max(seq_lens)).to(_orig_dtype) # [M, H, C]
|
| 158 |
elif ATTN in {'sdpa', 'naive'}:
|
| 159 |
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 160 |
else:
|