"""
ProtGPT3-MSA — homolog-conditioned protein generator (Hugging Face Space, Gradio).

Upload a FASTA of homologous sequences. If more than 15 are given, 15 are
randomly sampled and used to prompt AI4PD/ProtGPT3-MSA, which then generates
N new family-consistent sequences in parallel. Generation stops at EOS or the
separator, so each output is exactly one new sequence.
"""

import os
import random
import re
import tempfile

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# --------------------------------------------------------------------------- #
# Config / model load
# --------------------------------------------------------------------------- #
MODEL_ID = "AI4PD/ProtGPT3-MSA"
MAX_CONTEXT = 15  # the model is prompted with at most 15 homologs
DIRECTION = "1"  # "1" = N-to-C (fixed; "2" would need reversed seqs)
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    add_bos_token=False,  # BOS is added manually in build_prompt
    add_eos_token=False,
    padding_side="left",
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# Id of the separator token; generation stops on it as well as on EOS.
# NOTE(review): the token literal below is empty in this copy of the file —
# confirm it still names the intended special token.
S_TOKEN_ID = tokenizer.convert_tokens_to_ids("")
EOS_IDS = [i for i in [tokenizer.eos_token_id, S_TOKEN_ID] if i is not None]
# First id that is actually set. An `or` chain would wrongly skip a valid id 0.
PAD_ID = next(
    (i for i in (tokenizer.pad_token_id, tokenizer.eos_token_id, S_TOKEN_ID) if i is not None),
    None,
)


# --------------------------------------------------------------------------- #
# ProtGPT3-MSA prompt construction (from the model card)
# --------------------------------------------------------------------------- #
def process_style(seq: str, gap: bool) -> str:
    """Uppercase, drop X; keep gaps in aligned mode, strip '-' in unaligned.

    NOTE(review): in aligned mode, dropping 'X' shortens a row after
    build_prompt's equal-length check has run — confirm this is intended.
    """
    if gap:
        return re.sub(r"[X]", "", seq.upper())
    return re.sub(r"[X]", "", seq.replace("-", "").upper())


def build_prompt(sequences: list, gap: bool = False, direction: str = "1") -> str:
    """Build the space-separated ProtGPT3-MSA prompt from a list of homologs.

    The homologs are shuffled with the module-level ``random`` RNG (so a seeded
    run is reproducible), normalised with ``process_style`` and emitted one
    residue per token after a ``<|bos|>``/direction/mode header.

    Args:
        sequences: homolog sequences, at most ``MAX_CONTEXT`` of them. The
            caller's list is not modified.
        gap: True for aligned mode (gaps kept, equal row lengths required).
        direction: direction flag placed after BOS ("1" = N-to-C).

    Returns:
        The prompt string, tokens joined by single spaces.

    Raises:
        ValueError: too many sequences, or unequal lengths in aligned mode.
    """
    # Raise rather than assert: assertions are stripped under `python -O`.
    if len(sequences) > MAX_CONTEXT:
        raise ValueError(f"Cannot prompt with more than {MAX_CONTEXT} sequences.")
    sequences = list(sequences)  # shuffle a copy, never the caller's list
    random.shuffle(sequences)

    # NOTE(review): both mode-token literals (and the per-sequence marker
    # literals below) are empty in this copy, so aligned and unaligned prompts
    # get the same mode token — confirm the special tokens against the model card.
    if gap:
        gap_token = ""
        if not all(len(s) == len(sequences[0]) for s in sequences):
            raise ValueError(
                "Aligned mode needs sequences of equal length — align them or "
                "untick 'use aligned sequences'."
            )
    else:
        gap_token = ""

    tokens = ["<|bos|>", direction, gap_token]
    for seq in sequences:
        tokens.append("")
        tokens.extend(list(process_style(seq, gap=gap)))
        tokens.append("")
    return " ".join(tokens)


