Spaces:
Running
Running
| """ | |
| ProtGPT3-MSA — homolog-conditioned protein generator (Hugging Face Space, Gradio). | |
| Upload a FASTA of homologous sequences. If more than 15 are given, 15 are | |
| randomly sampled and used to prompt AI4PD/ProtGPT3-MSA, which then generates | |
| N new family-consistent sequences in parallel. Generation stops at EOS or the | |
| <s> separator, so each output is exactly one new sequence. | |
| """ | |
| import random | |
| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # --------------------------------------------------------------------------- # | |
| # Config / model load | |
| # --------------------------------------------------------------------------- # | |
MODEL_ID = "AI4PD/ProtGPT3-MSA"
MAX_CONTEXT = 15  # the model is prompted with at most 15 homologs
DIRECTION = "1"  # "1" = N-to-C (fixed; "2" would need reversed seqs)
# bfloat16 when CUDA is available, float32 otherwise.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Tokenizer and model are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    add_bos_token=False,  # BOS is added manually in build_prompt
    add_eos_token=False,
    padding_side="left",
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# Id of the "<s>" separator token; generation is stopped on it as well as on EOS.
S_TOKEN_ID = tokenizer.convert_tokens_to_ids("<s>")
EOS_IDS = [i for i in [tokenizer.eos_token_id, S_TOKEN_ID] if i is not None]
# First id that is actually set.  A plain `a or b or c` chain would treat a
# legitimate token id of 0 as "missing" and fall through to the next candidate.
PAD_ID = next(
    (
        i
        for i in (tokenizer.pad_token_id, tokenizer.eos_token_id, S_TOKEN_ID)
        if i is not None
    ),
    None,
)
| # --------------------------------------------------------------------------- # | |
| # ProtGPT3-MSA prompt construction (from the model card) | |
| # --------------------------------------------------------------------------- # | |
def process_style(seq: str, gap: bool):
    """Normalise one homolog before it is split into prompt tokens.

    The sequence is upper-cased and every ``X`` is removed.  When ``gap`` is
    false (unaligned mode) all ``-`` characters are removed as well; when it
    is true (aligned mode) they are left in place.
    """
    residues = seq if gap else seq.replace("-", "")
    return re.sub(r"[X]", "", residues.upper())
def build_prompt(sequences: list, gap: bool = False, direction: str = "1") -> str:
    """Build the space-separated ProtGPT3-MSA prompt for a set of homologs.

    The prompt is ``<|bos|> <direction> <gap|no_gap>`` followed, for each
    homolog in random order, by ``<s>`` and the homolog's characters as
    individual tokens, and finally a trailing ``<s>``.

    Args:
        sequences: Homolog sequences (at most 15).  The list is not modified.
        gap: True for aligned mode (``<gap>``), in which all sequences must
            have the same length; False for unaligned mode (``<no_gap>``).
        direction: Direction token placed right after ``<|bos|>``.

    Returns:
        The prompt string, tokens joined by single spaces.

    Raises:
        AssertionError: If more than 15 sequences are given, or if ``gap`` is
            true and the sequences differ in length.
    """
    # Raised explicitly rather than via `assert` so the checks are not
    # stripped under `python -O`; the exception type is unchanged for callers
    # that catch AssertionError.
    if len(sequences) > 15:
        raise AssertionError("Cannot prompt with more than 15 sequences.")

    # Shuffle a copy so the caller's list keeps its order.
    shuffled = list(sequences)
    random.shuffle(shuffled)

    if gap:
        gap_token = "<gap>"
        # NOTE(review): lengths are compared on the raw input, before
        # process_style removes X characters — confirm that is intended.
        if any(len(s) != len(shuffled[0]) for s in shuffled):
            raise AssertionError(
                "Aligned mode needs sequences of equal length — align them or "
                "untick 'use aligned sequences'."
            )
    else:
        gap_token = "<no_gap>"

    tokens = ["<|bos|>", direction, gap_token]
    for seq in shuffled:
        tokens.append("<s>")
        # Each character of the normalised sequence becomes its own token.
        tokens.extend(process_style(seq, gap=gap))
    # Trailing separator: the model continues with a new sequence from here.
    tokens.append("<s>")
    return " ".join(tokens)
