chenmy0814 committed on
Commit
8a32ada
·
verified ·
1 Parent(s): 84a01e7

Update seva/modules/transformer.py

Browse files
Files changed (1) hide show
  1. seva/modules/transformer.py +7 -1
seva/modules/transformer.py CHANGED
@@ -67,8 +67,14 @@ class Attention(nn.Module):
67
  lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
68
  (q, k, v),
69
  )
70
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
 
 
 
 
 
71
  out = F.scaled_dot_product_attention(q, k, v)
 
72
  out = rearrange(out, "b h l d -> b l (h d)")
73
  out = self.to_out(out)
74
  return out
 
67
  lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
68
  (q, k, v),
69
  )
70
+ with torch.backends.cuda.sdp_kernel(
71
+ enable_flash=False,
72
+ enable_mem_efficient=True,
73
+ enable_math=True,
74
+ enable_cudnn=True,
75
+ ):
76
  out = F.scaled_dot_product_attention(q, k, v)
77
+
78
  out = rearrange(out, "b h l d -> b l (h d)")
79
  out = self.to_out(out)
80
  return out