# model.py
import flax.linen as nn
import jax.numpy as jnp


class AQIPredictor(nn.Module):
    """Regress a single scalar per sequence (presumably an AQI value).

    Architecture: two ``Conv -> ReLU -> LayerNorm`` stages, mean-pooling over
    the sequence axis, then an MLP head
    ``Dense -> SiLU -> Dropout -> Dense -> SiLU -> Dense(1)``.

    Submodules are left to Flax's automatic naming (``Conv_0``,
    ``LayerNorm_0``, ``Conv_1``, ``LayerNorm_1``, ``Dense_0`` ...). The layers
    are created in the same order as the original single-line definition, so
    the parameter tree is unchanged with the default hyperparameters.

    Attributes:
        features: Required by the constructor but NOT used anywhere in the
            forward pass. NOTE(review): possibly meant as the conv width or the
            output size; confirm with callers before wiring it in.
        conv_features: Output channels of both convolution layers.
        kernel_size: Width of the 1-D convolution window.
        hidden_features: Width of the first Dense layer of the head.
        bottleneck_features: Width of the second Dense layer of the head.
        dropout_rate: Dropout probability applied after the first head layer.
    """

    features: int
    conv_features: int = 64
    kernel_size: int = 3
    hidden_features: int = 128
    bottleneck_features: int = 64
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
        """Run the forward pass.

        Args:
            x: Input sequence. Assumed to be ``(batch, seq, channels)``; TODO
                confirm with the data pipeline. The 1-D kernel means the
                second-to-last axis is treated as the sequence axis.
            deterministic: If True, dropout is disabled (evaluation). If False,
                the caller must provide a ``'dropout'`` PRNG key, e.g.
                ``model.apply(params, x, False, rngs={"dropout": key})``.

        Returns:
            Array with a trailing dimension of size 1, e.g. ``(batch, 1)``
            for ``(batch, seq, channels)`` input.
        """
        # Two identical conv stages. Layers are created in the same order as
        # before (Conv, LayerNorm, Conv, LayerNorm), so auto-names match.
        for _ in range(2):
            x = nn.Conv(
                features=self.conv_features,
                kernel_size=(self.kernel_size,),
            )(x)
            x = nn.relu(x)
            x = nn.LayerNorm()(x)

        # Average over the sequence axis. -2 is the conv's single spatial axis
        # regardless of how many leading batch dims are present; for 3-D
        # (batch, seq, channels) input this equals the original axis=1.
        x = jnp.mean(x, axis=-2)

        x = nn.Dense(self.hidden_features)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(self.bottleneck_features)(x)
        x = nn.silu(x)
        return nn.Dense(1)(x)