yushize commited on
Commit
9e1f83a
·
verified ·
1 Parent(s): 2770056

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -384
app.py CHANGED
@@ -5,184 +5,102 @@ import re
5
  import zipfile
6
  import tempfile
7
  from dataclasses import dataclass
8
- from typing import Dict, List, Tuple, Optional
9
 
10
  import gradio as gr
11
- import numpy as np
12
- import pandas as pd
13
  import torch
 
14
 
15
- from transformers import (
16
- AutoModel,
17
- AutoTokenizer,
18
- T5EncoderModel,
19
- T5Tokenizer,
20
- )
21
-
22
- # =========================
23
- # Global config
24
- # =========================
25
- APP_TITLE = "Protein Embedding Hub"
26
- APP_DESC = """
27
- Input FASTA protein sequences, choose a model, and export residue-level embeddings with shape L*d.
28
- This app automatically normalizes model outputs such as L+1, L+2, or tokenized variants back to strict residue-level L*d.
29
- """
30
 
31
  ALLOWED_AA = set(list("ACDEFGHIKLMNPQRSTVWYXBZJUO"))
32
  REPLACE_WITH_X = set(list("UZOB"))
33
 
34
- # =========================
35
- # Model registry
36
- # =========================
37
  @dataclass
38
  class ModelSpec:
39
  name: str
40
- family: str # "hf_encoder", "t5_encoder", "esmc"
41
  model_id: str
42
  tokenizer_id: Optional[str] = None
43
- note: str = ""
44
 
45
 
46
  MODEL_SPECS: Dict[str, ModelSpec] = {
47
- # ESM2
48
  "ESM2-8M": ModelSpec(
49
  name="ESM2-8M",
50
  family="hf_encoder",
51
  model_id="facebook/esm2_t6_8M_UR50D",
52
  tokenizer_id="facebook/esm2_t6_8M_UR50D",
53
- note="Very light."
54
  ),
55
  "ESM2-35M": ModelSpec(
56
  name="ESM2-35M",
57
  family="hf_encoder",
58
  model_id="facebook/esm2_t12_35M_UR50D",
59
  tokenizer_id="facebook/esm2_t12_35M_UR50D",
60
- note="Good small baseline."
61
  ),
62
  "ESM2-150M": ModelSpec(
63
  name="ESM2-150M",
64
  family="hf_encoder",
65
  model_id="facebook/esm2_t30_150M_UR50D",
66
  tokenizer_id="facebook/esm2_t30_150M_UR50D",
67
- note="Balanced."
68
  ),
69
  "ESM2-650M": ModelSpec(
70
  name="ESM2-650M",
71
  family="hf_encoder",
72
  model_id="facebook/esm2_t33_650M_UR50D",
73
  tokenizer_id="facebook/esm2_t33_650M_UR50D",
74
- note="Strong sequence-only baseline."
75
  ),
76
-
77
- # ESMC
78
  "ESMC-300M": ModelSpec(
79
  name="ESMC-300M",
80
  family="esmc",
81
  model_id="esmc_300m",
82
- note="Representation model; usually better efficiency/performance than similar-size ESM2."
83
  ),
84
  "ESMC-600M": ModelSpec(
85
  name="ESMC-600M",
86
  family="esmc",
87
  model_id="esmc_600m",
88
- note="Larger ESMC."
89
  ),
90
-
91
- # Ankh
92
  "Ankh-Base": ModelSpec(
93
  name="Ankh-Base",
94
  family="hf_encoder",
95
  model_id="ElnaggarLab/ankh-base",
96
  tokenizer_id="ElnaggarLab/ankh-base",
97
- note="Efficient strong general-purpose protein LM."
98
  ),
99
  "Ankh-Large": ModelSpec(
100
  name="Ankh-Large",
101
  family="hf_encoder",
102
  model_id="ElnaggarLab/ankh-large",
103
  tokenizer_id="ElnaggarLab/ankh-large",
104
- note="Larger Ankh variant."
105
  ),
106
-
107
- # ProtT5 encoder
108
  "ProtT5-XL-Encoder": ModelSpec(
109
  name="ProtT5-XL-Encoder",
110
  family="t5_encoder",
111
  model_id="Rostlab/prot_t5_xl_half_uniref50-enc",
112
  tokenizer_id="Rostlab/prot_t5_xl_half_uniref50-enc",
113
- note="Classic protein embedding model; heavy."
 
 
 
 
 
114
  ),
115
  }
116
 
117
 
