yushize committed · verified
Commit b0d1814 · Parent(s): 5c80eee

Create app.py

Files changed (1)
  1. app.py +623 -0
app.py ADDED
@@ -0,0 +1,623 @@
import gc
import io
import os
import re
import zipfile
import tempfile
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import gradio as gr
import numpy as np
import pandas as pd
import torch

from transformers import (
    AutoModel,
    AutoTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

# =========================
# Global config
# =========================
APP_TITLE = "Protein Embedding Hub"
APP_DESC = """
Paste FASTA protein sequences, choose a model, and export residue-level embeddings of shape L*d.
The app automatically normalizes model outputs (L+1, L+2, or other tokenized variants) back to strict residue-level L*d.
"""

# Accepted letters; the rare/ambiguous residues U, Z, O, B are mapped to X before embedding.
ALLOWED_AA = set("ACDEFGHIKLMNPQRSTVWYXBZJUO")
REPLACE_WITH_X = set("UZOB")

# =========================
# Model registry
# =========================
@dataclass
class ModelSpec:
    name: str
    family: str  # "hf_encoder", "t5_encoder", "esmc"
    model_id: str
    tokenizer_id: Optional[str] = None
    note: str = ""


MODEL_SPECS: Dict[str, ModelSpec] = {
    # ESM2
    "ESM2-8M": ModelSpec(
        name="ESM2-8M",
        family="hf_encoder",
        model_id="facebook/esm2_t6_8M_UR50D",
        tokenizer_id="facebook/esm2_t6_8M_UR50D",
        note="Very light.",
    ),
    "ESM2-35M": ModelSpec(
        name="ESM2-35M",
        family="hf_encoder",
        model_id="facebook/esm2_t12_35M_UR50D",
        tokenizer_id="facebook/esm2_t12_35M_UR50D",
        note="Good small baseline.",
    ),
    "ESM2-150M": ModelSpec(
        name="ESM2-150M",
        family="hf_encoder",
        model_id="facebook/esm2_t30_150M_UR50D",
        tokenizer_id="facebook/esm2_t30_150M_UR50D",
        note="Balanced.",
    ),
    "ESM2-650M": ModelSpec(
        name="ESM2-650M",
        family="hf_encoder",
        model_id="facebook/esm2_t33_650M_UR50D",
        tokenizer_id="facebook/esm2_t33_650M_UR50D",
        note="Strong sequence-only baseline.",
    ),

    # ESMC
    "ESMC-300M": ModelSpec(
        name="ESMC-300M",
        family="esmc",
        model_id="esmc_300m",
        note="Representation model; usually better efficiency/performance than similar-size ESM2.",
    ),
    "ESMC-600M": ModelSpec(
        name="ESMC-600M",
        family="esmc",
        model_id="esmc_600m",
        note="Larger ESMC.",
    ),

    # Ankh
    "Ankh-Base": ModelSpec(
        name="Ankh-Base",
        family="hf_encoder",
        model_id="ElnaggarLab/ankh-base",
        tokenizer_id="ElnaggarLab/ankh-base",
        note="Efficient, strong general-purpose protein LM.",
    ),
    "Ankh-Large": ModelSpec(
        name="Ankh-Large",
        family="hf_encoder",
        model_id="ElnaggarLab/ankh-large",
        tokenizer_id="ElnaggarLab/ankh-large",
        note="Larger Ankh variant.",
    ),

    # ProtT5 encoder
    "ProtT5-XL-Encoder": ModelSpec(
        name="ProtT5-XL-Encoder",
        family="t5_encoder",
        model_id="Rostlab/prot_t5_xl_half_uniref50-enc",
        tokenizer_id="Rostlab/prot_t5_xl_half_uniref50-enc",
        note="Classic protein embedding model; heavy.",
    ),
}


# =========================
# Model manager
# =========================
class ModelManager:
    def __init__(self):
        self.current_key = None
        self.current_family = None
        self.model = None
        self.tokenizer = None
        self.device = None

    def unload(self):
        self.model = None
        self.tokenizer = None
        self.current_key = None
        self.current_family = None
        self.device = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load(self, model_key: str, device: str):
        if self.current_key == model_key and self.device == device and self.model is not None:
            return

        self.unload()

        spec = MODEL_SPECS[model_key]
        resolved_device = _resolve_device(device)

        if spec.family == "hf_encoder":
            self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
            self.model = AutoModel.from_pretrained(spec.model_id)
            # Encoder-decoder checkpoints (e.g. the T5-based Ankh models) cannot
            # run a forward pass without decoder inputs; keep only the encoder,
            # which is all we need for embeddings.
            if getattr(self.model.config, "is_encoder_decoder", False):
                self.model = self.model.get_encoder()
            self.model.to(resolved_device)
            self.model.eval()

        elif spec.family == "t5_encoder":
            self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
            self.model = T5EncoderModel.from_pretrained(spec.model_id)
            self.model.to(resolved_device)
            self.model.eval()

        elif spec.family == "esmc":
            try:
                from esm.models.esmc import ESMC
            except Exception as e:
                raise RuntimeError(
                    "Failed to import ESMC. Please install the official `esm` package. "
                    f"Original error: {e}"
                )
            self.model = ESMC.from_pretrained(spec.model_id).to(resolved_device)
            self.model.eval()
            self.tokenizer = None

        else:
            raise ValueError(f"Unsupported family: {spec.family}")

        self.current_key = model_key
        self.current_family = spec.family
        self.device = resolved_device


MODEL_MANAGER = ModelManager()
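
# Typical usage within this app (illustrative sketch; see embed_sequences below):
#   MODEL_MANAGER.load("ESM2-150M", "auto")
#   emb = embed_one_hf_encoder(seq, MODEL_MANAGER.model, MODEL_MANAGER.tokenizer, MODEL_MANAGER.device)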


# =========================
# FASTA and sequence utils
# =========================
def parse_fasta(text: str) -> List[Dict[str, str]]:
    text = text.strip()
    if not text:
        raise ValueError("Empty FASTA input.")

    records = []
    current_id = None
    current_seq = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current_id is not None:
                seq = "".join(current_seq).strip()
                if not seq:
                    raise ValueError(f"Sequence for record '{current_id}' is empty.")
                records.append({"id": current_id, "sequence": seq})
            current_id = line[1:].strip() or f"seq_{len(records)+1}"
            current_seq = []
        else:
            if current_id is None:
                current_id = f"seq_{len(records)+1}"
            current_seq.append(line)

    if current_id is not None:
        seq = "".join(current_seq).strip()
        if not seq:
            raise ValueError(f"Sequence for record '{current_id}' is empty.")
        records.append({"id": current_id, "sequence": seq})

    if not records:
        raise ValueError("No FASTA records found.")
    return records


def clean_sequence(seq: str) -> Tuple[str, List[str]]:
    seq = re.sub(r"\s+", "", seq).upper()
    warnings = []

    if not seq:
        raise ValueError("Empty sequence after cleaning.")

    bad = sorted({c for c in seq if c not in ALLOWED_AA})
    if bad:
        raise ValueError(f"Invalid amino acid letters found: {bad}")

    replaced = sorted({c for c in seq if c in REPLACE_WITH_X})
    if replaced:
        for c in replaced:
            seq = seq.replace(c, "X")
        warnings.append(f"Replaced uncommon residues {replaced} with X.")

    return seq, warnings


def protein_to_spaced(seq: str) -> str:
    return " ".join(seq)


def safe_filename(x: str) -> str:
    x = re.sub(r"[^A-Za-z0-9._-]+", "_", x)
    x = x.strip("._")
    return x or "sequence"


def _resolve_device(device: str) -> str:
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device
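
# Illustrative behavior of the helpers above (comments only, not executed):
#   parse_fasta(">a\nMKV\n>b\nGA")
#       -> [{"id": "a", "sequence": "MKV"}, {"id": "b", "sequence": "GA"}]
#   clean_sequence("mk u")
#       -> ("MKX", ["Replaced uncommon residues ['U'] with X."])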


# =========================
# Embedding normalization
# =========================
def normalize_to_residue_level(
    hidden: torch.Tensor,
    expected_len: int,
    special_tokens_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Convert model output to strict residue-level shape [L, d].

    Priority:
    1) If special_tokens_mask exists, remove special tokens.
    2) If the length is already exactly L, keep it.
    3) If L+2, assume BOS/EOS and slice [1:-1].
    4) If L+1, trim one token from the end.
    5) Otherwise, crop to the first L tokens as a best effort.
    """
    if hidden.ndim != 2:
        raise ValueError(f"Expected hidden shape [T, d], got {tuple(hidden.shape)}")

    T, d = hidden.shape

    if special_tokens_mask is not None:
        mask = special_tokens_mask.bool().view(-1)
        if attention_mask is not None:
            attn = attention_mask.bool().view(-1)
            keep = (~mask) & attn
        else:
            keep = ~mask
        if keep.numel() == T:
            filtered = hidden[keep]
            if filtered.shape[0] == expected_len:
                return filtered
            if filtered.shape[0] > expected_len:
                return filtered[:expected_len]

    if T == expected_len:
        return hidden
    if T == expected_len + 2:
        return hidden[1:-1]
    if T == expected_len + 1:
        return hidden[:expected_len]
    if T > expected_len:
        return hidden[:expected_len]

    raise ValueError(
        f"Could not normalize token length {T} to residue length {expected_len}."
    )

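# Quick sanity check of the fallback rules above (illustrative, not executed):
# a fake [L+2, d] tensor, as produced by BOS/EOS-adding models, should come
# back as exactly L rows.
#   t = torch.zeros(12, 4)                             # L = 10, T = 12
#   assert normalize_to_residue_level(t, 10).shape == (10, 4)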

# =========================
# Model-specific embedding
# =========================
@torch.no_grad()
def embed_one_hf_encoder(seq: str, model, tokenizer, device: str) -> np.ndarray:
    enc = tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    out = model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
    hidden = out.last_hidden_state[0]  # [T, d]

    special_tokens_mask = enc.get("special_tokens_mask", None)
    attention_mask = enc.get("attention_mask", None)

    residue_hidden = normalize_to_residue_level(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=special_tokens_mask[0] if special_tokens_mask is not None else None,
        attention_mask=attention_mask[0] if attention_mask is not None else None,
    )
    return residue_hidden.detach().cpu().float().numpy()


@torch.no_grad()
def embed_one_t5_encoder(seq: str, model, tokenizer, device: str) -> np.ndarray:
    # ProtT5-style preprocessing: uppercase residues separated by spaces.
    spaced = protein_to_spaced(seq)
    enc = tokenizer(
        spaced,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    enc = {k: v.to(device) for k, v in enc.items()}
    out = model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
    hidden = out.last_hidden_state[0]

    special_tokens_mask = enc.get("special_tokens_mask", None)
    attention_mask = enc.get("attention_mask", None)

    residue_hidden = normalize_to_residue_level(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=special_tokens_mask[0] if special_tokens_mask is not None else None,
        attention_mask=attention_mask[0] if attention_mask is not None else None,
    )
    return residue_hidden.detach().cpu().float().numpy()


@torch.no_grad()
def embed_one_esmc(seq: str, model, device: str) -> np.ndarray:
    from esm.sdk.api import ESMProtein, LogitsConfig

    protein = ESMProtein(sequence=seq)
    protein_tensor = model.encode(protein)
    out = model.logits(
        protein_tensor,
        LogitsConfig(sequence=True, return_embeddings=True),
    )

    emb = out.embeddings
    if isinstance(emb, np.ndarray):
        arr = emb
    else:
        arr = emb.detach().cpu().float().numpy()

    # Expected shape is typically [1, T, d] or [T, d].
    if arr.ndim == 3:
        arr = arr[0]

    if arr.shape[0] == len(seq):
        return arr
    if arr.shape[0] == len(seq) + 2:
        return arr[1:-1]
    if arr.shape[0] == len(seq) + 1:
        return arr[:len(seq)]
    if arr.shape[0] > len(seq):
        return arr[:len(seq)]

    raise ValueError(
        f"ESMC returned incompatible shape {arr.shape} for sequence length {len(seq)}."
    )

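# Token counts these paths typically produce before normalization (assumption,
# based on each family's default tokenization; verify for new checkpoints):
#   ESM2 (HF):  <cls> + L residues + <eos>  ->  T = L + 2
#   ProtT5:     L residues + </s>           ->  T = L + 1
#   ESMC:       BOS + L residues + EOS      ->  T = L + 2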

def embed_sequences(
    fasta_text: str,
    model_key: str,
    device: str,
    progress=gr.Progress(track_tqdm=False),
):
    records = parse_fasta(fasta_text)

    cleaned_records = []
    global_warnings = []

    for rec in records:
        clean_seq, warnings = clean_sequence(rec["sequence"])
        cleaned_records.append({"id": rec["id"], "sequence": clean_seq})
        for w in warnings:
            global_warnings.append(f"{rec['id']}: {w}")

    MODEL_MANAGER.load(model_key, device)
    spec = MODEL_SPECS[model_key]

    embeddings_by_id: Dict[str, np.ndarray] = {}
    summary_rows = []
    first_preview = None
    first_preview_name = None

    for idx, rec in enumerate(cleaned_records, start=1):
        seq_id = rec["id"]
        seq = rec["sequence"]
        progress((idx - 1) / max(len(cleaned_records), 1), desc=f"Embedding {seq_id}")

        if spec.family == "hf_encoder":
            emb = embed_one_hf_encoder(seq, MODEL_MANAGER.model, MODEL_MANAGER.tokenizer, MODEL_MANAGER.device)
        elif spec.family == "t5_encoder":
            emb = embed_one_t5_encoder(seq, MODEL_MANAGER.model, MODEL_MANAGER.tokenizer, MODEL_MANAGER.device)
        elif spec.family == "esmc":
            emb = embed_one_esmc(seq, MODEL_MANAGER.model, MODEL_MANAGER.device)
        else:
            raise ValueError(f"Unsupported family: {spec.family}")

        if emb.shape[0] != len(seq):
            raise ValueError(
                f"Normalization failed for {seq_id}: got {emb.shape}, expected first dimension {len(seq)}."
            )

        # Note: duplicate FASTA ids overwrite earlier entries (last one wins).
        embeddings_by_id[seq_id] = emb

        summary_rows.append({
            "id": seq_id,
            "length_L": len(seq),
            "embedding_dim_d": emb.shape[1],
            "shape": f"{emb.shape[0]} x {emb.shape[1]}",
            "model": model_key,
        })

        if first_preview is None:
            preview_rows = min(20, emb.shape[0])
            preview_cols = min(8, emb.shape[1])
            df = pd.DataFrame(
                emb[:preview_rows, :preview_cols],
                index=[f"res_{i+1}" for i in range(preview_rows)],
                columns=[f"dim_{j+1}" for j in range(preview_cols)],
            )
            first_preview = df
            first_preview_name = seq_id

    progress(1.0, desc="Packaging outputs")

    out_zip = package_outputs(
        embeddings_by_id=embeddings_by_id,
        sequences={x["id"]: x["sequence"] for x in cleaned_records},
        model_key=model_key,
        notes=global_warnings,
    )

    summary_df = pd.DataFrame(summary_rows)
    log_text = []
    log_text.append(f"Loaded model: {model_key}")
    log_text.append(f"Resolved device: {MODEL_MANAGER.device}")
    log_text.append(f"Processed sequences: {len(cleaned_records)}")
    if global_warnings:
        log_text.append("")
        log_text.append("Warnings:")
        log_text.extend(global_warnings)

    preview_markdown = f"Preview shown for: `{first_preview_name}`"

    return summary_df, first_preview, preview_markdown, out_zip, "\n".join(log_text)

def package_outputs(
    embeddings_by_id: Dict[str, np.ndarray],
    sequences: Dict[str, str],
    model_key: str,
    notes: List[str],
) -> str:
    tmpdir = tempfile.mkdtemp(prefix="protein_embedding_hub_")
    zip_path = os.path.join(tmpdir, f"{safe_filename(model_key)}_embeddings.zip")

    summary_rows = []
    for seq_id, emb in embeddings_by_id.items():
        summary_rows.append({
            "id": seq_id,
            "length_L": len(sequences[seq_id]),
            "embedding_dim_d": emb.shape[1],
            "shape": f"{emb.shape[0]} x {emb.shape[1]}",
            "npy_file": f"{safe_filename(seq_id)}.npy",
        })

    summary_df = pd.DataFrame(summary_rows)
    sequences_df = pd.DataFrame(
        [{"id": k, "sequence": v} for k, v in sequences.items()]
    )

    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        # summary.csv
        with io.StringIO() as s:
            summary_df.to_csv(s, index=False)
            zf.writestr("summary.csv", s.getvalue())

        # sequences.csv
        with io.StringIO() as s:
            sequences_df.to_csv(s, index=False)
            zf.writestr("sequences.csv", s.getvalue())

        # notes.txt
        note_text = "\n".join(notes) if notes else "No warnings."
        zf.writestr("notes.txt", note_text)

        # per-sequence npy
        for seq_id, emb in embeddings_by_id.items():
            npy_name = f"embeddings/{safe_filename(seq_id)}.npy"
            buf = io.BytesIO()
            np.save(buf, emb)
            zf.writestr(npy_name, buf.getvalue())

    return zip_path

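# Downstream usage sketch (illustrative; the zip and entry names depend on the
# chosen model key and your sequence ids):
#   import io, zipfile
#   import numpy as np
#   with zipfile.ZipFile("ESM2-150M_embeddings.zip") as zf:
#       emb = np.load(io.BytesIO(zf.read("embeddings/seq1.npy")))  # shape [L, d]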

def clear_loaded_model():
    MODEL_MANAGER.unload()
    return "Model cache cleared."


# =========================
# Gradio UI
# =========================
EXAMPLE_FASTA = """>seq1
MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
>seq2
GAVLILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSR
"""

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESC)

    with gr.Row():
        with gr.Column(scale=2):
            fasta_input = gr.Textbox(
                label="Protein FASTA input",
                lines=16,
                value=EXAMPLE_FASTA,
                placeholder="Paste FASTA here...",
            )

            model_dropdown = gr.Dropdown(
                choices=list(MODEL_SPECS.keys()),
                value="ESM2-150M",
                label="Model",
            )

            device_dropdown = gr.Dropdown(
                choices=["auto", "cuda", "cpu"],
                value="auto",
                label="Device",
            )

            with gr.Row():
                run_btn = gr.Button("Run embedding", variant="primary")
                clear_btn = gr.Button("Clear loaded model")

        with gr.Column(scale=1):
            gr.Markdown("## Notes")
            gr.Markdown(
                "- Output is always normalized to residue-level `L*d`\n"
                "- ZIP contains one `.npy` per sequence\n"
                "- `summary.csv` records final shapes\n"
                "- Large models need a GPU"
            )
            model_note = gr.Markdown(
                value="\n".join(
                    [f"- **{k}**: {v.note}" for k, v in MODEL_SPECS.items()]
                )
            )

    with gr.Row():
        summary_output = gr.Dataframe(label="Summary", interactive=False)
    with gr.Row():
        preview_note = gr.Markdown()
    with gr.Row():
        preview_output = gr.Dataframe(label="Embedding preview (first sequence)", interactive=False)

    with gr.Row():
        download_output = gr.File(label="Download ZIP")
    with gr.Row():
        log_output = gr.Textbox(label="Log", lines=10)

    run_btn.click(
        fn=embed_sequences,
        inputs=[fasta_input, model_dropdown, device_dropdown],
        outputs=[summary_output, preview_output, preview_note, download_output, log_output],
    )

    clear_btn.click(
        fn=clear_loaded_model,
        inputs=[],
        outputs=[log_output],
    )

demo.queue(max_size=16)
demo.launch()