Bobby committed on
Commit
9ddbbcc
·
1 Parent(s): 8a888b5

attn bf16

Browse files
trellis/modules/attention/full_attn.py CHANGED
@@ -122,12 +122,16 @@ def scaled_dot_product_attention(*args, **kwargs):
122
  k, v = kv.unbind(dim=2)
123
  out = xops.memory_efficient_attention(q, k, v)
124
  elif BACKEND == 'flash_attn':
 
125
  if num_all_args == 1:
126
- out = flash_attn.flash_attn_qkvpacked_func(qkv)
 
127
  elif num_all_args == 2:
128
- out = flash_attn.flash_attn_kvpacked_func(q, kv)
 
129
  elif num_all_args == 3:
130
- out = flash_attn.flash_attn_func(q, k, v)
 
131
  elif BACKEND == 'sdpa':
132
  if num_all_args == 1:
133
  q, k, v = qkv.unbind(dim=2)
 
122
  k, v = kv.unbind(dim=2)
123
  out = xops.memory_efficient_attention(q, k, v)
124
  elif BACKEND == 'flash_attn':
125
+ _fa_dtype = torch.bfloat16
126
  if num_all_args == 1:
127
+ _orig_dtype = qkv.dtype
128
+ out = flash_attn.flash_attn_qkvpacked_func(qkv.to(_fa_dtype)).to(_orig_dtype)
129
  elif num_all_args == 2:
130
+ _orig_dtype = q.dtype
131
+ out = flash_attn.flash_attn_kvpacked_func(q.to(_fa_dtype), kv.to(_fa_dtype)).to(_orig_dtype)
132
  elif num_all_args == 3:
133
+ _orig_dtype = q.dtype
134
+ out = flash_attn.flash_attn_func(q.to(_fa_dtype), k.to(_fa_dtype), v.to(_fa_dtype)).to(_orig_dtype)
135
  elif BACKEND == 'sdpa':
136
  if num_all_args == 1:
137
  q, k, v = qkv.unbind(dim=2)
trellis/modules/sparse/attention/full_attn.py CHANGED
@@ -219,12 +219,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
219
  cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
220
  if num_all_args in [2, 3]:
221
  cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
 
222
  if num_all_args == 1:
223
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
 
224
  elif num_all_args == 2:
225
- out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
 
226
  elif num_all_args == 3:
227
- out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
 
228
  elif ATTN in {'sdpa', 'naive'}:
229
  outs = []
230
  q_start, kv_start = 0, 0
 
219
  cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
220
  if num_all_args in [2, 3]:
221
  cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
222
+ _fa_dtype = torch.bfloat16
223
  if num_all_args == 1:
224
+ _orig_dtype = qkv.dtype
225
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv.to(_fa_dtype), cu_seqlens_q, max(q_seqlen)).to(_orig_dtype)
226
  elif num_all_args == 2:
227
+ _orig_dtype = q.dtype
228
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q.to(_fa_dtype), kv.to(_fa_dtype), cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)).to(_orig_dtype)
229
  elif num_all_args == 3:
230
+ _orig_dtype = q.dtype
231
+ out = flash_attn.flash_attn_varlen_func(q.to(_fa_dtype), k.to(_fa_dtype), v.to(_fa_dtype), cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)).to(_orig_dtype)
232
  elif ATTN in {'sdpa', 'naive'}:
233
  outs = []
234
  q_start, kv_start = 0, 0
trellis/modules/sparse/attention/serialized_attn.py CHANGED
@@ -193,7 +193,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
193
  q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
194
  out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
195
  elif ATTN == 'flash_attn':
196
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
 
197
  elif ATTN in {'sdpa', 'naive'}:
198
  out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
199
  else:
@@ -210,7 +211,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
210
  elif ATTN == 'flash_attn':
211
  cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
212
  .to(qkv.device).int()
213
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
 
214
  elif ATTN in {'sdpa', 'naive'}:
215
  out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
216
  else:
 
193
  q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
194
  out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
195
  elif ATTN == 'flash_attn':
196
+ _orig_dtype = qkv_feats.dtype
197
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats.to(torch.bfloat16)).to(_orig_dtype) # [B, N, H, C]
198
  elif ATTN in {'sdpa', 'naive'}:
199
  out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
200
  else:
 
211
  elif ATTN == 'flash_attn':
212
  cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
213
  .to(qkv.device).int()
214
+ _orig_dtype = qkv_feats.dtype
215
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats.to(torch.bfloat16), cu_seqlens, max(seq_lens)).to(_orig_dtype) # [M, H, C]
216
  elif ATTN in {'sdpa', 'naive'}:
217
  out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
218
  else:
trellis/modules/sparse/attention/windowed_attn.py CHANGED
@@ -135,7 +135,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
135
  q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
136
  out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
137
  elif ATTN == 'flash_attn':
138
- out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
 
139
  elif ATTN in {'sdpa', 'naive'}:
140
  out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
141
  else:
@@ -152,7 +153,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
152
  elif ATTN == 'flash_attn':
153
  cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
154
  .to(qkv.device).int()
155
- out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
 
156
  elif ATTN in {'sdpa', 'naive'}:
157
  out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
158
  else:
 
135
  q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
136
  out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
137
  elif ATTN == 'flash_attn':
138
+ _orig_dtype = qkv_feats.dtype
139
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats.to(torch.bfloat16)).to(_orig_dtype) # [B, N, H, C]
140
  elif ATTN in {'sdpa', 'naive'}:
141
  out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
142
  else:
 
153
  elif ATTN == 'flash_attn':
154
  cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
155
  .to(qkv.device).int()
156
+ _orig_dtype = qkv_feats.dtype
157
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats.to(torch.bfloat16), cu_seqlens, max(seq_lens)).to(_orig_dtype) # [M, H, C]
158
  elif ATTN in {'sdpa', 'naive'}:
159
  out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
160
  else: