| """Interactive visualization of protein latent embeddings using PCA.""" |
| import streamlit as st |
| import numpy as np |
| import plotly.express as px |
| import plotly.graph_objects as go |
| import pandas as pd |
|
|
def _align_inputs(embeddings, seq, displacements):
    """Trim embeddings, sequence and displacements to their shortest common length.

    Returns a tuple ``(embeddings, seq, displacements)`` where the arrays are
    converted with ``np.asarray`` and all three have the same first-axis length.
    """
    embeddings = np.asarray(embeddings)
    displacements = np.asarray(displacements)
    # All three feed columns of one DataFrame, so they must agree in length;
    # checking only embeddings vs. seq lets a displacement mismatch slip through.
    n = min(len(embeddings), len(seq), len(displacements))
    return embeddings[:n], seq[:n], displacements[:n]


def _build_embedding_frame(emb_2d, seq, displacements):
    """Build the per-residue DataFrame plotted by the PCA scatter.

    ``emb_2d`` is the (N, 2) PCA projection; residue labels are 1-based
    (e.g. ``"A1"``, ``"G2"``).
    """
    return pd.DataFrame({
        "PC1": emb_2d[:, 0],
        "PC2": emb_2d[:, 1],
        "Residue": [f"{aa}{pos}" for pos, aa in enumerate(seq, start=1)],
        "Amino Acid": list(seq),
        "Mobility (Å)": displacements,
        "Position": np.arange(1, len(seq) + 1),
    })


def render_embedding_viewer(
    embeddings: np.ndarray,
    seq: str,
    displacements: np.ndarray,
    mode_label: str = "Mode 0",
) -> None:
    """
    Render a 2D PCA scatter plot of per-residue embeddings.

    The embeddings, sequence and displacement arrays are first trimmed to their
    shortest common length. If fewer than two residues or fewer than two
    embedding dimensions remain, a warning is shown instead of a plot.

    Args:
        embeddings: (N, D) numpy array of latent node representations.
        seq: Amino acid sequence.
        displacements: (N,) array of displacement magnitudes to color by.
        mode_label: Label for tooltip/titles.
    """
    if embeddings is None or len(embeddings) == 0:
        st.info("No embeddings available for this protein.")
        return

    embeddings, seq, displacements = _align_inputs(embeddings, seq, displacements)

    # PCA(n_components=2) requires >= 2 samples and >= 2 features; without this
    # guard scikit-learn raises (and ev[1] / shape[1] below would fail).
    if embeddings.ndim != 2 or embeddings.shape[0] < 2 or embeddings.shape[1] < 2:
        st.warning(
            "A 2D PCA projection needs at least two residues and two "
            "embedding dimensions."
        )
        return

    # Imported lazily so the rest of the app works without scikit-learn.
    try:
        from sklearn.decomposition import PCA
    except ImportError:
        st.error("scikit-learn is required to compute PCA for embeddings.")
        return

    with st.spinner("Computing PCA..."):
        pca = PCA(n_components=2)
        emb_2d = pca.fit_transform(embeddings)

    df_emb = _build_embedding_frame(emb_2d, seq, displacements)
    ev = pca.explained_variance_ratio_

    col1, col2 = st.columns([3, 1])

    with col1:
        fig = px.scatter(
            df_emb, x="PC1", y="PC2",
            color="Mobility (Å)",
            hover_name="Residue",
            hover_data={"PC1": False, "PC2": False, "Amino Acid": True, "Position": True},
            color_continuous_scale="Turbo",
            title=f"Node Embeddings PCA Space (colored by {mode_label} mobility)",
            template="plotly_dark",
        )
        fig.update_traces(marker=dict(size=8, opacity=0.8, line=dict(width=1, color='DarkSlateGrey')))
        fig.update_layout(
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(30,27,75,0.4)",
            xaxis_title=f"PC 1 ({ev[0]*100:.1f}%)",
            yaxis_title=f"PC 2 ({ev[1]*100:.1f}%)",
            margin=dict(l=40, r=40, t=40, b=40),
        )
        st.plotly_chart(fig, use_container_width=True, key="pca_scatter")

    with col2:
        st.markdown("### 🧠 Latent Space")
        st.markdown(
            "This 2D projection shows the latent representations "
            "of each residue inside the **SE(3)-Equivariant GNN**."
        )
        st.markdown(
            "Points that are close together in this space "
            "are considered structurally and kinematically similar by the model."
        )
        st.metric("Dimensions", f"{embeddings.shape[1]} → 2")
        st.metric("Total Explained Variance", f"{(ev[0]+ev[1])*100:.1f}%")
|
|