| import gradio as gr |
| import pandas as pd |
| import torch |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| import logging |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from io import BytesIO |
| from PIL import Image |
|
|
# Drop log records below ERROR from transformers' model-loading module.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# FusOn-pLM masked-language model from the Hugging Face Hub.
# trust_remote_code=True permits loading custom model/tokenizer code from the repo.
model_name = "ChatterjeeLab/FusOn-pLM"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Module.to() and Module.eval() both return the module itself, so they chain.
model = (
    AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
    .to(device)
    .eval()
)
|
|
# Vocabulary ids plotted as heatmap rows (ids 4..23 inclusive, i.e. 20 tokens).
# NOTE(review): presumably the 20 canonical amino-acid tokens of the model's
# vocabulary -- confirm against tokenizer.get_vocab().
_HEATMAP_TOKEN_IDS = list(range(4, 23 + 1))
# Spacing, in residues, between labelled ticks on the heatmap's x axis.
_HEATMAP_TICK_STEP = 50
# max_length handed to the tokenizer; longer inputs are truncated.
_MAX_TOKENS = 2000


def _masked_position_logits(sequence, i):
    """Return the model's 1-D vocabulary logits at position *i* of *sequence* when it is masked.

    Raises:
        gr.Error: if no mask token survives tokenization (most likely because
            truncation to ``_MAX_TOKENS`` cut the masked residue off).
    """
    masked_seq = sequence[:i] + '<mask>' + sequence[i + 1:]
    inputs = tokenizer(
        masked_seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=_MAX_TOKENS,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_positions.numel() == 0:
        # Without this check, indexing the (empty) top-k result raises IndexError.
        raise gr.Error(
            f"Residue {i + 1} was not found in the tokenized input; the sequence "
            f"may exceed the {_MAX_TOKENS}-token limit."
        )
    # Exactly one residue is masked, so the first mask hit is the one we want.
    return logits[0, mask_positions[0], :]


def _render_logit_heatmap(logit_rows, seq_len):
    """Plot min-max normalised per-residue logits as a heatmap and return it as a PIL image.

    ``logit_rows`` holds one 1-D array per residue (logits at ``_HEATMAP_TOKEN_IDS``);
    plot rows are tokens and columns are residues.
    """
    # A standalone Figure avoids pyplot's global "current figure" state, which
    # would be shared between requests if the handler runs concurrently.
    from matplotlib.figure import Figure

    logits_matrix = np.vstack(logit_rows)  # (seq_len, len(_HEATMAP_TOKEN_IDS))
    lo, hi = logits_matrix.min(), logits_matrix.max()
    span = hi - lo
    if span > 0:
        normalized = (logits_matrix - lo) / span
    else:
        # All logits identical: avoid 0/0 NaNs and draw a flat map instead.
        normalized = np.zeros_like(logits_matrix)

    # Decode only the plotted ids rather than the whole vocabulary.
    token_labels = [tokenizer.decode([idx]) for idx in _HEATMAP_TOKEN_IDS]
    # NOTE(review): these tick labels are 0-based, while the table's Position
    # column is 1-based.
    tick_positions = np.arange(0, seq_len, _HEATMAP_TICK_STEP)
    tick_labels = [str(pos) for pos in tick_positions]

    fig = Figure(figsize=(15, 8))
    ax = fig.subplots()
    sns.heatmap(
        normalized.T,
        cmap='plasma',
        xticklabels=False,  # x ticks are set explicitly below
        yticklabels=token_labels,
        ax=ax,
    )
    ax.set_title('Logits for masked per residue tokens')
    ax.set_ylabel('Token')
    ax.set_xlabel('Residue Index')
    ax.tick_params(axis='y', labelrotation=0)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=0)

    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def process_sequence(sequence, domain_bounds, n):
    """Predict the top-*n* tokens at every residue and plot the per-residue logits.

    Each residue is masked in turn (one forward pass per position). The logits
    at the mask feed both the predictions table and the heatmap.

    Args:
        sequence: protein sequence entered in the text box.
        domain_bounds: one-row table indexed as ``domain_bounds['start'][0]`` /
            ``domain_bounds['end'][0]``; 1-based, inclusive domain bounds.
        n: number of top predictions to list per position.

    Returns:
        Tuple of (predictions DataFrame restricted to the domain rows,
        PIL heatmap image covering the whole sequence).

    Raises:
        gr.Error: for an empty sequence, an unselected ``n``, non-numeric
            bounds, or a residue lost to tokenizer truncation.
    """
    # Validate cheap inputs before running one forward pass per residue.
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")
    if n is None:
        raise gr.Error("Please choose how many predictions (n) to return per position.")
    n = int(n)
    try:
        # Clamp so a start of 0 (or less) cannot become a negative slice index.
        start_index = max(int(domain_bounds['start'][0]) - 1, 0)
        end_index = max(int(domain_bounds['end'][0]), 0)
    except (KeyError, IndexError, TypeError, ValueError) as err:
        raise gr.Error("Please enter numeric start and end positions for the domain.") from err

    residues = []
    predictions = []
    logit_rows = []
    for i, residue in enumerate(sequence):
        mask_logits = _masked_position_logits(sequence, i)
        top_ids = torch.topk(mask_logits, n).indices.tolist()
        residues.append(residue)
        predictions.append([tokenizer.decode([tok]) for tok in top_ids])
        logit_rows.append(mask_logits.cpu().numpy()[_HEATMAP_TOKEN_IDS])

    img = _render_logit_heatmap(logit_rows, len(sequence))

    df = pd.DataFrame({
        'Original Residue': residues,
        'Predicted Residues (in order of decreasing likelihood)': predictions,
        'Position': list(range(1, len(sequence) + 1)),
    })
    return df.iloc[start_index:end_index], img
|
|
# Gradio front end: a free-text protein sequence, a fixed one-row (start, end)
# table for the domain of interest, and a 1..20 dropdown choosing n.
demo = gr.Interface(
    fn=process_sequence,
    description=(
        "Choose a number between 1-20 to predict n tokens for each position. "
        "Choose the start and end index of the domain of interest (indexing starts at 1)."
    ),
    inputs=[
        "text",
        gr.Dataframe(
            row_count=(1, "fixed"),
            col_count=(2, "fixed"),
            headers=["start", "end"],
            datatype=["number", "number"],
        ),
        gr.Dropdown(list(range(1, 21))),
    ],
    outputs=["dataframe", "image"],
)


if __name__ == "__main__":
    # Start the web server only when executed as a script, not when imported.
    demo.launch()
|
|