the-puzzler committed on
Commit
44b0b79
·
1 Parent(s): 174ad1f
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +260 -0
  3. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ *.pyc
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from typing import List, Tuple
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import plotly.express as px
8
+ import torch
9
+ import umap
10
+ from Bio import SeqIO
11
+ from transformers import AutoModel, AutoTokenizer
12
+
13
+ from model import MicrobiomeTransformer
14
+
15
+
16
+ MAX_GENES = 800
17
+ MAX_SEQ_LEN = 1024
18
+ PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
19
+ CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
20
+ BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
21
+ TRUST_REMOTE_CODE = "true"
22
+
23
+
24
+ @dataclass
25
+ class LoadedModels:
26
+ tokenizer: AutoTokenizer
27
+ prokbert: AutoModel
28
+ microbiome: MicrobiomeTransformer
29
+ device: torch.device
30
+
31
+
32
+ _MODELS: LoadedModels | None = None
33
+
34
+
35
+ def _load_models() -> LoadedModels:
36
+ global _MODELS
37
+ if _MODELS is not None:
38
+ return _MODELS
39
+
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ tokenizer = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
43
+ prokbert = AutoModel.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
44
+ prokbert.to(device)
45
+ prokbert.eval()
46
+
47
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
48
+ state_dict = checkpoint.get("model_state_dict", checkpoint)
49
+ microbiome = MicrobiomeTransformer(
50
+ input_dim_type1=384,
51
+ input_dim_type2=1536,
52
+ d_model=100,
53
+ nhead=5,
54
+ num_layers=5,
55
+ dim_feedforward=400,
56
+ dropout=0.1,
57
+ use_output_activation=False,
58
+ )
59
+ microbiome.load_state_dict(state_dict, strict=False)
60
+ microbiome.to(device)
61
+ microbiome.eval()
62
+
63
+ _MODELS = LoadedModels(
64
+ tokenizer=tokenizer,
65
+ prokbert=prokbert,
66
+ microbiome=microbiome,
67
+ device=device,
68
+ )
69
+ return _MODELS
70
+
71
+
72
+ def _read_fasta(path: str) -> Tuple[List[str], List[str], int, int]:
73
+ ids: List[str] = []
74
+ seqs: List[str] = []
75
+ truncated = 0
76
+
77
+ for record in SeqIO.parse(path, "fasta"):
78
+ seq = str(record.seq).upper()
79
+ if len(seq) > MAX_SEQ_LEN:
80
+ seq = seq[:MAX_SEQ_LEN]
81
+ truncated += 1
82
+ ids.append(record.id)
83
+ seqs.append(seq)
84
+
85
+ original_n = len(ids)
86
+ if original_n == 0:
87
+ raise ValueError("No FASTA records found.")
88
+
89
+ if original_n > MAX_GENES:
90
+ ids = ids[:MAX_GENES]
91
+ seqs = seqs[:MAX_GENES]
92
+
93
+ return ids, seqs, original_n, truncated
94
+
95
+
96
+ def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
97
+ mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
98
+ summed = (last_hidden_state * mask).sum(dim=1)
99
+ counts = mask.sum(dim=1).clamp(min=1e-8)
100
+ return summed / counts
101
+
102
+
103
+ def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
104
+ pooled_batches: List[np.ndarray] = []
105
+
106
+ for i in range(0, len(seqs), BATCH_SIZE):
107
+ batch = seqs[i : i + BATCH_SIZE]
108
+ inputs = models.tokenizer(
109
+ batch,
110
+ return_tensors="pt",
111
+ truncation=True,
112
+ max_length=MAX_SEQ_LEN,
113
+ padding=True,
114
+ )
115
+ inputs = {k: v.to(models.device) for k, v in inputs.items()}
116
+
117
+ with torch.no_grad():
118
+ outputs = models.prokbert(**inputs)
119
+ pooled = _mean_pool(outputs.last_hidden_state, inputs["attention_mask"])
120
+
121
+ pooled_batches.append(pooled.detach().cpu().numpy())
122
+
123
+ emb = np.vstack(pooled_batches)
124
+ if emb.shape[1] != 384:
125
+ raise ValueError(
126
+ f"Expected 384-d ProkBERT embeddings, got {emb.shape[1]} dimensions from {PROKBERT_MODEL_ID}."
127
+ )
128
+ return emb
129
+
130
+
131
+ def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
132
+ x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
133
+ n = x.shape[1]
134
+
135
+ empty_text = torch.zeros((1, 0, 1536), dtype=torch.float32, device=models.device)
136
+ mask = torch.ones((1, n), dtype=torch.bool, device=models.device)
137
+ type_indicators = torch.zeros((1, n), dtype=torch.long, device=models.device)
138
+
139
+ batch = {
140
+ "embeddings_type1": x,
141
+ "embeddings_type2": empty_text,
142
+ "mask": mask,
143
+ "type_indicators": type_indicators,
144
+ }
145
+
146
+ with torch.no_grad():
147
+ x_proj = models.microbiome.input_projection_type1(batch["embeddings_type1"])
148
+ final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=~mask)
149
+ logits = models.microbiome.output_projection(final_hidden).squeeze(-1)
150
+
151
+ return (
152
+ logits.squeeze(0).detach().cpu().numpy(),
153
+ final_hidden.squeeze(0).detach().cpu().numpy(),
154
+ )
155
+
156
+
157
+ def _umap_df(vectors: np.ndarray, labels: List[str], value_name: str):
158
+ n = vectors.shape[0]
159
+ if n < 2:
160
+ raise ValueError("Need at least 2 genes to compute UMAP.")
161
+
162
+ reducer = umap.UMAP(
163
+ n_components=2,
164
+ n_neighbors=min(15, n - 1),
165
+ min_dist=0.1,
166
+ metric="cosine",
167
+ random_state=42,
168
+ )
169
+ coords = reducer.fit_transform(vectors)
170
+ return {
171
+ "x": coords[:, 0],
172
+ "y": coords[:, 1],
173
+ "gene": labels,
174
+ value_name: np.linalg.norm(vectors, axis=1),
175
+ }
176
+
177
+
178
+ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
179
+ df = _umap_df(vectors, labels, "norm")
180
+ fig = px.scatter(
181
+ df,
182
+ x="x",
183
+ y="y",
184
+ hover_name="gene",
185
+ color="norm",
186
+ title=title,
187
+ color_continuous_scale="Viridis",
188
+ )
189
+ fig.update_traces(marker={"size": 9, "line": {"width": 0.5, "color": "black"}})
190
+ return fig
191
+
192
+
193
+ def _plot_logits(logits: np.ndarray, labels: List[str]):
194
+ fig = px.histogram(
195
+ x=logits,
196
+ nbins=min(50, max(10, len(logits) // 4)),
197
+ title="Logit Distribution Over Input DNA Embeddings",
198
+ )
199
+ fig.update_layout(xaxis_title="Logit", yaxis_title="Count")
200
+ return fig
201
+
202
+
203
+ def run_pipeline(fasta_file: str):
204
+ if fasta_file is None:
205
+ raise gr.Error("Upload a FASTA file first.")
206
+
207
+ models = _load_models()
208
+ labels, seqs, original_n, truncated = _read_fasta(fasta_file)
209
+
210
+ input_embeddings = _embed_sequences(seqs, models)
211
+ logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
212
+
213
+ input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings (ProkBERT Mean-Pooled)")
214
+ final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final Embeddings (After large-notext Transformer)")
215
+ logits_hist = _plot_logits(logits, labels)
216
+
217
+ capped_n = len(labels)
218
+ info = (
219
+ f"Loaded {original_n} genes. "
220
+ f"Used {capped_n} (cap={MAX_GENES}). "
221
+ f"Truncated {truncated} sequence(s) to {MAX_SEQ_LEN} nt."
222
+ )
223
+
224
+ top_idx = np.argsort(logits)[::-1]
225
+ top_rows = [[labels[i], float(logits[i])] for i in top_idx[: min(50, len(labels))]]
226
+
227
+ return info, input_umap, final_umap, logits_hist, top_rows
228
+
229
+
230
+ with gr.Blocks(title="Microbiome Space: ProkBERT -> large-notext") as demo:
231
+ gr.Markdown(
232
+ """
233
+ # Microbiome Gene Scoring Explorer
234
+ Upload a FASTA of genes, embed with `prokbert-mini-long` (mean pooling), score with `large-notext`, and inspect embedding geometry + logit distribution.
235
+
236
+ Constraints:
237
+ - Max genes per run: 800
238
+ - Max gene length: 1024 nt (longer sequences are truncated)
239
+ """
240
+ )
241
+
242
+ with gr.Row():
243
+ fasta_in = gr.File(label="FASTA file", file_types=[".fa", ".fasta", ".fna", ".txt"], type="filepath")
244
+ run_btn = gr.Button("Run", variant="primary")
245
+
246
+ status = gr.Textbox(label="Run Summary")
247
+ input_umap_plot = gr.Plot(label="Input Embedding UMAP")
248
+ final_umap_plot = gr.Plot(label="Final Embedding UMAP")
249
+ logits_plot = gr.Plot(label="Logit Distribution")
250
+ top_table = gr.Dataframe(headers=["gene_id", "logit"], label="Top genes by logit")
251
+
252
+ run_btn.click(
253
+ fn=run_pipeline,
254
+ inputs=[fasta_in],
255
+ outputs=[status, input_umap_plot, final_umap_plot, logits_plot, top_table],
256
+ )
257
+
258
+
259
+ if __name__ == "__main__":
260
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0.0
2
+ torch>=2.1.0
3
+ transformers>=4.44.0
4
+ sentencepiece>=0.2.0
5
+ biopython>=1.84
6
+ umap-learn>=0.5.6
7
+ plotly>=5.24.0
8
+ numpy>=1.26.0