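# Page-header residue from the hosting site, kept as comments so the file parses:
# author: yushize | commit message: "Update app.py" | commit: e6c3762 (verified)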
import gc
import io
import os
import re
import sys
import zipfile
import tempfile
import subprocess
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import gradio as gr
import torch
from transformers import (
AutoModel,
AutoModelForMaskedLM,
AutoTokenizer,
T5EncoderModel,
T5Tokenizer,
)
# Title used for the browser tab (gr.Blocks title) and the page heading.
APP_TITLE = "Protein Embedding"
# One-letter codes accepted by clean_sequence(): the 20 standard residues
# plus the extra codes X, B, Z, J, U and O.
ALLOWED_AA = set("ACDEFGHIKLMNPQRSTVWYXBZJUO")
# Codes that clean_sequence() rewrites to "X" after validation.
REPLACE_WITH_X = set("UZOB")
# Checkout location used by ensure_prosst_repo().
PROSST_REPO_DIR = "/tmp/ProSST"
@dataclass
class ModelSpec:
    """Static description of one selectable embedding model."""

    # Display name; also used as the key in MODEL_SPECS.
    name: str
    # Selects the loader / embedding path: "hf_encoder", "t5_encoder", "esmc" or "prosst".
    family: str
    # Checkpoint identifier handed to the family's loader.
    model_id: str
    # Tokenizer checkpoint; left as None for families that load no tokenizer.
    tokenizer_id: Optional[str] = None
# Registry of selectable models, keyed by display name (insertion order is the
# order shown in the UI checkbox group).
# Columns: ModelSpec(name, family, model_id, tokenizer_id)
MODEL_SPECS: Dict[str, ModelSpec] = {
    spec.name: spec
    for spec in (
        # ESM-2: tokenizer_id equals model_id.
        ModelSpec("ESM2-8M", "hf_encoder", "facebook/esm2_t6_8M_UR50D", "facebook/esm2_t6_8M_UR50D"),
        ModelSpec("ESM2-35M", "hf_encoder", "facebook/esm2_t12_35M_UR50D", "facebook/esm2_t12_35M_UR50D"),
        ModelSpec("ESM2-150M", "hf_encoder", "facebook/esm2_t30_150M_UR50D", "facebook/esm2_t30_150M_UR50D"),
        ModelSpec("ESM2-650M", "hf_encoder", "facebook/esm2_t33_650M_UR50D", "facebook/esm2_t33_650M_UR50D"),
        # ESM C: model_id is a model name for the `esm` package; no tokenizer_id.
        ModelSpec("ESMC-300M", "esmc", "esmc_300m"),
        ModelSpec("ESMC-600M", "esmc", "esmc_600m"),
        # Ankh
        ModelSpec("Ankh-Base", "hf_encoder", "ElnaggarLab/ankh-base", "ElnaggarLab/ankh-base"),
        ModelSpec("Ankh-Large", "hf_encoder", "ElnaggarLab/ankh-large", "ElnaggarLab/ankh-large"),
        # ProtT5 encoder-only checkpoint
        ModelSpec(
            "ProtT5-XL-Encoder",
            "t5_encoder",
            "Rostlab/prot_t5_xl_half_uniref50-enc",
            "Rostlab/prot_t5_xl_half_uniref50-enc",
        ),
        # ProSST (structure-aware; needs structure tokens at embedding time)
        ModelSpec("ProSST-2048", "prosst", "AI4Protein/ProSST-2048", "AI4Protein/ProSST-2048"),
    )
}
def resolve_device(device: str) -> str:
    """Translate a UI device choice into a concrete torch device string.

    "auto" picks "cuda" when CUDA is available and "cpu" otherwise; a "cuda"
    request on a machine without CUDA falls back to "cpu"; any other value
    is returned unchanged.
    """
    cuda_ok = torch.cuda.is_available()
    if device == "auto":
        return "cuda" if cuda_ok else "cpu"
    wants_missing_cuda = device == "cuda" and not cuda_ok
    return "cpu" if wants_missing_cuda else device
def safe_filename(x: str) -> str:
    """Reduce *x* to a conservative filename component.

    Every run of characters outside ``[A-Za-z0-9._-]`` becomes one "_", and
    leading/trailing "." and "_" are stripped. An empty result becomes "item".
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", x).strip("._")
    if cleaned:
        return cleaned
    return "item"
def parse_fasta(text: str) -> List[Dict[str, str]]:
    """Parse FASTA text into a list of ``{"id": ..., "sequence": ...}`` records.

    - Blank lines are ignored and every line is stripped.
    - The whole header text after ">" (description included) becomes the id.
      An empty header gets ``seq_<n>``, where n is the 1-based record number.
    - Sequence lines before the first header form a record named ``seq_<n>``.
    - Sequence lines are concatenated without separators.

    Raises:
        ValueError: if the input is blank, a record has no sequence, or no
            record is found.
    """
    body = text.strip()
    if not body:
        raise ValueError("Empty FASTA input.")

    records: List[Dict[str, str]] = []
    header: Optional[str] = None
    chunks: List[str] = []

    def flush() -> None:
        # Close the current record; an empty sequence is an error.
        joined = "".join(chunks).strip()
        if not joined:
            raise ValueError(f"Sequence for record '{header}' is empty.")
        records.append({"id": header, "sequence": joined})

    for line in (raw.strip() for raw in body.splitlines()):
        if not line:
            continue
        if not line.startswith(">"):
            if header is None:
                # Headerless leading sequence: synthesise an id.
                header = f"seq_{len(records) + 1}"
            chunks.append(line)
            continue
        if header is not None:
            flush()
        header = line[1:].strip() or f"seq_{len(records) + 1}"
        chunks.clear()

    if header is not None:
        flush()
    if not records:
        raise ValueError("No FASTA records found.")
    return records
def clean_sequence(seq: str) -> str:
    """Normalise a raw sequence for embedding.

    Removes all whitespace, upper-cases, rejects letters outside ALLOWED_AA,
    then rewrites every code in REPLACE_WITH_X to "X".

    Raises:
        ValueError: if nothing is left after removing whitespace, or invalid
            letters are present (listed sorted in the message).
    """
    compact = re.sub(r"\s+", "", seq).upper()
    if not compact:
        raise ValueError("Empty sequence after cleaning.")
    invalid = sorted(set(compact) - ALLOWED_AA)
    if invalid:
        raise ValueError(f"Invalid amino acid letters found: {invalid}")
    to_x = {ord(code): "X" for code in REPLACE_WITH_X}
    return compact.translate(to_x)
def protein_to_spaced(seq: str) -> str:
    """Return *seq* with a single space between consecutive characters."""
    return " ".join(seq)
def normalize_to_Ld(
    hidden: torch.Tensor,
    expected_len: int,
    special_tokens_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reduce token-level states ``[T, d]`` to ``expected_len`` rows.

    Strategy, in order:
      1. With a special-tokens mask, drop masked rows (and, if given, rows
         whose attention mask is 0). If at least ``expected_len`` rows remain,
         return the first ``expected_len`` of them.
      2. Otherwise fall back on T alone: ``T == expected_len + 2`` drops the
         first and last rows; any ``T >= expected_len`` keeps the first
         ``expected_len`` rows.

    Raises:
        ValueError: if *hidden* is not 2-D, or has fewer rows than needed.
    """
    if hidden.ndim != 2:
        raise ValueError(f"Expected [T, d], got {tuple(hidden.shape)}")
    n_tokens = hidden.shape[0]

    if special_tokens_mask is not None:
        keep = ~special_tokens_mask.bool().view(-1)
        if attention_mask is not None:
            keep = keep & attention_mask.bool().view(-1)
        residue_rows = hidden[keep]
        if residue_rows.shape[0] >= expected_len:
            return residue_rows[:expected_len]
        # Too few rows survived the mask: fall through to the length heuristics.

    if n_tokens == expected_len + 2:
        return hidden[1:-1]
    if n_tokens >= expected_len:
        return hidden[:expected_len]
    raise ValueError(f"Cannot normalize token length {n_tokens} to residue length {expected_len}.")
def ensure_prosst_repo():
    """Make the ProSST source tree importable, cloning it into PROSST_REPO_DIR if needed.

    If the ``prosst`` package directory is already present, the clone is
    skipped. A PROSST_REPO_DIR directory that exists without the package
    (e.g. left behind by an interrupted clone) is deleted first. Otherwise
    ``git clone`` would refuse the non-empty destination and fail on every
    later call. PROSST_REPO_DIR is appended to ``sys.path`` at most once.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
    """
    import shutil

    if not os.path.isdir(os.path.join(PROSST_REPO_DIR, "prosst")):
        if os.path.isdir(PROSST_REPO_DIR):
            # Stale / partial checkout: clear it so the clone can succeed.
            shutil.rmtree(PROSST_REPO_DIR)
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/openmedlab/ProSST.git", PROSST_REPO_DIR],
            check=True,
        )
    if PROSST_REPO_DIR not in sys.path:
        sys.path.append(PROSST_REPO_DIR)
class SingleModelRunner:
    """Keep at most one embedding model (plus its tokenizer/helpers) in memory.

    ``load`` returns immediately when the requested model is already resident
    on the requested device. Otherwise it releases the previous model first.
    """

    def __init__(self):
        self.model_key = None  # MODEL_SPECS key of the resident model
        self.family = None  # ModelSpec.family of the resident model
        self.device = None  # resolved device string (see resolve_device)
        self.model = None
        self.tokenizer = None  # stays None for the "esmc" family
        self.sst_predictor = None  # only set for the "prosst" family

    def unload(self):
        """Drop every reference to the resident model and release cached CUDA memory."""
        self.model_key = None
        self.family = None
        self.device = None
        self.model = None
        self.tokenizer = None
        self.sst_predictor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _encoder_only(model):
        """Return ``model.get_encoder()`` for encoder-decoder models, else *model* unchanged.

        ``AutoModel`` resolves encoder-decoder configs (such as T5) to the full
        seq2seq model. Its forward pass requires decoder inputs, which
        ``embed_hf_encoder`` never passes. Embeddings need only the encoder
        stack. NOTE(review): the Ankh entries in MODEL_SPECS are T5-based and
        are the expected users of this path.
        """
        config = getattr(model, "config", None)
        if getattr(config, "is_encoder_decoder", False) and hasattr(model, "get_encoder"):
            return model.get_encoder()
        return model

    def load(self, model_key: str, device: str):
        """Make *model_key* the resident model on *device*.

        Args:
            model_key: Key into MODEL_SPECS.
            device: "auto", "cuda" or "cpu"; resolved with resolve_device.

        Raises:
            KeyError: for an unknown *model_key*.
            ValueError: for an unsupported model family.

        If loading fails part-way, the runner is left fully unloaded, so no
        half-initialised model keeps holding memory.
        """
        target_device = resolve_device(device)
        if self.model_key == model_key and self.device == target_device and self.model is not None:
            return
        self.unload()
        spec = MODEL_SPECS[model_key]
        try:
            if spec.family == "hf_encoder":
                self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
                self.model = self._encoder_only(AutoModel.from_pretrained(spec.model_id))
                self.model.to(target_device)
                self.model.eval()
            elif spec.family == "t5_encoder":
                self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
                self.model = T5EncoderModel.from_pretrained(spec.model_id)
                self.model.to(target_device)
                self.model.eval()
            elif spec.family == "esmc":
                from esm.models.esmc import ESMC

                self.model = ESMC.from_pretrained(spec.model_id).to(target_device)
                self.model.eval()
                self.tokenizer = None
            elif spec.family == "prosst":
                ensure_prosst_repo()
                self.tokenizer = AutoTokenizer.from_pretrained(
                    spec.tokenizer_id,
                    trust_remote_code=True,
                )
                self.model = AutoModelForMaskedLM.from_pretrained(
                    spec.model_id,
                    trust_remote_code=True,
                    output_hidden_states=True,
                )
                self.model.to(target_device)
                self.model.eval()
                # Importable only after ensure_prosst_repo() put the checkout on sys.path.
                from prosst.structure.get_sst_seq import SSTPredictor

                self.sst_predictor = SSTPredictor()
            else:
                raise ValueError(f"Unsupported family: {spec.family}")
        except Exception:
            self.unload()
            raise
        self.model_key = model_key
        self.family = spec.family
        self.device = target_device
# Process-wide runner shared by the embed_* helpers and the Gradio callbacks.
RUNNER: SingleModelRunner = SingleModelRunner()
@torch.no_grad()
def embed_hf_encoder(seq: str) -> torch.Tensor:
    """Embed *seq* with the resident Hugging Face encoder; one row per residue.

    The raw sequence is tokenized with special tokens, and ``normalize_to_Ld``
    drops the special/padded positions. The result is detached, moved to the
    CPU and cast to float32.
    """
    batch = RUNNER.tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    batch = {name: tensor.to(RUNNER.device) for name, tensor in batch.items()}
    # The model does not accept the special-tokens mask as an input.
    model_inputs = {name: tensor for name, tensor in batch.items() if name != "special_tokens_mask"}
    hidden = RUNNER.model(**model_inputs).last_hidden_state[0]
    special = batch.get("special_tokens_mask")
    attention = batch.get("attention_mask")
    residues = normalize_to_Ld(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=None if special is None else special[0],
        attention_mask=None if attention is None else attention[0],
    )
    return residues.detach().cpu().float()
@torch.no_grad()
def embed_t5_encoder(seq: str) -> torch.Tensor:
    """Embed *seq* with the resident T5 encoder; one row per residue.

    The tokenizer receives the residues separated by single spaces.
    ``normalize_to_Ld`` then drops special/padded positions, and the result is
    detached, moved to the CPU and cast to float32.
    """
    batch = RUNNER.tokenizer(
        protein_to_spaced(seq),
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    batch = {name: tensor.to(RUNNER.device) for name, tensor in batch.items()}
    # The model does not accept the special-tokens mask as an input.
    model_inputs = {name: tensor for name, tensor in batch.items() if name != "special_tokens_mask"}
    hidden = RUNNER.model(**model_inputs).last_hidden_state[0]
    special = batch.get("special_tokens_mask")
    attention = batch.get("attention_mask")
    residues = normalize_to_Ld(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=None if special is None else special[0],
        attention_mask=None if attention is None else attention[0],
    )
    return residues.detach().cpu().float()
@torch.no_grad()
def embed_esmc(seq: str) -> torch.Tensor:
    """Embed *seq* with the resident ESM C model; one row per residue.

    Embeddings are requested through ``logits(..., return_embeddings=True)``.
    A leading batch axis is removed. Then exactly ``len(seq) + 2`` rows lose
    their first and last row, while any other count of at least ``len(seq)``
    rows keeps the first ``len(seq)``. The result is detached, moved to the
    CPU and cast to float32.

    Raises:
        ValueError: if fewer than ``len(seq)`` rows are returned.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    encoded = RUNNER.model.encode(ESMProtein(sequence=seq))
    result = RUNNER.model.logits(encoded, LogitsConfig(sequence=True, return_embeddings=True))
    emb = result.embeddings
    if not isinstance(emb, torch.Tensor):
        emb = torch.tensor(emb)
    if emb.ndim == 3:
        emb = emb[0]

    n_res = len(seq)
    n_rows = emb.shape[0]
    if n_rows == n_res + 2:
        emb = emb[1:-1]
    elif n_rows >= n_res:
        emb = emb[:n_res]
    else:
        raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")
    return emb.detach().cpu().float()
def coerce_sst_tokens(sst, seq_len: int) -> List[int]:
    """Convert raw SSTPredictor output into exactly *seq_len* integer tokens.

    Accepted inputs:
    - a whitespace-separated string;
    - a tensor, which is flattened;
    - anything with ``tolist()``, e.g. an array;
    - a list or tuple.

    If the first element is itself a list/tuple (a batched ``[[...], ...]``
    result), that first row is used. A result of ``seq_len + 2`` tokens loses
    its first and last token. Any other over-long result is cut to its first
    *seq_len* tokens.

    Raises:
        ValueError: for an unsupported input type, or if fewer than *seq_len*
            tokens remain.
    """
    if isinstance(sst, str):
        tokens = sst.split()
    elif isinstance(sst, torch.Tensor):
        # reshape (not view) also works for non-contiguous tensors.
        tokens = sst.detach().cpu().reshape(-1).tolist()
    elif hasattr(sst, "tolist"):
        tokens = sst.tolist()
    elif isinstance(sst, (list, tuple)):
        tokens = list(sst)
    else:
        raise ValueError(f"Unsupported SSTPredictor output type: {type(sst)}")
    if tokens and isinstance(tokens[0], (list, tuple)):
        tokens = tokens[0]
    tokens = [int(x) for x in tokens]
    if len(tokens) == seq_len + 2:
        tokens = tokens[1:-1]
    elif len(tokens) > seq_len:
        tokens = tokens[:seq_len]
    if len(tokens) != seq_len:
        raise ValueError(f"SST token length mismatch: got {len(tokens)}, expected {seq_len}")
    return tokens


def get_sst_tokens(seq: str) -> List[int]:
    """Predict structure tokens for *seq* with the resident SSTPredictor, one per residue.

    Diagnostics go to the module logger at DEBUG level instead of stdout.

    Raises:
        ValueError: if the predictor output cannot be coerced to ``len(seq)`` ints.
    """
    import logging

    log = logging.getLogger(__name__)
    # NOTE(review): predict() receives the raw amino-acid string; confirm this
    # matches the SSTPredictor API of the cloned ProSST revision.
    sst = RUNNER.sst_predictor.predict(seq)
    log.debug("SST raw type: %s", type(sst))
    log.debug("SST raw repr: %s", repr(sst)[:500])
    tokens = coerce_sst_tokens(sst, len(seq))
    log.debug("SST final length: %d", len(tokens))
    log.debug("SST first 30: %s", tokens[:30])
    return tokens
@torch.no_grad()
def embed_prosst(seq: str) -> Tuple[torch.Tensor, List[int]]:
    """Embed *seq* with the resident ProSST model, conditioned on structure tokens.

    Returns:
        ``(embedding, sst_tokens)``:
        - ``embedding`` is the last entry of ``hidden_states``, reduced to one
          row per residue by ``normalize_to_Ld`` (detached, CPU, float32);
        - ``sst_tokens`` are the per-residue structure tokens from
          ``get_sst_tokens``.

    Raises:
        RuntimeError: if the forward pass fails for every candidate keyword
            name of the structure-token input; the message lists each attempt.
    """
    sst_tokens = get_sst_tokens(seq)
    # The tokenizer is given the residues separated by single spaces.
    aa_spaced = protein_to_spaced(seq)
    seq_enc = RUNNER.tokenizer(
        aa_spaced,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    seq_enc = {k: v.to(RUNNER.device) for k, v in seq_enc.items()}
    # Shape [1, len(seq)]: get_sst_tokens raises unless it has exactly one token per residue.
    # NOTE(review): input_ids above were built with add_special_tokens=True, so they
    # are presumably longer than sst_ids. If the model expects structure ids aligned
    # 1:1 with input_ids (e.g. wrapped in special ids and/or offset), every attempt
    # below fails on shape. Confirm against ProSST's own usage example.
    sst_ids = torch.tensor([sst_tokens], dtype=torch.long, device=RUNNER.device)
    tried = []
    # The remote-code forward signature is not visible here, so candidate keyword
    # names for the structure-token input are tried in order; the first success wins.
    for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
        try:
            out = RUNNER.model(
                input_ids=seq_enc["input_ids"],
                attention_mask=seq_enc.get("attention_mask", None),
                output_hidden_states=True,
                return_dict=True,
                **{kw: sst_ids},
            )
            if getattr(out, "hidden_states", None) is None:
                raise RuntimeError("ProSST output has no hidden_states")
            # Last layer, first (only) batch item; normalize_to_Ld requires it to be 2-D.
            hidden = out.hidden_states[-1][0]
            emb = normalize_to_Ld(
                hidden=hidden,
                expected_len=len(seq),
                special_tokens_mask=seq_enc.get("special_tokens_mask", None)[0] if seq_enc.get("special_tokens_mask", None) is not None else None,
                attention_mask=seq_enc.get("attention_mask", None)[0] if seq_enc.get("attention_mask", None) is not None else None,
            )
            return emb.detach().cpu().float(), sst_tokens
        except Exception as e:
            # Any failure is recorded and the next keyword name is tried, e.g.:
            # an unknown kwarg, a shape mismatch, missing hidden states, or a
            # normalisation error.
            tried.append(f"{kw}: {repr(e)}")
    raise RuntimeError(
        "Failed to run ProSST with known structure-token arg names: " + " | ".join(tried)
    )
def embed_one_sequence(seq: str):
    """Embed *seq* with whichever model family RUNNER currently holds.

    Returns ``(embedding, sst_tokens)``. ``sst_tokens`` is None for every
    family except "prosst", whose helper returns both values itself.

    Raises:
        ValueError: if RUNNER.family is not a known family.
    """
    family = RUNNER.family
    if family == "prosst":
        return embed_prosst(seq)
    embedder = {
        "hf_encoder": embed_hf_encoder,
        "t5_encoder": embed_t5_encoder,
        "esmc": embed_esmc,
    }.get(family)
    if embedder is None:
        raise ValueError(f"Unsupported family: {RUNNER.family}")
    return embedder(seq), None
def _unique_record_stems(records: List[Dict[str, str]]) -> List[str]:
    """Return one filesystem-safe, batch-unique file stem per record id.

    ``safe_filename`` can map different ids (or repeated FASTA headers) to the
    same string. Without disambiguation, the zip would get duplicate member
    names and extraction would silently keep only one of them. Later
    collisions get ``_2``, ``_3``, ... suffixes. Comparison is
    case-insensitive, so the archive also extracts cleanly on
    case-insensitive filesystems.
    """
    stems: List[str] = []
    taken = set()
    for rec in records:
        base = safe_filename(rec["id"])
        stem = base
        suffix = 2
        while stem.lower() in taken:
            stem = f"{base}_{suffix}"
            suffix += 1
        taken.add(stem.lower())
        stems.append(stem)
    return stems


def run_embedding(fasta_text: str, model_keys: List[str], device: str, progress=gr.Progress()):
    """Embed every FASTA record with every selected model and bundle the results as a zip.

    Args:
        fasta_text: Raw FASTA text from the UI.
        model_keys: MODEL_SPECS keys to run, in order.
        device: "auto", "cuda" or "cpu"; resolved by RUNNER.load.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(zip_path, status_message)``. The archive contains:
        - ``<model>/<record>.pt`` holding one per-residue tensor each;
        - ``<model>_structure_tokens/<record>.txt`` for models that produce
          structure tokens.

    Raises:
        ValueError: if no model is selected, the FASTA is invalid, or an
            embedding does not have exactly one row per residue.
    """
    if not model_keys:
        raise ValueError("Please select at least one model.")
    records = [
        {"id": r["id"], "sequence": clean_sequence(r["sequence"])}
        for r in parse_fasta(fasta_text)
    ]
    # Computed once so a record keeps the same file stem under every model.
    stems = _unique_record_stems(records)
    tmpdir = tempfile.mkdtemp(prefix="protein_embeddings_")
    zip_path = os.path.join(tmpdir, "embeddings.zip")
    total_steps = len(model_keys) * len(records)
    step = 0
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for model_key in model_keys:
            RUNNER.load(model_key, device)
            model_dir = safe_filename(model_key)
            for rec, stem in zip(records, stems):
                step += 1
                progress(step / total_steps, desc=f"{model_key} | {rec['id']}")
                seq = rec["sequence"]
                emb, sst_tokens = embed_one_sequence(seq)
                if emb.ndim != 2 or emb.shape[0] != len(seq):
                    raise ValueError(
                        f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(seq)}, d)"
                    )
                pt_buf = io.BytesIO()
                # torch.save writes a view's entire backing storage; clone so
                # only this tensor's data ends up in the file.
                torch.save(emb.clone(), pt_buf)
                zf.writestr(f"{model_dir}/{stem}.pt", pt_buf.getvalue())
                if sst_tokens is not None:
                    zf.writestr(f"{model_dir}_structure_tokens/{stem}.txt", " ".join(map(str, sst_tokens)))
    return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."
def clear_cache():
    """Unload the resident model (emptying the CUDA cache when available) and report it."""
    RUNNER.unload()
    message = "Cache cleared."
    return message
# Two demo records pre-filled into the FASTA textbox (newline-terminated).
EXAMPLE_FASTA = "\n".join(
    [
        ">seq1",
        "MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE",
        ">seq2",
        "GAVLILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSR",
        "",
    ]
)
# Gradio page. Components are created in page order inside the Blocks context,
# so statement order here defines the layout.
# NOTE(review): the original indentation was lost; the Row below is assumed to
# hold only the two buttons — confirm the intended layout.
with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    # FASTA input, pre-filled with EXAMPLE_FASTA.
    fasta_input = gr.Textbox(
        label="FASTA",
        lines=16,
        value=EXAMPLE_FASTA,
        placeholder="Paste FASTA here",
    )
    # One checkbox per MODEL_SPECS key; ESM2-150M is selected by default.
    model_select = gr.CheckboxGroup(
        choices=list(MODEL_SPECS.keys()),
        value=["ESM2-150M"],
        label="Models",
    )
    # Passed through run_embedding to RUNNER.load, which resolves it with resolve_device.
    device_select = gr.Dropdown(
        choices=["auto", "cuda", "cpu"],
        value="auto",
        label="Device",
    )
    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        clear_btn = gr.Button("Clear cache")
    output_file = gr.File(label="Download")
    log_box = gr.Textbox(label="Log", lines=4)
    # run_embedding returns (zip_path, status message) -> (output_file, log_box).
    run_btn.click(
        fn=run_embedding,
        inputs=[fasta_input, model_select, device_select],
        outputs=[output_file, log_box],
    )
    # clear_cache unloads the resident model and writes its status to log_box.
    clear_btn.click(
        fn=clear_cache,
        inputs=[],
        outputs=[log_box],
    )
# Cap the number of pending requests in the queue at 8.
demo.queue(max_size=8)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module has no side effects beyond building `demo`.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)