118
- # =========================
119
- # Model manager
120
- # =========================
121
- class ModelManager:
122
- def __init__(self):
123
- self.current_key = None
124
- self.current_family = None
125
- self.model = None
126
- self.tokenizer = None
127
- self.device = None
128
-
129
- def unload(self):
130
- self.model = None
131
- self.tokenizer = None
132
- self.current_key = None
133
- self.current_family = None
134
- self.device = None
135
- gc.collect()
136
- if torch.cuda.is_available():
137
- torch.cuda.empty_cache()
138
-
139
- def load(self, model_key: str, device: str):
140
- if self.current_key == model_key and self.device == device and self.model is not None:
141
- return
142
-
143
- self.unload()
144
-
145
- spec = MODEL_SPECS[model_key]
146
- resolved_device = _resolve_device(device)
147
-
148
- if spec.family == "hf_encoder":
149
- self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
150
- self.model = AutoModel.from_pretrained(spec.model_id)
151
- self.model.to(resolved_device)
152
- self.model.eval()
153
-
154
- elif spec.family == "t5_encoder":
155
- self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
156
- self.model = T5EncoderModel.from_pretrained(spec.model_id)
157
- self.model.to(resolved_device)
158
- self.model.eval()
159
-
160
- elif spec.family == "esmc":
161
- try:
162
- from esm.models.esmc import ESMC
163
- except Exception as e:
164
- raise RuntimeError(
165
- "Failed to import ESMC. Please install the official `esm` package. "
166
- f"Original error: {e}"
167
- )
168
- self.model = ESMC.from_pretrained(spec.model_id).to(resolved_device)
169
- self.model.eval()
170
- self.tokenizer = None
171
-
172
- else:
173
- raise ValueError(f"Unsupported family: {spec.family}")
174
-
175
- self.current_key = model_key
176
- self.current_family = spec.family
177
- self.device = resolved_device
178
 
179
 
180
- MODEL_MANAGER = ModelManager()
 
 
 
181
 
182
 
183
- # =========================
184
- # FASTA and sequence utils
185
- # =========================
186
  def parse_fasta(text: str) -> List[Dict[str, str]]:
187
  text = text.strip()
188
  if not text:
@@ -220,10 +138,8 @@ def parse_fasta(text: str) -> List[Dict[str, str]]:
220
  return records
221
 
222
 
223
- def clean_sequence(seq: str) -> Tuple[str, List[str]]:
224
  seq = re.sub(r"\s+", "", seq).upper()
225
- warnings = []
226
-
227
  if not seq:
228
  raise ValueError("Empty sequence after cleaning.")
229
 
@@ -231,70 +147,35 @@ def clean_sequence(seq: str) -> Tuple[str, List[str]]:
231
  if bad:
232
  raise ValueError(f"Invalid amino acid letters found: {bad}")
233
 
234
- replaced = sorted({c for c in seq if c in REPLACE_WITH_X})
235
- if replaced:
236
- for c in replaced:
237
- seq = seq.replace(c, "X")
238
- warnings.append(f"Replaced uncommon residues {replaced} with X.")
239
-
240
- return seq, warnings
241
 
242
 
243
  def protein_to_spaced(seq: str) -> str:
244
  return " ".join(list(seq))
245
 
246
 
247
- def safe_filename(x: str) -> str:
248
- x = re.sub(r"[^A-Za-z0-9._-]+", "_", x)
249
- x = x.strip("._")
250
- return x or "sequence"
251
-
252
-
253
- def _resolve_device(device: str) -> str:
254
- if device == "auto":
255
- return "cuda" if torch.cuda.is_available() else "cpu"
256
- if device == "cuda" and not torch.cuda.is_available():
257
- return "cpu"
258
- return device
259
-
260
-
261
- # =========================
262
- # Embedding normalization
263
- # =========================
264
- def normalize_to_residue_level(
265
  hidden: torch.Tensor,
266
  expected_len: int,
267
  special_tokens_mask: Optional[torch.Tensor] = None,
268
  attention_mask: Optional[torch.Tensor] = None,
269
  ) -> torch.Tensor:
270
- """
271
- Convert model output to strict residue-level shape [L, d].
272
-
273
- Priority:
274
- 1) If special_tokens_mask exists, remove special tokens.
275
- 2) If already exactly L, keep.
276
- 3) If L+2, assume BOS/EOS and slice [1:-1].
277
- 4) If L+1, trim one token from the end.
278
- 5) Else crop to first L after best effort.
279
- """
280
  if hidden.ndim != 2:
281
- raise ValueError(f"Expected hidden shape [T, d], got {tuple(hidden.shape)}")
282
 
283
- T, d = hidden.shape
284
 
285
  if special_tokens_mask is not None:
286
- mask = special_tokens_mask.bool().view(-1)
287
  if attention_mask is not None:
288
- attn = attention_mask.bool().view(-1)
289
- keep = (~mask) & attn
290
- else:
291
- keep = ~mask
292
- if keep.numel() == T:
293
- filtered = hidden[keep]
294
- if filtered.shape[0] == expected_len:
295
- return filtered
296
- if filtered.shape[0] > expected_len:
297
- return filtered[:expected_len]
298
 
299
  if T == expected_len:
300
  return hidden
