yujia committed
Commit 41be0d1 · 1 Parent(s): d4c7a24

fix w/o flash attn

Files changed (1)
  1. utonia/model.py +9 -12
utonia/model.py CHANGED
@@ -292,15 +292,16 @@ class SerializedAttention(PointModule):
         qkv = self.qkv(point.feat)[order]
 
         rope_coord = point.coord[order].clone()
+
+        qkv = qkv.reshape(-1, 3, H, C // H)
+        q, k, v = qkv.unbind(dim=1)
+        q, k = self.rope(q, k, rope_coord)
+
         if not self.enable_flash:
-            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
-            qkv = qkv.reshape(-1, 3, H, C // H)
-            q, k, v = qkv.unbind(dim=1)
-            q, k = self.rope(q, k, rope_coord)
-            qkv_roped = torch.stack([q, k, v], dim=1)
-            q, k, v = (
-                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
-            )
+            q = q.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+            k = k.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+            v = v.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+
             # attn
             if self.upcast_attention:
                 q = q.float()
@@ -314,9 +315,6 @@ class SerializedAttention(PointModule):
             attn = self.attn_drop(attn).to(qkv.dtype)
             feat = (attn @ v).transpose(1, 2).reshape(-1, C)
         else:
-            qkv = qkv.reshape(-1, 3, H, C // H)
-            q, k, v = qkv.unbind(dim=1)
-            q, k = self.rope(q, k, rope_coord)
             qkv_roped = torch.stack([q, k, v], dim=1)
             feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                 qkv_roped.to(torch.bfloat16),
@@ -334,7 +332,6 @@ class SerializedAttention(PointModule):
         point.feat = feat
         return point
 
-
 class MLP(nn.Module):
     def __init__(
         self,
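The bug this commit fixes: in the old non-flash branch, `q`/`k` were rotary-embedded and stacked into `qkv_roped`, but the very next statement rebuilt `q, k, v` from the pre-RoPE `qkv` tensor, so the rotary embedding was silently discarded whenever flash attention was disabled. The fix hoists the RoPE application above the branch so both paths share it. A minimal sketch of the corrected non-flash tensor flow, runnable on CPU; the toy sizes, the identity `rope` stand-in, and the bare unscaled softmax (the real module applies a scale and attention dropout) are assumptions for illustration, not the repo's code:

import torch

# Assumed toy sizes; K points per serialized patch, H heads, channel dim C.
N_patches, K, H, C = 4, 16, 2, 32
N = N_patches * K

def rope(q, k, coord):
    # Identity stand-in for self.rope(); the real rotary embedding
    # rotates q/k with sinusoids derived from the point coordinates.
    return q, k

qkv = torch.randn(N, 3 * C)          # stand-in for self.qkv(point.feat)[order]
rope_coord = torch.randn(N, 3)       # stand-in for point.coord[order].clone()

# Shared prefix (new in this commit): RoPE is applied once, before the branch.
qkv = qkv.reshape(-1, 3, H, C // H)  # (N, 3, H, C//H)
q, k, v = qkv.unbind(dim=1)          # each (N, H, C//H)
q, k = rope(q, k, rope_coord)

# Non-flash path: regroup points into patches, move heads ahead of the K axis.
q = q.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)  # (N_patches, H, K, C//H)
k = k.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
v = v.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)

attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)     # (N_patches, H, K, K)
feat = (attn @ v).transpose(1, 2).reshape(-1, C)     # back to (N, C)
print(feat.shape)                                    # torch.Size([64, 32])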
 
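For reference, the flash path now consumes the same roped tensors, packed back into `qkv_roped` for the varlen kernel. A sketch of that call under stated assumptions: it needs the flash-attn 2.x package and a CUDA device, and it uses hypothetical uniform patches of K points, whereas the real module derives `cu_seqlens` from the serialization offsets, which this hunk does not show:

import torch
import flash_attn

# Assumed sizes mirroring the sketch above; requires a CUDA GPU.
N_patches, K, H, C = 4, 16, 2, 32
N = N_patches * K

qkv_roped = torch.randn(N, 3, H, C // H, device="cuda")  # packed (q, k, v)

# Hypothetical uniform patch boundaries: [0, K, 2K, ..., N].
cu_seqlens = torch.arange(0, N + 1, K, dtype=torch.int32, device="cuda")

feat = flash_attn.flash_attn_varlen_qkvpacked_func(
    qkv_roped.to(torch.bfloat16),  # flash-attn expects fp16/bf16 inputs
    cu_seqlens,
    max_seqlen=K,
).reshape(-1, C)                   # (N, H, C//H) -> (N, C), as in the module
print(feat.shape)                  # torch.Size([64, 32])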