# prot-bert-mlm / app.py
# Araik Tamazian
# minor fix
# 2370c0b
import re
import itertools
import pandas as pd
from transformers import BertForMaskedLM, BertTokenizer, pipeline
import gradio as gr
# Load the ProtBert checkpoint once at import time and wrap it in a fill-mask
# pipeline; `unmask_sequence` calls `unmasker` on every request.
_PROTBERT_CHECKPOINT = "Rostlab/prot_bert"
# do_lower_case=False: the tokenizer must not lower-case the residue letters.
tokenizer = BertTokenizer.from_pretrained(_PROTBERT_CHECKPOINT, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(_PROTBERT_CHECKPOINT)
unmasker = pipeline(task="fill-mask", model=model, tokenizer=tokenizer)
def unmask_sequence(seq, top_n=5):
    """Fill every unknown residue (``X``) in a protein sequence with ProtBert.

    Each ``X`` is replaced by the ``[MASK]`` token and the sequence is sent
    through the fill-mask pipeline. The pipeline proposes candidate residues,
    each with a probability, for every masked position. Candidates are
    combined across positions, and each combination is scored by the product
    of its per-position probabilities.

    Note: all masks are scored in a single pipeline call while the other
    masks are still masked. The product is therefore an independence
    approximation of the joint probability, not an exact joint.

    Args:
        seq: Protein sequence in one-letter code. All whitespace (including
            line breaks from the multi-line textbox) is ignored. The input is
            upper-cased because the tokenizer is loaded with
            ``do_lower_case=False`` and ProtBert's residue tokens are
            upper-case, so a lower-case ``x`` is also treated as a mask.
        top_n: Maximum number of completed sequences to return.

    Returns:
        A ``pandas.DataFrame`` with the columns "Complete Sequence",
        "Joint Probability (%)" (rounded to 4 decimals) and "Combination",
        sorted by descending joint probability, with at most ``top_n`` rows.

    Raises:
        gr.Error: If the sequence is empty or contains no ``X``. Gradio shows
            this message to the user instead of a raw pipeline exception.
    """
    columns = ["Complete Sequence", "Joint Probability (%)", "Combination"]
    # Clamp at 0: a negative top_n would otherwise slice rows off the end.
    keep = max(0, top_n)

    # Normalise: strip every kind of whitespace, upper-case, mark unknowns.
    clean_seq = re.sub(r"\s+", "", seq or "").upper().replace("X", "[MASK]")
    if not clean_seq:
        raise gr.Error("Please enter a protein sequence.")

    # Residue runs between masks: k masks split the sequence into k + 1 runs.
    segments = clean_seq.split("[MASK]")
    if len(segments) < 2:
        raise gr.Error("No unknown position found. Use X to mark the amino acids to predict.")

    # ProtBert expects one residue per whitespace-separated token.
    spaced_seq = " [MASK] ".join(" ".join(segment) for segment in segments)

    # top_n candidates per mask are enough to find the exact top_n
    # combinations (see the pruning argument below). The default top_k of 5
    # would otherwise cap the number of rows for a single mask.
    results = unmasker(spaced_seq, top_k=max(1, keep))
    # One mask gives a flat list of dicts; several masks give one list per mask.
    if isinstance(results[0], dict):
        results = [results]
    mask_candidates = [
        [(res["token_str"], res["score"]) for res in mask_group]
        for mask_group in results
    ]

    # Grow combinations one mask at a time and keep only the `keep` best.
    # Scores are non-negative and multiplied, so a prefix that is not among
    # the best `keep` prefixes can never finish among the best `keep`
    # combinations. The result is exact (up to ties) and avoids materialising
    # the exponential Cartesian product of all candidates.
    beams = [(1.0, ())]
    for candidates in mask_candidates:
        extended = [
            (joint_score * score, combo + (amino_acid,))
            for joint_score, combo in beams
            for amino_acid, score in candidates
        ]
        # Sort on the raw score, not the rounded percentage shown to the user.
        extended.sort(key=lambda item: item[0], reverse=True)
        beams = extended[:keep]

    rows = []
    for joint_score, combo in beams:
        # Re-insert the predicted residues between the original runs.
        completed = segments[0] + "".join(
            amino_acid + segment for amino_acid, segment in zip(combo, segments[1:])
        )
        rows.append({
            "Complete Sequence": completed,
            "Joint Probability (%)": round(joint_score * 100, 4),
            "Combination": " + ".join(combo),
        })
    # Explicit columns keep the table shape even when no rows are returned.
    return pd.DataFrame(rows, columns=columns)
# Gradio UI: a free-text sequence box in, a table of the best completions out.
app = gr.Interface(
    fn=unmask_sequence,
    inputs=gr.Textbox(lines=3, placeholder="Enter your sequence here. Use X to denote unknown amino acid."),
    # The label matches unmask_sequence's default top_n=5.
    outputs=gr.DataFrame(label="Top 5 Predictions"),
    title="ProtBert MLM",
    description="This app predicts missing amino acid with ProtBert.",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    app.launch()