| |
|
|
| import matplotlib.pyplot as plt |
| import torch |
| from scaling import PiecewiseLinear, ScheduledFloat, SwooshL, SwooshR |
|
|
|
|
| def test_piecewise_linear(): |
| |
| |
| |
| |
| |
| pl = PiecewiseLinear((0, 0), (1, 1), (2, 0)) |
| assert pl(0.25) == 0.25, pl(0.25) |
| assert pl(0.625) == 0.625, pl(0.625) |
| assert pl(1.25) == 0.75, pl(1.25) |
|
|
| assert pl(-10) == pl(0), pl(-10) |
| assert pl(10) == pl(2), pl(10) |
|
|
| |
| pl10 = pl * 10 |
| assert pl10(1) == 10 * pl(1) |
| assert pl10(0.5) == 10 * pl(0.5) |
|
|
|
|
| def test_scheduled_float(): |
| |
| dropout = ScheduledFloat((0, 0.2), (4000, 0.0), default=0.0) |
| dropout.batch_count = 0 |
| assert float(dropout) == 0.2, (float(dropout), dropout.batch_count) |
|
|
| dropout.batch_count = 1000 |
| assert abs(float(dropout) - 0.15) < 1e-5, (float(dropout), dropout.batch_count) |
|
|
| dropout.batch_count = 2000 |
| assert float(dropout) == 0.1, (float(dropout), dropout.batch_count) |
|
|
| dropout.batch_count = 3000 |
| assert abs(float(dropout) - 0.05) < 1e-5, (float(dropout), dropout.batch_count) |
|
|
| dropout.batch_count = 4000 |
| assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) |
|
|
| dropout.batch_count = 5000 |
| assert float(dropout) == 0.0, (float(dropout), dropout.batch_count) |
|
|
|
|
| def test_swoosh(): |
| x1 = torch.linspace(start=-10, end=0, steps=100, dtype=torch.float32) |
| x2 = torch.linspace(start=0, end=10, steps=100, dtype=torch.float32) |
| x = torch.cat([x1, x2[1:]]) |
|
|
| left = SwooshL()(x) |
| r = SwooshR()(x) |
|
|
| relu = torch.nn.functional.relu(x) |
| print(left[x == 0], r[x == 0]) |
| plt.plot(x, left, "k") |
| plt.plot(x, r, "r") |
| plt.plot(x, relu, "b") |
| plt.axis([-10, 10, -1, 10]) |
| plt.legend( |
| [ |
| "SwooshL(x) = log(1 + exp(x-4)) - 0.08x - 0.035 ", |
| "SwooshR(x) = log(1 + exp(x-1)) - 0.08x - 0.313261687", |
| "ReLU(x) = max(0, x)", |
| ] |
| ) |
| plt.grid() |
| plt.savefig("swoosh.pdf") |
|
|
|
|
| def main(): |
| test_piecewise_linear() |
| test_scheduled_float() |
| test_swoosh() |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|