# ProtGPT3-MSA / app.py
# Michele Garibbo
# changes
# 8ed9f7b
# Raw
# History Blame Contribute Delete
# 12.7 kB
"""
ProtGPT3-MSA — homolog-conditioned protein generator (Hugging Face Space, Gradio).
Upload a FASTA of homologous sequences. If more than 15 are given, 15 are
randomly sampled and used to prompt AI4PD/ProtGPT3-MSA, which then generates
N new family-consistent sequences in parallel. Generation stops at EOS or the
<s> separator, so each output is exactly one new sequence.
"""
import random
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --------------------------------------------------------------------------- #
# Config / model load
# --------------------------------------------------------------------------- #
MODEL_ID = "AI4PD/ProtGPT3-MSA"
MAX_CONTEXT = 15  # the model is prompted with at most 15 homologs
DIRECTION = "1"  # "1" = N-to-C (fixed; "2" would need reversed seqs)
# 16-bit weights when a GPU is present; plain float32 on CPU.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    add_bos_token=False,  # BOS is added manually in build_prompt
    add_eos_token=False,
    padding_side="left",
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# "<s>" separates homologs in the prompt, so it also terminates a generated
# sequence alongside the tokenizer's own EOS.
S_TOKEN_ID = tokenizer.convert_tokens_to_ids("<s>")
# Order-preserving de-duplication (eos and "<s>" may be the same id).
EOS_IDS = list(
    dict.fromkeys(i for i in (tokenizer.eos_token_id, S_TOKEN_ID) if i is not None)
)
# First *defined* candidate id. Compare against None rather than relying on
# truthiness: token id 0 is a legitimate id and must not be skipped.
PAD_ID = next(
    (
        i
        for i in (tokenizer.pad_token_id, tokenizer.eos_token_id, S_TOKEN_ID)
        if i is not None
    ),
    None,
)
# --------------------------------------------------------------------------- #
# ProtGPT3-MSA prompt construction (from the model card)
# --------------------------------------------------------------------------- #
def process_style(seq: str, gap: bool) -> str:
    """Normalize one homolog for the prompt: uppercase it and drop ``X``.

    In aligned mode (*gap* true) ``-`` gap characters are kept; in unaligned
    mode they are stripped.

    NOTE(review): ``X`` is removed in aligned mode too. That shortens the row
    and shifts its columns relative to the other homologs. Confirm against the
    model card whether ``X`` should be kept or mapped to a gap instead.
    """
    cleaned = seq.upper() if gap else seq.replace("-", "").upper()
    return cleaned.replace("X", "")


def build_prompt(sequences: list, gap: bool = False, direction: str = "1") -> str:
    """Build the space-separated ProtGPT3-MSA prompt for a set of homologs.

    Layout: ``<|bos|> {direction} {<gap>|<no_gap>}``, then for each homolog an
    ``<s>`` separator followed by one token per residue, and a final ``<s>``.
    The homologs are shuffled with the module-level ``random`` generator. The
    caller's list is not modified.

    Args:
        sequences: Homolog sequences (at most 15).
        gap: Aligned mode. Emits ``<gap>`` and keeps ``-`` characters. All
            sequences must then have the same length.
        direction: Direction token placed after BOS.

    Returns:
        The prompt string, ready to be tokenized.

    Raises:
        AssertionError: More than 15 sequences, or unequal lengths in aligned
            mode. Raised explicitly (not via ``assert``) so the checks survive
            ``python -O``. The type is kept because callers catch it.
    """
    if len(sequences) > 15:
        raise AssertionError("Cannot prompt with more than 15 sequences.")
    # Shuffle a copy so reordering is not a side effect on the caller's list.
    ordered = list(sequences)
    random.shuffle(ordered)

    if gap:
        gap_token = "<gap>"
        width = len(ordered[0]) if ordered else 0
        if any(len(s) != width for s in ordered):
            raise AssertionError(
                "Aligned mode needs sequences of equal length — align them or "
                "untick 'use aligned sequences'."
            )
    else:
        gap_token = "<no_gap>"

    tokens = ["<|bos|>", direction, gap_token]
    for seq in ordered:
        tokens.append("<s>")
        tokens.extend(process_style(seq, gap=gap))  # one token per residue
    tokens.append("<s>")
    return " ".join(tokens)