| # --------------------------------------------------------------------------- # | |
| # FASTA parsing | |
| # --------------------------------------------------------------------------- # | |
def parse_fasta(text: str):
    """Split FASTA text into a list of ``(header, sequence)`` tuples.

    Lines are stripped and blank lines ignored.  A line starting with ``>``
    opens a new record whose header is the rest of that line (stripped); the
    following lines are concatenated into that record's sequence.  Sequence
    lines that appear before the first header are discarded.
    """
    entries = []
    current_name = None
    chunks = []
    for raw in text.splitlines():
        stripped = raw.strip()
        if stripped == "":
            continue
        if not stripped.startswith(">"):
            chunks.append(stripped)
            continue
        # New header: close the record that was open, if any.
        if current_name is not None:
            entries.append((current_name, "".join(chunks)))
        current_name = stripped[1:].strip()
        chunks = []
    # Close the last open record.
    if current_name is not None:
        entries.append((current_name, "".join(chunks)))
    return entries
| # --------------------------------------------------------------------------- # | |
| # Generation | |
| # --------------------------------------------------------------------------- # | |
def generate(fasta_file, n_sequences, max_new_tokens, temperature, top_p,
             use_aligned, seed):
    """Gradio callback: sample new sequences conditioned on uploaded homologs.

    Args:
        fasta_file: Path of the uploaded FASTA file, or None if nothing was
            uploaded.
        n_sequences: Number of sequences to sample (cast to int).
        max_new_tokens: Upper bound on generated tokens per sequence (cast to int).
        temperature: Sampling temperature (cast to float).
        top_p: Nucleus-sampling threshold (cast to float).
        use_aligned: Truthy to build the prompt in aligned (gapped) mode.
        seed: Values >= 0 seed both ``random`` and ``torch``; None or a
            negative value leaves the RNG state untouched.

    Returns:
        A ``(text, path)`` tuple.  On success ``text`` is the generated FASTA
        and ``path`` is a file containing it; on failure ``text`` is a message
        for the user and ``path`` is None.
    """
    import tempfile  # local import: only used to write the download file

    if fasta_file is None:
        return "Please upload a FASTA file.", None

    # "utf-8-sig" drops a leading BOM, which would otherwise hide the first
    # ">" and make parse_fasta discard the first record.
    try:
        with open(fasta_file, "r", encoding="utf-8-sig") as f:
            raw_text = f.read()
    except (OSError, UnicodeDecodeError):
        return "Could not read the uploaded file as UTF-8 text.", None

    records = parse_fasta(raw_text)
    seqs = [s for _, s in records if s]
    if not seqs:
        return "No sequences found in the uploaded file.", None

    if seed is not None and int(seed) >= 0:
        random.seed(int(seed))
        torch.manual_seed(int(seed))

    # Subsample only when there are more homologs than the prompt can hold.
    context = random.sample(seqs, MAX_CONTEXT) if len(seqs) > MAX_CONTEXT else seqs
    try:
        prompt = build_prompt(list(context), gap=bool(use_aligned), direction=DIRECTION)
    except (AssertionError, ValueError) as e:
        return f"⚠️ {e}", None

    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        inputs["input_ids"],
        # Passed explicitly so generate() does not have to infer the mask.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        eos_token_id=EOS_IDS,
        pad_token_id=PAD_ID,
        num_return_sequences=int(n_sequences),
    )
    # Keep only the continuation; the prompt occupies the first prompt_len columns.
    generated = output_ids[:, prompt_len:]

    mode = "aligned" if use_aligned else "unaligned"
    records_out = []
    for i, row in enumerate(generated, 1):
        text = tokenizer.decode(row, skip_special_tokens=True)
        clean = text.replace("<s>", "").replace(" ", "").strip()
        if clean:
            records_out.append(
                f">generated_{i} | ProtGPT3-MSA | {len(context)} homologs | {mode}\n{clean}"
            )
    if not records_out:
        return "Model produced no sequence — try a higher 'max new tokens'.", None

    fasta_text = "\n".join(records_out)
    # One uniquely named file per request: a fixed name in the working
    # directory would be overwritten by concurrent requests.  The file is
    # deliberately kept (delete=False) so it still exists when it is served.
    with tempfile.NamedTemporaryFile(
        "w", prefix="generated_", suffix=".fasta", delete=False, encoding="utf-8"
    ) as f:
        f.write(fasta_text + "\n")
        out_path = f.name
    return fasta_text, out_path
