# Microbiome Transformer (Set-Based OTU Stability Model)

This repository provides Transformer checkpoints for microbiome set modeling using SSU rRNA OTU embeddings (ProkBERT-derived vectors) and optional text metadata embeddings.

See https://github.com/the-puzzler/microbiome-model for more information and the accompanying code.

## Model summary

- **Architecture:** `MicrobiomeTransformer` (see `model.py`)
- **Input type 1 (DNA/OTU):** 384-d embeddings
- **Input type 2 (text metadata):** 1536-d embeddings
- **Core behavior:** permutation-invariant set encoding via a Transformer encoder (no positional encodings)
- **Output:** per-token scalar logits (used as stability scores)

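Because no positional encodings are added, the encoder's per-token outputs permute along with the input set. A minimal sketch of this property using a plain `torch.nn.TransformerEncoder` at the small variant's dimensions (a stand-in, not the repository's `model.py`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the set encoder: no positional encodings are added,
# so the encoder is equivariant to permutations of the input tokens.
layer = nn.TransformerEncoderLayer(
    d_model=20, nhead=5, dim_feedforward=80, dropout=0.0, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=3).eval()

x = torch.randn(1, 6, 20)   # (B, N, d_model): a set of 6 tokens
perm = torch.randperm(6)

with torch.no_grad():
    out = encoder(x)
    out_perm = encoder(x[:, perm])

# Permuting the input permutes the per-token outputs identically.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # prints True
```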
## Available checkpoints

| Filename | Size variant | Metadata variant | `d_model` | `num_layers` | `dim_feedforward` | `nhead` |
|---|---|---|---:|---:|---:|---:|
| `small-notext.pt` | small | DNA-only | 20 | 3 | 80 | 5 |
| `small-text.pt` | small | DNA + text | 20 | 3 | 80 | 5 |
| `large-notext.pt` | large | DNA-only | 100 | 5 | 400 | 5 |
| `large-text.pt` | large | DNA + text | 100 | 5 | 400 | 5 |

Shared dimensions across all variants:

- `OTU_EMB = 384`
- `TXT_EMB = 1536`
- `DROPOUT = 0.1`

## Input expectations

1. Build a set of OTU embeddings (ProkBERT vectors) for each sample.
2. For text-enabled variants, optionally build a set of text (metadata) embeddings for each sample.
3. Feed both sets to the model as:
   - `embeddings_type1`: shape `(B, N_otu, 384)`
   - `embeddings_type2`: shape `(B, N_txt, 1536)`
   - `mask`: shape `(B, N_otu + N_txt)`, `True` at valid token positions and `False` at padding
   - `type_indicators`: shape `(B, N_otu + N_txt)`, `0` for OTU tokens and `1` for text tokens

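As a concrete illustration, the tensors above can be assembled from variable-length sets with standard padding utilities. This is a hypothetical sketch (a batch of two samples with random embeddings), not code from the repository:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical batch of 2 samples with variable-length sets.
otu_sets = [torch.randn(7, 384), torch.randn(4, 384)]
txt_sets = [torch.randn(2, 1536), torch.randn(1, 1536)]

# Pad each modality to a common length within the batch.
embeddings_type1 = pad_sequence(otu_sets, batch_first=True)  # (B, N_otu, 384)
embeddings_type2 = pad_sequence(txt_sets, batch_first=True)  # (B, N_txt, 1536)

B, n_otu, _ = embeddings_type1.shape
_, n_txt, _ = embeddings_type2.shape

# Validity mask over the concatenated [OTU | text] token sequence:
# True marks real tokens, False marks padding.
otu_valid = torch.tensor([len(s) for s in otu_sets])
txt_valid = torch.tensor([len(s) for s in txt_sets])
mask = torch.cat([
    torch.arange(n_otu).expand(B, -1) < otu_valid[:, None],
    torch.arange(n_txt).expand(B, -1) < txt_valid[:, None],
], dim=1)                                                    # (B, N_otu + N_txt)

# 0 for OTU tokens, 1 for text tokens, in the same concatenation order.
type_indicators = torch.cat([
    torch.zeros(B, n_otu, dtype=torch.long),
    torch.ones(B, n_txt, dtype=torch.long),
], dim=1)                                                    # (B, N_otu + N_txt)
```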
## Minimal loading example

```python
import torch
from model import MicrobiomeTransformer  # model definition from the linked repository

ckpt_path = "large-notext.pt"  # or small-notext.pt / small-text.pt / large-text.pt
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Checkpoints may store weights under "model_state_dict" or as a bare state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint)

# Hyperparameters follow the table above: small -> (20, 3, 80), large -> (100, 5, 400).
is_small = "small" in ckpt_path
model = MicrobiomeTransformer(
    input_dim_type1=384,
    input_dim_type2=1536,
    d_model=20 if is_small else 100,
    nhead=5,
    num_layers=3 if is_small else 5,
    dim_feedforward=80 if is_small else 400,
    dropout=0.1,
)

# strict=False tolerates auxiliary keys in the checkpoint beyond the model weights.
model.load_state_dict(state_dict, strict=False)
model.eval()
```

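Once the model returns per-token logits of shape `(B, N_otu + N_txt)`, the same validity mask can be used to pull out each sample's scores for real tokens. A small sketch with placeholder values standing in for the model's output:

```python
import torch

# Placeholder logits and mask standing in for a model forward pass on a
# batch of 2 samples with 4 token slots each (trailing slots are padding).
logits = torch.tensor([[0.5, -1.25, 0.75, 0.0],
                       [2.5,  0.25, 0.0,  0.0]])
mask = torch.tensor([[True, True, True,  False],
                     [True, True, False, False]])

# Keep only the scores at valid (non-padding) positions, per sample.
scores = [row[m].tolist() for row, m in zip(logits, mask)]
print(scores)  # [[0.5, -1.25, 0.75], [2.5, 0.25]]
```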
## Intended use

- Microbiome representation learning from OTU sets
- Stability-style scoring of community members
- Downstream analyses such as dropout/colonization prediction and rollout trajectory experiments

## Limitations

- This is a research model, not a clinical diagnostic tool.
- Outputs depend strongly on upstream OTU mapping, embedding resolution, and cohort preprocessing.
- Text-enabled checkpoints expect a compatible metadata embedding pipeline.