# Microbiome Transformer (Set-Based OTU Stability Model)
This repository provides Transformer checkpoints for microbiome set modeling using SSU rRNA OTU embeddings (ProkBERT-derived vectors) and optional text metadata embeddings.
Please see https://github.com/the-puzzler/microbiome-model for more information and relevant code.
## Model summary
- **Architecture:** `MicrobiomeTransformer` (see `model.py`)
- **Input type 1 (DNA/OTU):** 384-d embeddings
- **Input type 2 (text metadata):** 1536-d embeddings
- **Core behavior:** order-agnostic set encoding via a Transformer encoder with no positional encodings, so permuting the input tokens permutes the per-token outputs in the same way
- **Output:** per-token scalar logits (used as stability scores)
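The order-agnostic behavior follows from the missing positional encodings: self-attention and the position-wise layers treat tokens as an unordered set. The toy sketch below checks this property with standard `torch.nn` modules at the small variant's sizes. It is an illustration only, not the repository's `MicrobiomeTransformer`; the single input projection and the names `proj`, `encoder`, `head` and `score` are assumptions for the demo, and how `model.py` combines the two input types is not reproduced here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in (NOT model.py): project 384-d OTU vectors, encode with a
# Transformer that has no positional encodings, then emit one logit per token.
d_model, nhead = 20, 5
proj = nn.Linear(384, d_model)
layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward=80, dropout=0.0, batch_first=True
)
# enable_nested_tensor=False avoids a warning about the odd head count.
encoder = nn.TransformerEncoder(layer, num_layers=3, enable_nested_tensor=False)
head = nn.Linear(d_model, 1)


def score(x):
    """x: (B, N, 384) -> per-token logits of shape (B, N)."""
    return head(encoder(proj(x))).squeeze(-1)


with torch.no_grad():
    x = torch.randn(1, 6, 384)
    perm = torch.randperm(6)
    out = score(x)
    out_perm = score(x[:, perm])

# Shuffling the input tokens shuffles the outputs identically.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))
```

Equality holds up to floating-point summation order, hence the small tolerance.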
## Available checkpoints
| Filename | Size variant | Metadata variant | `d_model` | `num_layers` | `dim_feedforward` | `nhead` |
|---|---|---|---:|---:|---:|---:|
| `small-notext.pt` | small | DNA-only | 20 | 3 | 80 | 5 |
| `small-text.pt` | small | DNA + text | 20 | 3 | 80 | 5 |
| `large-notext.pt` | large | DNA-only | 100 | 5 | 400 | 5 |
| `large-text.pt` | large | DNA + text | 100 | 5 | 400 | 5 |
Shared dimensions:
- `OTU_EMB = 384`
- `TXT_EMB = 1536`
- `DROPOUT = 0.1`
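The table can also be encoded as a lookup so the constructor arguments come from the filename rather than a hand-written conditional. The helper `config_for` below is hypothetical and assumes the four filenames above are used unchanged. Note that a plain substring test for `"text"` cannot tell the variants apart, because `"notext"` also contains `"text"`; the sketch splits the name on `-` instead.

```python
from pathlib import Path

# Values transcribed from the checkpoint table and shared dimensions above.
SIZES = {
    "small": dict(d_model=20, num_layers=3, dim_feedforward=80, nhead=5),
    "large": dict(d_model=100, num_layers=5, dim_feedforward=400, nhead=5),
}
SHARED = dict(input_dim_type1=384, input_dim_type2=1536, dropout=0.1)


def config_for(filename):
    """Return (constructor kwargs, uses_text) for a published checkpoint name."""
    size, _, variant = Path(filename).stem.partition("-")  # "small", "notext"
    if size not in SIZES or variant not in ("text", "notext"):
        raise ValueError(f"unrecognised checkpoint name: {filename}")
    kwargs = {**SHARED, **SIZES[size]}
    # Multi-head attention needs d_model divisible by nhead (4 or 20 per head here).
    assert kwargs["d_model"] % kwargs["nhead"] == 0
    return kwargs, variant == "text"
```

The keys match the keyword arguments used in the loading example, so `kwargs, uses_text = config_for(ckpt_path)` followed by `MicrobiomeTransformer(**kwargs)` should be equivalent to spelling them out.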
## Input expectations
1. Build a set of OTU embeddings (ProkBERT vectors) per sample.
2. For text-enabled variants, also build a set of text (metadata) embeddings per sample.
3. Pass them to the model as:
   - `embeddings_type1`: shape `(B, N_otu, 384)`
   - `embeddings_type2`: shape `(B, N_txt, 1536)`
   - `mask`: shape `(B, N_otu + N_txt)`, `True` at valid (non-padding) positions
   - `type_indicators`: shape `(B, N_otu + N_txt)` (0 for OTU tokens, 1 for text tokens)
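Samples usually have different numbers of OTUs, so the sets must be padded to a common length before batching. The hypothetical helper `build_batch` below sketches one way to produce the four tensors. Its choices are assumptions, not taken from `model.py`: OTU slots precede text slots in `mask` and `type_indicators`, padding is zero-filled, `type_indicators` uses `torch.long`, and DNA-only use is expressed as zero text tokens per sample.

```python
import torch

OTU_EMB, TXT_EMB = 384, 1536


def build_batch(otu_sets, txt_sets=None):
    """Pad per-sample sets into (type1, type2, mask, type_indicators).

    otu_sets: list of (n_i, 384) float tensors, one per sample.
    txt_sets: optional list of (m_i, 1536) float tensors, one per sample.
    """
    B = len(otu_sets)
    if txt_sets is None:
        # Assumption: DNA-only input = zero text tokens per sample.
        txt_sets = [torch.zeros(0, TXT_EMB) for _ in range(B)]
    n_otu = max(s.shape[0] for s in otu_sets)
    n_txt = max(s.shape[0] for s in txt_sets)

    type1 = torch.zeros(B, n_otu, OTU_EMB)
    type2 = torch.zeros(B, n_txt, TXT_EMB)
    mask = torch.zeros(B, n_otu + n_txt, dtype=torch.bool)
    # Assumption: OTU slots come first, then text slots.
    type_indicators = torch.cat(
        [torch.zeros(B, n_otu, dtype=torch.long),
         torch.ones(B, n_txt, dtype=torch.long)],
        dim=1,
    )
    for i, (otu, txt) in enumerate(zip(otu_sets, txt_sets)):
        type1[i, : otu.shape[0]] = otu
        type2[i, : txt.shape[0]] = txt
        mask[i, : otu.shape[0]] = True                # real OTU tokens
        mask[i, n_otu : n_otu + txt.shape[0]] = True  # real text tokens
    return type1, type2, mask, type_indicators
```

Padding each type to its own batch maximum keeps the OTU and text blocks at fixed offsets, which is what lets a single `type_indicators` row describe every sample.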
## Minimal loading example
```python
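import torch

# model.py comes from https://github.com/the-puzzler/microbiome-model
from model import MicrobiomeTransformer

ckpt_path = "large-notext.pt"  # or small-notext.pt / small-text.pt / large-text.pt

# Load on CPU; the weights may be wrapped under "model_state_dict".
# On PyTorch >= 2.6 torch.load defaults to weights_only=True; if loading a
# trusted file fails for that reason, pass weights_only=False.
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint.get("model_state_dict", checkpoint)

# Hyperparameters follow the checkpoint table; only the size variant differs.
is_small = "small" in ckpt_path
model = MicrobiomeTransformer(
    input_dim_type1=384,   # OTU_EMB
    input_dim_type2=1536,  # TXT_EMB
    d_model=20 if is_small else 100,
    nhead=5,
    num_layers=3 if is_small else 5,
    dim_feedforward=80 if is_small else 400,
    dropout=0.1,
)

# strict=False skips mismatched keys silently, so report any that occur.
result = model.load_state_dict(state_dict, strict=False)
if result.missing_keys or result.unexpected_keys:
    print("missing:", result.missing_keys)
    print("unexpected:", result.unexpected_keys)
model.eval()  # disable dropout for inference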
```
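Because the per-token logits serve as stability scores, ranking a sample's community members amounts to sorting its valid OTU positions by logit. The hypothetical helper `top_otus` below assumes the forward pass returns logits of shape `(B, N_otu + N_txt)` (or with a trailing singleton dimension) aligned with `mask` and `type_indicators`; the actual forward signature is defined in `model.py` and is not reproduced here.

```python
import torch


def top_otus(logits, mask, type_indicators, k=5):
    """Return (values, indices) of the k highest-scoring OTU tokens per sample.

    Assumes logits is (B, N) or (B, N, 1), aligned with mask and
    type_indicators of shape (B, N) as described under "Input expectations".
    """
    if logits.dim() == 3:
        logits = logits.squeeze(-1)
    # Keep only real (unpadded) OTU tokens; text tokens and padding get -inf.
    keep = mask & (type_indicators == 0)
    scores = logits.masked_fill(~keep, float("-inf"))
    k = min(k, scores.shape[1])
    return torch.topk(scores, k, dim=1)
```

If a sample has fewer than `k` valid OTUs, its trailing results are `-inf` entries and should be discarded.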
## Intended use
- Microbiome representation learning from OTU sets
- Stability-style scoring of community members
- Downstream analyses such as dropout/colonization prediction and rollout trajectory experiments
## Limitations
- This is a research model and not a clinical diagnostic tool.
- Outputs depend strongly on upstream OTU mapping, embedding resolution, and cohort preprocessing.
- Text-enabled checkpoints expect 1536-d metadata embeddings produced by a pipeline compatible with the one used in training.