---
library_name: flax
tags:
- jax
- flax
- orbax
- mamba
- malaria
- simulation
- time-series
- regression
pipeline_tag: time-series-forecasting
---
# StateMINT
StateMINT is a neural emulator for `malariasimulation` outputs. This model repository contains two exported inference artifacts:
- `prevalence/`: predicts malaria prevalence over time.
- `cases/`: predicts malaria cases over time.
Both artifacts use the same `Mamba2Regressor` architecture but have separate weights and preprocessing metadata. Users should load the folder that matches the target they want to predict.
## Repository Layout
```text
.
├── prevalence/
│ ├── checkpoint/
│ ├── model_config.json
│ └── preprocessing_config.json
├── cases/
│ ├── checkpoint/
│ ├── model_config.json
│ └── preprocessing_config.json
└── README.md
```
Each target folder is self-contained:
- `checkpoint/` contains model-only Orbax checkpoint data.
- `model_config.json` contains the model architecture settings needed to instantiate `Mamba2Regressor`.
- `preprocessing_config.json` contains feature ordering, intervention timing, target transform settings, and the fitted static covariate scaler.
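To inspect an exported artifact without instantiating the model, the two JSON files can be read directly. The following is a minimal sketch that assumes the target folder has already been downloaded to a local directory; the helper name `load_target_configs` is hypothetical and not part of the StateMINT API.

```python
import json
from pathlib import Path


def load_target_configs(repo_root, target):
    """Read both JSON configs for one target folder ('prevalence' or 'cases')."""
    if target not in ("prevalence", "cases"):
        raise ValueError(f"unknown target: {target!r}")
    folder = Path(repo_root) / target
    configs = {}
    # File names follow the repository layout shown above.
    for name in ("model_config", "preprocessing_config"):
        with open(folder / f"{name}.json", encoding="utf-8") as fh:
            configs[name] = json.load(fh)
    return configs
```

Reading both files from the same folder keeps the architecture settings and the fitted scaler paired with the checkpoint they were exported with.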
## Intended Use
These models are intended for emulating the trajectories that `malariasimulation`-style simulations generate from a given set of inputs. They are designed for research and analysis workflows where fast approximate prediction of simulated prevalence or cases is useful.
They are not intended for direct clinical decision-making or for use on real-world surveillance data without additional validation.
## Installation
Install the package from PyPI:
```bash
pip install mintstate
# For GPU support, install with the `[gpu]` extra:
pip install "mintstate[gpu]"
```
If installing from source:
```bash
git clone https://github.com/mrc-ide/stateMINT.git
cd stateMINT
pip install -e .
# For GPU support, install with the `[gpu]` extra:
pip install -e ".[gpu]"
```
## Loading A Model
Recommended high-level API:
```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)
model = artifact.model
```
To load the cases model:
```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
model = artifact.model
```
`from_pretrained` returns a `ModelArtifact` containing:
- `artifact.model`: the restored `Mamba2Regressor`.
- `artifact.preprocessing_config`: the exported preprocessing metadata.
- `artifact.scaler`: the fitted static covariate scaler.
- `artifact.prepare_inputs(...)`: converts raw static covariate dictionaries into model inputs.
- `artifact.predict(...)`: predicts directly from raw static covariate dictionaries.
## Example Use Case
Use the prevalence model to emulate malaria prevalence trajectories for two intervention scenarios. Provide one raw static covariate dictionary per trajectory:
```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)

static_covars = [
    {
        "eir": 50.0,
        "dn0_use": 0.3,
        "dn0_future": 0.4,
        "Q0": 0.8,
        "phi_bednets": 0.7,
        "seasonal": 1.0,
        "routine": 0.5,
        "itn_use": 0.2,
        "irs_use": 0.1,
        "itn_future": 0.3,
        "irs_future": 0.2,
        "lsm": 0.0,
    },
    {
        "eir": 120.0,
        "dn0_use": 0.2,
        "dn0_future": 0.2,
        "Q0": 0.9,
        "phi_bednets": 0.6,
        "seasonal": 1.0,
        "routine": 0.4,
        "itn_use": 0.3,
        "irs_use": 0.0,
        "itn_future": 0.5,
        "irs_future": 0.1,
        "lsm": 0.2,
    },
]

predicted_prevalence = artifact.predict(static_covars)
print(predicted_prevalence.shape)  # (2, n_steps)
print(predicted_prevalence[0])  # prevalence trajectory for the first scenario
```
For case-count predictions, load `predictor="cases"` and call the same `.predict(...)` method:
```python
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
predicted_cases = artifact.predict(static_covars)
```
By default, `.predict(...)` returns predictions on the original target scale:
- prevalence predictions are probabilities in `[0, 1]`;
- cases predictions are case counts.
To get predictions in the model's transformed training space, pass `transformed=True`:
```python
raw_predictions = artifact.predict(static_covars, transformed=True)
```
If you use the low-level model directly, it expects already-prepared arrays with shape:
```text
(batch, time, input_size)
```
```python
import jax.numpy as jnp
X = artifact.prepare_inputs(static_covars)
raw_predictions = artifact.model(jnp.asarray(X))
```
## Preprocessing Contract
The model was trained on transformed inputs, not raw covariates. To reproduce training-time behavior, users must apply the same preprocessing described in `preprocessing_config.json`.
The static covariates are expected in this order:
```text
eir
dn0_use
dn0_future
Q0
phi_bednets
seasonal
routine
itn_use
irs_use
itn_future
irs_future
lsm
```
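When preparing inputs by hand rather than through `artifact.prepare_inputs(...)`, a raw covariate dictionary has to be flattened into the order listed above. The following is a small sketch of that step; the helper name `to_ordered_row` and its strict key check are illustrative choices, not StateMINT behavior.

```python
# Feature order copied from the list above.
STATIC_ORDER = [
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
]


def to_ordered_row(covars, order=STATIC_ORDER):
    """Return the covariate values as a list of floats in the expected order.

    Raises KeyError if a required key is missing or an unexpected key is present.
    """
    missing = [k for k in order if k not in covars]
    extra = [k for k in covars if k not in order]
    if missing or extra:
        raise KeyError(f"missing={missing}, unexpected={extra}")
    return [float(covars[k]) for k in order]
```

Failing loudly on a missing or misspelled key avoids silently feeding a misaligned feature vector to the model.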
The following covariates are zeroed before the intervention day:
```text
dn0_future
itn_future
irs_future
lsm
routine
```
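One way to picture this masking, as a rough sketch only: repeat the static covariates across time, then zero the listed columns for every timestep before the intervention. The function name `mask_pre_intervention`, the `intervention_step` argument, the strict "before" boundary, and the choice to mask unscaled values are all assumptions made for illustration; the actual intervention timing is stored in `preprocessing_config.json`.

```python
import numpy as np

STATIC_ORDER = [
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
]
ZEROED_PRE_INTERVENTION = ["dn0_future", "itn_future", "irs_future", "lsm", "routine"]


def mask_pre_intervention(static_row, n_steps, intervention_step):
    """Repeat one static row over time and zero future covariates early on.

    static_row: shape (n_static,), values in STATIC_ORDER.
    Returns an array of shape (n_steps, n_static).
    """
    tiled = np.tile(np.asarray(static_row, dtype=float), (n_steps, 1))
    cols = [STATIC_ORDER.index(name) for name in ZEROED_PRE_INTERVENTION]
    # Assumption: steps strictly before `intervention_step` are pre-intervention.
    tiled[:intervention_step, cols] = 0.0
    return tiled
```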
Each timestep uses:
```text
time_normalized / cyclized, scaled_static_covariates, post_intervention_flag, years_since_intervention
```
The static covariates must be standardized using the fitted scaler stored in `preprocessing_config.json`:
```python
scaled_static = (raw_static - scaler_mean) / scaler_scale
```
Do not refit the scaler for inference. The exported scaler is part of the trained model.
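As a sketch of what applying the exported scaler looks like, assume it is stored as two per-feature arrays of means and scales. The key names `mean` and `scale` below are hypothetical; check `preprocessing_config.json` for the actual layout.

```python
import numpy as np


def apply_exported_scaler(raw_static, scaler):
    """Standardize raw static covariates with already-fitted statistics.

    raw_static: array of shape (batch, n_static) in the expected feature order.
    scaler: dict with hypothetical keys 'mean' and 'scale', each of length n_static.
    """
    raw_static = np.asarray(raw_static, dtype=float)
    mean = np.asarray(scaler["mean"], dtype=float)
    scale = np.asarray(scaler["scale"], dtype=float)
    if raw_static.shape[-1] != mean.shape[0] or mean.shape != scale.shape:
        raise ValueError("scaler statistics do not match the number of features")
    # No fitting happens here: the statistics are used exactly as exported.
    return (raw_static - mean) / scale
```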
## Prediction Scale
The model predicts in the transformed target space used during training.
For prevalence:
```python
import jax
prevalence = jax.nn.sigmoid(raw_prediction)
```
For cases:
```python
import jax.numpy as jnp
cases = jnp.expm1(raw_prediction)
```
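`artifact.predict(...)` applies this inverse transform by default and skips it when called with `transformed=True`. Calling `artifact.model(...)` directly returns outputs in the transformed space, so apply the inverse transform yourself in that case.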
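For reference, both inverse transforms can be written with NumPy alone, which is handy for post-processing saved outputs that are still in the transformed space. This is a minimal sketch; the function names are illustrative and not part of StateMINT.

```python
import numpy as np


def inverse_prevalence(raw_prediction):
    """Map transformed outputs to probabilities in [0, 1] via the logistic sigmoid."""
    x = np.asarray(raw_prediction, dtype=float)
    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)); this form avoids overflow in exp.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def inverse_cases(raw_prediction):
    """Map transformed outputs to case counts via expm1, i.e. exp(x) - 1."""
    return np.expm1(np.asarray(raw_prediction, dtype=float))
```

A transformed prevalence output of `0.0` corresponds to a prevalence of `0.5`, and a transformed cases output of `0.0` corresponds to zero cases.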
## General Notes
- The `prevalence` and `cases` folders have separate checkpoints and separate fitted scalers. Always load the folder corresponding to the target being predicted.