revert: remove sparse attention _linear inference-tensor clone hack
Browse files
trellis/modules/sparse/attention/modules.py
CHANGED
|
@@ -74,21 +74,9 @@ class SparseMultiHeadAttention(nn.Module):
|
|
| 74 |
if use_rope:
|
| 75 |
self.rope = RotaryPositionEmbedder(channels)
|
| 76 |
|
| 77 |
-
@staticmethod
|
| 78 |
-
def _sparse_tensor_autograd_safe(x: SparseTensor) -> SparseTensor:
|
| 79 |
-
"""If autograd is recording this forward, inference-mode feats break nn.Linear (cannot save for backward)."""
|
| 80 |
-
if not torch.is_grad_enabled():
|
| 81 |
-
return x
|
| 82 |
-
f = x.feats
|
| 83 |
-
is_inference = getattr(f, "is_inference", None)
|
| 84 |
-
if callable(is_inference) and is_inference():
|
| 85 |
-
return x.replace(f.clone())
|
| 86 |
-
return x
|
| 87 |
-
|
| 88 |
@staticmethod
|
| 89 |
def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
| 90 |
if isinstance(x, SparseTensor):
|
| 91 |
-
x = SparseMultiHeadAttention._sparse_tensor_autograd_safe(x)
|
| 92 |
return x.replace(module(x.feats))
|
| 93 |
else:
|
| 94 |
return module(x)
|
|
|
|
| 74 |
if use_rope:
|
| 75 |
self.rope = RotaryPositionEmbedder(channels)
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
@staticmethod
|
| 78 |
def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
| 79 |
if isinstance(x, SparseTensor):
|
|
|
|
| 80 |
return x.replace(module(x.feats))
|
| 81 |
else:
|
| 82 |
return module(x)
|