# Microbiome Transformer (Set-Based OTU Stability Model)

This repository provides Transformer checkpoints for microbiome set modeling using SSU rRNA OTU embeddings (ProkBERT-derived vectors) and optional text metadata embeddings.

Please see https://github.com/the-puzzler/microbiome-model for more information and relevant code.

## Model summary

- **Architecture:** `MicrobiomeTransformer` (see `model.py`)
- **Input type 1 (DNA/OTU):** 384-d embeddings
- **Input type 2 (text metadata):** 1536-d embeddings
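- **Core behavior:** set encoding via a Transformer encoder with no positional encodings, so token order carries no information. Reordering the input tokens reorders the per-token outputs in the same way (permutation equivariance).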
- **Output:** per-token scalar logits (used as stability scores)
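
No positional encodings are added, so a token's score depends on the set it belongs to and not on where it sits in the input. The sketch below checks that property on a small hypothetical `ToySetScorer` built from `torch.nn.TransformerEncoder`. It is not the repository's `MicrobiomeTransformer`: it omits the text input and type indicators, and its sizes are borrowed from the small variant for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToySetScorer(nn.Module):
    """Hypothetical stand-in (not MicrobiomeTransformer): project, encode, score each token."""

    def __init__(self, in_dim=384, d_model=20, nhead=5, num_layers=3, dim_feedforward=80):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers, enable_nested_tensor=False)
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):  # x: (B, N, in_dim)
        h = self.encoder(self.proj(x))  # no positional encoding is added anywhere
        return self.head(h).squeeze(-1)  # (B, N) per-token logits


model = ToySetScorer().eval()  # eval() disables dropout so outputs are deterministic
x = torch.randn(2, 7, 384)
perm = torch.randperm(7)
with torch.no_grad():
    out = model(x)
    out_perm = model(x[:, perm])

# Shuffling the tokens shuffles the logits the same way (permutation equivariance).
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```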

## Available checkpoints

| Filename | Size variant | Metadata variant | `d_model` | `num_layers` | `dim_feedforward` | `nhead` |
|---|---|---|---:|---:|---:|---:|
| `small-notext.pt` | small | DNA-only | 20 | 3 | 80 | 5 |
| `small-text.pt` | small | DNA + text | 20 | 3 | 80 | 5 |
| `large-notext.pt` | large | DNA-only | 100 | 5 | 400 | 5 |
| `large-text.pt` | large | DNA + text | 100 | 5 | 400 | 5 |

Shared hyperparameters (all checkpoints):
- `OTU_EMB = 384`
- `TXT_EMB = 1536`
- `DROPOUT = 0.1`
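
The table and shared values can also be kept in one place and looked up from the checkpoint filename. This is a sketch that assumes the files keep the `<size>-<text|notext>.pt` naming shown above; `model_kwargs_for` and `SIZE_VARIANTS` are hypothetical helpers, and the keyword names simply mirror the constructor call in the loading example.

```python
from pathlib import Path

# Values transcribed from the checkpoint table and shared hyperparameters.
OTU_EMB, TXT_EMB, DROPOUT = 384, 1536, 0.1
SIZE_VARIANTS = {
    "small": {"d_model": 20, "num_layers": 3, "dim_feedforward": 80, "nhead": 5},
    "large": {"d_model": 100, "num_layers": 5, "dim_feedforward": 400, "nhead": 5},
}


def model_kwargs_for(ckpt_path):
    """Map a published checkpoint filename to constructor kwargs and a text flag."""
    # "large-text" -> ("large", "-", "text"); unknown names fall through to the check below.
    size, _, variant = Path(ckpt_path).stem.partition("-")
    if size not in SIZE_VARIANTS or variant not in ("text", "notext"):
        raise ValueError(f"unrecognised checkpoint name: {ckpt_path}")
    kwargs = {"input_dim_type1": OTU_EMB, "input_dim_type2": TXT_EMB, "dropout": DROPOUT}
    kwargs.update(SIZE_VARIANTS[size])
    return kwargs, variant == "text"


kwargs, uses_text = model_kwargs_for("runs/small_baseline/large-text.pt")
print(kwargs["d_model"], kwargs["num_layers"], uses_text)  # 100 5 True
```

Parsing the filename stem, rather than testing `"small" in ckpt_path`, avoids selecting the wrong size when a directory name happens to contain "small", as in the example path.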

## Input expectations

1. Build a set of OTU embeddings (ProkBERT vectors) per sample.
2. Optionally build a set of text embeddings (metadata) per sample for text-enabled variants.
3. Pass the inputs to the model as:
   - `embeddings_type1`: shape `(B, N_otu, 384)`
   - `embeddings_type2`: shape `(B, N_txt, 1536)`
   - `mask`: shape `(B, N_otu + N_txt)` with valid positions as `True`
   - `type_indicators`: shape `(B, N_otu + N_txt)` (0 for OTU tokens, 1 for text tokens)
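
A minimal sketch of assembling a padded batch from per-sample sets of different sizes. The helper `build_batch` is hypothetical, and it assumes that OTU tokens precede text tokens along the combined `N_otu + N_txt` axis, that padding uses zeros, and that type indicators are integer (`long`) tensors; none of this is confirmed here, so check it against `model.py`.

```python
import torch


def build_batch(otu_sets, txt_sets, otu_dim=384, txt_dim=1536):
    """Pad per-sample sets into (embeddings_type1, embeddings_type2, mask, type_indicators).

    otu_sets: list of (n_i, 384) tensors; txt_sets: list of (m_i, 1536) tensors
    (use a zero-row tensor for a sample without metadata).
    Assumption: OTU positions come first, then text positions.
    """
    B = len(otu_sets)
    n_otu = max(s.shape[0] for s in otu_sets)
    n_txt = max(s.shape[0] for s in txt_sets)
    emb1 = torch.zeros(B, n_otu, otu_dim)
    emb2 = torch.zeros(B, n_txt, txt_dim)
    mask = torch.zeros(B, n_otu + n_txt, dtype=torch.bool)
    for i, (otu, txt) in enumerate(zip(otu_sets, txt_sets)):
        emb1[i, : otu.shape[0]] = otu
        emb2[i, : txt.shape[0]] = txt
        mask[i, : otu.shape[0]] = True  # valid OTU positions
        mask[i, n_otu : n_otu + txt.shape[0]] = True  # valid text positions
    # 0 for every OTU slot, 1 for every text slot (padding included).
    type_indicators = torch.cat(
        [torch.zeros(B, n_otu, dtype=torch.long), torch.ones(B, n_txt, dtype=torch.long)],
        dim=1,
    )
    return emb1, emb2, mask, type_indicators


otus = [torch.randn(5, 384), torch.randn(3, 384)]
txts = [torch.randn(2, 1536), torch.zeros(0, 1536)]
e1, e2, m, t = build_batch(otus, txts)
print(e1.shape, e2.shape, m.sum(dim=1).tolist())
# torch.Size([2, 5, 384]) torch.Size([2, 2, 1536]) [7, 3]
```

Here `mask` follows the convention stated above (`True` marks valid positions), which is the inverse of `src_key_padding_mask` in `torch.nn.TransformerEncoder`, where `True` marks positions to ignore.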

## Minimal loading example

```python
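import torch
from model import MicrobiomeTransformer  # requires model.py on the Python path

ckpt_path = "large-notext.pt"  # or small-notext.pt / small-text.pt / large-text.pt
checkpoint = torch.load(ckpt_path, map_location="cpu")
# Accept either a wrapped checkpoint ({"model_state_dict": ...}) or a bare state dict.
state_dict = checkpoint.get("model_state_dict", checkpoint)

# Small and large variants differ in d_model, num_layers and dim_feedforward (see the table above).
is_small = "small" in ckpt_path
model = MicrobiomeTransformer(
    input_dim_type1=384,   # OTU_EMB
    input_dim_type2=1536,  # TXT_EMB
    d_model=20 if is_small else 100,
    nhead=5,
    num_layers=3 if is_small else 5,
    dim_feedforward=80 if is_small else 400,
    dropout=0.1,
)
# strict=False skips mismatched keys silently, so report any that were skipped.
result = model.load_state_dict(state_dict, strict=False)
if result.missing_keys or result.unexpected_keys:
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)
model.eval()  # disable dropout for inference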
```
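
Once the model has produced its per-token logits, the OTU scores usually need to be separated from padding and text tokens before ranking. The sketch below assumes the logits come back with shape `(B, N_otu + N_txt)`, aligned with `mask` and `type_indicators`, and that a higher logit means a more stable community member; `rank_otus` is a hypothetical helper, not part of the repository.

```python
import torch


def rank_otus(logits, mask, type_indicators, k=3):
    """Return, per sample, indices of the k highest-scoring valid OTU tokens.

    Assumptions: logits is (B, N_otu + N_txt) and aligned with mask/type_indicators;
    higher logits are read as "more stable".
    """
    is_otu = mask & (type_indicators == 0)
    scores = logits.masked_fill(~is_otu, float("-inf"))  # hide padding and text tokens
    ranked = []
    for row, valid in zip(scores, is_otu):
        k_i = min(k, int(valid.sum()))  # never return masked-out positions
        ranked.append(torch.topk(row, k_i).indices.tolist())
    return ranked


logits = torch.tensor([[0.2, 1.5, -0.3, 9.0], [2.0, -1.0, 0.0, 0.5]])
mask = torch.tensor([[True, True, True, True], [True, True, False, True]])
types = torch.tensor([[0, 0, 0, 1], [0, 0, 0, 1]])
print(rank_otus(logits, mask, types, k=2))  # [[1, 0], [0, 1]]
```

The text token in the first sample has the largest logit but is excluded, and the padded OTU slot in the second sample is never returned.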

## Intended use

- Microbiome representation learning from OTU sets
- Stability-style scoring of community members
- Downstream analyses such as dropout/colonization prediction and rollout trajectory experiments

## Limitations

- This is a research model and not a clinical diagnostic tool.
- Outputs depend strongly on upstream OTU mapping, embedding resolution, and cohort preprocessing.
- Text-enabled checkpoints expect 1536-d metadata embeddings produced by the same (or a compatible) embedding pipeline used during training.