Update
Browse filesFix for numerical stability
- spear_modules.py +4 -2
spear_modules.py
CHANGED
|
@@ -1547,7 +1547,8 @@ class SwooshROnnx(torch.nn.Module):
|
|
| 1547 |
# ActivationDropoutAndLinearFunction.
|
| 1548 |
def SwooshLForward(x: Tensor):
|
| 1549 |
x_offset = x - 4.0
|
| 1550 |
-
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
|
|
|
| 1551 |
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1552 |
return log_sum - 0.08 * x - 0.035
|
| 1553 |
|
|
@@ -1556,7 +1557,8 @@ def SwooshLForward(x: Tensor):
|
|
| 1556 |
# ActivationDropoutAndLinearFunction.
|
| 1557 |
def SwooshRForward(x: Tensor):
|
| 1558 |
x_offset = x - 1.0
|
| 1559 |
-
log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
|
|
|
| 1560 |
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1561 |
return log_sum - 0.08 * x - 0.313261687
|
| 1562 |
|
|
|
|
| 1547 |
# ActivationDropoutAndLinearFunction.
|
| 1548 |
def SwooshLForward(x: Tensor):
|
| 1549 |
x_offset = x - 4.0
|
| 1550 |
+
# log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
| 1551 |
+
log_sum = torch.nn.functional.softplus(x_offset)
|
| 1552 |
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1553 |
return log_sum - 0.08 * x - 0.035
|
| 1554 |
|
|
|
|
| 1557 |
# ActivationDropoutAndLinearFunction.
|
| 1558 |
def SwooshRForward(x: Tensor):
|
| 1559 |
x_offset = x - 1.0
|
| 1560 |
+
# log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
|
| 1561 |
+
log_sum = torch.nn.functional.softplus(x_offset)
|
| 1562 |
log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
|
| 1563 |
return log_sum - 0.08 * x - 0.313261687
|
| 1564 |
|