# --------------------------------------------------------------------------- #
# FASTA parsing
# --------------------------------------------------------------------------- #
def parse_fasta(text: str):
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    Blank lines are skipped and whitespace inside sequence lines is removed.
    Sequence lines that appear before the first '>' header are ignored.
    """
    records, header, seq = [], None, []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(seq)))
            header, seq = line[1:].strip(), []
        else:
            seq.append("".join(line.split()))
    if header is not None:
        records.append((header, "".join(seq)))
    return records


# --------------------------------------------------------------------------- #
# Generation
# --------------------------------------------------------------------------- #
@torch.no_grad()
def generate(fasta_file, n_sequences, max_new_tokens, temperature, top_p, use_aligned, seed):
    """Gradio callback: sample homologs, prompt the model, return new sequences.

    Args:
        fasta_file: path of the uploaded FASTA file (``gr.File`` with
            ``type="filepath"``), or None when nothing was uploaded.
        n_sequences: number of sequences to sample (``num_return_sequences``).
        max_new_tokens: per-sequence generation budget.
        temperature, top_p: sampling parameters.
        use_aligned: run in aligned (gapped) mode.
        seed: RNG seed; values < 0 (or None) leave the RNGs unseeded.

    Returns:
        ``(text, path)``: the generated FASTA text (or a user-facing message)
        and the path of a downloadable FASTA file, or None on failure.
    """
    if fasta_file is None:
        return "Please upload a FASTA file.", None

    # errors="replace": a stray non-UTF-8 byte should not crash the request.
    with open(fasta_file, "r", encoding="utf-8", errors="replace") as f:
        records = parse_fasta(f.read())
    seqs = [s for _, s in records if s]
    if not seqs:
        return "No sequences found in the uploaded file.", None

    if seed is not None and int(seed) >= 0:
        random.seed(int(seed))
        torch.manual_seed(int(seed))

    context = random.sample(seqs, MAX_CONTEXT) if len(seqs) > MAX_CONTEXT else seqs
    try:
        prompt = build_prompt(context, gap=bool(use_aligned), direction=DIRECTION)
    except ValueError as e:
        return f"⚠️ {e}", None

    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    output_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly so generate() need not infer it from
        # pad_token_id, which may coincide with the EOS id.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        eos_token_id=EOS_IDS,
        pad_token_id=PAD_ID,
        num_return_sequences=int(n_sequences),
    )
    generated = output_ids[:, prompt_len:]  # drop the echoed prompt

    mode = "aligned" if use_aligned else "unaligned"
    records_out = []
    for i, row in enumerate(generated, 1):
        text = tokenizer.decode(row, skip_special_tokens=True)
        # NOTE(review): the first replace() target is empty in this copy, which
        # makes it a no-op — confirm which marker token it should strip.
        clean = text.replace("", "").replace(" ", "").strip()
        if clean:
            records_out.append(
                f">generated_{i} | ProtGPT3-MSA | {len(context)} homologs | {mode}\n{clean}"
            )

    if not records_out:
        return "Model produced no sequence — try a higher 'max new tokens'.", None

    fasta_text = "\n".join(records_out)
    # One fresh directory per request so concurrent users of the Space never
    # overwrite each other's download; the file name stays "generated.fasta".
    out_path = os.path.join(tempfile.mkdtemp(prefix="protgpt3_msa_"), "generated.fasta")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(fasta_text + "\n")
    return fasta_text, out_path


# --------------------------------------------------------------------------- #
# Header banner (inline SVG — ships inside the app, works in light/dark mode)
# --------------------------------------------------------------------------- #
def _banner_svg() -> str:
    """Return the markup for the header banner shown at the top of the app."""
    rng = random.Random(11)  # private, fixed-seed RNG: identical banner every load
    bars = []
    for i in range(6):
        w = rng.randint(150, 225)
        y = 100 + i * 20
        op = 0.5 + i * 0.08
        # NOTE(review): this f-string is empty in this copy of the file, so
        # w, y and op are unused — the per-bar markup appears to be missing.
        bars.append(
            f''
        )
    bars_svg = "\n".join(bars)
    return f'''
ProtGPT3-MSA homolog-conditioned protein sequence generation {bars_svg} up to 15 homologs ProtGPT3-MSA N new sequences
'''


INTRO_MD = """
**ProtGPT3-MSA** is an autoregressive protein language model that can be prompted \
with up to 15 homologous protein sequences to generate new, family-consistent \
sequences. ([read the paper](https://www.biorxiv.org/content/10.64898/2026.06.04.730041v1))

It operates in two modalities. In **unaligned** mode, homologs are supplied as raw \
sequences and the model returns a plain sequence. In **aligned** mode, homologs are \
supplied as a gapped multiple-sequence alignment and the model returns a new aligned \
(gapped) sequence consistent with that alignment.
"""

SETTINGS_MD = """
**Settings**

- **Upload a FASTA file** of homologous sequences. If it contains more than 15 \
sequences, a random subset of 15 is sampled to build the prompt.
- **Temperature** sets how stochastic generation is: raise it (e.g. >= 1.0) for more \
diverse, exploratory sequences, or lower it (e.g. < 1.0) to keep generation \
conservative and close to the input family.
- **Use aligned sequences** — tick this box to run in aligned mode, in which the model \
generates aligned (gapped) sequences. Note that your uploaded homologs must then \
already be aligned (equal length, with gap characters).

**Recommendation.** ProtGPT3-MSA performs best when several homologs are provided. \
With very few sequences (fewer than 5), generation quality may be limited.

For full functionality and programmatic use, ProtGPT3-MSA is available on the \
[Hugging Face Hub](https://huggingface.co/AI4PD/ProtGPT3-MSA).

**Cite work**

Garibbo, M., Boxo, G., Stocco, F., Illanes-Vicioso, R., Middendorf, L., & Ferruz, N. (2026). ProtGPT3: an Open-source family of Promptable and Aligned Protein Language Models. bioRxiv, 2026-06.
"""

# --------------------------------------------------------------------------- #
# UI
# --------------------------------------------------------------------------- #
CHECKBOX_CSS = """
#aligned_box input[type="checkbox"] {
    width: 18px;
    height: 18px;
    accent-color: #1c3a5e !important;
    box-shadow: 0 0 0 2px #1c3a5e !important;
    border-radius: 3px;
    cursor: pointer;
}
"""

with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="teal"),
    css=CHECKBOX_CSS,
    title="ProtGPT3-MSA",
) as demo:
    gr.HTML(_banner_svg())
    gr.Markdown(INTRO_MD)
    with gr.Row():
        with gr.Column(scale=1):
            fasta = gr.File(
                label="Homologs (FASTA)",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath",
            )
            n_seq = gr.Slider(1, 100, value=20, step=1, label="N sequences to generate")
            max_tok = gr.Slider(16, 1024, value=1024, step=16, label="Max new tokens (per sequence)")
            temp = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            aligned = gr.Checkbox(value=False, label="use aligned sequences", elem_id="aligned_box")
            seed = gr.Number(value=-1, label="Random seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            out_text = gr.Textbox(label="Generated sequences (FASTA)", lines=20)
            out_file = gr.File(label="Download FASTA")
    gr.Markdown(SETTINGS_MD)

    btn.click(
        generate,
        inputs=[fasta, n_seq, max_tok, temp, top_p, aligned, seed],
        outputs=[out_text, out_file],
        api_name="predict",
    )

if __name__ == "__main__":
    demo.launch()