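"""Gradio Space for per-residue protein embeddings.

Paste FASTA, pick one or more protein language models (ESM2, ESMC, Ankh,
ProtT5, ProSST), and download a zip of [L, d] PyTorch tensors, one per
(model, sequence) pair. ProSST additionally emits its structure tokens.
"""
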
import gc
import io
import os
import re
import sys
import zipfile
import tempfile
import subprocess
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

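# Ambiguous or nonstandard residues U, Z, O, and B are accepted on input but
# mapped to X before embedding, since not every model vocabulary handles them.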
APP_TITLE = "Protein Embedding"
ALLOWED_AA = set("ACDEFGHIKLMNPQRSTVWYXBZJUO")
REPLACE_WITH_X = set("UZOB")
PROSST_REPO_DIR = "/tmp/ProSST"

@dataclass
class ModelSpec:
    name: str
    family: str
    model_id: str
    tokenizer_id: Optional[str] = None

MODEL_SPECS: Dict[str, ModelSpec] = {
    "ESM2-8M": ModelSpec(
        name="ESM2-8M",
        family="hf_encoder",
        model_id="facebook/esm2_t6_8M_UR50D",
        tokenizer_id="facebook/esm2_t6_8M_UR50D",
    ),
    "ESM2-35M": ModelSpec(
        name="ESM2-35M",
        family="hf_encoder",
        model_id="facebook/esm2_t12_35M_UR50D",
        tokenizer_id="facebook/esm2_t12_35M_UR50D",
    ),
    "ESM2-150M": ModelSpec(
        name="ESM2-150M",
        family="hf_encoder",
        model_id="facebook/esm2_t30_150M_UR50D",
        tokenizer_id="facebook/esm2_t30_150M_UR50D",
    ),
    "ESM2-650M": ModelSpec(
        name="ESM2-650M",
        family="hf_encoder",
        model_id="facebook/esm2_t33_650M_UR50D",
        tokenizer_id="facebook/esm2_t33_650M_UR50D",
    ),
    "ESMC-300M": ModelSpec(
        name="ESMC-300M",
        family="esmc",
        model_id="esmc_300m",
    ),
    "ESMC-600M": ModelSpec(
        name="ESMC-600M",
        family="esmc",
        model_id="esmc_600m",
    ),
    "Ankh-Base": ModelSpec(
        name="Ankh-Base",
        family="hf_encoder",
        model_id="ElnaggarLab/ankh-base",
        tokenizer_id="ElnaggarLab/ankh-base",
    ),
    "Ankh-Large": ModelSpec(
        name="Ankh-Large",
        family="hf_encoder",
        model_id="ElnaggarLab/ankh-large",
        tokenizer_id="ElnaggarLab/ankh-large",
    ),
    "ProtT5-XL-Encoder": ModelSpec(
        name="ProtT5-XL-Encoder",
        family="t5_encoder",
        model_id="Rostlab/prot_t5_xl_half_uniref50-enc",
        tokenizer_id="Rostlab/prot_t5_xl_half_uniref50-enc",
    ),
    "ProSST-2048": ModelSpec(
        name="ProSST-2048",
        family="prosst",
        model_id="AI4Protein/ProSST-2048",
        tokenizer_id="AI4Protein/ProSST-2048",
    ),
}

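# "family" routes each checkpoint to the matching load/embed path below, so
# adding a model is one registry entry. A hypothetical, untested example for
# the larger ESM2 checkpoint would look like:
#
#   "ESM2-3B": ModelSpec(
#       name="ESM2-3B",
#       family="hf_encoder",
#       model_id="facebook/esm2_t36_3B_UR50D",
#       tokenizer_id="facebook/esm2_t36_3B_UR50D",
#   ),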
def resolve_device(device: str) -> str:
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device

def safe_filename(x: str) -> str:
    x = re.sub(r"[^A-Za-z0-9._-]+", "_", x)
    x = x.strip("._")
    return x or "item"

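# e.g. safe_filename("sp|P12345|ALBU_HUMAN") -> "sp_P12345_ALBU_HUMAN"
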
def parse_fasta(text: str) -> List[Dict[str, str]]:
    text = text.strip()
    if not text:
        raise ValueError("Empty FASTA input.")
    records = []
    current_id = None
    current_seq = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current_id is not None:
                seq = "".join(current_seq).strip()
                if not seq:
                    raise ValueError(f"Sequence for record '{current_id}' is empty.")
                records.append({"id": current_id, "sequence": seq})
            current_id = line[1:].strip() or f"seq_{len(records)+1}"
            current_seq = []
        else:
            if current_id is None:
                current_id = f"seq_{len(records)+1}"
            current_seq.append(line)
    if current_id is not None:
        seq = "".join(current_seq).strip()
        if not seq:
            raise ValueError(f"Sequence for record '{current_id}' is empty.")
        records.append({"id": current_id, "sequence": seq})
    if not records:
        raise ValueError("No FASTA records found.")
    return records

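# parse_fasta(">a\nMKV\nTF\n>b\nGAV") ->
#   [{"id": "a", "sequence": "MKVTF"}, {"id": "b", "sequence": "GAV"}]
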
def clean_sequence(seq: str) -> str:
    seq = re.sub(r"\s+", "", seq).upper()
    if not seq:
        raise ValueError("Empty sequence after cleaning.")
    bad = sorted({c for c in seq if c not in ALLOWED_AA})
    if bad:
        raise ValueError(f"Invalid amino acid letters found: {bad}")
    for c in REPLACE_WITH_X:
        seq = seq.replace(c, "X")
    return seq

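# e.g. clean_sequence("mk uV") -> "MKXV" (whitespace stripped, U mapped to X)
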
def protein_to_spaced(seq: str) -> str:
    return " ".join(seq)

def normalize_to_Ld(
    hidden: torch.Tensor,
    expected_len: int,
    special_tokens_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if hidden.ndim != 2:
        raise ValueError(f"Expected [T, d], got {tuple(hidden.shape)}")
    T = hidden.shape[0]
    if special_tokens_mask is not None:
        # Prefer the tokenizer's own mask for dropping special tokens.
        keep = ~special_tokens_mask.bool().view(-1)
        if attention_mask is not None:
            keep = keep & attention_mask.bool().view(-1)
        filtered = hidden[keep]
        if filtered.shape[0] == expected_len:
            return filtered
        if filtered.shape[0] > expected_len:
            return filtered[:expected_len]
    # Fallback heuristics: BOS+EOS (T == L+2), EOS only (T == L+1), else truncate.
    if T == expected_len:
        return hidden
    if T == expected_len + 2:
        return hidden[1:-1]
    if T == expected_len + 1:
        return hidden[:expected_len]
    if T > expected_len:
        return hidden[:expected_len]
    raise ValueError(f"Cannot normalize token length {T} to residue length {expected_len}.")

def ensure_prosst_repo():
    if os.path.isdir(PROSST_REPO_DIR) and os.path.isdir(os.path.join(PROSST_REPO_DIR, "prosst")):
        if PROSST_REPO_DIR not in sys.path:
            sys.path.append(PROSST_REPO_DIR)
        return
    subprocess.run(
        ["git", "clone", "--depth", "1", "https://github.com/openmedlab/ProSST.git", PROSST_REPO_DIR],
        check=True,
    )
    if PROSST_REPO_DIR not in sys.path:
        sys.path.append(PROSST_REPO_DIR)

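# Exactly one model is kept resident at a time; switching models (or devices)
# unloads the previous one first so memory stays bounded.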
class SingleModelRunner:
    def __init__(self):
        self.model_key = None
        self.family = None
        self.device = None
        self.model = None
        self.tokenizer = None
        self.sst_predictor = None

    def unload(self):
        self.model_key = None
        self.family = None
        self.device = None
        self.model = None
        self.tokenizer = None
        self.sst_predictor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load(self, model_key: str, device: str):
        target_device = resolve_device(device)
        if self.model_key == model_key and self.device == target_device and self.model is not None:
            return
        self.unload()
        spec = MODEL_SPECS[model_key]
        if spec.family == "hf_encoder":
            self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
            self.model = AutoModel.from_pretrained(spec.model_id)
            self.model.to(target_device)
            self.model.eval()
        elif spec.family == "t5_encoder":
            self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
            self.model = T5EncoderModel.from_pretrained(spec.model_id)
            self.model.to(target_device)
            self.model.eval()
        elif spec.family == "esmc":
            from esm.models.esmc import ESMC
            self.model = ESMC.from_pretrained(spec.model_id).to(target_device)
            self.model.eval()
            self.tokenizer = None
        elif spec.family == "prosst":
            ensure_prosst_repo()
            self.tokenizer = AutoTokenizer.from_pretrained(
                spec.tokenizer_id,
                trust_remote_code=True,
            )
            self.model = AutoModelForMaskedLM.from_pretrained(
                spec.model_id,
                trust_remote_code=True,
                output_hidden_states=True,
            )
            self.model.to(target_device)
            self.model.eval()
            from prosst.structure.get_sst_seq import SSTPredictor
            self.sst_predictor = SSTPredictor()
        else:
            raise ValueError(f"Unsupported family: {spec.family}")
        self.model_key = model_key
        self.family = spec.family
        self.device = target_device

RUNNER = SingleModelRunner()

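# Each embed_* helper returns a float32 CPU tensor of shape [L, d] (one row
# per residue), with special tokens stripped via normalize_to_Ld or trimming.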
@torch.no_grad()  # inference only; avoids accumulating autograd state
def embed_hf_encoder(seq: str) -> torch.Tensor:
    enc = RUNNER.tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
    out = RUNNER.model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
    hidden = out.last_hidden_state[0]
    stm = enc.get("special_tokens_mask")
    att = enc.get("attention_mask")
    emb = normalize_to_Ld(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=stm[0] if stm is not None else None,
        attention_mask=att[0] if att is not None else None,
    )
    return emb.detach().cpu().float()

@torch.no_grad()
def embed_t5_encoder(seq: str) -> torch.Tensor:
    spaced = protein_to_spaced(seq)
    enc = RUNNER.tokenizer(
        spaced,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
    out = RUNNER.model(**{k: v for k, v in enc.items() if k != "special_tokens_mask"})
    hidden = out.last_hidden_state[0]
    stm = enc.get("special_tokens_mask")
    att = enc.get("attention_mask")
    emb = normalize_to_Ld(
        hidden=hidden,
        expected_len=len(seq),
        special_tokens_mask=stm[0] if stm is not None else None,
        attention_mask=att[0] if att is not None else None,
    )
    return emb.detach().cpu().float()

@torch.no_grad()
def embed_esmc(seq: str) -> torch.Tensor:
    from esm.sdk.api import ESMProtein, LogitsConfig
    protein = ESMProtein(sequence=seq)
    protein_tensor = RUNNER.model.encode(protein)
    out = RUNNER.model.logits(
        protein_tensor,
        LogitsConfig(sequence=True, return_embeddings=True)
    )
    emb = out.embeddings
    if not isinstance(emb, torch.Tensor):
        emb = torch.tensor(emb)
    if emb.ndim == 3:
        emb = emb[0]  # drop the batch dimension
    # ESMC embeddings may include BOS/EOS tokens; trim back to one row per residue.
    if emb.shape[0] == len(seq):
        return emb.detach().cpu().float()
    if emb.shape[0] == len(seq) + 2:
        return emb[1:-1].detach().cpu().float()
    if emb.shape[0] == len(seq) + 1:
        return emb[:len(seq)].detach().cpu().float()
    if emb.shape[0] > len(seq):
        return emb[:len(seq)].detach().cpu().float()
    raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")

def get_sst_tokens(seq: str) -> List[int]:
    sst = RUNNER.sst_predictor.predict(seq)
    print("SST raw type:", type(sst))
    print("SST raw repr:", repr(sst)[:500])
    if isinstance(sst, str):
        tokens = [int(x) for x in sst.strip().split()]
    elif isinstance(sst, torch.Tensor):
        tokens = sst.detach().cpu().view(-1).tolist()
    elif hasattr(sst, "tolist"):
        tokens = sst.tolist()
        if isinstance(tokens, list) and len(tokens) > 0 and isinstance(tokens[0], list):
            tokens = tokens[0]
    elif isinstance(sst, (list, tuple)):
        tokens = list(sst)
    else:
        raise ValueError(f"Unsupported SSTPredictor output type: {type(sst)}")
    tokens = [int(x) for x in tokens]
    if len(tokens) == len(seq) + 2:
        tokens = tokens[1:-1]
    elif len(tokens) == len(seq) + 1:
        tokens = tokens[:len(seq)]
    elif len(tokens) > len(seq):
        tokens = tokens[:len(seq)]
    if len(tokens) != len(seq):
        raise ValueError(f"SST token length mismatch: got {len(tokens)}, expected {len(seq)}")
    print("SST final length:", len(tokens))
    print("SST first 30:", tokens[:30])
    return tokens

@torch.no_grad()
def embed_prosst(seq: str) -> Tuple[torch.Tensor, List[int]]:
    sst_tokens = get_sst_tokens(seq)
    aa_spaced = protein_to_spaced(seq)
    seq_enc = RUNNER.tokenizer(
        aa_spaced,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    seq_enc = {k: v.to(RUNNER.device) for k, v in seq_enc.items()}
    sst_ids = torch.tensor([sst_tokens], dtype=torch.long, device=RUNNER.device)
    # The keyword for structure ids varies between model revisions, so try the
    # known names until one is accepted.
    tried = []
    for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
        try:
            out = RUNNER.model(
                input_ids=seq_enc["input_ids"],
                attention_mask=seq_enc.get("attention_mask"),
                output_hidden_states=True,
                return_dict=True,
                **{kw: sst_ids},
            )
            if getattr(out, "hidden_states", None) is None:
                raise RuntimeError("ProSST output has no hidden_states")
            hidden = out.hidden_states[-1][0]
            stm = seq_enc.get("special_tokens_mask")
            att = seq_enc.get("attention_mask")
            emb = normalize_to_Ld(
                hidden=hidden,
                expected_len=len(seq),
                special_tokens_mask=stm[0] if stm is not None else None,
                attention_mask=att[0] if att is not None else None,
            )
            return emb.detach().cpu().float(), sst_tokens
        except Exception as e:
            tried.append(f"{kw}: {repr(e)}")
    raise RuntimeError(
        "Failed to run ProSST with known structure-token arg names: " + " | ".join(tried)
    )

def embed_one_sequence(seq: str):
    if RUNNER.family == "hf_encoder":
        return embed_hf_encoder(seq), None
    if RUNNER.family == "t5_encoder":
        return embed_t5_encoder(seq), None
    if RUNNER.family == "esmc":
        return embed_esmc(seq), None
    if RUNNER.family == "prosst":
        return embed_prosst(seq)
    raise ValueError(f"Unsupported family: {RUNNER.family}")

def run_embedding(fasta_text: str, model_keys: List[str], device: str, progress=gr.Progress()):
    if not model_keys:
        raise ValueError("Please select at least one model.")
    records = parse_fasta(fasta_text)
    records = [{"id": r["id"], "sequence": clean_sequence(r["sequence"])} for r in records]
    tmpdir = tempfile.mkdtemp(prefix="protein_embeddings_")
    zip_path = os.path.join(tmpdir, "embeddings.zip")
    total_steps = len(model_keys) * len(records)
    step = 0
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for model_key in model_keys:
            RUNNER.load(model_key, device)
            for rec in records:
                step += 1
                progress(step / total_steps, desc=f"{model_key} | {rec['id']}")
                emb, sst_tokens = embed_one_sequence(rec["sequence"])
                if emb.ndim != 2 or emb.shape[0] != len(rec["sequence"]):
                    raise ValueError(
                        f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(rec['sequence'])}, d)"
                    )
                pt_name = f"{safe_filename(model_key)}/{safe_filename(rec['id'])}.pt"
                pt_buf = io.BytesIO()
                torch.save(emb, pt_buf)
                zf.writestr(pt_name, pt_buf.getvalue())
                if sst_tokens is not None:
                    tok_name = f"{safe_filename(model_key)}_structure_tokens/{safe_filename(rec['id'])}.txt"
                    zf.writestr(tok_name, " ".join(map(str, sst_tokens)))
    return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."

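# Zip layout, one entry per (model, sequence):
#   embeddings.zip
#     ESM2-150M/seq1.pt                      # torch.load(f) -> FloatTensor [L, d]
#     ProSST-2048/seq1.pt
#     ProSST-2048_structure_tokens/seq1.txt  # space-separated structure token ids
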
def clear_cache():
    RUNNER.unload()
    return "Cache cleared."

EXAMPLE_FASTA = """>seq1
MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
>seq2
GAVLILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSR
"""

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    fasta_input = gr.Textbox(
        label="FASTA",
        lines=16,
        value=EXAMPLE_FASTA,
        placeholder="Paste FASTA here",
    )
    model_select = gr.CheckboxGroup(
        choices=list(MODEL_SPECS.keys()),
        value=["ESM2-150M"],
        label="Models",
    )
    device_select = gr.Dropdown(
        choices=["auto", "cuda", "cpu"],
        value="auto",
        label="Device",
    )
    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        clear_btn = gr.Button("Clear cache")
    output_file = gr.File(label="Download")
    log_box = gr.Textbox(label="Log", lines=4)
    run_btn.click(
        fn=run_embedding,
        inputs=[fasta_input, model_select, device_select],
        outputs=[output_file, log_box],
    )
    clear_btn.click(
        fn=clear_cache,
        inputs=[],
        outputs=[log_box],
    )

demo.queue(max_size=8)
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)