yushize committed on
Commit
f7bcbc4
·
verified ·
1 Parent(s): 21429bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -48
app.py CHANGED
@@ -2,10 +2,12 @@ import gc
2
  import io
3
  import os
4
  import re
 
5
  import zipfile
6
  import tempfile
 
7
  from dataclasses import dataclass
8
- from typing import Dict, List, Optional
9
 
10
  import gradio as gr
11
  import torch
@@ -16,6 +18,8 @@ APP_TITLE = "Protein Embedding"
16
  ALLOWED_AA = set(list("ACDEFGHIKLMNPQRSTVWYXBZJUO"))
17
  REPLACE_WITH_X = set(list("UZOB"))
18
 
 
 
19
 
20
  @dataclass
21
  class ModelSpec:
@@ -189,6 +193,21 @@ def normalize_to_Ld(
189
  raise ValueError(f"Cannot normalize token length {T} to residue length {expected_len}.")
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  class SingleModelRunner:
193
  def __init__(self):
194
  self.model_key = None
@@ -235,6 +254,8 @@ class SingleModelRunner:
235
  self.model.eval()
236
 
237
  elif spec.family == "prosst":
 
 
238
  self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id, trust_remote_code=True)
239
  self.model = AutoModel.from_pretrained(
240
  spec.model_id,
@@ -244,8 +265,6 @@ class SingleModelRunner:
244
  self.model.to(target_device)
245
  self.model.eval()
246
 
247
- # Official ProSST sequence-only route:
248
- # predict structure tokens from sequence, then feed them into ProSST.
249
  from prosst.structure.get_sst_seq import SSTPredictor
250
  self.sst_predictor = SSTPredictor()
251
 
@@ -335,76 +354,86 @@ def embed_esmc(seq: str) -> torch.Tensor:
335
  raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")
336
 
337
 
338
- @torch.no_grad()
339
- def embed_prosst(seq: str) -> torch.Tensor:
340
- # Sequence-only mode:
341
- # 1) predict structure token sequence from amino-acid sequence
342
- # 2) feed sequence + structure tokens into ProSST
343
- structure_tokens = RUNNER.sst_predictor.predict(seq)
344
-
345
- # Structure tokens may come back as list[int], np.ndarray, or space-separated string
346
- if isinstance(structure_tokens, str):
347
- sst_seq = structure_tokens
 
 
 
348
  else:
349
- sst_seq = " ".join([str(x) for x in structure_tokens])
350
 
351
- aa_spaced = protein_to_spaced(seq)
352
 
353
- enc = RUNNER.tokenizer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  aa_spaced,
355
  return_tensors="pt",
356
  add_special_tokens=True,
357
  return_special_tokens_mask=True,
358
  truncation=False,
359
  )
360
- enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
361
 
362
- # Different ProSST remote-code implementations may expect different kwarg names.
363
- # Try the common names first.
364
- tried = []
365
 
 
366
  for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
367
  try:
368
- sst_enc = RUNNER.tokenizer(
369
- sst_seq,
370
- return_tensors="pt",
371
- add_special_tokens=True,
372
- truncation=False,
373
- )
374
- sst_ids = sst_enc["input_ids"].to(RUNNER.device)
375
-
376
  out = RUNNER.model(
377
- input_ids=enc["input_ids"],
378
- attention_mask=enc.get("attention_mask", None),
379
  output_hidden_states=True,
380
  **{kw: sst_ids},
381
  )
382
-
383
  hidden = out.hidden_states[-1][0]
384
  emb = normalize_to_Ld(
385
  hidden=hidden,
386
  expected_len=len(seq),
387
- special_tokens_mask=enc.get("special_tokens_mask", None)[0] if enc.get("special_tokens_mask", None) is not None else None,
388
- attention_mask=enc.get("attention_mask", None)[0] if enc.get("attention_mask", None) is not None else None,
389
  )
390
- return emb.detach().cpu().float()
391
-
392
  except Exception as e:
393
  tried.append(f"{kw}: {repr(e)}")
394
 
395
- raise RuntimeError(
396
- "Failed to run ProSST. The installed ProSST remote-code signature may differ. "
397
- + " | ".join(tried)
398
- )
399
 
400
 
401
- def embed_one_sequence(seq: str) -> torch.Tensor:
402
  if RUNNER.family == "hf_encoder":
403
- return embed_hf_encoder(seq)
404
  if RUNNER.family == "t5_encoder":
405
- return embed_t5_encoder(seq)
406
  if RUNNER.family == "esmc":
407
- return embed_esmc(seq)
408
  if RUNNER.family == "prosst":
409
  return embed_prosst(seq)
410
  raise ValueError(f"Unsupported family: {RUNNER.family}")
