---
library_name: flax
tags:
- jax
- flax
- orbax
- mamba
- malaria
- simulation
- time-series
- regression
pipeline_tag: time-series-forecasting
---

# StateMINT

StateMINT is a neural emulator for `malariasimulation` outputs. This model repository contains two exported inference artifacts:

- `prevalence/`: predicts malaria prevalence over time.
- `cases/`: predicts malaria cases over time.

Both artifacts use the same `Mamba2Regressor` architecture but have separate weights and preprocessing metadata. Load the folder that matches the target you want to predict.

## Repository Layout

```text
.
├── prevalence/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
├── cases/
│   ├── checkpoint/
│   ├── model_config.json
│   └── preprocessing_config.json
└── README.md
```

Each target folder is self-contained:

- `checkpoint/` contains model-only Orbax checkpoint data.
- `model_config.json` contains the model architecture settings needed to instantiate `Mamba2Regressor`.
- `preprocessing_config.json` contains feature ordering, intervention timing, target transform settings, and the fitted static covariate scaler.

## Intended Use

These models are intended for emulating trajectories generated from `malariasimulation`-style simulation inputs. They are designed for research and analysis workflows where fast approximate prediction of simulated prevalence or cases is useful.

They are not intended for direct clinical decision-making or for use on real-world surveillance data without additional validation.

## Installation

Install from PyPI and the Hugging Face Hub client:

```bash
pip install mintstate

# For GPU support, install with the `[gpu]` extra
# (quoted so that shells such as zsh do not expand the brackets):
pip install "mintstate[gpu]"
```

If installing from source:

```bash
git clone https://github.com/mrc-ide/stateMINT.git
cd stateMINT
pip install -e .

# For GPU support, install with the `[gpu]` extra:
pip install -e ".[gpu]"
```

## Loading A Model

Recommended high-level API:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)
model = artifact.model
```

To load the cases model:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
model = artifact.model
```

`from_pretrained` returns a `ModelArtifact` containing:

- `artifact.model`: the restored `Mamba2Regressor`.
- `artifact.preprocessing_config`: the exported preprocessing metadata.
- `artifact.scaler`: the fitted static covariate scaler.
- `artifact.prepare_inputs(...)`: converts raw static covariate dictionaries into model inputs.
- `artifact.predict(...)`: predicts directly from raw static covariate dictionaries.

## Example Use Case

Use the prevalence model to emulate malaria prevalence trajectories for two intervention scenarios.
Provide one raw static covariate dictionary per trajectory:

```python
from stateMINT.model import Mamba2Regressor

artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="prevalence",
    revision="v1.2.0",
)

static_covars = [
    {
        "eir": 50.0,
        "dn0_use": 0.3,
        "dn0_future": 0.4,
        "Q0": 0.8,
        "phi_bednets": 0.7,
        "seasonal": 1.0,
        "routine": 0.5,
        "itn_use": 0.2,
        "irs_use": 0.1,
        "itn_future": 0.3,
        "irs_future": 0.2,
        "lsm": 0.0,
    },
    {
        "eir": 120.0,
        "dn0_use": 0.2,
        "dn0_future": 0.2,
        "Q0": 0.9,
        "phi_bednets": 0.6,
        "seasonal": 1.0,
        "routine": 0.4,
        "itn_use": 0.3,
        "irs_use": 0.0,
        "itn_future": 0.5,
        "irs_future": 0.1,
        "lsm": 0.2,
    },
]

predicted_prevalence = artifact.predict(static_covars)
print(predicted_prevalence.shape)  # (2, n_steps)
print(predicted_prevalence[0])  # prevalence trajectory for the first scenario
```

For case-count predictions, load `predictor="cases"` and call the same `.predict(...)` method:

```python
artifact = Mamba2Regressor.from_pretrained(
    "dide-ic/stateMINT",
    predictor="cases",
    revision="v1.2.0",
)
predicted_cases = artifact.predict(static_covars)
```

By default, `.predict(...)` returns predictions on the original target scale:

- prevalence predictions are probabilities in `[0, 1]`;
- cases predictions are case counts.

To get predictions in the model's transformed training space, pass `transformed=True`:

```python
raw_predictions = artifact.predict(static_covars, transformed=True)
```

If you use the low-level model directly, it expects already-prepared arrays with shape:

```text
(batch, time, input_size)
```

```python
import jax.numpy as jnp

X = artifact.prepare_inputs(static_covars)
raw_predictions = artifact.model(jnp.asarray(X))
```

## Preprocessing Contract

The model was trained on transformed inputs, not raw covariates. To reproduce training-time behavior, users must apply the same preprocessing described in `preprocessing_config.json`.
The static covariates are expected in this order:

```text
eir
dn0_use
dn0_future
Q0
phi_bednets
seasonal
routine
itn_use
irs_use
itn_future
irs_future
lsm
```

The following covariates are zeroed before the intervention day:

```text
dn0_future
itn_future
irs_future
lsm
routine
```

Each timestep uses the following features:

```text
time_normalized / cyclized,
scaled_static_covariates,
post_intervention_flag,
years_since_intervention
```

The static covariates must be standardized using the fitted scaler stored in `preprocessing_config.json`:

```python
scaled_static = (raw_static - scaler_mean) / scaler_scale
```

Do not refit the scaler for inference. The exported scaler is part of the trained model.

## Prediction Scale

The model predicts in the transformed target space used during training. Apply the inverse transform to recover the original scale.

For prevalence:

```python
from jax.nn import sigmoid
prevalence = sigmoid(raw_prediction)
```

For cases:

```python
from jax.numpy import expm1
cases = expm1(raw_prediction)
```

`artifact.predict(...)` applies this inverse transform by default. Passing `transformed=True`, or calling `artifact.model(...)` directly, returns values in the transformed space.

## General Notes

- The `prevalence` and `cases` folders have separate checkpoints and separate fitted scalers. Always load the folder corresponding to the target being predicted.
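
The standardization, pre-intervention zeroing, and inverse target transforms described in the Preprocessing Contract and Prediction Scale sections can be sketched with NumPy. This is a minimal illustration under stated assumptions, not StateMINT's implementation: the helper names `standardize`, `mask_future_covariates`, and `invert_target`, the `intervention_step` index, and the placeholder scaler statistics are all hypothetical. Real scaler values and intervention timing come from `preprocessing_config.json`. Whether zeroing is applied before or after standardization is not stated here, so the sketch shows the two steps separately rather than composing them.

```python
import numpy as np

# Covariate order and pre-intervention zeroing list, copied from the contract above.
FEATURE_ORDER = [
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
]
ZEROED_BEFORE_INTERVENTION = ["dn0_future", "itn_future", "irs_future", "lsm", "routine"]


def standardize(raw_static, scaler_mean, scaler_scale):
    """Apply the fitted scaler: (raw - mean) / scale. Never refit at inference."""
    return (np.asarray(raw_static, dtype=float) - scaler_mean) / scaler_scale


def mask_future_covariates(raw_static, n_steps, intervention_step):
    """Repeat one covariate vector over time and zero the listed covariates
    on every step before `intervention_step` (a hypothetical step index)."""
    tiled = np.tile(np.asarray(raw_static, dtype=float), (n_steps, 1))  # (time, n_covariates)
    cols = [FEATURE_ORDER.index(name) for name in ZEROED_BEFORE_INTERVENTION]
    tiled[:intervention_step, cols] = 0.0
    return tiled


def invert_target(raw_prediction, predictor):
    """Map transformed-space predictions back to the original target scale."""
    raw_prediction = np.asarray(raw_prediction, dtype=float)
    if predictor == "prevalence":
        return 1.0 / (1.0 + np.exp(-raw_prediction))  # sigmoid
    if predictor == "cases":
        return np.expm1(raw_prediction)
    raise ValueError(f"unknown predictor: {predictor!r}")


# First scenario from the example above, arranged in FEATURE_ORDER.
raw = np.array([50.0, 0.3, 0.4, 0.8, 0.7, 1.0, 0.5, 0.2, 0.1, 0.3, 0.2, 0.0])

# Placeholder scaler statistics for illustration only; use the exported values in practice.
scaler_mean = np.zeros(len(FEATURE_ORDER))
scaler_scale = np.ones(len(FEATURE_ORDER))

scaled = standardize(raw, scaler_mean, scaler_scale)
masked = mask_future_covariates(raw, n_steps=4, intervention_step=2)
prevalence = invert_target(np.array([0.0, 2.0]), "prevalence")
cases = invert_target(np.array([0.0, 2.0]), "cases")
```

In normal use, `artifact.prepare_inputs(...)` and `artifact.predict(...)` cover these steps. The sketch is mainly useful for checking a custom pipeline against the contract.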