Spaces:
Sleeping
Sleeping
File size: 620 Bytes
865db26 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | # model.py
import flax.linen as nn
import jax.numpy as jnp
class AQIPredictor(nn.Module):
    """Convolutional regressor that maps a sequence to one value per example.

    Pipeline, in order:
      1. Two identical stages of ``Conv(64, kernel_size=(3,))`` -> ``relu``
         -> ``LayerNorm``.
      2. Mean over axis 1.
      3. MLP head: ``Dense(128)`` -> ``silu`` -> ``Dropout(0.1)`` ->
         ``Dense(64)`` -> ``silu`` -> ``Dense(1)``.

    NOTE(review): the input is presumably shaped (batch, length, channels),
    given the 1-D kernel and the reduction over axis 1 -- confirm against
    the caller.
    """

    # Declared module field. It is not read anywhere in ``__call__`` below;
    # the layer widths are hard-coded.
    features: int

    @nn.compact
    def __call__(self, x, deterministic: bool):
        """Run the forward pass.

        Args:
            x: Input array fed to the first convolution.
            deterministic: Forwarded to ``nn.Dropout`` as its
                ``deterministic`` argument.

        Returns:
            The output of the final ``Dense(1)`` layer (last dimension 1).
        """
        hidden = x
        # Two conv stages. Sub-modules are created in the same order as an
        # unrolled version, so Flax's automatic parameter names are unchanged.
        for _ in range(2):
            hidden = nn.Conv(features=64, kernel_size=(3,))(hidden)
            hidden = nn.LayerNorm()(nn.relu(hidden))

        # Collapse axis 1 by averaging before the dense head.
        pooled = jnp.mean(hidden, axis=1)

        head = nn.silu(nn.Dense(features=128)(pooled))
        head = nn.Dropout(rate=0.1)(head, deterministic=deterministic)
        head = nn.silu(nn.Dense(features=64)(head))
        return nn.Dense(features=1)(head)
|