@@ -430,17 +459,21 @@ def run_embedding(fasta_text: str, model_keys: List[str], device: str, progress=
430
  for rec in records:
431
  step += 1
432
  progress(step / total_steps, desc=f"{model_key} | {rec['id']}")
433
- emb = embed_one_sequence(rec["sequence"])
434
 
435
  if emb.ndim != 2 or emb.shape[0] != len(rec["sequence"]):
436
  raise ValueError(
437
  f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(rec['sequence'])}, d)"
438
  )
439
 
440
- inner_name = f"{safe_filename(model_key)}/{safe_filename(rec['id'])}.pt"
441
- buffer = io.BytesIO()
442
- torch.save(emb, buffer)
443
- zf.writestr(inner_name, buffer.getvalue())
 
 
 
 
444
 
445
  return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."
446
 
 
2
  import io
3
  import os
4
  import re
5
+ import sys
6
  import zipfile
7
  import tempfile
8
+ import subprocess
9
  from dataclasses import dataclass
10
+ from typing import Dict, List, Optional, Tuple
11
 
12
  import gradio as gr
13
  import torch
 
18
  ALLOWED_AA = set(list("ACDEFGHIKLMNPQRSTVWYXBZJUO"))
19
  REPLACE_WITH_X = set(list("UZOB"))
20
 
21
+ PROSST_REPO_DIR = "/tmp/ProSST"
22
+
23
 
24
  @dataclass
25
  class ModelSpec:
 
193
  raise ValueError(f"Cannot normalize token length {T} to residue length {expected_len}.")
194
 
195
 
196
def ensure_prosst_repo():
    """Make the ProSST source tree importable, cloning it on first use.

    Performs a shallow ``git clone`` of the official ProSST repository into
    ``PROSST_REPO_DIR`` when the checkout (detected via its ``prosst``
    package directory) is not already present, then ensures the checkout
    directory is on ``sys.path`` so ``from prosst...`` imports resolve.

    Raises:
        subprocess.CalledProcessError: if the ``git clone`` command fails.
    """
    repo_ready = os.path.isdir(PROSST_REPO_DIR) and os.path.isdir(
        os.path.join(PROSST_REPO_DIR, "prosst")
    )
    if not repo_ready:
        # --depth 1 keeps the download small; check=True surfaces clone failures.
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/openmedlab/ProSST.git", PROSST_REPO_DIR],
            check=True,
        )
    if PROSST_REPO_DIR not in sys.path:
        sys.path.append(PROSST_REPO_DIR)
209
+
210
+
211
  class SingleModelRunner:
212
  def __init__(self):
213
  self.model_key = None
 
254
  self.model.eval()
255
 
256
  elif spec.family == "prosst":
257
+ ensure_prosst_repo()
258
+
259
  self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id, trust_remote_code=True)
