# Microbiome Transformer (Set-Based OTU Stability Model)

This repository provides Transformer checkpoints for microbiome set modeling using SSU rRNA OTU embeddings (ProkBERT-derived vectors) and optional text metadata embeddings.

Please see https://github.com/the-puzzler/microbiome-model for more information and relevant code.

## Model summary

- **Architecture:** `MicrobiomeTransformer` (see `model.py`)
- **Input type 1 (DNA/OTU):** 384-d embeddings
- **Input type 2 (text metadata):** 1536-d embeddings
- **Core behavior:** permutation-invariant set encoding via Transformer encoder (no positional encodings)
- **Output:** per-token scalar logits (used as stability scores)

## Available checkpoints

| Filename | Size variant | Metadata variant | `d_model` | `num_layers` | `dim_feedforward` | `nhead` |
|---|---|---|---:|---:|---:|---:|
| `small-notext.pt` | small | DNA-only | 20 | 3 | 80 | 5 |
| `small-text.pt` | small | DNA + text | 20 | 3 | 80 | 5 |
| `large-notext.pt` | large | DNA-only | 100 | 5 | 400 | 5 |
| `large-text.pt` | large | DNA + text | 100 | 5 | 400 | 5 |

Shared dimensions:

- `OTU_EMB = 384`
- `TXT_EMB = 1536`
- `DROPOUT = 0.1`

## Input expectations

1. Build a set of OTU embeddings (ProkBERT vectors) per sample.
2. Optionally build a set of text embeddings (metadata) per sample for text-enabled variants.
3. Feed both sets as:
   - `embeddings_type1`: shape `(B, N_otu, 384)`
   - `embeddings_type2`: shape `(B, N_txt, 1536)`
   - `mask`: shape `(B, N_otu + N_txt)` with valid positions as `True`
   - `type_indicators`: shape `(B, N_otu + N_txt)` (0 for OTU tokens, 1 for text tokens)

## Minimal loading example

```python
import torch
from model import MicrobiomeTransformer

ckpt_path = "large-notext.pt"  # or small-notext.pt / small-text.pt / large-text.pt
checkpoint = torch.load(ckpt_path, map_location="cpu")
state_dict = checkpoint.get("model_state_dict", checkpoint)

is_small = "small" in ckpt_path
model = MicrobiomeTransformer(
    input_dim_type1=384,
    input_dim_type2=1536,
    d_model=20 if is_small else 100,
    nhead=5,
    num_layers=3 if is_small else 5,
    dim_feedforward=80 if is_small else 400,
    dropout=0.1,
)
model.load_state_dict(state_dict, strict=False)
model.eval()
```

## Intended use

- Microbiome representation learning from OTU sets
- Stability-style scoring of community members
- Downstream analyses such as dropout/colonization prediction and rollout trajectory experiments

## Limitations

- This is a research model and not a clinical diagnostic tool.
- Outputs depend strongly on upstream OTU mapping, embedding resolution, and cohort preprocessing.
- Text-enabled checkpoints expect compatible metadata embedding pipelines.
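
## Example: hyperparameters by checkpoint name

The loading example picks hyperparameters by checking whether `"small"` appears in the checkpoint path. A slightly more explicit alternative is to keep the values from the "Available checkpoints" table in a lookup. The sketch below is not part of the repository: `SIZE_CONFIGS`, `SHARED_CONFIG`, and `config_for_checkpoint` are hypothetical names, and the only assumption is that checkpoint files keep the `<size>-<variant>.pt` naming shown in the table.

```python
from pathlib import Path

# Values copied from the "Available checkpoints" table.
# The text and no-text variants of each size share the same hyperparameters.
SIZE_CONFIGS = {
    "small": {"d_model": 20, "num_layers": 3, "dim_feedforward": 80, "nhead": 5},
    "large": {"d_model": 100, "num_layers": 5, "dim_feedforward": 400, "nhead": 5},
}

# Shared dimensions listed under the table (OTU_EMB, TXT_EMB, DROPOUT).
SHARED_CONFIG = {"input_dim_type1": 384, "input_dim_type2": 1536, "dropout": 0.1}


def config_for_checkpoint(ckpt_path):
    """Return constructor keyword arguments for a checkpoint file.

    Hypothetical helper: assumes the file name starts with "small-" or "large-".
    """
    name = Path(ckpt_path).name      # e.g. "large-notext.pt"
    size = name.split("-", 1)[0]     # text before the first hyphen
    if size not in SIZE_CONFIGS:
        raise ValueError(f"Unrecognized checkpoint name: {name!r}")
    return {**SHARED_CONFIG, **SIZE_CONFIGS[size]}
```

With this in place, the model could be constructed as `MicrobiomeTransformer(**config_for_checkpoint(ckpt_path))`, using the same keyword names as the loading example. An unknown file name raises an error instead of silently falling back to the large configuration.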
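
## Example: mask and type-indicator layout

The shapes listed under "Input expectations" can be illustrated with a small, framework-free sketch that builds the `mask` and `type_indicators` rows for a padded batch. It rests on assumptions this README does not confirm: that OTU tokens occupy the first `N_otu` positions and text tokens the last `N_txt`, and that each block is padded separately to the batch maximum. `build_mask_and_types` is a hypothetical helper, not a function from `model.py`.

```python
def build_mask_and_types(otu_counts, txt_counts):
    """Build padded mask and type-indicator rows for a batch.

    otu_counts[i] and txt_counts[i] are the numbers of real OTU and text
    tokens in sample i. Both results are nested lists of shape
    (B, N_otu + N_txt). Assumes OTU positions come first, then text.
    """
    if len(otu_counts) != len(txt_counts):
        raise ValueError("otu_counts and txt_counts must have the same length")

    n_otu = max(otu_counts, default=0)  # padded width of the OTU block
    n_txt = max(txt_counts, default=0)  # padded width of the text block

    mask, type_indicators = [], []
    for k_otu, k_txt in zip(otu_counts, txt_counts):
        # True marks a real token, False marks padding.
        otu_mask = [True] * k_otu + [False] * (n_otu - k_otu)
        txt_mask = [True] * k_txt + [False] * (n_txt - k_txt)
        mask.append(otu_mask + txt_mask)
        # 0 for OTU positions, 1 for text positions (padded slots included).
        type_indicators.append([0] * n_otu + [1] * n_txt)
    return mask, type_indicators


# Two samples: 2 OTUs + 1 text token, and 3 OTUs + no text tokens.
mask, type_indicators = build_mask_and_types([2, 3], [1, 0])
```

The nested lists can then be converted with `torch.tensor(mask, dtype=torch.bool)` and `torch.tensor(type_indicators)`. The integer dtype the model expects for `type_indicators` is not stated here, so check `model.py` before relying on a default.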