| # --------------------------------------------------------------------------- # | |
| # Header banner (inline SVG — ships inside the app, works in light/dark mode) | |
| # --------------------------------------------------------------------------- # | |
def _banner_svg() -> str:
    """Return the header banner as an HTML string wrapping an inline SVG.

    The banner shows six "homolog" bars on the left, a model icon in the
    middle and three "generated" bars on the right.  The bar widths come from
    a private ``random.Random(11)`` instance, so the output is the same on
    every call and the module-level RNG is not touched.
    """
    rng = random.Random(11)
    bars = []
    for i in range(6):
        w = rng.randint(150, 225)  # bar width
        y = 100 + i * 20  # bars are stacked 20 units apart, starting at y=100
        op = 0.5 + i * 0.08  # opacity grows from 0.50 to 0.90 down the stack
        bars.append(
            f'<rect x="60" y="{y}" width="{w}" height="11" rx="5.5" '
            f'fill="url(#homo)" opacity="{op:.2f}"/>'
        )
    bars_svg = "\n".join(bars)
    return f'''
<div style="width:100%;max-width:980px;margin:0 auto 4px;">
<svg viewBox="0 0 980 250" xmlns="http://www.w3.org/2000/svg"
style="width:100%;height:auto;display:block;border-radius:20px;">
<defs>
<linearGradient id="panel" x1="0" y1="0" x2="1" y2="1">
<stop offset="0" stop-color="#0f2742"/>
<stop offset="1" stop-color="#1c3a5e"/>
</linearGradient>
<linearGradient id="homo" x1="0" y1="0" x2="1" y2="0">
<stop offset="0" stop-color="#5b8fc7"/>
<stop offset="1" stop-color="#83b4dc"/>
</linearGradient>
<linearGradient id="gen" x1="0" y1="0" x2="1" y2="0">
<stop offset="0" stop-color="#34d29c"/>
<stop offset="1" stop-color="#7ff0cb"/>
</linearGradient>
<filter id="glow" x="-40%" y="-80%" width="180%" height="260%">
<feGaussianBlur stdDeviation="4" result="b"/>
<feMerge><feMergeNode in="b"/><feMergeNode in="SourceGraphic"/></feMerge>
</filter>
<marker id="arr" markerWidth="9" markerHeight="9" refX="5.5" refY="3" orient="auto">
<path d="M0,0 L6,3 L0,6 Z" fill="#86a5cb"/>
</marker>
</defs>
<rect x="0" y="0" width="980" height="250" rx="20" fill="url(#panel)"/>
<text x="60" y="46" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="31" font-weight="700" fill="#eef4fb">ProtGPT3-MSA</text>
<text x="60" y="72" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="14" fill="#9fbbd8">homolog-conditioned protein sequence generation</text>
{bars_svg}
<text x="60" y="236" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="12" fill="#9fbbd8">up to 15 homologs</text>
<line x1="300" y1="151" x2="360" y2="151" stroke="#86a5cb"
stroke-width="2.5" marker-end="url(#arr)"/>
<rect x="384" y="119" width="64" height="64" rx="15"
fill="rgba(255,255,255,0.05)" stroke="#466a99" stroke-width="1.5"/>
<g stroke="#83b4dc" stroke-width="1.4" fill="#83b4dc">
<line x1="402" y1="135" x2="430" y2="151"/><line x1="402" y1="151" x2="430" y2="151"/>
<line x1="402" y1="167" x2="430" y2="151"/><line x1="402" y1="135" x2="430" y2="167"/>
<line x1="402" y1="167" x2="430" y2="135"/>
<circle cx="402" cy="135" r="3.4"/><circle cx="402" cy="151" r="3.4"/>
<circle cx="402" cy="167" r="3.4"/><circle cx="430" cy="151" r="3.8"/>
</g>
<text x="416" y="201" text-anchor="middle"
font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="10.5" font-weight="600" fill="#cfe0f2">ProtGPT3-MSA</text>
<line x1="460" y1="151" x2="524" y2="151" stroke="#86a5cb"
stroke-width="2.5" marker-end="url(#arr)"/>
<rect x="552" y="130" width="222" height="13" rx="6.5" fill="url(#gen)" opacity="0.25"/>
<rect x="548" y="138" width="228" height="14" rx="7" fill="url(#gen)" opacity="0.45"/>
<rect x="544" y="147" width="234" height="17" rx="8.5" fill="url(#gen)" filter="url(#glow)"/>
<text x="544" y="186" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="12" fill="#9fbbd8">N new sequences</text>
</svg>
</div>'''
# Markdown shown under the banner.  Paragraphs are separated by blank lines so
# Markdown renders them as separate paragraphs; a trailing backslash only joins
# a long source line with the next one.
INTRO_MD = """
**ProtGPT3-MSA** is an autoregressive protein language model that can be prompted \
with up to 15 homologous protein sequences to generate new, family-consistent \
sequences. ([read the paper](https://www.biorxiv.org/content/10.64898/2026.06.04.730041v1))

It operates in two modalities. In **unaligned** mode, homologs are supplied as raw \
sequences and the model returns a plain sequence. In **aligned** mode, homologs are \
supplied as a gapped multiple-sequence alignment and the model returns a new aligned \
(gapped) sequence consistent with that alignment.
"""

