marcoyang commited on
Commit
372db9a
·
verified ·
1 Parent(s): 2bbfdf9

Fix for numerical stability

Files changed (1) hide show
  1. spear_modules.py +5 -3
spear_modules.py CHANGED
@@ -1547,7 +1547,8 @@ class SwooshROnnx(torch.nn.Module):
1547
  # ActivationDropoutAndLinearFunction.
1548
  def SwooshLForward(x: Tensor):
1549
  x_offset = x - 4.0
1550
- log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
 
1551
  log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1552
  return log_sum - 0.08 * x - 0.035
1553
 
@@ -1556,7 +1557,8 @@ def SwooshLForward(x: Tensor):
1556
  # ActivationDropoutAndLinearFunction.
1557
  def SwooshRForward(x: Tensor):
1558
  x_offset = x - 1.0
1559
- log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
 
1560
  log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1561
  return log_sum - 0.08 * x - 0.313261687
1562
 
@@ -1962,4 +1964,4 @@ if __name__ == "__main__":
1962
  _test_double_swish_deriv()
1963
  _test_swooshr_deriv()
1964
  _test_swooshl_deriv()
1965
- _test_activation_dropout_and_linear()
 
1547
  # ActivationDropoutAndLinearFunction.
1548
  def SwooshLForward(x: Tensor):
1549
  x_offset = x - 4.0
1550
+ # log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1551
+ log_sum = torch.nn.functional.softplus(x_offset)
1552
  log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1553
  return log_sum - 0.08 * x - 0.035
1554
 
 
1557
  # ActivationDropoutAndLinearFunction.
1558
  def SwooshRForward(x: Tensor):
1559
  x_offset = x - 1.0
1560
+ # log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1561
+ log_sum = torch.nn.functional.softplus(x_offset)
1562
  log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1563
  return log_sum - 0.08 * x - 0.313261687
1564
 
 
1964
  _test_double_swish_deriv()
1965
  _test_swooshr_deriv()
1966
  _test_swooshl_deriv()
1967
+ _test_activation_dropout_and_linear()