"""Gradio app that computes per-residue protein embeddings.

The user pastes FASTA, picks one or more protein language models, and gets back
a zip containing one ``[L, d]`` float32 tensor (``.pt``) per model and sequence.
For ProSST, the structure tokens used are also written as text files.
"""

import gc
import io
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import zipfile
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
    AutoTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)

logger = logging.getLogger(__name__)

APP_TITLE = "Protein Embedding"
# Letters accepted by clean_sequence (anything else is rejected).
ALLOWED_AA = set(list("ACDEFGHIKLMNPQRSTVWYXBZJUO"))
# Letters rewritten to "X" by clean_sequence before embedding.
REPLACE_WITH_X = set(list("UZOB"))
# Clone target for the ProSST source repo (provides SSTPredictor).
PROSST_REPO_DIR = "/tmp/ProSST"


@dataclass
class ModelSpec:
    """Static description of one selectable model."""

    name: str
    # Selects the load/embed code path in SingleModelRunner / embed_one_sequence.
    family: str
    model_id: str
    tokenizer_id: Optional[str] = None


MODEL_SPECS: Dict[str, ModelSpec] = {
    "ESM2-8M": ModelSpec(
        name="ESM2-8M",
        family="hf_encoder",
        model_id="facebook/esm2_t6_8M_UR50D",
        tokenizer_id="facebook/esm2_t6_8M_UR50D",
    ),
    "ESM2-35M": ModelSpec(
        name="ESM2-35M",
        family="hf_encoder",
        model_id="facebook/esm2_t12_35M_UR50D",
        tokenizer_id="facebook/esm2_t12_35M_UR50D",
    ),
    "ESM2-150M": ModelSpec(
        name="ESM2-150M",
        family="hf_encoder",
        model_id="facebook/esm2_t30_150M_UR50D",
        tokenizer_id="facebook/esm2_t30_150M_UR50D",
    ),
    "ESM2-650M": ModelSpec(
        name="ESM2-650M",
        family="hf_encoder",
        model_id="facebook/esm2_t33_650M_UR50D",
        tokenizer_id="facebook/esm2_t33_650M_UR50D",
    ),
    "ESMC-300M": ModelSpec(
        name="ESMC-300M",
        family="esmc",
        model_id="esmc_300m",
    ),
    "ESMC-600M": ModelSpec(
        name="ESMC-600M",
        family="esmc",
        model_id="esmc_600m",
    ),
    # Ankh checkpoints are T5 encoder-decoders; they need their own family so
    # only the encoder is loaded (AutoModel would build T5Model, whose forward
    # pass requires decoder inputs).
    "Ankh-Base": ModelSpec(
        name="Ankh-Base",
        family="ankh",
        model_id="ElnaggarLab/ankh-base",
        tokenizer_id="ElnaggarLab/ankh-base",
    ),
    "Ankh-Large": ModelSpec(
        name="Ankh-Large",
        family="ankh",
        model_id="ElnaggarLab/ankh-large",
        tokenizer_id="ElnaggarLab/ankh-large",
    ),
    "ProtT5-XL-Encoder": ModelSpec(
        name="ProtT5-XL-Encoder",
        family="t5_encoder",
        model_id="Rostlab/prot_t5_xl_half_uniref50-enc",
        tokenizer_id="Rostlab/prot_t5_xl_half_uniref50-enc",
    ),
    "ProSST-2048": ModelSpec(
        name="ProSST-2048",
        family="prosst",
        model_id="AI4Protein/ProSST-2048",
        tokenizer_id="AI4Protein/ProSST-2048",
    ),
}


def resolve_device(device: str) -> str:
    """Map the UI device choice to a usable torch device string.

    "auto" picks CUDA when available; a "cuda" request silently falls back to
    "cpu" when CUDA is unavailable. Any other value is returned unchanged.
    """
    if device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device


