import os import tempfile import h5py import numpy as np import gradio as gr import torch from transformers import EsmForMaskedLM, EsmTokenizer # ── constants ────────────────────────────────────────────────────────────────── MODEL_REGISTRY: dict[str, str | None] = { "PlantPLM-8M (7.5M params)": "dipayan26/PlantPLM-8M", "PlantPLM-35M (35M params) — coming soon": None, "PlantPLM-150M (150M params) — coming soon": None, "PlantPLM-650M (650M params) — coming soon": None, } MAX_LEN = 256 # AA per sequence on CPU MAX_SEQS = 50 # sequences per job MAX_FILES = 10 # uploaded files per job EXAMPLE_SEQ = ( "MSPQTETKASVGFKAGVKDYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVAAESSTGT" "WTTPWTPTFGDDKIMASVGFKAGVKDYKLTYYTPEYETKDTDILAAFRVTPQPGVPPEEAGAAVA" ) # RuBisCO large subunit fragment — Spinacia oleracea EMB_PER_PROTEIN = "Per-protein · mean pool → shape [hidden_dim]" EMB_PER_AA = "Per-amino-acid · residue-level → shape [L × hidden_dim]" FMT_PT = ".pt — PyTorch tensor" FMT_H5 = ".h5 — HDF5" device = "cuda" if torch.cuda.is_available() else "cpu" ON_CPU = device == "cpu" # ── load model ───────────────────────────────────────────────────────────────── print("Loading PlantPLM-8M …") tokenizer: EsmTokenizer = EsmTokenizer.from_pretrained("dipayan26/PlantPLM-8M") model: EsmForMaskedLM = EsmForMaskedLM.from_pretrained("dipayan26/PlantPLM-8M") model.eval() model.to(device) # type: ignore[arg-type] print(f"Model ready on {device}.") # ── helpers ──────────────────────────────────────────────────────────────────── def parse_fasta(text: str) -> list[tuple[str, str]]: seqs: list[tuple[str, str]] = [] cur_id: str | None = None cur_seq: list[str] = [] for line in text.strip().splitlines(): line = line.strip() if not line: continue if line.startswith(">"): if cur_id is not None and cur_seq: seqs.append((cur_id, "".join(cur_seq))) header = line[1:].strip() cur_id = header.split()[0] if header else f"seq_{len(seqs) + 1}" cur_seq = [] else: cur_seq.append(line.upper()) if cur_id is not None and cur_seq: seqs.append((cur_id, "".join(cur_seq))) return seqs def read_file(f) -> str: path = f if isinstance(f, str) else getattr(f, "name", str(f)) with open(path, encoding="utf-8", errors="ignore") as fh: return fh.read() def embed_sequence(seq: str) -> tuple[np.ndarray, np.ndarray]: """ Returns: protein_emb : [hidden_dim] — mean pool over residue tokens aa_emb : [L, hidden_dim] — per-residue, CLS and EOS excluded """ inputs = tokenizer( seq, return_tensors="pt", truncation=True, max_length=MAX_LEN + 2 ).to(device) with torch.no_grad(): # last_hidden_state: [1, seq_len, hidden_dim] # seq_len = 1 (CLS) + L (AAs) + 1 (EOS) hidden = model.esm(**inputs).last_hidden_state hidden = hidden[0] # [seq_len, hidden_dim] n_total = int(inputs["attention_mask"][0].sum().item()) # CLS + L + EOS L = n_total - 2 aa_emb = hidden[1 : L + 1].cpu().numpy() # [L, hidden_dim] protein_emb = aa_emb.mean(axis=0) # [hidden_dim] return protein_emb, aa_emb # ── main generation function ─────────────────────────────────────────────────── def generate_embeddings( model_key: str, input_method: str, single_seq: str, fasta_text: str, fasta_files, emb_type: str, fmt: str, ): if MODEL_REGISTRY.get(model_key) is None: return "❌ This model is not yet available. Please select PlantPLM-8M.", None # ── collect sequences ─────────────────────────────────────────────────────── sequences: list[tuple[str, str]] = [] if input_method == "Single sequence": seq = single_seq.strip().upper().replace(" ", "") if not seq: return "❌ Please enter a protein sequence.", None sequences = [("seq_1", seq)] elif input_method == "Paste FASTA": txt = (fasta_text or "").strip() if not txt: return "❌ Please paste at least one sequence.", None if not txt.startswith(">"): sequences = [("seq_1", txt.upper().replace(" ", "").replace("\n", ""))] else: sequences = parse_fasta(txt) if not sequences: return "❌ No valid sequences found in the pasted text.", None elif input_method == "Upload files": if not fasta_files: return "❌ Please upload at least one FASTA file.", None files = fasta_files if isinstance(fasta_files, list) else [fasta_files] files = [f for f in files if f is not None] if len(files) > MAX_FILES: return f"❌ Too many files ({len(files)}). Maximum: {MAX_FILES}.", None for f in files: content = read_file(f) parsed = parse_fasta(content) if not parsed: raw = content.strip().upper().replace(" ", "") if raw: name = os.path.basename( f if isinstance(f, str) else f.name ).rsplit(".", 1)[0] parsed = [(name, raw)] sequences.extend(parsed) if not sequences: return "❌ No valid sequences found in uploaded files.", None # ── cap and truncate ──────────────────────────────────────────────────────── capped_warn = "" if len(sequences) > MAX_SEQS: sequences = sequences[:MAX_SEQS] capped_warn = f"⚠️ Only first {MAX_SEQS} sequences processed (per-job limit)." truncated: list[str] = [] clean: list[tuple[str, str]] = [] for sid, seq in sequences: if len(seq) > MAX_LEN: truncated.append(sid) seq = seq[:MAX_LEN] clean.append((sid, seq)) # ── embed ─────────────────────────────────────────────────────────────────── protein_embs: dict[str, np.ndarray] = {} aa_embs: dict[str, np.ndarray] = {} for sid, seq in clean: p, a = embed_sequence(seq) protein_embs[sid] = p aa_embs[sid] = a # ── select data to export ─────────────────────────────────────────────────── use_per_aa = EMB_PER_AA in emb_type data = aa_embs if use_per_aa else protein_embs use_h5 = FMT_H5 in fmt ext = ".h5" if use_h5 else ".pt" label = "aa" if use_per_aa else "protein" tmp = tempfile.NamedTemporaryFile( suffix=f"_{label}_emb{ext}", prefix="plantplm_", delete=False ) tmp.close() if use_h5: with h5py.File(tmp.name, "w") as hf: for sid, emb in data.items(): hf.create_dataset(sid, data=emb) else: torch.save( {sid: torch.from_numpy(emb) for sid, emb in data.items()}, tmp.name, ) # ── status log ────────────────────────────────────────────────────────────── D = next(iter(protein_embs.values())).shape[0] type_str = f"per-amino-acid [L × {D}]" if use_per_aa else f"per-protein [{D}]" lines = [ f"✅ {len(clean)} sequence(s) embedded on {device.upper()}.", f" Embedding type : {type_str}", f" File format : {ext}", f" Model : {model_key}", ] if capped_warn: lines.append(capped_warn) if truncated: shown = ", ".join(truncated[:5]) + ("…" if len(truncated) > 5 else "") lines.append(f"⚠️ {len(truncated)} sequence(s) truncated to {MAX_LEN} AA: {shown}") return "\n".join(lines), tmp.name # ── Gradio UI ────────────────────────────────────────────────────────────────── with gr.Blocks(title="PlantPLM — Embedding Generator", theme=gr.themes.Soft()) as demo: gr.HTML("""