Spaces:
Sleeping
Sleeping
Update text description
Browse files
hexviz/Attention_Visualization.py
CHANGED
|
@@ -9,10 +9,6 @@ from hexviz.models import Model, ModelType
|
|
| 9 |
|
| 10 |
st.title("Attention Visualization on proteins")
|
| 11 |
|
| 12 |
-
"""
|
| 13 |
-
Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
|
| 14 |
-
Pick a PDB ID, layer and head to visualize attention.
|
| 15 |
-
"""
|
| 16 |
|
| 17 |
models = [
|
| 18 |
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
|
|
@@ -69,6 +65,7 @@ with right:
|
|
| 69 |
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
| 70 |
head = head_one - 1
|
| 71 |
|
|
|
|
| 72 |
if selected_model.name == ModelType.ZymCTRL:
|
| 73 |
try:
|
| 74 |
ec_class = structure.header["compound"]["1"]["ec"]
|
|
@@ -110,8 +107,14 @@ def get_3dview(pdb):
|
|
| 110 |
|
| 111 |
xyzview = get_3dview(pdb_id)
|
| 112 |
showmol(xyzview, height=500, width=800)
|
| 113 |
-
st.markdown(f'PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id})', unsafe_allow_html=True)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
|
| 117 |
data = []
|
|
@@ -122,9 +125,12 @@ for att_weight, _ , _ , chain, first, second in top_n:
|
|
| 122 |
data.append(el)
|
| 123 |
|
| 124 |
df = pd.DataFrame(data, columns=['Avg attention', 'Residue pair'])
|
| 125 |
-
f"
|
| 126 |
st.table(df)
|
| 127 |
|
|
|
|
|
|
|
|
|
|
| 128 |
"""
|
| 129 |
-
|
| 130 |
"""
|
|
|
|
| 9 |
|
| 10 |
st.title("Attention Visualization on proteins")
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
models = [
|
| 14 |
Model(name=ModelType.TAPE_BERT, layers=12, heads=12),
|
|
|
|
| 65 |
head_one = st.number_input("Head", value=1, min_value=1, max_value=selected_model.heads)
|
| 66 |
head = head_one - 1
|
| 67 |
|
| 68 |
+
|
| 69 |
if selected_model.name == ModelType.ZymCTRL:
|
| 70 |
try:
|
| 71 |
ec_class = structure.header["compound"]["1"]["ec"]
|
|
|
|
| 107 |
|
| 108 |
xyzview = get_3dview(pdb_id)
|
| 109 |
showmol(xyzview, height=500, width=800)
|
|
|
|
| 110 |
|
| 111 |
+
st.markdown(f"""
|
| 112 |
+
Visualize attention weights from protein language models on protein structures.
|
| 113 |
+
Currently attention weights for PDB: [{pdb_id}](https://www.rcsb.org/structure/{pdb_id}) from layer: {layer_one}, head: {head_one} above {min_attn} from {selected_model.name.value}
|
| 114 |
+
are visualized as red bars. The highest {n_pairs} attention pairs are labeled.
|
| 115 |
+
Visualize attention weights on protein structures for the protein language models TAPE-BERT and ZymCTRL.
|
| 116 |
+
Pick a PDB ID, layer and head to visualize attention.
|
| 117 |
+
""", unsafe_allow_html=True)
|
| 118 |
|
| 119 |
chain_dict = {f"{chain.id}": chain for chain in list(structure.get_chains())}
|
| 120 |
data = []
|
|
|
|
| 125 |
data.append(el)
|
| 126 |
|
| 127 |
df = pd.DataFrame(data, columns=['Avg attention', 'Residue pair'])
|
| 128 |
+
st.markdown(f"The {n_pairs} residue pairs with the highest average attention weights are labeled in the visualization and listed below:")
|
| 129 |
st.table(df)
|
| 130 |
|
| 131 |
+
st.markdown("""Clik in to the [Identify Interesting heads](#Identify-Interesting-heads) page to get an overview of attention
|
| 132 |
+
patterns across all layers and heads
|
| 133 |
+
to help you find heads with interesting attention patterns to study here.""")
|
| 134 |
"""
|
| 135 |
+
The attention visualization is inspired by [provis](https://github.com/salesforce/provis#provis-attention-visualizer).
|
| 136 |
"""
|