File size: 3,312 Bytes
128b46f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Interactive visualization of protein latent embeddings using PCA."""
import streamlit as st
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

def render_embedding_viewer(
    embeddings: np.ndarray,
    seq: str,
    displacements: np.ndarray,
    mode_label: str = "Mode 0",
    key: str = "pca_scatter",
):
    """
    Render a 2D PCA scatter plot of per-residue embeddings.

    The scatter plot fills a wide left column. A narrow right column shows a
    short explanation, the embedding dimensionality, and the fraction of
    variance captured by the two principal components.

    Args:
        embeddings: (N, D) numpy array of latent node representations.
        seq: Amino acid sequence.
        displacements: (N,) array of displacement magnitudes to color by.
        mode_label: Label for tooltip/titles.
        key: Streamlit element key for the chart. Pass a distinct value when
            rendering more than one viewer on the same page, because
            Streamlit rejects duplicate element keys.
    """
    if embeddings is None or len(embeddings) == 0:
        st.info("No embeddings available for this protein.")
        return

    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2:
        st.error(
            "Expected a 2D (residues x features) embedding array, "
            f"got shape {embeddings.shape}."
        )
        return

    # The three per-residue inputs should share one length (shouldn't differ,
    # but guard anyway): truncate all of them to the shortest so every
    # DataFrame column below lines up.
    n_res = min(len(embeddings), len(seq), len(displacements))
    embeddings = embeddings[:n_res]
    seq = seq[:n_res]
    displacements = displacements[:n_res]

    # A 2-component PCA needs at least 2 samples and 2 features; scikit-learn
    # raises ValueError otherwise (e.g. for a single-residue chain).
    if n_res < 2 or embeddings.shape[1] < 2:
        st.info(
            "At least 2 residues and 2 embedding dimensions are required "
            "for a 2D PCA projection."
        )
        return

    # scikit-learn is optional for the rest of the app, so import lazily.
    try:
        from sklearn.decomposition import PCA
    except ImportError:
        st.error("scikit-learn is required to compute PCA for embeddings.")
        return

    with st.spinner("Computing PCA..."):
        # Reduce dimensionality to 2D
        pca = PCA(n_components=2)
        emb_2d = pca.fit_transform(embeddings)

    # Fraction of total variance captured by PC1 and PC2.
    ev = pca.explained_variance_ratio_

    # Prepare DataFrame for Plotly
    df_emb = pd.DataFrame({
        "PC1": emb_2d[:, 0],
        "PC2": emb_2d[:, 1],
        # Hover label: one-letter residue code followed by its 1-based position.
        "Residue": [f"{aa}{pos}" for pos, aa in enumerate(seq, start=1)],
        "Amino Acid": list(seq),
        "Mobility (Å)": displacements,
        "Position": np.arange(1, n_res + 1),
    })

    col1, col2 = st.columns([3, 1])

    with col1:
        # PC coordinates are hidden from the tooltip since the axes show them.
        fig = px.scatter(
            df_emb, x="PC1", y="PC2",
            color="Mobility (Å)",
            hover_name="Residue",
            hover_data={"PC1": False, "PC2": False, "Amino Acid": True, "Position": True},
            color_continuous_scale="Turbo",
            title=f"Node Embeddings PCA Space (colored by {mode_label} mobility)",
            template="plotly_dark",
        )

        fig.update_traces(marker=dict(size=8, opacity=0.8, line=dict(width=1, color='DarkSlateGrey')))
        fig.update_layout(
            paper_bgcolor="rgba(0,0,0,0)",
            plot_bgcolor="rgba(30,27,75,0.4)",
            xaxis_title=f"PC 1 ({ev[0]*100:.1f}%)",
            yaxis_title=f"PC 2 ({ev[1]*100:.1f}%)",
            margin=dict(l=40, r=40, t=40, b=40),
        )
        st.plotly_chart(fig, use_container_width=True, key=key)

    with col2:
        st.markdown("### 🧠 Latent Space")
        st.markdown(
            "This 2D projection shows the latent representations "
            "of each residue inside the **SE(3)-Equivariant GNN**."
        )
        st.markdown(
            "Points that are close together in this space "
            "are considered structurally and kinematically similar by the model."
        )

        st.metric("Dimensions", f"{embeddings.shape[1]} → 2")
        st.metric("Total Explained Variance", f"{(ev[0]+ev[1])*100:.1f}%")