Spaces:
Sleeping
Sleeping
| import re | |
| import itertools | |
| import pandas as pd | |
| from transformers import BertForMaskedLM, BertTokenizer, pipeline | |
| import gradio as gr | |
# One checkpoint name feeds both the tokenizer and the masked-LM weights so
# the two can never drift apart.
_PROTBERT_CHECKPOINT = "Rostlab/prot_bert"

# do_lower_case=False: keep the input's letter case instead of lower-casing it.
tokenizer = BertTokenizer.from_pretrained(_PROTBERT_CHECKPOINT, do_lower_case=False)
model = BertForMaskedLM.from_pretrained(_PROTBERT_CHECKPOINT)

# Fill-mask pipeline used by unmask_sequence() to score candidates for each
# [MASK] position.
unmasker = pipeline(task="fill-mask", tokenizer=tokenizer, model=model)
def _top_joint_combinations(mask_candidates, top_n):
    """Return the ``top_n`` best candidate combinations, best first.

    Args:
        mask_candidates: One list per mask position, each holding
            ``(token, score)`` pairs with non-negative scores.
        top_n: Maximum number of combinations to return.

    Returns:
        A list of ``(joint_score, combo)`` tuples ordered from highest to
        lowest joint score, where ``combo`` is a list with one
        ``(token, score)`` pair per mask position and ``joint_score`` is the
        product of the individual scores. Empty when ``top_n <= 0`` or when
        any position has no candidates.
    """
    import heapq  # local import: keeps this block self-contained

    if top_n <= 0 or not mask_candidates:
        return []

    # Best candidate first at every position, so that moving one step to the
    # right in any list can never increase the joint score.
    ranked = [
        sorted(group, key=lambda pair: pair[1], reverse=True)
        for group in mask_candidates
    ]
    if any(not group for group in ranked):
        return []

    def joint_score(indices):
        score = 1.0
        for group, idx in zip(ranked, indices):
            score *= group[idx][1]
        return score

    # Best-first search over index tuples: start from "best candidate
    # everywhere" and expand by demoting one position at a time. This visits
    # only about top_n * n_masks combinations instead of the full Cartesian
    # product, which grows exponentially with the number of masks.
    start = (0,) * len(ranked)
    frontier = [(-joint_score(start), start)]  # max-heap via negated score
    seen = {start}
    best = []
    while frontier and len(best) < top_n:
        neg_score, indices = heapq.heappop(frontier)
        combo = [group[idx] for group, idx in zip(ranked, indices)]
        best.append((-neg_score, combo))
        for pos, group in enumerate(ranked):
            if indices[pos] + 1 >= len(group):
                continue
            successor = indices[:pos] + (indices[pos] + 1,) + indices[pos + 1:]
            if successor not in seen:
                seen.add(successor)
                heapq.heappush(frontier, (-joint_score(successor), successor))
    return best


def unmask_sequence(seq, top_n=5):
    """Predict the most likely fillings for unknown residues in a sequence.

    Every ``X`` (or literal ``[MASK]``) in ``seq`` is treated as an unknown
    position. The module-level ``unmasker`` pipeline proposes candidates for
    each position, and the candidates are combined into complete sequences
    ranked by joint probability (the product of the per-position scores).

    Args:
        seq: Amino-acid sequence. Whitespace (including line breaks) is
            ignored and letters are upper-cased before prediction.
        top_n: Number of highest-scoring complete sequences to return.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"Complete Sequence"``,
        ``"Joint Probability (%)"`` and ``"Combination"``, best row first,
        with at most ``top_n`` rows.

    Raises:
        ValueError: If the cleaned sequence contains no unknown position.
    """
    mask_token = "[MASK]"
    columns = ["Complete Sequence", "Joint Probability (%)", "Combination"]

    # Strip *all* whitespace (the input box is multi-line) and upper-case:
    # the tokenizer is loaded case-sensitive, and per the ProtBert model card
    # the vocabulary is upper-case only.
    residues = "".join((seq or "").split()).upper()
    # A literally typed "[MASK]" is accepted as an alias for X.
    residues = residues.replace(mask_token, "X")

    tokens = [mask_token if residue == "X" else residue for residue in residues]
    mask_positions = [i for i, token in enumerate(tokens) if token == mask_token]
    if not mask_positions:
        # Without a mask the fill-mask pipeline has nothing to predict; fail
        # early with an actionable message instead of an internal error.
        raise ValueError(
            "No unknown amino acid found: mark each position to predict with X."
        )

    # NOTE(review): the number of candidates per mask is whatever the
    # pipeline's default top_k is (presumably 5) - confirm if top_n > 5 is
    # ever requested.
    results = unmasker(" ".join(tokens))
    # A single mask yields a flat list of dicts; several masks yield one list
    # per mask. Normalise to the nested form.
    if isinstance(results[0], dict):
        results = [results]

    # Candidates for each mask: [(amino_acid, score), ...]
    mask_candidates = [
        [(res["token_str"], res["score"]) for res in mask_group]
        for mask_group in results
    ]

    rows = []
    for joint, combo in _top_joint_combinations(mask_candidates, top_n):
        filled = list(tokens)
        for position, (amino_acid, _score) in zip(mask_positions, combo):
            filled[position] = amino_acid
        rows.append({
            "Complete Sequence": "".join(filled).replace(" ", ""),
            "Joint Probability (%)": round(joint * 100, 4),
            "Combination": " + ".join(amino_acid for amino_acid, _score in combo),
        })

    # Explicit columns keep the frame well-formed even when there are no rows.
    return pd.DataFrame(rows, columns=columns)
# Input widget: a three-line textbox for the (partially unknown) sequence.
_sequence_box = gr.Textbox(
    placeholder="Enter your sequence here. Use X to denote unknown amino acid.",
    lines=3,
)

# Output widget: the DataFrame returned by unmask_sequence().
_prediction_table = gr.DataFrame(label="Top 5 Predictions")

# Wire the prediction function to the two widgets.
app = gr.Interface(
    fn=unmask_sequence,
    inputs=_sequence_box,
    outputs=_prediction_table,
    title="ProtBert MLM",
    description="This app predicts missing amino acid with ProtBert.",
)

# Start the web server.
app.launch()