| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel |
| import gradio as gr |
| import matplotlib.pyplot as plt |
|
|
| |
# Hugging Face Hub id of the ESM-2 checkpoint served by this app.
MODEL_NAME = "facebook/esm2_t36_3B_UR50D"

# Tokenizer and encoder come from the same checkpoint. The encoder is asked to
# return the hidden states of every layer, and is switched to evaluation mode
# right away (``eval()`` returns the module itself, so the call can be chained).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True).eval()

# This script only runs inference, so autograd is switched off process-wide.
torch.set_grad_enabled(False)
|
|
| |
def _residue_token_index(pos, input_ids, which):
    """Validate a 0-based residue position and map it to a token index.

    Args:
        pos: Residue position as delivered by the UI (int, float or None).
        input_ids: Tokenized sequence; its second axis is the token axis.
        which: 1 or 2, used only to word the error message.

    Returns:
        The integer index of the residue along the token axis.

    Raises:
        gr.Error: If the sequence is empty, or the position is missing,
            not a whole number, or outside the sequence.
    """
    # The original code offsets positions by one for the leading <cls> token.
    # ESM tokenizers are also expected to append a trailing <eos> token, so two
    # tokens are not residues -- NOTE(review): confirm for other checkpoints.
    num_residues = input_ids.shape[1] - 2
    if num_residues < 1:
        raise gr.Error(f"Protein sequence {which} is empty.")

    try:
        residue = int(pos)
    except (TypeError, ValueError, OverflowError):
        # None (empty field), NaN and infinity all end up here.
        raise gr.Error(
            f"Residue index for sequence {which} must be a whole number."
        ) from None
    if residue != pos:
        raise gr.Error(
            f"Residue index for sequence {which} must be a whole number, got {pos}."
        )

    # Reject negatives explicitly: tensor indexing would silently wrap around.
    if not 0 <= residue < num_residues:
        raise gr.Error(
            f"Residue index for sequence {which} must be between 0 and "
            f"{num_residues - 1}, got {residue}."
        )
    return residue + 1


def compute_dot_product_plot(seq1, pos1, seq2, pos2, pdb1, metal1, pdb2, metal2):
    """Plot the layer-wise cosine similarity between two residue embeddings.

    Both sequences are run through the module-level ESM ``model``; for every
    entry of the returned ``hidden_states`` the embedding of the chosen residue
    in each sequence is extracted and the two vectors are compared.

    Args:
        seq1: First protein sequence.
        pos1: 0-based residue position in ``seq1`` (int or whole-valued float).
        seq2: Second protein sequence.
        pos2: 0-based residue position in ``seq2`` (int or whole-valued float).
        pdb1: PDB id of the first protein (used only in the legend).
        metal1: Metal bound by the first protein (used only in the legend).
        pdb2: PDB id of the second protein (used only in the legend).
        metal2: Metal bound by the second protein (used only in the legend).

    Returns:
        A matplotlib ``Figure`` with one point per hidden-state entry.

    Raises:
        gr.Error: If a sequence is empty or a residue position is invalid.
    """
    inputs1 = tokenizer(seq1, return_tensors="pt")
    inputs2 = tokenizer(seq2, return_tensors="pt")

    # Validate before the forward passes so bad input fails fast and cheaply.
    token_index1 = _residue_token_index(pos1, inputs1["input_ids"], 1)
    token_index2 = _residue_token_index(pos2, inputs2["input_ids"], 2)

    # Inference only; do not rely on the caller having disabled autograd.
    with torch.no_grad():
        hidden_states1 = model(**inputs1).hidden_states
        hidden_states2 = model(**inputs2).hidden_states

    # Batch index 0: each call above encodes exactly one sequence.
    similarities = [
        F.cosine_similarity(
            layer1[0, token_index1, :], layer2[0, token_index2, :], dim=0
        ).item()
        for layer1, layer2 in zip(hidden_states1, hidden_states2)
    ]

    label = f"{pdb1}({metal1})-{pdb2}({metal2})"

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(range(len(similarities)), similarities, marker='o', label=label)
    ax.set_xlabel("Transformer Layer")
    ax.set_ylabel("Cosine Similarity")
    # Cosine similarity is bounded by [-1, 1]; a fixed axis keeps plots comparable.
    ax.set_ylim(-1, 1)
    ax.grid(True)
    ax.legend()

    # Drop the figure from pyplot's global registry so a long-running server
    # does not accumulate one open figure per request. The Figure object is
    # still returned for rendering.
    plt.close(fig)
    return fig
|
|
| |
# Gradio front end: two sequences, one residue position in each, plus the
# PDB ids and bound metals that are used to label the plotted curve.
demo = gr.Interface(
    fn=compute_dot_product_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence 1"),
        # precision=0 makes Gradio deliver the value as an int. Without it the
        # component yields a float, which cannot be used as a tensor index.
        gr.Number(label="Residue Index in Sequence 1 (0-based)", precision=0),
        gr.Textbox(label="Protein Sequence 2"),
        gr.Number(label="Residue Index in Sequence 2 (0-based)", precision=0),
        gr.Textbox(label="PDB ID of Protein 1"),
        gr.Textbox(label="Metal Bound by Protein 1"),
        gr.Textbox(label="PDB ID of Protein 2"),
        gr.Textbox(label="Metal Bound by Protein 2"),
    ],
    outputs=gr.Plot(label="Cosine Similarity Across Transformer Layers"),
    title="ESM Layer-wise Residue Similarity",
    description="Enter two protein sequences and choose one residue from each. Provide the PDB IDs and bound metals. The plot shows the cosine similarity between their embedding vectors at each layer of the ESM transformer.",
)


# Start the web server only when the file is executed as a script, so that
# importing the module does not launch the UI.
if __name__ == "__main__":
    demo.launch()
|
|