| from typing import Sequence |
|
|
| import flax.linen as nn |
|
|
|
|
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Dense`` layers.

    Each entry of ``features`` except the last creates a ``Dense`` layer
    whose output is passed through ``nn.relu``. The last entry creates a
    final ``Dense`` layer whose output is returned with no activation
    applied in this module.
    """

    # Layer widths, in order; the final entry is the width of the output
    # layer. Indexing ``features[-1]`` fails on an empty sequence.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the hidden layers, then the output layer."""
        hidden_widths = self.features[:-1]
        for width in hidden_widths:
            pre_activation = nn.Dense(width)(x)
            x = nn.relu(pre_activation)

        # The output layer is built after all hidden layers so that the
        # order in which submodules are created stays the same.
        output_layer = nn.Dense(self.features[-1])
        return output_layer(x)
|
|
|
|
def assertEqual(actual, expected, msg, first="Got", second="Expected"):
    """Raise ``ValueError`` when ``actual != expected``.

    Args:
        actual: The value that was produced.
        expected: The value ``actual`` is compared against.
        msg: Text placed at the start of the error message.
        first: Label printed before ``actual`` in the error message.
        second: Label printed before ``expected`` in the error message.

    Raises:
        ValueError: If ``actual != expected`` is truthy.
    """
    differs = actual != expected
    if differs:
        # Both values are wrapped in double quotes in the message.
        detail = f' {first}: "{actual}" {second}: "{expected}"'
        raise ValueError(msg + detail)
|
|
|
|
def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
    """Raise ``ValueError`` when ``actual`` is not a member of ``expected``.

    Args:
        actual: The value to look up.
        expected: The container tested with the ``in`` operator.
        msg: Text placed at the start of the error message.
        first: Label printed before ``actual`` in the error message.
        second: Label printed before ``expected`` in the error message.

    Raises:
        ValueError: If ``actual in expected`` is falsy.
    """
    if actual in expected:
        return

    # Only ``actual`` is quoted; ``expected`` is rendered as-is.
    detail = f' {first}: "{actual}" {second}: {expected}'
    raise ValueError(msg + detail)
|
|