AndreiB137 commited on
Commit
9874994
·
verified ·
1 Parent(s): f02bdcc

Delete mlp_pinn.py

Browse files
Files changed (1) hide show
  1. mlp_pinn.py +0 -44
mlp_pinn.py DELETED
@@ -1,44 +0,0 @@
1
- # Flax
2
- import jax.numpy as jnp
3
- from flax import linen as nn
4
- from typing import Callable, Union, Dict
5
- from .utils import Dense, FourierEmbs
6
-
7
- # Modified MLP version based on the state-of-the-art practicies in PINN training:
8
- # Fourier embeddings and random weight factorization
9
- # You can read more about it in the paper: https://arxiv.org/pdf/2210.01274
10
-
11
-
12
- class MLP_PINN(nn.Module):
13
- hidden_dim: int
14
- output_dim: int
15
- num_layers: int
16
- act: Callable = nn.silu
17
- dtype: jnp.dtype = jnp.float32
18
- reparam : Union[None, Dict] = None
19
- fourier_emb : Union[None, Dict] = None
20
-
21
- @nn.compact
22
- def __call__(self, x):
23
- if self.fourier_emb is not None:
24
- x = FourierEmbs(**self.fourier_emb)(x)
25
- else:
26
- x = Dense(
27
- features=self.hidden_dim,
28
- reparam=self.reparam,
29
- dtype=self.dtype
30
- )(x)
31
- x = self.act(x)
32
- for _ in range(self.num_layers):
33
- x = Dense(
34
- features=self.hidden_dim,
35
- reparam=self.reparam,
36
- dtype=self.dtype
37
- )(x)
38
- x = self.act(x)
39
- x = Dense(
40
- features=self.output_dim,
41
- reparam=self.reparam,
42
- dtype=self.dtype
43
- )(x)
44
- return x