Spaces:
Runtime error
Runtime error
Update seva/modules/transformer.py
Browse files
seva/modules/transformer.py
CHANGED
|
@@ -67,8 +67,14 @@ class Attention(nn.Module):
|
|
| 67 |
lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
|
| 68 |
(q, k, v),
|
| 69 |
)
|
| 70 |
-
with
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
out = F.scaled_dot_product_attention(q, k, v)
|
|
|
|
| 72 |
out = rearrange(out, "b h l d -> b l (h d)")
|
| 73 |
out = self.to_out(out)
|
| 74 |
return out
|
|
|
|
| 67 |
lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
|
| 68 |
(q, k, v),
|
| 69 |
)
|
| 70 |
+
with torch.backends.cuda.sdp_kernel(
|
| 71 |
+
enable_flash=False,
|
| 72 |
+
enable_mem_efficient=True,
|
| 73 |
+
enable_math=True,
|
| 74 |
+
enable_cudnn=True,
|
| 75 |
+
):
|
| 76 |
out = F.scaled_dot_product_attention(q, k, v)
|
| 77 |
+
|
| 78 |
out = rearrange(out, "b h l d -> b l (h d)")
|
| 79 |
out = self.to_out(out)
|
| 80 |
return out
|