---
library_name: flax
tags:
- jax
- flax
- orbax
- mamba
- malaria
- simulation
- time-series
- regression
pipeline_tag: time-series-forecasting
---

# StateMINT

StateMINT is a neural emulator for `malariasimulation` outputs. This model repository contains two exported inference artifacts:

- `prevalence/`: predicts malaria prevalence over time.
- `cases/`: predicts malaria cases over time.

Both artifacts use the same `Mamba2Regressor` architecture but have separate weights and preprocessing metadata. Users should load the folder that matches the target they want to predict.

## Repository Layout

```text
.
├── prevalence/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
├── cases/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
└── README.md
```

Each target folder is self-contained:

- `checkpoint/` contains model-only Orbax checkpoint data.
- `model_config.json` contains the model architecture settings needed to instantiate `Mamba2Regressor`.
- `preprocessing_config.json` contains feature ordering, intervention timing, target transform settings, and the fitted static covariate scaler.
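
If you fetch a target folder yourself (for example, to inspect its JSON files), a quick check that it matches the layout above catches a wrong path before anything is restored. A minimal standard-library sketch; `check_target_folder` is a hypothetical helper, not part of StateMINT:

```python
from pathlib import Path
from typing import List, Union


def check_target_folder(folder: Union[str, Path]) -> List[str]:
    """Return the expected entries that are missing from a target folder.

    A target folder (e.g. `prevalence/` or `cases/`) is expected to hold a
    `checkpoint/` directory plus the two JSON config files.
    """
    root = Path(folder)
    missing = []
    if not (root / "checkpoint").is_dir():
        missing.append("checkpoint/")
    for name in ("model_config.json", "preprocessing_config.json"):
        if not (root / name).is_file():
            missing.append(name)
    return missing
```

An empty list means all three entries are present; the helper does not validate the checkpoint contents themselves.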

## Intended Use

These models are intended for emulating the trajectories that `malariasimulation` produces for a given set of simulation inputs. They are designed for research and analysis workflows where fast approximate prediction of simulated prevalence or cases is useful.

They are not intended for direct clinical decision-making or for use on real-world surveillance data without additional validation.

## Installation

Install `mintstate` and the Hugging Face Hub client from PyPI:

```bash
pip install mintstate huggingface_hub

# For GPU support, install with the `[gpu]` extra
# (quoted so that shells such as zsh do not treat the brackets as a glob):
pip install "mintstate[gpu]"
```

If installing from source:

```bash
git clone https://github.com/mrc-ide/stateMINT.git
cd stateMINT

pip install -e .

# For GPU support, install with the `[gpu]` extra:
pip install -e ".[gpu]"
```

## Loading a Model

Recommended high-level API:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)

model = artifact.model
```

To load the cases model:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)

model = artifact.model
```

`from_pretrained` returns a `ModelArtifact` exposing:

- `artifact.model`: the restored `Mamba2Regressor`.
- `artifact.preprocessing_config`: the exported preprocessing metadata.
- `artifact.scaler`: the fitted static covariate scaler.
- `artifact.prepare_inputs(...)`: converts raw static covariate dictionaries into model inputs.
- `artifact.predict(...)`: predicts directly from raw static covariate dictionaries.

## Example Use Case

Use the prevalence model to emulate malaria prevalence trajectories for two intervention scenarios. Provide one raw static covariate dictionary per trajectory:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)

static_covars = [
    {
        "eir": 50.0,
        "dn0_use": 0.3,
        "dn0_future": 0.4,
        "Q0": 0.8,
        "phi_bednets": 0.7,
        "seasonal": 1.0,
        "routine": 0.5,
        "itn_use": 0.2,
        "irs_use": 0.1,
        "itn_future": 0.3,
        "irs_future": 0.2,
        "lsm": 0.0,
    },
    {
        "eir": 120.0,
        "dn0_use": 0.2,
        "dn0_future": 0.2,
        "Q0": 0.9,
        "phi_bednets": 0.6,
        "seasonal": 1.0,
        "routine": 0.4,
        "itn_use": 0.3,
        "irs_use": 0.0,
        "itn_future": 0.5,
        "irs_future": 0.1,
        "lsm": 0.2,
    },
]

predicted_prevalence = artifact.predict(static_covars)

print(predicted_prevalence.shape)  # (2, n_steps)
print(predicted_prevalence[0])  # prevalence trajectory for the first scenario
```

For case-count predictions, load the artifact with `predictor="cases"` and call the same `.predict(...)` method:

```python
# Reuses `Mamba2Regressor` and `static_covars` from the prevalence example above.
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)

predicted_cases = artifact.predict(static_covars)
```

By default, `.predict(...)` returns predictions on the original target scale:

- prevalence predictions are probabilities in `[0, 1]`;
- cases predictions are case counts.
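
A cheap guard on `.predict(...)` output can catch obvious problems, such as NaNs or prevalence values outside `[0, 1]` (a sign that transformed-space outputs were passed in by mistake). A minimal NumPy sketch; `check_output_scale` is a hypothetical helper, and for cases it only checks finiteness because case counts have no fixed upper bound:

```python
import numpy as np

VALID_PREDICTORS = ("prevalence", "cases")


def check_output_scale(predictions, predictor: str) -> None:
    """Raise ValueError if `.predict(...)` output looks off-scale (hypothetical helper).

    Assumes the default `transformed=False`, i.e. original-scale outputs.
    """
    if predictor not in VALID_PREDICTORS:
        raise ValueError(f"unknown predictor: {predictor!r}")
    values = np.asarray(predictions, dtype=float)
    if not np.all(np.isfinite(values)):
        raise ValueError("predictions contain NaN or infinite values")
    if predictor == "prevalence" and (values.min() < 0.0 or values.max() > 1.0):
        # Sigmoid outputs lie in [0, 1]; transformed-space (logit) values usually do not.
        raise ValueError("prevalence predictions fall outside [0, 1]")
```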

To get predictions in the model's transformed training space, pass `transformed=True`:

```python
raw_predictions = artifact.predict(static_covars, transformed=True)
```

If you use the low-level model directly, it expects already-prepared arrays with shape:

```text
(batch, time, input_size)
```
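
A single trajectory prepared as a `(time, input_size)` array needs a leading batch axis before it matches this shape. A small NumPy sketch of that check; `as_model_batch` is hypothetical, `input_size` is passed in by the caller rather than read from `model_config.json` (whose key names are not documented here), and the `float32` cast is an assumption chosen to match JAX's default dtype:

```python
import numpy as np


def as_model_batch(x, input_size: int) -> np.ndarray:
    """Return `x` as a float32 array of shape (batch, time, input_size)."""
    arr = np.asarray(x, dtype=np.float32)
    if arr.ndim == 2:
        # A single trajectory: add the leading batch axis.
        arr = arr[np.newaxis, ...]
    if arr.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {arr.shape}")
    if arr.shape[-1] != input_size:
        raise ValueError(f"expected input_size={input_size}, got {arr.shape[-1]}")
    return arr
```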

```python
import jax.numpy as jnp

X = artifact.prepare_inputs(static_covars)  # shape (batch, time, input_size)
raw_predictions = artifact.model(jnp.asarray(X))  # outputs are in the transformed target space
```

## Preprocessing Contract

The model was trained on transformed inputs, not raw covariates. `artifact.prepare_inputs(...)` and `artifact.predict(...)` apply this preprocessing for you; if you build inputs yourself, you must apply the same preprocessing described in `preprocessing_config.json` to reproduce training-time behavior.

The static covariates are expected in this order:

```text
eir
dn0_use
dn0_future
Q0
phi_bednets
seasonal
routine
itn_use
irs_use
itn_future
irs_future
lsm
```
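
When building inputs without `artifact.prepare_inputs(...)`, the first step is turning each raw covariate dictionary into a vector in the order above. A minimal NumPy sketch; `STATIC_COVARIATE_ORDER` and `covariates_to_matrix` are illustrative names, in practice the order should be read from `preprocessing_config.json` rather than hard-coded, and rejecting unknown keys (to catch typos) is a choice of this sketch:

```python
import numpy as np

# Order copied from the list above; prefer the order stored in preprocessing_config.json.
STATIC_COVARIATE_ORDER = (
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
)


def covariates_to_matrix(static_covars, order=STATIC_COVARIATE_ORDER) -> np.ndarray:
    """Stack raw covariate dicts into a (batch, n_covariates) array in `order`."""
    rows = []
    for i, covars in enumerate(static_covars):
        missing = [name for name in order if name not in covars]
        if missing:
            raise KeyError(f"scenario {i} is missing covariates: {missing}")
        unknown = sorted(set(covars) - set(order))
        if unknown:
            raise KeyError(f"scenario {i} has unexpected covariates: {unknown}")
        # Dict insertion order is ignored; only `order` decides the column layout.
        rows.append([float(covars[name]) for name in order])
    return np.asarray(rows, dtype=np.float64)
```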

The following covariates are zeroed before the intervention day:

```text
dn0_future
itn_future
irs_future
lsm
routine
```
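
Applied to a per-timestep array, this masking can be sketched with NumPy as follows. The names `FUTURE_COVARIATES`, `zero_before_intervention`, and `intervention_day`, and the assumption that timesteps are indexed by day, are illustrative; the actual intervention timing, and whether masking happens before or after scaling, should be taken from `preprocessing_config.json` and the StateMINT source:

```python
import numpy as np

FUTURE_COVARIATES = ("dn0_future", "itn_future", "irs_future", "lsm", "routine")


def zero_before_intervention(static_matrix, covariate_order, days, intervention_day):
    """Broadcast (batch, n_cov) covariates over time and mask the future ones.

    Returns a new (batch, time, n_cov) array in which every covariate listed in
    FUTURE_COVARIATES is 0 at each timestep whose day is < intervention_day.
    """
    static_matrix = np.asarray(static_matrix, dtype=float)
    days = np.asarray(days)
    order = list(covariate_order)
    # np.repeat copies, so the caller's array is left untouched.
    per_step = np.repeat(static_matrix[:, np.newaxis, :], len(days), axis=1)
    pre_intervention = days < intervention_day  # boolean mask over time
    for name in FUTURE_COVARIATES:
        per_step[:, pre_intervention, order.index(name)] = 0.0
    return per_step
```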

Each timestep's input combines the following features:

```text
time_normalized / cyclized
scaled_static_covariates
post_intervention_flag
years_since_intervention
```
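
One plausible way to assemble these per-timestep features is sketched below. Everything specific here is an assumption for illustration, not the exported preprocessing: time is a plain linear ramp over `[0, 1]` (the cyclized variant is not shown), the flag is 1 from `intervention_day` onward, a year is 365 days, and the features are concatenated in the order listed. `assemble_timestep_features` is hypothetical; compare against `artifact.prepare_inputs(...)` before relying on it:

```python
import numpy as np


def assemble_timestep_features(scaled_static, days, intervention_day, days_per_year=365.0):
    """Build a (batch, time, n_cov + 3) feature array (hypothetical layout).

    scaled_static: (batch, time, n_cov) static covariates, already scaled and masked.
    days: (time,) day index of each timestep.
    """
    scaled_static = np.asarray(scaled_static, dtype=float)
    days = np.asarray(days, dtype=float)
    batch, n_time, _ = scaled_static.shape
    if n_time != len(days):
        raise ValueError("scaled_static and days disagree on the number of timesteps")
    # Assumed linear time normalization to [0, 1].
    span = days.max() - days.min()
    time_norm = (days - days.min()) / span if span > 0 else np.zeros_like(days)
    post_flag = (days >= intervention_day).astype(float)
    years_since = np.maximum(days - intervention_day, 0.0) / days_per_year
    per_time = np.stack([time_norm, post_flag, years_since], axis=-1)  # (time, 3)
    per_time = np.broadcast_to(per_time, (batch, n_time, 3))
    # Layout: [time, static covariates..., flag, years_since]
    return np.concatenate([per_time[..., :1], scaled_static, per_time[..., 1:]], axis=-1)
```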

The static covariates must be standardized using the fitted scaler stored in `preprocessing_config.json`:

```python
scaled_static = (raw_static - scaler_mean) / scaler_scale
```
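
With NumPy, applying the exported scaler to a batch and inverting it looks like this. The mean and scale values are made-up placeholders for three covariates; the real statistics come from the fitted scaler in `preprocessing_config.json` (also exposed as `artifact.scaler`), whose exact field names are not assumed here:

```python
import numpy as np


def standardize(raw_static, scaler_mean, scaler_scale):
    """Apply a fitted standard scaler, (x - mean) / scale, broadcast over the batch."""
    raw = np.asarray(raw_static, dtype=float)
    return (raw - np.asarray(scaler_mean, dtype=float)) / np.asarray(scaler_scale, dtype=float)


def unstandardize(scaled_static, scaler_mean, scaler_scale):
    """Invert `standardize`, e.g. to check a round trip."""
    scaled = np.asarray(scaled_static, dtype=float)
    return scaled * np.asarray(scaler_scale, dtype=float) + np.asarray(scaler_mean, dtype=float)


# Placeholder statistics (NOT the exported values).
mean = np.array([80.0, 0.25, 0.3])
scale = np.array([40.0, 0.1, 0.15])
raw = np.array([[50.0, 0.3, 0.4], [120.0, 0.2, 0.2]])
scaled = standardize(raw, mean, scale)
```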

Do not refit the scaler for inference. The exported scaler is part of the trained model.

## Prediction Scale

The low-level model (`artifact.model`) outputs predictions in the transformed target space used during training.

For prevalence:

```python
import jax

prevalence = jax.nn.sigmoid(raw_prediction)
```

For cases:

```python
import jax.numpy as jnp

cases = jnp.expm1(raw_prediction)
```

`artifact.predict(...)` applies this inverse transform by default; pass `transformed=True` to skip it. Calling `artifact.model(...)` directly returns transformed-space values, so apply the inverse yourself.
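
The forward transforms undone by these inverses are the logit (for the sigmoid) and `log1p` (for `expm1`); whether training used exactly these, or added clipping or offsets, is an assumption to check against the target transform settings in `preprocessing_config.json`. A NumPy round-trip sketch, useful when comparing transformed-space outputs with simulator targets; `to_transformed`, `to_original`, and the `1e-6` clipping bound are hypothetical:

```python
import numpy as np


def to_transformed(values, predictor):
    """Map original-scale targets into the (assumed) training space."""
    values = np.asarray(values, dtype=float)
    if predictor == "prevalence":
        # Assumed logit; clipping avoids infinities at exactly 0 or 1.
        p = np.clip(values, 1e-6, 1 - 1e-6)
        return np.log(p) - np.log1p(-p)
    if predictor == "cases":
        return np.log1p(values)  # inverse of expm1
    raise ValueError(f"unknown predictor: {predictor!r}")


def to_original(raw, predictor):
    """The inverse transforms above, written with NumPy instead of JAX."""
    raw = np.asarray(raw, dtype=float)
    if predictor == "prevalence":
        return 1.0 / (1.0 + np.exp(-raw))  # sigmoid
    if predictor == "cases":
        return np.expm1(raw)
    raise ValueError(f"unknown predictor: {predictor!r}")
```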

## General Notes

- The `prevalence` and `cases` folders have separate checkpoints and separate fitted scalers. Always load the folder corresponding to the target being predicted.