# Markdown shown below the input/output columns.  The blank line after the
# bullet list keeps the following paragraphs from being absorbed into the
# last bullet.
SETTINGS_MD = """
**Settings**

- **Upload a FASTA file** of homologous sequences. If it contains more than 15 \
sequences, a random subset of 15 is sampled to build the prompt.
- **Temperature** sets how stochastic generation is: raise it (e.g. >= 1.0) for more \
diverse, exploratory sequences, or lower it (e.g. < 1.0) to keep generation \
conservative and close to the input family.
- **Use aligned sequences** — tick this box to run in aligned mode, in which the model \
generates aligned (gapped) sequences. Note that your uploaded homologs must then \
already be aligned (equal length, with gap characters).

**Recommendation.** ProtGPT3-MSA performs best when several homologs are provided. \
With very few sequences (fewer than 5), generation quality may be limited.

For full functionality and programmatic use, ProtGPT3-MSA is available on the \
[Hugging Face Hub](https://huggingface.co/AI4PD/ProtGPT3-MSA).

**Cite work**

Garibbo, M., Boxo, G., Stocco, F., Illanes-Vicioso, R., Middendorf, L., & Ferruz, N. (2026). ProtGPT3: an Open-source family of Promptable and Aligned Protein Language Models. bioRxiv, 2026-06.
"""
| # --------------------------------------------------------------------------- # | |
| # UI | |
| # --------------------------------------------------------------------------- # | |
# Extra CSS handed to gr.Blocks: restyles the checkbox inside the element
# with id "aligned_box" (size, accent colour, outline, pointer cursor).
CHECKBOX_CSS = """
#aligned_box input[type="checkbox"] {
width: 18px; height: 18px;
accent-color: #1c3a5e !important;
box-shadow: 0 0 0 2px #1c3a5e !important;
border-radius: 3px;
cursor: pointer;
}
"""
# Page layout: banner and intro text on top, then a row with the input
# controls and the outputs, then the settings notes.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="teal"),
               css=CHECKBOX_CSS, title="ProtGPT3-MSA") as demo:
    gr.HTML(_banner_svg())
    gr.Markdown(INTRO_MD)
    with gr.Row():
        # Input controls; their values are passed to generate() in the order
        # listed in `inputs=` below.
        with gr.Column(scale=1):
            fasta = gr.File(label="Homologs (FASTA)",
                            file_types=[".fasta", ".fa", ".txt"], type="filepath")
            n_seq = gr.Slider(1, 100, value=20, step=1, label="N sequences to generate")
            max_tok = gr.Slider(16, 1024, value=1024, step=16,
                                label="Max new tokens (per sequence)")
            temp = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            # elem_id ties this checkbox to the selector in CHECKBOX_CSS.
            aligned = gr.Checkbox(value=False, label="use aligned sequences",
                                  elem_id="aligned_box")
            seed = gr.Number(value=-1, label="Random seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")
        # Outputs: the FASTA text and the same content as a downloadable file.
        with gr.Column(scale=1):
            out_text = gr.Textbox(label="Generated sequences (FASTA)", lines=20)
            out_file = gr.File(label="Download FASTA")
    gr.Markdown(SETTINGS_MD)
    # generate() returns (text, path), matching the two outputs in order.
    btn.click(
        generate,
        inputs=[fasta, n_seq, max_tok, temp, top_p, aligned, seed],
        outputs=[out_text, out_file],
        api_name="predict",
    )
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()