# --------------------------------------------------------------------------- #
# FASTA parsing
# --------------------------------------------------------------------------- #
def parse_fasta(text: str) -> list:
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    * A header line starts with ``>``. The header is the rest of that line,
      stripped.
    * Sequence lines up to the next header are concatenated. All whitespace
      inside them is removed, because each character later becomes one
      prompt token.
    * Blank lines are ignored.
    * Sequence lines that appear before the first header are dropped.
    * A header with no sequence lines yields ``(header, "")``.
    * A leading UTF-8 byte-order mark is tolerated.
    """
    # str.strip() does not remove U+FEFF, so without this a BOM-prefixed file
    # would hide its first ">" and silently lose the first record.
    if text.startswith("\ufeff"):
        text = text[1:]

    records, header, chunks = [], None, []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:].strip(), []
        elif header is not None:
            chunks.append("".join(line.split()))
    if header is not None:
        records.append((header, "".join(chunks)))
    return records
# --------------------------------------------------------------------------- #
# Generation
# --------------------------------------------------------------------------- #
@torch.no_grad()
def generate(fasta_file, n_sequences, max_new_tokens, temperature, top_p,
             use_aligned, seed):
    """Generate new family-consistent sequences from an uploaded FASTA file.

    Reads the homologs in *fasta_file* and randomly samples ``MAX_CONTEXT`` of
    them if more are given. Builds a ProtGPT3-MSA prompt from them and samples
    *n_sequences* continuations.

    Args:
        fasta_file: Path of the uploaded FASTA file, or None if nothing was
            uploaded.
        n_sequences: Number of sequences to sample (``num_return_sequences``).
        max_new_tokens: Per-sequence generation budget.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        use_aligned: If truthy, build an aligned (``<gap>``) prompt. The
            homologs must then have equal length.
        seed: If ``>= 0``, reseeds ``random`` and ``torch``. None or a
            negative value leaves the RNG state untouched.

    Returns:
        ``(fasta_text, path)`` on success, where *path* is a per-request
        temporary FASTA file. Returns ``(message, None)`` when the input is
        missing or invalid, or when nothing was generated.
    """
    import os
    import tempfile

    if fasta_file is None:
        return "Please upload a FASTA file.", None
    # Explicit encoding instead of the locale default. Undecodable bytes
    # (e.g. in headers) are replaced rather than failing the request.
    with open(fasta_file, "r", encoding="utf-8", errors="replace") as f:
        records = parse_fasta(f.read())
    seqs = [s for _, s in records if s]
    if not seqs:
        return "No sequences found in the uploaded file.", None

    if seed is not None and int(seed) >= 0:
        random.seed(int(seed))
        torch.manual_seed(int(seed))

    context = random.sample(seqs, MAX_CONTEXT) if len(seqs) > MAX_CONTEXT else seqs
    try:
        prompt = build_prompt(list(context), gap=bool(use_aligned), direction=DIRECTION)
    except AssertionError as e:
        return f"⚠️ {e}", None

    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    output_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly. PAD_ID may equal the EOS id, in which case
        # generate() cannot infer the mask from the input on its own.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        eos_token_id=EOS_IDS,
        pad_token_id=PAD_ID,
        num_return_sequences=int(n_sequences),
    )
    generated = output_ids[:, prompt_len:]  # keep only the newly sampled tokens

    mode = "aligned" if use_aligned else "unaligned"
    records_out = []
    for i, row in enumerate(generated, 1):
        text = tokenizer.decode(row, skip_special_tokens=True)
        # Drop any remaining "<s>" markers and the spaces between residue tokens.
        clean = text.replace("<s>", "").replace(" ", "").strip()
        if clean:
            records_out.append(
                f">generated_{i} | ProtGPT3-MSA | {len(context)} homologs | {mode}\n{clean}"
            )
    if not records_out:
        return "Model produced no sequence — try a higher 'max new tokens'.", None

    fasta_text = "\n".join(records_out)
    # Use a fresh directory per request. A fixed "generated.fasta" in the CWD
    # is shared by every user of the app, so overlapping requests could
    # overwrite (and serve) each other's results. The file name stays
    # "generated.fasta" for the download.
    out_dir = tempfile.mkdtemp(prefix="protgpt3_msa_")
    out_path = os.path.join(out_dir, "generated.fasta")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(fasta_text + "\n")
    return fasta_text, out_path
# --------------------------------------------------------------------------- #
# Header banner (inline SVG — ships inside the app, works in light/dark mode)
# --------------------------------------------------------------------------- #
def _banner_svg() -> str:
    """Return the header banner: an HTML ``<div>`` wrapping an inline SVG.

    The six "homolog" bars get their widths from a private
    ``random.Random(11)``. The banner is therefore identical on every call,
    and the global ``random`` state is never consumed.
    """
    width_rng = random.Random(11)
    homolog_rows = "\n".join(
        '<rect x="60" y="{y}" width="{w}" height="11" rx="5.5" '
        'fill="url(#homo)" opacity="{o:.2f}"/>'.format(
            y=100 + row * 20,
            w=width_rng.randint(150, 225),
            o=0.5 + row * 0.08,
        )
        for row in range(6)
    )
    return f'''
<div style="width:100%;max-width:980px;margin:0 auto 4px;">
<svg viewBox="0 0 980 250" xmlns="http://www.w3.org/2000/svg"
style="width:100%;height:auto;display:block;border-radius:20px;">
<defs>
<linearGradient id="panel" x1="0" y1="0" x2="1" y2="1">
<stop offset="0" stop-color="#0f2742"/>
<stop offset="1" stop-color="#1c3a5e"/>
</linearGradient>
<linearGradient id="homo" x1="0" y1="0" x2="1" y2="0">
<stop offset="0" stop-color="#5b8fc7"/>
<stop offset="1" stop-color="#83b4dc"/>
</linearGradient>
<linearGradient id="gen" x1="0" y1="0" x2="1" y2="0">
<stop offset="0" stop-color="#34d29c"/>
<stop offset="1" stop-color="#7ff0cb"/>
</linearGradient>
<filter id="glow" x="-40%" y="-80%" width="180%" height="260%">
<feGaussianBlur stdDeviation="4" result="b"/>
<feMerge><feMergeNode in="b"/><feMergeNode in="SourceGraphic"/></feMerge>
</filter>
<marker id="arr" markerWidth="9" markerHeight="9" refX="5.5" refY="3" orient="auto">
<path d="M0,0 L6,3 L0,6 Z" fill="#86a5cb"/>
</marker>
</defs>
<rect x="0" y="0" width="980" height="250" rx="20" fill="url(#panel)"/>
<text x="60" y="46" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="31" font-weight="700" fill="#eef4fb">ProtGPT3-MSA</text>
<text x="60" y="72" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="14" fill="#9fbbd8">homolog-conditioned protein sequence generation</text>
{homolog_rows}
<text x="60" y="236" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="12" fill="#9fbbd8">up to 15 homologs</text>
<line x1="300" y1="151" x2="360" y2="151" stroke="#86a5cb"
stroke-width="2.5" marker-end="url(#arr)"/>
<rect x="384" y="119" width="64" height="64" rx="15"
fill="rgba(255,255,255,0.05)" stroke="#466a99" stroke-width="1.5"/>
<g stroke="#83b4dc" stroke-width="1.4" fill="#83b4dc">
<line x1="402" y1="135" x2="430" y2="151"/><line x1="402" y1="151" x2="430" y2="151"/>
<line x1="402" y1="167" x2="430" y2="151"/><line x1="402" y1="135" x2="430" y2="167"/>
<line x1="402" y1="167" x2="430" y2="135"/>
<circle cx="402" cy="135" r="3.4"/><circle cx="402" cy="151" r="3.4"/>
<circle cx="402" cy="167" r="3.4"/><circle cx="430" cy="151" r="3.8"/>
</g>
<text x="416" y="201" text-anchor="middle"
font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="10.5" font-weight="600" fill="#cfe0f2">ProtGPT3-MSA</text>
<line x1="460" y1="151" x2="524" y2="151" stroke="#86a5cb"
stroke-width="2.5" marker-end="url(#arr)"/>
<rect x="552" y="130" width="222" height="13" rx="6.5" fill="url(#gen)" opacity="0.25"/>
<rect x="548" y="138" width="228" height="14" rx="7" fill="url(#gen)" opacity="0.45"/>
<rect x="544" y="147" width="234" height="17" rx="8.5" fill="url(#gen)" filter="url(#glow)"/>
<text x="544" y="186" font-family="system-ui,Segoe UI,Helvetica,Arial"
font-size="12" fill="#9fbbd8">N new sequences</text>
</svg>
</div>'''
# Introductory Markdown, rendered with gr.Markdown right after the banner.
# A trailing backslash inside this (non-raw) string is a line continuation:
# the newline is dropped, so each paragraph ends up on a single Markdown line.
INTRO_MD = """
**ProtGPT3-MSA** is an autoregressive protein language model that can be prompted \
with up to 15 homologous protein sequences to generate new, family-consistent \
sequences. ([read the paper](https://www.biorxiv.org/content/10.64898/2026.06.04.730041v1))
It operates in two modalities. In **unaligned** mode, homologs are supplied as raw \
sequences and the model returns a plain sequence. In **aligned** mode, homologs are \
supplied as a gapped multiple-sequence alignment and the model returns a new aligned \
(gapped) sequence consistent with that alignment.
"""
# Help text for the input controls plus the citation, rendered with
# gr.Markdown in the UI section. The "15" mentioned here mirrors MAX_CONTEXT
# and the limit enforced in build_prompt; keep them in sync.
# Trailing backslashes are line continuations inside the string.
SETTINGS_MD = """
**Settings**
- **Upload a FASTA file** of homologous sequences. If it contains more than 15 \
sequences, a random subset of 15 is sampled to build the prompt.
- **Temperature** sets how stochastic generation is: raise it (e.g. >= 1.0) for more \
diverse, exploratory sequences, or lower it (e.g. < 1.0) to keep generation \
conservative and close to the input family.
- **Use aligned sequences** — tick this box to run in aligned mode, in which the model \
generates aligned (gapped) sequences. Note that your uploaded homologs must then \
already be aligned (equal length, with gap characters).
**Recommendation.** ProtGPT3-MSA performs best when several homologs are provided. \
With very few sequences (fewer than 5), generation quality may be limited.
For full functionality and programmatic use, ProtGPT3-MSA is available on the \
[Hugging Face Hub](https://huggingface.co/AI4PD/ProtGPT3-MSA).
**Cite work**
Garibbo, M., Boxo, G., Stocco, F., Illanes-Vicioso, R., Middendorf, L., & Ferruz, N. (2026). ProtGPT3: an Open-source family of Promptable and Aligned Protein Language Models. bioRxiv, 2026-06.
"""
# --------------------------------------------------------------------------- #
# UI
# --------------------------------------------------------------------------- #
# Custom CSS passed to gr.Blocks(css=...). It targets the checkbox <input>
# inside the component created with elem_id="aligned_box" (the
# "use aligned sequences" toggle), making it larger and tinted in the
# banner's dark blue.
CHECKBOX_CSS = """
#aligned_box input[type="checkbox"] {
width: 18px; height: 18px;
accent-color: #1c3a5e !important;
box-shadow: 0 0 0 2px #1c3a5e !important;
border-radius: 3px;
cursor: pointer;
}
"""
# Page layout. The UI is declared by the order and nesting of component
# creation inside these Blocks/Row/Column contexts. Reordering statements here
# changes what the user sees.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="teal"),
               css=CHECKBOX_CSS, title="ProtGPT3-MSA") as demo:
    gr.HTML(_banner_svg())
    gr.Markdown(INTRO_MD)
    with gr.Row():
        # Left column: inputs and the Generate button.
        with gr.Column(scale=1):
            # type="filepath": generate() receives a path string and opens it itself.
            fasta = gr.File(label="Homologs (FASTA)",
                            file_types=[".fasta", ".fa", ".txt"], type="filepath")
            n_seq = gr.Slider(1, 100, value=20, step=1, label="N sequences to generate")
            max_tok = gr.Slider(16, 1024, value=1024, step=16,
                                label="Max new tokens (per sequence)")
            temp = gr.Slider(0.5, 1.5, value=0.8, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            # elem_id is the hook CHECKBOX_CSS uses to style this checkbox.
            aligned = gr.Checkbox(value=False, label="use aligned sequences",
                                  elem_id="aligned_box")
            # generate() only reseeds when the value is >= 0, so -1 keeps the
            # current RNG state.
            seed = gr.Number(value=-1, label="Random seed (-1 = random)")
            btn = gr.Button("Generate", variant="primary")
        # Right column: generated FASTA text and a downloadable file.
        with gr.Column(scale=1):
            out_text = gr.Textbox(label="Generated sequences (FASTA)", lines=20)
            out_file = gr.File(label="Download FASTA")
    gr.Markdown(SETTINGS_MD)
    # The inputs list is passed positionally, so its order must match
    # generate()'s parameters. The outputs receive its (text, file path)
    # return tuple. api_name exposes the handler as the "predict" API endpoint.
    btn.click(
        generate,
        inputs=[fasta, n_seq, max_tok, temp, top_p, aligned, seed],
        outputs=[out_text, out_file],
        api_name="predict",
    )
if __name__ == "__main__":
    demo.launch()