File size: 620 Bytes
865db26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# model.py
import flax.linen as nn
import jax.numpy as jnp

class AQIPredictor(nn.Module):
    """Regress a single value from a sequence of per-step feature vectors.

    Two 1-D convolution blocks (Conv -> ReLU -> LayerNorm) extract local
    features, the length axis is mean-pooled, and an MLP head
    (Dense(128) -> SiLU -> Dropout -> Dense(64) -> SiLU -> Dense(1))
    produces one output per example.

    Attributes:
        features: Required by the constructor but not read by ``__call__``.
            NOTE(review): kept only so existing ``AQIPredictor(features=...)``
            calls keep working; confirm whether it was meant to set the conv
            width (currently ``conv_features``).
        conv_features: Output channels of each convolution (default 64).
        dropout_rate: Dropout probability applied after the first dense
            layer when ``deterministic`` is False (default 0.1).
    """

    features: int
    conv_features: int = 64
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
        """Apply the model to ``x``.

        Args:
            x: Channels-last input whose last two axes are (length, channels),
                with zero or more leading batch axes. Presumably
                (batch, time, features) in practice -- TODO confirm with caller.
            deterministic: If True, dropout is a no-op; if False, dropout is
                active and ``apply`` needs an rng for the Dropout layer.

        Returns:
            An array with the leading batch axes of ``x`` and a final axis of
            size 1.
        """
        # Two identical conv blocks; creation order (Conv_0, LayerNorm_0,
        # Conv_1, LayerNorm_1) matches the unrolled form, so parameter names
        # are unchanged.
        for _ in range(2):
            x = nn.Conv(features=self.conv_features, kernel_size=(3,))(x)
            x = nn.relu(x)
            x = nn.LayerNorm()(x)

        # With a 1-D kernel the conv output is (*batch, length, channels), so
        # the length axis is always -2. A fixed axis=1 would average over the
        # channel axis for unbatched input, or over a batch axis when there
        # are several leading batch axes.
        x = jnp.mean(x, axis=-2)

        x = nn.Dense(128)(x)
        x = nn.silu(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(64)(x)
        x = nn.silu(x)
        return nn.Dense(1)(x)