260
  self.model = AutoModel.from_pretrained(
261
  spec.model_id,
 
265
  self.model.to(target_device)
266
  self.model.eval()
267
 
 
 
268
  from prosst.structure.get_sst_seq import SSTPredictor
269
  self.sst_predictor = SSTPredictor()
270
 
 
354
  raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")
355
 
356
 
357
def get_sst_tokens(seq: str) -> List[int]:
    """Predict ProSST structure tokens for *seq*, normalized to len(seq).

    The ``SSTPredictor`` output format varies between installs; this accepts
    a space-separated string, a ``torch.Tensor``, an array-like exposing
    ``tolist`` (possibly batched as ``[[...]]``), or a plain list/tuple, and
    coerces all of them to a flat list of ints.  Outputs that include extra
    special tokens are trimmed to exactly one token per residue.

    Args:
        seq: amino-acid sequence of length L.

    Returns:
        A list of L ints (one structure token per residue).

    Raises:
        ValueError: for unsupported predictor output types, or when the
            token count cannot be reconciled with ``len(seq)``.
    """
    sst = RUNNER.sst_predictor.predict(seq)

    if isinstance(sst, str):
        tokens = [int(x) for x in sst.strip().split()]
    elif isinstance(sst, torch.Tensor):
        tokens = sst.detach().cpu().view(-1).tolist()
    elif hasattr(sst, "tolist"):
        tokens = sst.tolist()
        # Unwrap a leading batch dimension, e.g. [[t0, t1, ...]] -> [t0, t1, ...].
        if isinstance(tokens, list) and tokens and isinstance(tokens[0], list):
            tokens = tokens[0]
    elif isinstance(sst, (list, tuple)):
        tokens = list(sst)
    else:
        raise ValueError(f"Unsupported SSTPredictor output type: {type(sst)}")

    tokens = [int(x) for x in tokens]

    # Best-effort trim to exactly L residue tokens.
    if len(tokens) == len(seq) + 2:
        # Presumably BOS/EOS wrapping -- drop both ends.  TODO(review):
        # confirm against the installed SSTPredictor's output convention.
        tokens = tokens[1:-1]
    elif len(tokens) > len(seq):
        # Any other surplus (e.g. a single trailing EOS): keep the prefix.
        # (The original also had a dedicated len(seq)+1 branch, but it did
        # exactly this same prefix slice, so the two are merged.)
        tokens = tokens[: len(seq)]

    if len(tokens) != len(seq):
        raise ValueError(f"SST token length mismatch: got {len(tokens)}, expected {len(seq)}")

    return tokens
387
+
388
+
389
@torch.no_grad()
def embed_prosst(seq: str) -> Tuple[torch.Tensor, List[int]]:
    """Embed *seq* with ProSST using sequence + predicted structure tokens.

    Sequence-only route: structure tokens are first predicted from the
    amino-acid sequence (``get_sst_tokens``), then fed to the ProSST model
    alongside the tokenized sequence.

    Args:
        seq: amino-acid sequence of length L.

    Returns:
        ``(emb, sst_tokens)`` where ``emb`` is a float32 CPU tensor of shape
        (L, d) and ``sst_tokens`` is the per-residue structure-token list
        that was fed to the model.

    Raises:
        RuntimeError: if none of the known structure-token keyword names is
            accepted by the installed ProSST remote-code model.
    """
    sst_tokens = get_sst_tokens(seq)

    aa_spaced = protein_to_spaced(seq)
    seq_enc = RUNNER.tokenizer(
        aa_spaced,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    seq_enc = {k: v.to(RUNNER.device) for k, v in seq_enc.items()}

    # ProSST takes the structure tokens as a second id sequence: build [1, L].
    sst_ids = torch.tensor([sst_tokens], dtype=torch.long, device=RUNNER.device)

    # Hoist the optional tensors once instead of repeated .get() lookups.
    attention_mask = seq_enc.get("attention_mask")
    special_tokens_mask = seq_enc.get("special_tokens_mask")

    # Different ProSST remote-code revisions name the structure-token kwarg
    # differently; probe the known names in order.
    tried = []
    out = None
    for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
        try:
            # Keep the try minimal: only the forward call can legitimately
            # fail with a signature mismatch.  Errors in post-processing
            # should propagate rather than be misreported as a bad kwarg
            # (the original wrapped normalize_to_Ld in this try as well).
            out = RUNNER.model(
                input_ids=seq_enc["input_ids"],
                attention_mask=attention_mask,
                output_hidden_states=True,
                **{kw: sst_ids},
            )
            break
        except Exception as e:
            tried.append(f"{kw}: {repr(e)}")

    if out is None:
        raise RuntimeError("Failed to run ProSST with known structure-token arg names: " + " | ".join(tried))

    # Single-sequence batch: take item 0 -> [T, d] token embeddings.
    hidden = out.hidden_states[-1][0]
    emb = normalize_to_Ld(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=special_tokens_mask[0] if special_tokens_mask is not None else None,
        attention_mask=attention_mask[0] if attention_mask is not None else None,
    )
    return emb.detach().cpu().float(), sst_tokens
 
 
 
428
 
429
 
430
+ def embed_one_sequence(seq: str):
431
  if RUNNER.family == "hf_encoder":
432
+ return embed_hf_encoder(seq), None
433
  if RUNNER.family == "t5_encoder":
434
+ return embed_t5_encoder(seq), None
435
  if RUNNER.family == "esmc":
436
+ return embed_esmc(seq), None
437
  if RUNNER.family == "prosst":
438
  return embed_prosst(seq)
439
  raise ValueError(f"Unsupported family: {RUNNER.family}")
 
459
  for rec in records:
460
  step += 1
461
  progress(step / total_steps, desc=f"{model_key} | {rec['id']}")
462
+ emb, sst_tokens = embed_one_sequence(rec["sequence"])
463
 
464
  if emb.ndim != 2 or emb.shape[0] != len(rec["sequence"]):
465
  raise ValueError(
466
  f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(rec['sequence'])}, d)"
467
  )
468
 
469
+ pt_name = f"{safe_filename(model_key)}/{safe_filename(rec['id'])}.pt"
470
+ pt_buf = io.BytesIO()
471
+ torch.save(emb, pt_buf)
472
+ zf.writestr(pt_name, pt_buf.getvalue())
473
+
474
+ if sst_tokens is not None:
475
+ tok_name = f"{safe_filename(model_key)}_structure_tokens/{safe_filename(rec['id'])}.txt"
476
+ zf.writestr(tok_name, " ".join(map(str, sst_tokens)))
477
 
478
  return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."
479