# Provenance: app.py uploaded by dsk129 -- commit ff46b1e ("Update app.py"), verified.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import gradio as gr
import matplotlib.pyplot as plt
# ESM-2 checkpoint used for every request.
# NOTE(review): the "t36_3B" in the name indicates the 36-layer, ~3B-parameter
# variant, i.e. one of the larger ESM-2 checkpoints rather than a small one --
# confirm the hosting hardware has enough memory for it.
MODEL_NAME = "facebook/esm2_t36_3B_UR50D"

# Tokenizer and model are loaded once at import time and shared by all calls.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_hidden_states=True makes each forward pass return the per-layer
# hidden states; .eval() switches the module to inference behaviour and
# returns the module itself.
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True).eval()

# The app only runs inference, so autograd bookkeeping is switched off here.
torch.set_grad_enabled(False)
# Main function
def _clean_sequence(seq, which):
    """Return *seq* with all whitespace removed; raise ValueError if nothing is left."""
    cleaned = "".join(str(seq or "").split())
    if not cleaned:
        raise ValueError(f"Protein sequence {which} is empty.")
    return cleaned


def _token_index(pos, seq, which):
    """Validate a 0-based residue position in *seq* and return its token index."""
    if pos is None:
        raise ValueError(f"Residue index for sequence {which} is missing.")
    try:
        numeric = float(pos)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Residue index for sequence {which} must be a number, got {pos!r}."
        ) from err
    # Number widgets may deliver floats; only whole values are usable as
    # indices (this also rejects NaN and infinity).
    if not numeric.is_integer():
        raise ValueError(
            f"Residue index for sequence {which} must be a whole number, got {pos!r}."
        )
    index = int(numeric)
    # Reject negatives explicitly: a negative tensor index would silently wrap
    # around, and index == len(seq) would select the token after the sequence.
    if not 0 <= index < len(seq):
        raise ValueError(
            f"Residue index {index} is out of range for sequence {which} "
            f"(valid range: 0 to {len(seq) - 1})."
        )
    # +1 skips the token the tokenizer prepends before the first residue (BOS).
    return index + 1


def compute_dot_product_plot(seq1, pos1, seq2, pos2, pdb1, metal1, pdb2, metal2):
    """Plot the per-layer cosine similarity between two residue embeddings.

    Both sequences are run through the module-level ``model``; for every entry
    of the returned ``hidden_states`` tuple, the embedding of the chosen residue
    in sequence 1 is compared with that of the chosen residue in sequence 2.

    Args:
        seq1, seq2: Protein sequences as strings. Whitespace (e.g. pasted line
            breaks) is removed before tokenization.
        pos1, pos2: 0-based residue positions. Integer-valued floats are
            accepted because UI number fields may deliver floats.
        pdb1, metal1, pdb2, metal2: Free-text identifiers used only to build
            the legend label.

    Returns:
        A matplotlib ``Figure`` with one point per hidden-state entry
        (x axis: index into ``hidden_states``; y axis: cosine similarity,
        fixed to the range [-1, 1]).

    Raises:
        ValueError: If a sequence is empty, or a position is missing,
            non-integer, or outside its sequence.
    """
    # Validate everything before running the (expensive) model.
    seq1 = _clean_sequence(seq1, 1)
    seq2 = _clean_sequence(seq2, 2)
    token_index1 = _token_index(pos1, seq1, 1)
    token_index2 = _token_index(pos2, seq2, 2)

    # Grad mode is thread-local in PyTorch, so disable it here as well: this
    # function may run on a worker thread where a module-level
    # torch.set_grad_enabled(False) has no effect.
    with torch.no_grad():
        # hidden_states is a tuple of per-layer tensors, each [1, L, D].
        hidden_states1 = model(**tokenizer(seq1, return_tensors="pt")).hidden_states
        hidden_states2 = model(**tokenizer(seq2, return_tensors="pt")).hidden_states

        similarities = [
            F.cosine_similarity(
                layer1[0, token_index1, :], layer2[0, token_index2, :], dim=0
            ).item()
            for layer1, layer2 in zip(hidden_states1, hidden_states2)
        ]

    # Construct legend label
    label = f"{pdb1}({metal1})-{pdb2}({metal2})"

    # Plot the similarities
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(range(len(similarities)), similarities, marker='o', label=label)
    ax.set_xlabel("Transformer Layer")
    ax.set_ylabel("Cosine Similarity")
    ax.set_ylim(-1, 1)
    ax.grid(True)
    ax.legend()

    # Unregister the figure from pyplot's global registry so repeated requests
    # do not accumulate open figures; the Figure object itself is still
    # returned to the caller for rendering.
    plt.close(fig)
    return fig
# Gradio interface
# The two residue positions are used as array indices by the callback, so the
# Number inputs use precision=0 to round to the nearest integer and deliver an
# int rather than a float.
demo = gr.Interface(
    fn=compute_dot_product_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence 1"),
        gr.Number(label="Residue Index in Sequence 1 (0-based)", precision=0),
        gr.Textbox(label="Protein Sequence 2"),
        gr.Number(label="Residue Index in Sequence 2 (0-based)", precision=0),
        gr.Textbox(label="PDB ID of Protein 1"),
        gr.Textbox(label="Metal Bound by Protein 1"),
        gr.Textbox(label="PDB ID of Protein 2"),
        gr.Textbox(label="Metal Bound by Protein 2"),
    ],
    outputs=gr.Plot(label="Cosine Similarity Across Transformer Layers"),
    title="ESM Layer-wise Residue Similarity",
    description="Enter two protein sequences and choose one residue from each. Provide the PDB IDs and bound metals. The plot shows the cosine similarity between their embedding vectors at each layer of the ESM transformer.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()