def safe_filename(x: str) -> str:
    """Collapse characters outside ``[A-Za-z0-9._-]`` to "_" for use in paths.

    Leading/trailing "." and "_" are stripped; an empty result becomes "item".
    """
    x = re.sub(r"[^A-Za-z0-9._-]+", "_", x)
    x = x.strip("._")
    return x or "item"


def parse_fasta(text: str) -> List[Dict[str, str]]:
    """Parse FASTA text into ``[{"id": ..., "sequence": ...}, ...]``.

    Header-less leading sequence lines and blank headers get ids ``seq_<n>``.

    Raises:
        ValueError: on empty input, a record with no sequence, or no records.
    """
    text = text.strip()
    if not text:
        raise ValueError("Empty FASTA input.")

    records: List[Dict[str, str]] = []
    current_id: Optional[str] = None
    current_seq: List[str] = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if current_id is not None:
                seq = "".join(current_seq).strip()
                if not seq:
                    raise ValueError(f"Sequence for record '{current_id}' is empty.")
                records.append({"id": current_id, "sequence": seq})
            current_id = line[1:].strip() or f"seq_{len(records)+1}"
            current_seq = []
        else:
            if current_id is None:
                current_id = f"seq_{len(records)+1}"
            current_seq.append(line)

    if current_id is not None:
        seq = "".join(current_seq).strip()
        if not seq:
            raise ValueError(f"Sequence for record '{current_id}' is empty.")
        records.append({"id": current_id, "sequence": seq})

    if not records:
        raise ValueError("No FASTA records found.")
    return records


def clean_sequence(seq: str) -> str:
    """Remove whitespace, upper-case, validate, and map REPLACE_WITH_X letters to "X".

    Raises:
        ValueError: if the sequence is empty or contains letters outside ALLOWED_AA.
    """
    seq = re.sub(r"\s+", "", seq).upper()
    if not seq:
        raise ValueError("Empty sequence after cleaning.")
    bad = sorted({c for c in seq if c not in ALLOWED_AA})
    if bad:
        raise ValueError(f"Invalid amino acid letters found: {bad}")
    return seq.translate(str.maketrans({c: "X" for c in REPLACE_WITH_X}))


def protein_to_spaced(seq: str) -> str:
    """Return the residues separated by single spaces (ProtT5 input format)."""
    return " ".join(seq)


def _residue_slice(total: int, expected: int) -> Optional[slice]:
    """Slice selecting ``expected`` residue positions out of ``total`` tokens.

    Handles the three layouts produced by the supported models: residues only,
    residues + one trailing special token, and leading + trailing special
    tokens. Returns None for any other length so callers fail loudly instead of
    silently returning misaligned per-residue data.
    """
    if total == expected:
        return slice(0, expected)
    if total == expected + 2:
        return slice(1, total - 1)
    if total == expected + 1:
        return slice(0, expected)
    return None


