Petimot / app /components /embedding_viewer.py
Valmbd's picture
fix: correct GT eigvects parsing + add pt_bb_to_pdb for offline 3D viz
128b46f
"""Interactive visualization of protein latent embeddings using PCA."""
import streamlit as st
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
def render_embedding_viewer(
    embeddings: np.ndarray,
    seq: str,
    displacements: np.ndarray,
    mode_label: str = "Mode 0",
    key: str = "pca_scatter",
) -> None:
    """
    Render a 2D PCA scatter plot of per-residue embeddings.

    If ``embeddings``, ``seq`` and ``displacements`` disagree in length, all
    three are truncated to the shortest one so every plotted point has a
    residue label and a mobility value.

    Args:
        embeddings: (N, D) numpy array of latent node representations.
        seq: Amino acid sequence, one character per residue.
        displacements: (N,) array of displacement magnitudes to color by.
        mode_label: Label used in the plot title.
        key: Streamlit element key for the chart. Pass a distinct value when
            this viewer is rendered more than once on the same page.
    """
    if embeddings is None or len(embeddings) == 0:
        st.info("No embeddings available for this protein.")
        return

    embeddings = np.asarray(embeddings)
    displacements = np.asarray(displacements)

    # Truncate (never pad) to the common length: the DataFrame below needs
    # every column to have the same number of rows.
    n_res = min(len(embeddings), len(seq), len(displacements))
    embeddings = embeddings[:n_res]
    seq = seq[:n_res]
    displacements = displacements[:n_res]

    # A 2-component PCA needs a 2D input with at least 2 samples and 2 features;
    # otherwise sklearn raises a ValueError.
    if n_res < 2 or embeddings.ndim != 2 or embeddings.shape[1] < 2:
        st.info(
            "At least two residues and two embedding dimensions are needed "
            "for a 2D PCA projection."
        )
        return

    try:
        from sklearn.decomposition import PCA
    except ImportError:
        st.error("scikit-learn is required to compute PCA for embeddings.")
        return

    with st.spinner("Computing PCA..."):
        # Reduce dimensionality to 2D
        pca = PCA(n_components=2)
        emb_2d = pca.fit_transform(embeddings)

    # With zero total variance the ratios can come out as NaN (0/0); show 0%
    # rather than "nan%" in the labels below.
    ev = np.nan_to_num(pca.explained_variance_ratio_)

    # Prepare DataFrame for Plotly
    df_emb = pd.DataFrame({
        "PC1": emb_2d[:, 0],
        "PC2": emb_2d[:, 1],
        "Residue": [f"{aa}{pos}" for pos, aa in enumerate(seq, start=1)],
        "Amino Acid": list(seq),
        "Mobility (Å)": displacements,
        "Position": np.arange(1, n_res + 1),
    })

    col1, col2 = st.columns([3, 1])

    with col1:
        # Create interactive scatter plot
        fig = px.scatter(
            df_emb, x="PC1", y="PC2",
            color="Mobility (Å)",
            hover_name="Residue",
            hover_data={"PC1": False, "PC2": False, "Amino Acid": True, "Position": True},
            color_continuous_scale="Turbo",
            title=f"Node Embeddings PCA Space (colored by {mode_label} mobility)",
            template="plotly_dark",
        )
        fig.update_traces(marker=dict(size=8, opacity=0.8, line=dict(width=1, color='DarkSlateGrey')))
        fig.update_layout(
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(30,27,75,0.4)",
            xaxis_title=f"PC 1 ({ev[0]*100:.1f}%)",
            yaxis_title=f"PC 2 ({ev[1]*100:.1f}%)",
            margin=dict(l=40, r=40, t=40, b=40),
        )
        st.plotly_chart(fig, use_container_width=True, key=key)

    with col2:
        st.markdown("### 🧠 Latent Space")
        st.markdown(
            "This 2D projection shows the latent representations "
            "of each residue inside the **SE(3)-Equivariant GNN**."
        )
        st.markdown(
            "Points that are close together in this space "
            "are considered structurally and kinematically similar by the model."
        )
        st.metric("Dimensions", f"{embeddings.shape[1]} → 2")
        st.metric("Total Explained Variance", f"{(ev[0]+ev[1])*100:.1f}%")