telegram191 committed on
Commit
fd3ad60
·
verified ·
1 Parent(s): 353793a

Make xformers optional to reduce Space build failures

Browse files
Files changed (1) hide show
  1. audiocraft/modules/transformer.py +17 -5
audiocraft/modules/transformer.py CHANGED
@@ -20,7 +20,10 @@ import torch
20
  import torch.nn as nn
21
  from torch.nn import functional as F
22
  from torch.utils.checkpoint import checkpoint as torch_checkpoint
23
- from xformers import ops
 
 
 
24
 
25
  from .rope import RotaryEmbedding
26
  from .streaming import StreamingModule
@@ -31,7 +34,9 @@ _efficient_attention_backend: str = 'torch'
31
  def set_efficient_attention_backend(backend: str = 'torch'):
32
  # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
33
  global _efficient_attention_backend
34
- assert _efficient_attention_backend in ['xformers', 'torch']
 
 
35
  _efficient_attention_backend = backend
36
 
37
 
@@ -236,7 +241,7 @@ class StreamingMultiheadAttention(StreamingModule):
236
  # We actually return a bias for the attention score, as this has the same
237
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
238
  time_dim = _get_attention_time_dimension(self.memory_efficient)
239
- if self.memory_efficient:
240
  from xformers.ops import LowerTriangularMask
241
  if current_steps == 1:
242
  # If we only have one step, then we do not need a mask.
@@ -373,7 +378,10 @@ class StreamingMultiheadAttention(StreamingModule):
373
  else:
374
  bound_layout = "b t p h d"
375
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
376
- q, k, v = ops.unbind(packed, dim=2)
 
 
 
377
  else:
378
  embed_dim = self.embed_dim
379
  per_head_dim = (embed_dim // self.num_heads)
@@ -425,7 +433,11 @@ class StreamingMultiheadAttention(StreamingModule):
425
  x = torch.nn.functional.scaled_dot_product_attention(
426
  q, k, v, is_causal=self.causal, attn_mask=attn_mask, dropout_p=p)
427
  else:
428
- x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
 
 
 
 
429
  else:
430
  # We include the dot product as float32, for consistency
431
  # with the other implementations that include that step
 
20
  import torch.nn as nn
21
  from torch.nn import functional as F
22
  from torch.utils.checkpoint import checkpoint as torch_checkpoint
23
+ try:
24
+ from xformers import ops
25
+ except Exception:
26
+ ops = None
27
 
28
  from .rope import RotaryEmbedding
29
  from .streaming import StreamingModule
 
34
  def set_efficient_attention_backend(backend: str = 'torch'):
35
  # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
36
  global _efficient_attention_backend
37
+ assert backend in ['xformers', 'torch']
38
+ if backend == 'xformers' and ops is None:
39
+ backend = 'torch'
40
  _efficient_attention_backend = backend
41
 
42
 
 
241
  # We actually return a bias for the attention score, as this has the same
242
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
243
  time_dim = _get_attention_time_dimension(self.memory_efficient)
244
+ if self.memory_efficient and _efficient_attention_backend == 'xformers' and ops is not None:
245
  from xformers.ops import LowerTriangularMask
246
  if current_steps == 1:
247
  # If we only have one step, then we do not need a mask.
 
378
  else:
379
  bound_layout = "b t p h d"
380
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
381
+ if ops is None:
382
+ q, k, v = torch.unbind(packed, dim=2)
383
+ else:
384
+ q, k, v = ops.unbind(packed, dim=2)
385
  else:
386
  embed_dim = self.embed_dim
387
  per_head_dim = (embed_dim // self.num_heads)
 
433
  x = torch.nn.functional.scaled_dot_product_attention(
434
  q, k, v, is_causal=self.causal, attn_mask=attn_mask, dropout_p=p)
435
  else:
436
+ if ops is None:
437
+ x = torch.nn.functional.scaled_dot_product_attention(
438
+ q, k, v, is_causal=self.causal, attn_mask=attn_mask, dropout_p=p)
439
+ else:
440
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
441
  else:
442
  # We include the dot product as float32, for consistency
443
  # with the other implementations that include that step