def normalize_to_Ld(
    hidden: torch.Tensor,
    expected_len: int,
    special_tokens_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reduce per-token hidden states ``[T, d]`` to per-residue ``[expected_len, d]``.

    If a special-tokens mask is given (optionally combined with the attention
    mask) and removing masked tokens leaves exactly ``expected_len`` rows,
    those rows are returned. Otherwise the result falls back to the length
    layouts understood by ``_residue_slice``.

    Raises:
        ValueError: if ``hidden`` is not 2-D or no layout matches.
    """
    if hidden.ndim != 2:
        raise ValueError(f"Expected [T, d], got {tuple(hidden.shape)}")

    total = hidden.shape[0]

    if special_tokens_mask is not None:
        keep = ~special_tokens_mask.bool().view(-1)
        if attention_mask is not None:
            keep = keep & attention_mask.bool().view(-1)
        filtered = hidden[keep]
        if filtered.shape[0] == expected_len:
            return filtered
        # The mask did not isolate exactly one token per residue (e.g. a custom
        # tokenizer that doesn't flag its specials); try the length layouts
        # rather than truncating, which would misalign residues.

    residues = _residue_slice(total, expected_len)
    if residues is None:
        raise ValueError(f"Cannot normalize token length {total} to residue length {expected_len}.")
    return hidden[residues]


def ensure_prosst_repo():
    """Make the ProSST source repo importable, cloning it if needed.

    Clones into PROSST_REPO_DIR when its ``prosst`` package directory is
    missing. A leftover directory from an interrupted clone is removed first,
    since ``git clone`` refuses a non-empty destination. The repo directory is
    appended to ``sys.path`` once.
    """
    if not os.path.isdir(os.path.join(PROSST_REPO_DIR, "prosst")):
        if os.path.isdir(PROSST_REPO_DIR):
            shutil.rmtree(PROSST_REPO_DIR)
        subprocess.run(
            ["git", "clone", "--depth", "1", "https://github.com/openmedlab/ProSST.git", PROSST_REPO_DIR],
            check=True,
        )
    if PROSST_REPO_DIR not in sys.path:
        sys.path.append(PROSST_REPO_DIR)


class SingleModelRunner:
    """Holds at most one loaded model (plus tokenizer / SST predictor) at a time."""

    def __init__(self):
        self.model_key = None
        self.family = None
        self.device = None
        self.model = None
        self.tokenizer = None
        self.sst_predictor = None

    def unload(self):
        """Drop all references to the current model and release cached GPU memory."""
        self.model_key = None
        self.family = None
        self.device = None
        self.model = None
        self.tokenizer = None
        self.sst_predictor = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load(self, model_key: str, device: str):
        """Load ``model_key`` onto ``device``, reusing the current model if it matches.

        Any previously loaded model is unloaded first. ``model_key``/``family``/
        ``device`` are only recorded once loading fully succeeds.

        Raises:
            KeyError: if ``model_key`` is not in MODEL_SPECS.
            ValueError: if the spec's family is unsupported.
        """
        target_device = resolve_device(device)
        if self.model_key == model_key and self.device == target_device and self.model is not None:
            return

        self.unload()
        spec = MODEL_SPECS[model_key]

        if spec.family == "hf_encoder":
            self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
            self.model = AutoModel.from_pretrained(spec.model_id)
        elif spec.family == "ankh":
            # Only the T5 encoder is needed for embeddings.
            self.tokenizer = AutoTokenizer.from_pretrained(spec.tokenizer_id)
            self.model = T5EncoderModel.from_pretrained(spec.model_id)
        elif spec.family == "t5_encoder":
            self.tokenizer = T5Tokenizer.from_pretrained(spec.tokenizer_id, do_lower_case=False)
            self.model = T5EncoderModel.from_pretrained(spec.model_id)
        elif spec.family == "esmc":
            from esm.models.esmc import ESMC

            self.model = ESMC.from_pretrained(spec.model_id)
            self.tokenizer = None
        elif spec.family == "prosst":
            ensure_prosst_repo()
            self.tokenizer = AutoTokenizer.from_pretrained(
                spec.tokenizer_id,
                trust_remote_code=True,
            )
            self.model = AutoModelForMaskedLM.from_pretrained(
                spec.model_id,
                trust_remote_code=True,
                output_hidden_states=True,
            )
            from prosst.structure.get_sst_seq import SSTPredictor

            # NOTE(review): ProSST's README constructs this as
            # SSTPredictor(structure_vocab_size=2048); confirm the default
            # vocabulary size matches the ProSST-2048 checkpoint.
            self.sst_predictor = SSTPredictor()
        else:
            raise ValueError(f"Unsupported family: {spec.family}")

        self.model.to(target_device)
        self.model.eval()

        self.model_key = model_key
        self.family = spec.family
        self.device = target_device


RUNNER = SingleModelRunner()


def _first_row(enc: Dict[str, torch.Tensor], key: str) -> Optional[torch.Tensor]:
    """Return ``enc[key][0]`` (batch item 0) or None when the key is absent."""
    value = enc.get(key)
    return value[0] if value is not None else None


@torch.no_grad()
def _run_encoder(enc, expected_len: int) -> torch.Tensor:
    """Run RUNNER.model on tokenizer output and return a ``[expected_len, d]`` CPU float tensor."""
    enc = {k: v.to(RUNNER.device) for k, v in enc.items()}
    # special_tokens_mask is tokenizer metadata, not a model input.
    model_inputs = {k: v for k, v in enc.items() if k != "special_tokens_mask"}
    out = RUNNER.model(**model_inputs)
    emb = normalize_to_Ld(
        hidden=out.last_hidden_state[0],
        expected_len=expected_len,
        special_tokens_mask=_first_row(enc, "special_tokens_mask"),
        attention_mask=_first_row(enc, "attention_mask"),
    )
    return emb.detach().cpu().float()


@torch.no_grad()
def embed_hf_encoder(seq: str) -> torch.Tensor:
    """Per-residue embedding from a Hugging Face encoder (ESM2) fed the raw sequence."""
    enc = RUNNER.tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    return _run_encoder(enc, len(seq))


@torch.no_grad()
def embed_ankh(seq: str) -> torch.Tensor:
    """Per-residue embedding from an Ankh encoder.

    Follows Ankh's documented input format: the sequence as a list of
    single-residue words with ``is_split_into_words=True``.
    """
    enc = RUNNER.tokenizer(
        list(seq),
        is_split_into_words=True,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    return _run_encoder(enc, len(seq))


@torch.no_grad()
def embed_t5_encoder(seq: str) -> torch.Tensor:
    """Per-residue embedding from the ProtT5 encoder (space-separated residues)."""
    enc = RUNNER.tokenizer(
        protein_to_spaced(seq),
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    return _run_encoder(enc, len(seq))


@torch.no_grad()
def embed_esmc(seq: str) -> torch.Tensor:
    """Per-residue embedding from an ESMC model via the ``esm`` SDK.

    Raises:
        ValueError: if the returned embedding length can't be aligned to ``seq``.
    """
    from esm.sdk.api import ESMProtein, LogitsConfig

    protein_tensor = RUNNER.model.encode(ESMProtein(sequence=seq))
    out = RUNNER.model.logits(
        protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
    )
    emb = out.embeddings
    if not isinstance(emb, torch.Tensor):
        emb = torch.tensor(emb)
    if emb.ndim == 3:
        emb = emb[0]

    residues = _residue_slice(emb.shape[0], len(seq)) if emb.ndim == 2 else None
    if residues is None:
        raise ValueError(f"ESMC returned shape {tuple(emb.shape)} for sequence length {len(seq)}.")
    return emb[residues].detach().cpu().float()


def get_sst_tokens(seq: str) -> List[int]:
    """Return one integer structure token per residue of ``seq`` from RUNNER.sst_predictor.

    Accepts a whitespace-separated string, a tensor, anything with ``tolist``
    (first row taken if nested), or a list/tuple; trailing/leading special
    tokens are stripped via ``_residue_slice``.

    Raises:
        ValueError: for an unsupported output type or an unalignable length.
    """
    # NOTE(review): ProSST documents SSTPredictor as predicting from a PDB
    # structure (predict_from_pdb); confirm a sequence-only `predict` exists.
    sst = RUNNER.sst_predictor.predict(seq)
    logger.debug("SST raw type: %s", type(sst))
    logger.debug("SST raw repr: %s", repr(sst)[:500])

    if isinstance(sst, str):
        tokens = sst.strip().split()
    elif isinstance(sst, torch.Tensor):
        tokens = sst.detach().cpu().reshape(-1).tolist()
    elif hasattr(sst, "tolist"):
        tokens = sst.tolist()
        if isinstance(tokens, list) and tokens and isinstance(tokens[0], list):
            tokens = tokens[0]
    elif isinstance(sst, (list, tuple)):
        tokens = list(sst)
    else:
        raise ValueError(f"Unsupported SSTPredictor output type: {type(sst)}")

    tokens = [int(x) for x in tokens]
    residues = _residue_slice(len(tokens), len(seq))
    if residues is None:
        raise ValueError(f"SST token length mismatch: got {len(tokens)}, expected {len(seq)}")
    tokens = tokens[residues]

    logger.debug("SST final length: %d", len(tokens))
    logger.debug("SST first 30: %s", tokens[:30])
    return tokens


def _structure_input_ids(sst_tokens: List[int], n_input_tokens: int) -> List[int]:
    """Build ProSST structure ids aligned with ``input_ids`` of length ``n_input_tokens``.

    Per the ProSST README, raw structure tokens are shifted by 3 (ids 0-2 are
    reserved) and wrapped as ``[1, *shifted, 2]`` to line up with the
    sequence's BOS/EOS tokens.
    """
    shifted = [t + 3 for t in sst_tokens]
    if n_input_tokens == len(shifted) + 2:
        return [1, *shifted, 2]
    if n_input_tokens == len(shifted):
        return shifted
    raise ValueError(
        f"Cannot align {len(sst_tokens)} structure tokens with {n_input_tokens} input tokens."
    )


@torch.no_grad()
def embed_prosst(seq: str) -> Tuple[torch.Tensor, List[int]]:
    """Per-residue ProSST embedding (last hidden layer) plus the structure tokens used.

    Tries several keyword names for the structure-token input because the
    remote-code model's signature is not pinned here; only a ``TypeError``
    (wrong keyword) moves on to the next name.

    Raises:
        RuntimeError: if no known keyword is accepted or hidden states are missing.
    """
    sst_tokens = get_sst_tokens(seq)

    # The ProSST README feeds the plain (unspaced) residue string.
    seq_enc = RUNNER.tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        return_special_tokens_mask=True,
        truncation=False,
    )
    seq_enc = {k: v.to(RUNNER.device) for k, v in seq_enc.items()}
    input_ids = seq_enc["input_ids"]
    sst_ids = torch.tensor(
        [_structure_input_ids(sst_tokens, input_ids.shape[1])],
        dtype=torch.long,
        device=RUNNER.device,
    )

    tried = []
    for kw in ("ss_input_ids", "structure_ids", "sst_input_ids", "struc_input_ids"):
        try:
            out = RUNNER.model(
                input_ids=input_ids,
                attention_mask=seq_enc.get("attention_mask", None),
                output_hidden_states=True,
                return_dict=True,
                **{kw: sst_ids},
            )
        except TypeError as e:
            # Most likely "unexpected keyword argument": try the next name.
            tried.append(f"{kw}: {e!r}")
            continue

        if getattr(out, "hidden_states", None) is None:
            raise RuntimeError("ProSST output has no hidden_states")
        emb = normalize_to_Ld(
            hidden=out.hidden_states[-1][0],
            expected_len=len(seq),
            special_tokens_mask=_first_row(seq_enc, "special_tokens_mask"),
            attention_mask=_first_row(seq_enc, "attention_mask"),
        )
        return emb.detach().cpu().float(), sst_tokens

    raise RuntimeError(
        "Failed to run ProSST with known structure-token arg names: " + " | ".join(tried)
    )


def embed_one_sequence(seq: str) -> Tuple[torch.Tensor, Optional[List[int]]]:
    """Embed ``seq`` with the loaded model; returns ``(embedding, structure_tokens_or_None)``."""
    if RUNNER.family == "hf_encoder":
        return embed_hf_encoder(seq), None
    if RUNNER.family == "ankh":
        return embed_ankh(seq), None
    if RUNNER.family == "t5_encoder":
        return embed_t5_encoder(seq), None
    if RUNNER.family == "esmc":
        return embed_esmc(seq), None
    if RUNNER.family == "prosst":
        return embed_prosst(seq)
    raise ValueError(f"Unsupported family: {RUNNER.family}")


def _unique_stems(ids: List[str]) -> List[str]:
    """Map record ids to safe file stems that are unique even case-insensitively.

    Colliding stems get ``_2``, ``_3``, ... suffixes so zip members never clash
    (also after extraction on case-insensitive filesystems).
    """
    used = set()
    stems = []
    for rid in ids:
        base = safe_filename(rid)
        candidate = base
        suffix = 2
        while candidate.lower() in used:
            candidate = f"{base}_{suffix}"
            suffix += 1
        used.add(candidate.lower())
        stems.append(candidate)
    return stems


def run_embedding(fasta_text: str, model_keys: List[str], device: str, progress=gr.Progress()):
    """Embed every FASTA record with every selected model and zip the results.

    Zip layout: ``<model>/<record>.pt`` holds an ``[L, d]`` float32 tensor;
    ProSST additionally writes ``<model>_structure_tokens/<record>.txt``.

    Returns:
        (zip_path, status_message) for the Gradio File and Textbox outputs.

    Raises:
        ValueError: on no selected models, bad FASTA, or a wrongly shaped embedding.
    """
    if not model_keys:
        raise ValueError("Please select at least one model.")

    records = parse_fasta(fasta_text)
    records = [{"id": r["id"], "sequence": clean_sequence(r["sequence"])} for r in records]
    stems = _unique_stems([r["id"] for r in records])

    tmpdir = tempfile.mkdtemp(prefix="protein_embeddings_")
    zip_path = os.path.join(tmpdir, "embeddings.zip")

    total_steps = len(model_keys) * len(records)
    step = 0

    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for model_key in model_keys:
            RUNNER.load(model_key, device)
            model_dir = safe_filename(model_key)

            for rec, stem in zip(records, stems):
                step += 1
                progress(step / total_steps, desc=f"{model_key} | {rec['id']}")

                emb, sst_tokens = embed_one_sequence(rec["sequence"])
                if emb.ndim != 2 or emb.shape[0] != len(rec["sequence"]):
                    raise ValueError(
                        f"{model_key} failed on {rec['id']}: got shape {tuple(emb.shape)}, expected ({len(rec['sequence'])}, d)"
                    )

                pt_buf = io.BytesIO()
                torch.save(emb, pt_buf)
                zf.writestr(f"{model_dir}/{stem}.pt", pt_buf.getvalue())

                if sst_tokens is not None:
                    zf.writestr(
                        f"{model_dir}_structure_tokens/{stem}.txt",
                        " ".join(map(str, sst_tokens)),
                    )

    return zip_path, f"Done: {len(records)} sequence(s), {len(model_keys)} model(s)."


def clear_cache():
    """Unload the current model; returns a status message for the log box."""
    RUNNER.unload()
    return "Cache cleared."


EXAMPLE_FASTA = """>seq1
MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGE
>seq2
GAVLILKKKGHHEAELKPLAQSHATKHKIPIKYLEFISEAIIHVLHSR
"""

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# {APP_TITLE}")

    fasta_input = gr.Textbox(
        label="FASTA",
        lines=16,
        value=EXAMPLE_FASTA,
        placeholder="Paste FASTA here",
    )

    model_select = gr.CheckboxGroup(
        choices=list(MODEL_SPECS.keys()),
        value=["ESM2-150M"],
        label="Models",
    )

    device_select = gr.Dropdown(
        choices=["auto", "cuda", "cpu"],
        value="auto",
        label="Device",
    )

    with gr.Row():
        run_btn = gr.Button("Run", variant="primary")
        clear_btn = gr.Button("Clear cache")

    output_file = gr.File(label="Download")
    log_box = gr.Textbox(label="Log", lines=4)

    run_btn.click(
        fn=run_embedding,
        inputs=[fasta_input, model_select, device_select],
        outputs=[output_file, log_box],
    )

    clear_btn.click(
        fn=clear_cache,
        inputs=[],
        outputs=[log_box],
    )

demo.queue(max_size=8)
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)