"""Interactive visualization of protein latent embeddings using PCA."""

import streamlit as st
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd


def render_embedding_viewer(
    embeddings: np.ndarray,
    seq: str,
    displacements: np.ndarray,
    mode_label: str = "Mode 0",
) -> None:
    """
    Render a 2D PCA scatter plot of per-residue embeddings.

    The plot and an explanatory side panel are written into the current
    Streamlit container; nothing is returned. When the inputs cannot be
    projected (no data, fewer than two residues or embedding dimensions,
    or scikit-learn missing) a message is shown instead of a plot.

    Args:
        embeddings: (N, D) numpy array of latent node representations.
        seq: Amino acid sequence.
        displacements: (N,) array of displacement magnitudes to color by.
        mode_label: Label for tooltip/titles.
    """
    if embeddings is None or len(embeddings) == 0:
        st.info("No embeddings available for this protein.")
        return

    # Work on an ndarray so that .ndim / .shape are always available below.
    embeddings = np.asarray(embeddings)

    # Reconcile lengths across *all* per-residue inputs by trimming to the
    # shortest one. Mismatches shouldn't happen, but an unaligned
    # `displacements` would otherwise make the DataFrame constructor raise.
    lengths = [len(embeddings), len(seq)]
    if displacements is not None:
        lengths.append(len(displacements))
    n_res = min(lengths)

    embeddings = embeddings[:n_res]
    seq = seq[:n_res]
    if displacements is not None:
        displacements = displacements[:n_res]

    if n_res == 0:
        st.info("No residues available to project for this protein.")
        return

    # PCA(n_components=2) requires at least 2 samples and 2 features.
    if embeddings.ndim != 2:
        st.warning("Embeddings must be a 2D (residues × features) array to compute PCA.")
        return
    if min(embeddings.shape) < 2:
        st.warning(
            "At least 2 residues and 2 embedding dimensions are required "
            "for a 2D PCA projection."
        )
        return

    # Imported lazily so the rest of the app still loads without scikit-learn.
    try:
        from sklearn.decomposition import PCA
    except ImportError:
        st.error("scikit-learn is required to compute PCA for embeddings.")
        return

    with st.spinner("Computing PCA..."):
        # Reduce dimensionality to 2D
        pca = PCA(n_components=2)
        emb_2d = pca.fit_transform(embeddings)

        # Prepare DataFrame for Plotly (residue labels are 1-based, e.g. "A1")
        residues = [f"{aa}{pos}" for pos, aa in enumerate(seq, start=1)]
        df_emb = pd.DataFrame({
            "PC1": emb_2d[:, 0],
            "PC2": emb_2d[:, 1],
            "Residue": residues,
            "Amino Acid": list(seq),
            "Mobility (Å)": displacements,
            "Position": np.arange(1, n_res + 1),
        })

        # Explained variance
        ev = pca.explained_variance_ratio_

    col1, col2 = st.columns([3, 1])

    with col1:
        # Create interactive scatter plot
        fig = px.scatter(
            df_emb,
            x="PC1",
            y="PC2",
            color="Mobility (Å)",
            hover_name="Residue",
            hover_data={"PC1": False, "PC2": False, "Amino Acid": True, "Position": True},
            color_continuous_scale="Turbo",
            title=f"Node Embeddings PCA Space (colored by {mode_label} mobility)",
            template="plotly_dark",
        )

        fig.update_traces(
            marker=dict(size=8, opacity=0.8, line=dict(width=1, color='DarkSlateGrey'))
        )
        fig.update_layout(
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(30,27,75,0.4)",
            xaxis_title=f"PC 1 ({ev[0]*100:.1f}%)",
            yaxis_title=f"PC 2 ({ev[1]*100:.1f}%)",
            margin=dict(l=40, r=40, t=40, b=40),
        )

        st.plotly_chart(fig, use_container_width=True, key="pca_scatter")

    with col2:
        st.markdown("### 🧠 Latent Space")
        st.markdown(
            "This 2D projection shows the latent representations "
            "of each residue inside the **SE(3)-Equivariant GNN**."
        )
        st.markdown(
            "Points that are close together in this space "
            "are considered structurally and kinematically similar by the model."
        )
        st.metric("Dimensions", f"{embeddings.shape[1]} → 2")
        st.metric("Total Explained Variance", f"{(ev[0]+ev[1])*100:.1f}%")