luh1124 commited on
Commit
3f38e8f
·
1 Parent(s): c7b33ad

revert: remove sparse attention _linear inference-tensor clone hack

Browse files
trellis/modules/sparse/attention/modules.py CHANGED
@@ -74,21 +74,9 @@ class SparseMultiHeadAttention(nn.Module):
74
  if use_rope:
75
  self.rope = RotaryPositionEmbedder(channels)
76
 
77
- @staticmethod
78
- def _sparse_tensor_autograd_safe(x: SparseTensor) -> SparseTensor:
79
- """If autograd is recording this forward, inference-mode feats break nn.Linear (cannot save for backward)."""
80
- if not torch.is_grad_enabled():
81
- return x
82
- f = x.feats
83
- is_inference = getattr(f, "is_inference", None)
84
- if callable(is_inference) and is_inference():
85
- return x.replace(f.clone())
86
- return x
87
-
88
  @staticmethod
89
  def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
90
  if isinstance(x, SparseTensor):
91
- x = SparseMultiHeadAttention._sparse_tensor_autograd_safe(x)
92
  return x.replace(module(x.feats))
93
  else:
94
  return module(x)
 
74
  if use_rope:
75
  self.rope = RotaryPositionEmbedder(channels)
76
 
 
 
 
 
 
 
 
 
 
 
 
77
  @staticmethod
78
  def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
79
  if isinstance(x, SparseTensor):
 
80
  return x.replace(module(x.feats))
81
  else:
82
  return module(x)