@@ -305,247 +186,270 @@ def normalize_to_residue_level(
305
  if T > expected_len:
306
  return hidden[:expected_len]
307
 
308
- raise ValueError(
309
- f"Could not normalize token length {T} to residue length {expected_len}."
310
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
 
313
- # =========================
314
- # Model-specific embedding
315
- # =========================
316
  @torch.no_grad()
317
- def embed_one_hf_encoder(seq: str, model, tokenizer, device: str) -> np.ndarray:
318
- enc = tokenizer(
319
  seq,
320
  return_tensors="pt",
321
  add_special_tokens=True,
322
  return_special_tokens_mask=True,
323
  truncation=False,
324
  )
325
- enc = {k: v.to(device) for k, v in enc.items()}
326
- out = model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
327
- hidden = out.last_hidden_state[0] # [T, d]
328
-
329
- special_tokens_mask = enc.get("special_tokens_mask", None)
330
- attention_mask = enc.get("attention_mask", None)
331
 
332
- residue_hidden = normalize_to_residue_level(
333
  hidden=hidden,
334
  expected_len=len(seq),
335
- special_tokens_mask=special_tokens_mask[0] if special_tokens_mask is not None else None,
336
- attention_mask=attention_mask[0] if attention_mask is not None else None,
337
  )
338
- return residue_hidden.detach().cpu().float().numpy()
339
 
340
 
341
  @torch.no_grad()
342
- def embed_one_t5_encoder(seq: str, model, tokenizer, device: str) -> np.ndarray:
343
- # ProtT5 style preprocessing: uppercase residues separated by spaces.
344
  spaced = protein_to_spaced(seq)
345
- enc = tokenizer(
346
  spaced,
347
  return_tensors="pt",
348
  add_special_tokens=True,
349
  return_special_tokens_mask=True,
350
  truncation=False,
351
  )
352
- enc = {k: v.to(device) for k, v in enc.items()}
353
- out = model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
354
  hidden = out.last_hidden_state[0]
355
 
356
- special_tokens_mask = enc.get("special_tokens_mask", None)
357
- attention_mask = enc.get("attention_mask", None)
358
-
359
- residue_hidden = normalize_to_residue_level(
360
  hidden=hidden,
361
  expected_len=len(seq),
362
- special_tokens_mask=special_tokens_mask[0] if special_tokens_mask is not None else None,
363
- attention_mask=attention_mask[0] if attention_mask is not None else None,
364
  )
365
- return residue_hidden.detach().cpu().float().numpy()
366
 
367
 
368
  @torch.no_grad()
369
- def embed_one_esmc(seq: str, model, device: str) -> np.ndarray:
370
  from esm.sdk.api import ESMProtein, LogitsConfig
371
 
372
  protein = ESMProtein(sequence=seq)
373
- protein_tensor = model.encode(protein)
374
- out = model.logits(
375
  protein_tensor,
376
  LogitsConfig(sequence=True, return_embeddings=True)
377
  )
378
 
379
  emb = out.embeddings
380
- if isinstance(emb, np.ndarray):
381
- arr = emb
382
- else:
383
- arr = emb.detach().cpu().float().numpy()
384
-
385
- # Expected shape is typically [1, T, d] or [T, d]
386
- if arr.ndim == 3:
387
- arr = arr[0]
388
-
389
- if arr.shape[0] == len(seq):
390
- return arr
391
- if arr.shape[0] == len(seq) + 2:
392
- return arr[1:-1]
393
- if arr.shape[0] == len(seq) + 1:
394
- return arr[:len(seq)]
395
- if arr.shape[0] > len(seq):
396
- return arr[:len(seq)]
397
-
398
- raise ValueError(
399
- f"ESMC returned incompatible shape {arr.shape} for sequence length {len(seq)}."
400
- )
401
 
 
 
402
 
403
- def embed_sequences(
404
- fasta_text: str,
405
- model_key: str,
406
- device: str,
407
- progress=gr.Progress(track_tqdm=False),
408
- ):
409
- records = parse_fasta(fasta_text)
410
-
411
- cleaned_records = []
412
- global_warnings = []
413
 
414
- for rec in records:
415
- clean_seq, warnings = clean_sequence(rec["sequence"])
416
- cleaned_records.append({"id": rec["id"], "sequence": clean_seq})
417
- for w in warnings:
418
- global_warnings.append(f"{rec['id']}: {w}")
419
 
420
- MODEL_MANAGER.load(model_key, device)
421
- spec = MODEL_SPECS[model_key]
422
 
423
- embeddings_by_id: Dict[str, np.ndarray] = {}
424
- summary_rows = []
425
- first_preview = None
426
- first_preview_name = None
 
 
 
 
 
 
 
 
427
 
428
- for idx, rec in enumerate(cleaned_records, start=1):
429
- seq_id = rec["id"]
430
- seq = rec["sequence"]
431
- progress((idx - 1) / max(len(cleaned_records), 1), desc=f"Embedding {seq_id}")
432
 
433
- if spec.family == "hf_encoder":
434
- emb = embed_one_hf_encoder(seq, MODEL_MANAGER.model, MODEL_MANAGER.tokenizer, MODEL_MANAGER.device)
435
- elif spec.family == "t5_encoder":
436
- emb = embed_one_t5_encoder(seq, MODEL_MANAGER.model, MODEL_MANAGER.tokenizer, MODEL_MANAGER.device)
437
- elif spec.family == "esmc":
438
- emb = embed_one_esmc(seq, MODEL_MANAGER.model, MODEL_MANAGER.device)
439
- else:
440
- raise ValueError(f"Unsupported family: {spec.family}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
- if emb.shape[0] != len(seq):
443
- raise ValueError(
444
- f"Normalization failed for {seq_id}: got {emb.shape}, expected first dimension {len(seq)}."
 
 
445
  )
446
 
447
- embeddings_by_id[seq_id] = emb
448
-
449
- summary_rows.append({
450
- "id": seq_id,
451
- "length_L": len(seq),
452
- "embedding_dim_d": emb.shape[1],
453
- "shape": f"{emb.shape[0]} x {emb.shape[1]}",
454
- "model": model_key,
455
- })
456
-
457
- if first_preview is None:
458
- preview_rows = min(20, emb.shape[0])
459
- preview_cols = min(8, emb.shape[1])
460
- df = pd.DataFrame(
461
- emb[:preview_rows, :preview_cols],
462
- index=[f"res_{i+1}" for i in range(preview_rows)],
463
- columns=[f"dim_{j+1}" for j in range(preview_cols)],
464
  )
465
- first_preview = df
466
- first_preview_name = seq_id
467
 
468
- progress(1.0, desc="Packaging outputs")
 
469
 
470
- out_zip = package_outputs(
471
- embeddings_by_id=embeddings_by_id,
472
- sequences={x["id"]: x["sequence"] for x in cleaned_records},
473
- model_key=model_key,
474
- notes=global_warnings,
475
  )
476
 
477
- summary_df = pd.DataFrame(summary_rows)
478
- log_text = []
479
- log_text.append(f"Loaded model: {model_key}")
480
- log_text.append(f"Resolved device: {MODEL_MANAGER.device}")
481
- log_text.append(f"Processed sequences: {len(cleaned_records)}")
482
- if global_warnings:
483
- log_text.append("")
484
- log_text.append("Warnings:")
485
- log_text.extend(global_warnings)
486
-
487
- preview_markdown = f"Preview shown for: `{first_preview_name}`"
488
-
489
- return summary_df, first_preview, preview_markdown, out_zip, "\n".join(log_text)
490
-
491
-
492
- def package_outputs(
493
- embeddings_by_id: Dict[str, np.ndarray],
494
- sequences: Dict[str, str],
495
- model_key: str,
496
- notes: List[str],
497
- ) -> str:
498
- tmpdir = tempfile.mkdtemp(prefix="protein_embedding_hub_")
499
- zip_path = os.path.join(tmpdir, f"{safe_filename(model_key)}_embeddings.zip")
500
-
501
- summary_rows = []
502
- for seq_id, emb in embeddings_by_id.items():
503
- summary_rows.append({
504
- "id": seq_id,
505
- "length_L": sequences[seq_id].__len__(),
506
- "embedding_dim_d": emb.shape[1],
507
- "shape": f"{emb.shape[0]} x {emb.shape[1]}",
508
- "npy_file": f"{safe_filename(seq_id)}.npy",
509
- })
510
-
511
- summary_df = pd.DataFrame(summary_rows)
512
- sequences_df = pd.DataFrame(
513
- [{"id": k, "sequence": v} for k, v in sequences.items()]
514
- )
515
 
516
  with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
517
- # summary.csv
518
- with io.StringIO() as s:
519
- summary_df.to_csv(s, index=False)
520
- zf.writestr("summary.csv", s.getvalue())
521
 
522
- # sequences.csv
523
- with io.StringIO() as s:
524
- sequences_df.to_csv(s, index=False)
525
- zf.writestr("sequences.csv", s.getvalue())
526
 
527
- # notes.txt
528
- note_text = "\n".join(notes) if notes else "No warnings."
529
- zf.writestr("notes.txt", note_text)
 
530
 
531
- # per-sequence npy
532
- for seq_id, emb in embeddings_by_id.items():
533
- npy_name = f"embeddings/{safe_filename(seq_id)}.npy"
534
- buf = io.BytesIO()
535
- np.save(buf, emb)
536
- zf.writestr(npy_name, buf.getvalue())
537
 
538
- return zip_path
539
 
540
 
541
- def clear_loaded_model():
542
- MODEL_MANAGER.unload()
543
- return "Model cache cleared."
544
 
545
 
546
- # =========================
547
- # Gradio UI
548
- # =========================
549
  EXAMPLE_FASTA = """>seq1
550
  MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
551
  >seq2
@@ -554,70 +458,45 @@ GAVLILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSR
554
 
555
  with gr.Blocks(title=APP_TITLE) as demo:
556
  gr.Markdown(f"# {APP_TITLE}")
557
- gr.Markdown(APP_DESC)
558
-
559
- with gr.Row():
560
- with gr.Column(scale=2):
561
- fasta_input = gr.Textbox(
562
- label="Protein FASTA input",
563
- lines=16,
564
- value=EXAMPLE_FASTA,
565
- placeholder="Paste FASTA here..."
566
- )
567
 
568
- model_dropdown = gr.Dropdown(
569
- choices=list(MODEL_SPECS.keys()),
570
- value="ESM2-150M",
571
- label="Model"
572
- )
 
573
 
574
- device_dropdown = gr.Dropdown(
575
- choices=["auto", "cuda", "cpu"],
576
- value="auto",
577
- label="Device"
578
- )
579
 
580
- with gr.Row():
581
- run_btn = gr.Button("Run embedding", variant="primary")
582
- clear_btn = gr.Button("Clear loaded model")
583
-
584
- with gr.Column(scale=1):
585
- gr.Markdown("## Notes")
586
- gr.Markdown(
587
- "- Output is always normalized to residue-level `L*d`\n"
588
- "- ZIP contains one `.npy` per sequence\n"
589
- "- `summary.csv` records final shapes\n"
590
- "- Large models need GPU"
591
- )
592
- model_note = gr.Markdown(
593
- value="\n".join(
594
- [f"- **{k}**: {v.note}" for k, v in MODEL_SPECS.items()]
595
- )
596
- )
597
 
598
  with gr.Row():
599
- summary_output = gr.Dataframe(label="Summary", interactive=False)
600
- with gr.Row():
601
- preview_note = gr.Markdown()
602
- with gr.Row():
603
- preview_output = gr.Dataframe(label="Embedding preview (first sequence)", interactive=False)
604
 
605
- with gr.Row():
606
- download_output = gr.File(label="Download ZIP")
607
- with gr.Row():
608
- log_output = gr.Textbox(label="Log", lines=10)
609
 
610
  run_btn.click(
611
- fn=embed_sequences,
612
- inputs=[fasta_input, model_dropdown, device_dropdown],
613
- outputs=[summary_output, preview_output, preview_note, download_output, log_output],
614
  )
615
 
616
  clear_btn.click(
617
- fn=clear_loaded_model,
618
  inputs=[],
619
- outputs=[log_output],
620
  )
621
 
622
- demo.queue(max_size=16)
623
- demo.launch()
 
 
5
  import zipfile
6
  import tempfile
7
  from dataclasses import dataclass
8
+ from typing import Dict, List, Optional
9
 
10
  import gradio as gr
 
 
11
  import torch
12
+ from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
13
 
14
+ APP_TITLE = "Protein Embedding"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  ALLOWED_AA = set(list("ACDEFGHIKLMNPQRSTVWYXBZJUO"))
17
  REPLACE_WITH_X = set(list("UZOB"))
18
 
19
+
 
 
20
  @dataclass
21
  class ModelSpec:
22
  name: str
23
+ family: str
24
  model_id: str
25
  tokenizer_id: Optional[str] = None
 
26
 
27
 
28
  MODEL_SPECS: Dict[str, ModelSpec] = {
 
29
  "ESM2-8M": ModelSpec(
30
  name="ESM2-8M",
31
  family="hf_encoder",
32
  model_id="facebook/esm2_t6_8M_UR50D",
33
  tokenizer_id="facebook/esm2_t6_8M_UR50D",
 
34
  ),
35
  "ESM2-35M": ModelSpec(
36
  name="ESM2-35M",
37
  family="hf_encoder",
38
  model_id="facebook/esm2_t12_35M_UR50D",
39
  tokenizer_id="facebook/esm2_t12_35M_UR50D",
 
40
  ),
41
  "ESM2-150M": ModelSpec(
42
  name="ESM2-150M",
43
  family="hf_encoder",
44
  model_id="facebook/esm2_t30_150M_UR50D",
45
  tokenizer_id="facebook/esm2_t30_150M_UR50D",
 
46
  ),
47
  "ESM2-650M": ModelSpec(
48
  name="ESM2-650M",
49
  family="hf_encoder",
50
  model_id="facebook/esm2_t33_650M_UR50D",
51
  tokenizer_id="facebook/esm2_t33_650M_UR50D",
 
52
  ),
 
 
53
  "ESMC-300M": ModelSpec(
54
  name="ESMC-300M",
55
  family="esmc",
56
  model_id="esmc_300m",
 
57
  ),
58
  "ESMC-600M": ModelSpec(
59
  name="ESMC-600M",
60
  family="esmc",
61
  model_id="esmc_600m",
 
62
  ),
 
 
63
  "Ankh-Base": ModelSpec(
64
  name="Ankh-Base",
65
  family="hf_encoder",
66
  model_id="ElnaggarLab/ankh-base",
67
  tokenizer_id="ElnaggarLab/ankh-base",
 
68
  ),
69
  "Ankh-Large": ModelSpec(
70
  name="Ankh-Large",
71
  family="hf_encoder",
72
  model_id="ElnaggarLab/ankh-large",
73
  tokenizer_id="ElnaggarLab/ankh-large",
 
74
  ),
 
 
75
  "ProtT5-XL-Encoder": ModelSpec(
76
  name="ProtT5-XL-Encoder",
77
  family="t5_encoder",
78
  model_id="Rostlab/prot_t5_xl_half_uniref50-enc",
79
  tokenizer_id="Rostlab/prot_t5_xl_half_uniref50-enc",
80
+ ),
81
+ "ProSST-2048": ModelSpec(
82
+ name="ProSST-2048",
83
+ family="prosst",
84
+ model_id="AI4Protein/ProSST-2048",
85
+ tokenizer_id="AI4Protein/ProSST-2048",
86
  ),
87
  }
88
 
89
 
90
+ def resolve_device(device: str) -> str:
91
+ if device == "auto":
92
+ return "cuda" if torch.cuda.is_available() else "cpu"
93
+ if device == "cuda" and not torch.cuda.is_available():
94
+ return "cpu"
95
+ return device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
+ def safe_filename(x: str) -> str:
99
+ x = re.sub(r"[^A-Za-z0-9._-]+", "_", x)
100
+ x = x.strip("._")
101
+ return x or "item"
102
 
103
 
 
 
 
104
  def parse_fasta(text: str) -> List[Dict[str, str]]:
105
  text = text.strip()
106
  if not text:
 
138
  return records
139
 
140
 
141
+ def clean_sequence(seq: str) -> str:
142
  seq = re.sub(r"\s+", "", seq).upper()
 
 
143
  if not seq:
144
  raise ValueError("Empty sequence after cleaning.")
145
 
 
147
  if bad:
148
  raise ValueError(f"Invalid amino acid letters found: {bad}")
149
 
150
+ for c in REPLACE_WITH_X:
151
+ seq = seq.replace(c, "X")
152
+ return seq
 
 
 
 
153
 
154
 
155
  def protein_to_spaced(seq: str) -> str:
156
  return " ".join(list(seq))
157
 
158
 
159
+ def normalize_to_Ld(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  hidden: torch.Tensor,
161
  expected_len: int,
162
  special_tokens_mask: Optional[torch.Tensor] = None,
163
  attention_mask: Optional[torch.Tensor] = None,
164
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
165
  if hidden.ndim != 2:
166
+ raise ValueError(f"Expected [T, d], got {tuple(hidden.shape)}")
167
 
168
+ T = hidden.shape[0]
169
 
170
  if special_tokens_mask is not None:
171
+ keep = ~special_tokens_mask.bool().view(-1)
172
  if attention_mask is not None:
173
+ keep = keep & attention_mask.bool().view(-1)
174
+ filtered = hidden[keep]
175
+ if filtered.shape[0] == expected_len:
176
+ return filtered
177
+ if filtered.shape[0] > expected_len:
178
+ return filtered[:expected_len]
 
 
 
 
179
 
180
  if T == expected_len:
181
  return hidden
 
186
  if T > expected_len:
187
  return hidden[:expected_len]
188
 
189
+ raise ValueError(f"Cannot normalize token length {T} to residue length {expected_len}.")
190
+
191
+
192
+ class SingleModelRunner:
193
+ def __init__(self):
194
+ self.model_key = None
195
+ self.family = None
196
+ self.device = None
197
+ self.model = None
198
+ self.tokenizer = None
199
+ self.sst_predictor = None
200
+
201
+ def unload(self):
202
+ self.model_key = None
203
+ self.family = None
204
+ self.device = None
205
+ self.model = None
206
+ self.tokenizer = None
207
+ self.sst_predictor = None
208
+ gc.collect()
209
+ if torch.cuda.is_available():
210
+ torch.cuda.empty_cache()
211
+
212
+ def load(self, model_key: str, device: str):
213
+ target_device = resolve_device(device)
214
+ if self.model_key == model_key and self.device == target_device and self.model is not None:
215
+ return
216
+
217
+ self.unload()
218
+ spec = MODEL_SPECS[model_key]
219
+
220
+ if spec.family == "hf_encoder":
221
+ self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
222
+ self.model = AutoModel.from_pretrained(spec.model_id)
223
+ self.model.to(target_device)
224
+ self.model.eval()
225
+
226
+ elif spec.family == "t5_encoder":
227
+ self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
228
+ self.model = T5EncoderModel.from_pretrained(spec.model_id)
229
+ self.model.to(target_device)
230
+ self.model.eval()
231
+
232
+ elif spec.family == "esmc":
233
+ from esm.models.esmc import ESMC
234
+ self.model = ESMC.from_pretrained(spec.model_id).to(target_device)
235
+ self.model.eval()
236
+
237
+ elif spec.family == "prosst":
238
+ self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id, trust_remote_code=True)
239
+ self.model = AutoModel.from_pretrained(
240
+ spec.model_id,
241
+ trust_remote_code=True,
242
+ output_hidden_states=True,
243
+ )
244
+ self.model.to(target_device)
245
+ self.model.eval()
246
+
247
+ # Official ProSST sequence-only route:
248
+ # predict structure tokens from sequence, then feed them into ProSST.
249
+ from prosst.structure.get_sst_seq import SSTPredictor
250
+ self.sst_predictor = SSTPredictor()
251
+
252
+ else:
253
+ raise ValueError(f"Unsupported family: {spec.family}")
254
+
255
+ self.model_key = model_key
256
+ self.family = spec.family
257
+ self.device = target_device
258
+
259
+
260
+ RUNNER = SingleModelRunner()
261
 
262
 
 
 
 
263
  @torch.no_grad()
264
+ def embed_hf_encoder(seq: str) -> torch.Tensor:
265
+ enc = RUNNER.tokenizer(
266
  seq,
267
  return_tensors="pt",
268
  add_special_tokens=True,
269
  return_special_tokens_mask=True,
270
  truncation=False,
271
  )
272
+ enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
273
+ out = RUNNER.model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
274
+ hidden = out.last_hidden_state[0]
 
 
 
275
 
276
+ emb = normalize_to_Ld(
277
  hidden=hidden,
278
  expected_len=len(seq),
279
+ special_tokens_mask=enc.get("special_tokens_mask", None)[0] if enc.get("special_tokens_mask", None) is not None else None,
280
+ attention_mask=enc.get("attention_mask", None)[0] if enc.get("attention_mask", None) is not None else None,
281
  )
282
+ return emb.detach().cpu().float()
283
 
284
 
285
  @torch.no_grad()
286
+ def embed_t5_encoder(seq: str) -> torch.Tensor:
 
287
  spaced = protein_to_spaced(seq)
288
+ enc = RUNNER.tokenizer(
289
  spaced,
290
  return_tensors="pt",
291
  add_special_tokens=True,
292
  return_special_tokens_mask=True,
293
  truncation=False,
294
  )
295
+ enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
296
+ out = RUNNER.model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
297
  hidden = out.last_hidden_state[0]
298
 
299
+ emb = normalize_to_Ld(
 
 
 
300
  hidden=hidden,
301
  expected_len=len(seq),
302
+ special_tokens_mask=enc.get("special_tokens_mask", None)[0] if enc.get("special_tokens_mask", None) is not None else None,
303
+ attention_mask=enc.get("attention_mask", None)[0] if enc.get("attention_mask", None) is not None else None,
304
  )
305
+ return emb.detach().cpu().float()
306
 
307
 
308
  @torch.no_grad()
309
+ def embed_esmc(seq: str) -> torch.Tensor:
310
  from esm.sdk.api import ESMProtein, LogitsConfig
311
 
312
  protein = ESMProtein(sequence=seq)
313
+ protein_tensor = RUNNER.model.encode(protein)
314
+ out = RUNNER.model.logits(
315
  protein_tensor,
316
  LogitsConfig(sequence=True, return_embeddings=True)
317
  )
318
 
319
  emb = out.embeddings
320
+ if not isinstance(emb, torch.Tensor):
321
+ emb = torch.tensor(emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ if emb.ndim == 3:
324
+ emb = emb[0]
325
 
326
+ if emb.shape[0] == len(seq):
327
+ return emb.detach().cpu().float()
328
+ if emb.shape[0] == len(seq) + 2:
329
+ return emb[1:-1].detach().cpu().float()
330
+ if emb.shape[0] == len(seq) + 1:
331
+ return emb[:len(seq)].detach().cpu().float()
332
+ if emb.shape[0] > len(seq):
333
+ return emb[:len(seq)].detach().cpu().float()
 
 
334
 
335
+ raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")
 
 
 
 
336
 
 
 
337
 
338
+ @torch.no_grad()
339
+ def embed_prosst(seq: str) -> torch.Tensor:
340
+ # Sequence-only mode:
341
+ # 1) predict structure token sequence from amino-acid sequence
342
+ # 2) feed sequence + structure tokens into ProSST
343
+ structure_tokens = RUNNER.sst_predictor.predict(seq)
344
+
345
+ # Structure tokens may come back as list[int], np.ndarray, or space-separated string
346
+ if isinstance(structure_tokens, str):
347
+ sst_seq = structure_tokens
348
+ else:
349
+ sst_seq = " ".join([str(x) for x in structure_tokens])
350
 
351
+ aa_spaced = protein_to_spaced(seq)
 
 
 
352
 
353
+ enc = RUNNER.tokenizer(
354
+ aa_spaced,
355
+ return_tensors="pt",
356
+ add_special_tokens=True,
357
+ return_special_tokens_mask=True,
358
+ truncation=False,
359
+ )
360
+ enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
361
+
362
+ # Different ProSST remote-code implementations may expect different kwarg names.
363
+ # Try the common names first.
364
+ tried = []
365
+
366
+ for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
367
+ try:
368
+ sst_enc = RUNNER.tokenizer(
369
+ sst_seq,
370
+ return_tensors="pt",
371
+ add_special_tokens=True,
372
+ truncation=False,
373
+ )
374
+ sst_ids = sst_enc["input_ids"].to(RUNNER.device)
375
 
376
+ out = RUNNER.model(
377
+ input_ids=enc["input_ids"],
378
+ attention_mask=enc.get("attention_mask", None),
379
+ output_hidden_states=True,
380
+ **{kw: sst_ids},
381
  )
382
 
383
+ hidden = out.hidden_states[-1][0]
384
+ emb = normalize_to_Ld(
385
+ hidden=hidden,
386
+ expected_len=len(seq),
387
+ special_tokens_mask=enc.get("special_tokens_mask", None)[0] if enc.get("special_tokens_mask", None) is not None else None,
388
+ attention_mask=enc.get("attention_mask", None)[0] if enc.get("attention_mask", None) is not None else None,
 
 
 
 
 
 
 
 
 
 
 
389
  )
390
+ return emb.detach().cpu().float()
 
391
 
392
+ except Exception as e:
393
+ tried.append(f"{kw}: {repr(e)}")
394
 
395
+ raise RuntimeError(
396
+ "Failed to run ProSST. The installed ProSST remote-code signature may differ. "
397
+ + " | ".join(tried)
 
 
398
  )
399
 
400
+
401
+ def embed_one_sequence(seq: str) -> torch.Tensor:
402
+ if RUNNER.family == "hf_encoder":
403
+ return embed_hf_encoder(seq)
404
+ if RUNNER.family == "t5_encoder":
405
+ return embed_t5_encoder(seq)
406
+ if RUNNER.family == "esmc":
407
+ return embed_esmc(seq)
408
+ if RUNNER.family == "prosst":
409
+ return embed_prosst(seq)
410
+ raise ValueError(f"Unsupported family: {RUNNER.family}")
411
+
412
+
413
+ def run_embedding(fasta_text: str, model_keys: List[str], device: str, progress=gr.Progress()):
414
+ if not model_keys:
415
+ raise ValueError("Please select at least one model.")
416
+
417
+ records = parse_fasta(fasta_text)
418
+ records = [{"id": r["id"], "sequence": clean_sequence(r["sequence"])} for r in records]
419
+
420
+ tmpdir = tempfile.mkdtemp(prefix="protein_embeddings_")
421
+ zip_path = os.path.join(tmpdir, "embeddings.zip")
422
+
423
+ total_steps = len(model_keys) * len(records)
424
+ step = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
  with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
427
+ for model_key in model_keys:
428
+ RUNNER.load(model_key, device)
 
 
429
 
430
+ for rec in records:
431
+ step += 1
432
+ progress(step / total_steps, desc=f"{model_key} | {rec['id']}")
433
+ emb = embed_one_sequence(rec["sequence"])
434
 
435
+ if emb.ndim != 2 or emb.shape[0] != len(rec["sequence"]):
436
+ raise ValueError(
437
+ f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(rec['sequence'])}, d)"
438
+ )
439
 
440
+ inner_name = f"{safe_filename(model_key)}/{safe_filename(rec['id'])}.pt"
441
+ buffer = io.BytesIO()
442
+ torch.save(emb, buffer)
443
+ zf.writestr(inner_name, buffer.getvalue())
 
 
444
 
445
+ return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."
446
 
447
 
448
+ def clear_cache():
449
+ RUNNER.unload()
450
+ return "Cache cleared."
451
 
452
 
 
 
 
453
  EXAMPLE_FASTA = """>seq1
454
  MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
455
  >seq2
 
458
 
459
  with gr.Blocks(title=APP_TITLE) as demo:
460
  gr.Markdown(f"# {APP_TITLE}")
 
 
 
 
 
 
 
 
 
 
461
 
462
+ fasta_input = gr.Textbox(
463
+ label="FASTA",
464
+ lines=16,
465
+ value=EXAMPLE_FASTA,
466
+ placeholder="Paste FASTA here",
467
+ )
468
 
469
+ model_select = gr.CheckboxGroup(
470
+ choices=list(MODEL_SPECS.keys()),
471
+ value=["ESM2-150M"],
472
+ label="Models",
473
+ )
474
 
475
+ device_select = gr.Dropdown(
476
+ choices=["auto", "cuda", "cpu"],
477
+ value="auto",
478
+ label="Device",
479
+ )
 
 
 
 
 
 
 
 
 
 
 
 
480
 
481
  with gr.Row():
482
+ run_btn = gr.Button("Run", variant="primary")
483
+ clear_btn = gr.Button("Clear cache")
 
 
 
484
 
485
+ output_file = gr.File(label="Download")
486
+ log_box = gr.Textbox(label="Log", lines=4)
 
 
487
 
488
  run_btn.click(
489
+ fn=run_embedding,
490
+ inputs=[fasta_input, model_select, device_select],
491
+ outputs=[output_file, log_box],
492
  )
493
 
494
  clear_btn.click(
495
+ fn=clear_cache,
496
  inputs=[],
497
+ outputs=[log_box],
498
  )
499
 
500
+ demo.queue(max_size=8)
501
+ demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
502
+