Commit 41be0d1
Parent(s): d4c7a24
yujia committed

fix w/o flash attn

Files changed: utonia/model.py  +9  -12
utonia/model.py CHANGED

@@ -292,15 +292,16 @@ class SerializedAttention(PointModule):
         qkv = self.qkv(point.feat)[order]
 
         rope_coord = point.coord[order].clone()
+
+        qkv = qkv.reshape(-1, 3, H, C // H)
+        q, k, v = qkv.unbind(dim=1)
+        q, k = self.rope(q, k, rope_coord)
+
         if not self.enable_flash:
-            qkv = qkv.reshape(-1, 3, H, C // H)
-            q, k, v = qkv.unbind(dim=1)
-            q, k = self.rope(q, k, rope_coord)
-
-            qkv_roped = torch.stack([q, k, v], dim=1)
-            q, k, v = (
-                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
-            )
+            q = q.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+            k = k.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+            v = v.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
+
             # attn
             if self.upcast_attention:
                 q = q.float()
@@ -314,9 +315,6 @@ class SerializedAttention(PointModule):
             attn = self.attn_drop(attn).to(qkv.dtype)
             feat = (attn @ v).transpose(1, 2).reshape(-1, C)
         else:
-            qkv = qkv.reshape(-1, 3, H, C // H)
-            q, k, v = qkv.unbind(dim=1)
-            q, k = self.rope(q, k, rope_coord)
             qkv_roped = torch.stack([q, k, v], dim=1)
             feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                 qkv_roped.to(torch.bfloat16),
@@ -334,7 +332,6 @@ class SerializedAttention(PointModule):
         point.feat = feat
         return point
 
-
 class MLP(nn.Module):
     def __init__(
         self,
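
Note on the change: in the removed non-flash branch, `qkv_roped` was built from the rotated tensors but `q`, `k`, `v` were then re-derived from the un-rotated `qkv`, so RoPE was silently dropped whenever flash attention was disabled. The fix hoists the reshape/unbind/rope block above the branch so both paths consume the same rotated tensors. Below is a minimal, self-contained sketch of the dispatch after this commit; it assumes the conventional reading of the identifiers in the diff (N points, K serialized window size, H heads, C channels), uses a plain softmax-attention stand-in for the module's attn/scale/dropout details, and stubs `self.rope` with an identity, since its implementation is outside this diff.

import torch

# Assumed sizes (hypothetical, for illustration only):
# N = total points, K = window size, H = heads, C = channels.
N, K, H, C = 8, 4, 2, 16

def rope(q, k, coord):
    # Placeholder for self.rope: a real rotary embedding would rotate
    # q/k by phases derived from coord; identity preserves the shapes.
    return q, k

qkv = torch.randn(N, 3 * C)            # output of the qkv projection
rope_coord = torch.randn(N, 3)         # per-point coordinates
qkv_raw = qkv.clone()                  # kept for the layout check below

# Hoisted block from the fix: rope is applied once, before the branch,
# so the flash and non-flash paths both see rotated q/k.
qkv = qkv.reshape(-1, 3, H, C // H)    # (N, 3, H, C')
q, k, v = qkv.unbind(dim=1)            # each (N, H, C')
q, k = rope(q, k, rope_coord)

enable_flash = False
if not enable_flash:
    # Non-flash path: group points into windows of K and move heads
    # ahead of the window axis for batched matmul attention.
    q = q.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)  # (N/K, H, K, C')
    k = k.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
    v = v.reshape(-1, K, H, C // H).permute(0, 2, 1, 3)
    attn = (q * (C // H) ** -0.5) @ k.transpose(-2, -1)  # (N/K, H, K, K)
    attn = attn.softmax(dim=-1)
    feat = (attn @ v).transpose(1, 2).reshape(-1, C)     # (N, C)
else:
    # Flash path: repack the rotated tensors for the varlen kernel.
    qkv_roped = torch.stack([q, k, v], dim=1)            # (N, 3, H, C')
    # feat = flash_attn.flash_attn_varlen_qkvpacked_func(
    #     qkv_roped.to(torch.bfloat16), ...)

# Layout check: rope aside, the new per-tensor reshape yields the same
# (window, head, K, C') layout as the fused reshape the commit removed.
q_old, _, _ = (
    qkv_raw.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
)
q_new = (
    qkv_raw.reshape(-1, 3, H, C // H)
    .unbind(dim=1)[0]
    .reshape(-1, K, H, C // H)
    .permute(0, 2, 1, 3)
)
assert torch.equal(q_old, q_new)

The final assertion is why the three per-tensor reshapes can replace the old single `permute(2, 0, 3, 1, 4)`: once q, k, v are unbound before windowing, each tensor only needs its own `(N, H, C') -> (N/K, H, K, C')` rearrangement, and the resulting layout is identical to the old fused path.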
|