"""Mode panel — per-mode stats, RMSIP, alignment, dot-product, overlay."""
| import streamlit as st |
| import numpy as np |
| import plotly.graph_objects as go |
| import plotly.express as px |
|
|
|
|
| |
| |
| |
|
|
| def _rmsip(pred_modes: np.ndarray, gt_modes: np.ndarray) -> float: |
| """Root Mean Square Inner Product between two subspaces. |
| |
| RMSIP β [0,1]: 1 = perfect alignment, 0 = orthogonal subspaces. |
| |
| Args: |
| pred_modes: (N, 3*m) or list of (N, 3) arrays β predicted modes |
| gt_modes: (N, 3*m) or list of (N, 3) arrays β ground truth modes |
| """ |
| |
| if isinstance(pred_modes, list): |
| P = np.column_stack([v.flatten() for v in pred_modes]) |
| else: |
| P = pred_modes |
| if isinstance(gt_modes, list): |
| G = np.column_stack([v.flatten() for v in gt_modes]) |
| else: |
| G = gt_modes |
|
|
| |
| P = P / (np.linalg.norm(P, axis=0, keepdims=True) + 1e-12) |
| G = G / (np.linalg.norm(G, axis=0, keepdims=True) + 1e-12) |
|
|
| dot = P.T @ G |
| return float(np.sqrt((dot ** 2).mean())) |
|
|
|
|
| def _per_mode_overlap(pred_modes: dict, gt_modes: dict) -> np.ndarray: |
| """Per-mode overlap (|cos ΞΈ|) matrix between pred and GT modes. |
| |
| Returns (n_pred, n_gt) matrix. |
| """ |
| keys_pred = sorted(pred_modes.keys()) |
| keys_gt = sorted(gt_modes.keys()) |
| mat = np.zeros((len(keys_pred), len(keys_gt))) |
| for i, kp in enumerate(keys_pred): |
| vp = pred_modes[kp].flatten() |
| np_ = np.linalg.norm(vp) |
| for j, kg in enumerate(keys_gt): |
| vg = gt_modes[kg].flatten() |
| ng = np.linalg.norm(vg) |
| if np_ > 1e-8 and ng > 1e-8: |
| mat[i, j] = abs(np.dot(vp, vg)) / (np_ * ng) |
| return mat |
|
|
|
|
| def _parse_gt_modes(gt: dict, n_res: int) -> dict: |
| """Extract per-mode vectors from GT .pt data dict. |
| |
| GT format: eigvects shape (3N, n_modes) or (N, 3, n_modes). |
| Returns {mode_idx: (N, 3) array}. |
| """ |
| if gt is None: |
| return {} |
| ev = gt.get("eigvects", None) |
| if ev is None: |
| return {} |
| if hasattr(ev, "numpy"): |
| ev = ev.numpy() |
| ev = np.array(ev) |
|
|
| gt_modes = {} |
| if ev.ndim == 2: |
| |
| n_modes = ev.shape[1] |
| if ev.shape[0] == 3 * n_res: |
| for k in range(n_modes): |
| gt_modes[k] = ev[:, k].reshape(n_res, 3) |
| elif ev.shape[0] == n_res: |
| |
| for k in range(n_modes): |
| col = ev[:, k] |
| vec = np.zeros((n_res, 3)) |
| vec[:, 0] = col |
| gt_modes[k] = vec |
| elif ev.ndim == 3: |
| |
| for k in range(ev.shape[-1]): |
| gt_modes[k] = ev[:, :, k] |
|
|
| return gt_modes |
|
|
|
|
| |
| |
| |
|
|
def render_mode_panel(
    modes: dict,
    seq: str = "",
    eigenvalues: np.ndarray = None,
    gt: dict = None,
) -> int:
    """Render one tab per predicted mode with summary statistics.

    Each tab shows the mean / max / std of the per-residue displacement
    magnitude, either the matching eigenvalue or the residue count, the
    |cos θ| overlap with the ground-truth mode of the same index (when one
    exists), and a table of the five most mobile residues.

    Args:
        modes: predicted modes, looked up as ``modes[k]`` for
            ``k = 0 .. len(modes) - 1``; each value is a 2-D per-residue
            array whose rows are reduced with a norm along axis 1.
        seq: residue letters used to label rows; positions beyond its length
            are shown as ``?``.
        eigenvalues: optional per-mode eigenvalues; entry ``k`` is shown for
            mode ``k`` when present.
        gt: optional ground-truth dict handed to ``_parse_gt_modes``.

    Returns:
        ``st.session_state["_active_mode_tab"]`` if set, otherwise 0. This
        function does not itself write that key. Returns 0 after a warning
        when ``modes`` is empty.
    """
    n_modes = len(modes)
    if n_modes == 0:
        st.warning("No modes available")
        return 0

    n_res = len(next(iter(modes.values())))
    gt_modes = _parse_gt_modes(gt, n_res)

    tabs = st.tabs([f"Mode {k}" for k in range(n_modes)])

    for k in range(n_modes):
        with tabs[k]:
            vecs = modes[k]
            mags = np.linalg.norm(vecs, axis=1)

            c1, c2, c3, c4, c5 = st.columns(5)
            c1.metric("Mean Δ", f"{mags.mean():.3f} Å")
            c2.metric("Max Δ", f"{mags.max():.3f} Å")
            c3.metric("Std Δ", f"{mags.std():.3f} Å")
            if eigenvalues is not None and k < len(eigenvalues):
                c4.metric("λ (GT)", f"{eigenvalues[k]:.4f}")
            else:
                c4.metric("Residues", f"{n_res}")

            # Directional agreement with the ground-truth mode of same index.
            if k in gt_modes:
                gm = gt_modes[k].flatten()
                pm = vecs.flatten()
                cos = abs(np.dot(pm, gm)) / (np.linalg.norm(pm) * np.linalg.norm(gm) + 1e-12)
                c5.metric(
                    "Overlap GT",
                    f"{cos:.3f}",
                    help="|cos θ| = directional agreement with ground truth mode",
                )

            # top5 is already ordered by decreasing magnitude, so the rank of
            # each entry is simply its position in that list.
            top5 = np.argsort(mags)[-5:][::-1]
            top_data = [
                {
                    "Residue": f"{seq[i] if i < len(seq) else '?'}{i+1}",
                    "Δ (Å)": f"{mags[i]:.3f}",
                    "Rank": f"#{rank}",
                }
                for rank, i in enumerate(top5, start=1)
            ]
            st.markdown("**Most mobile residues [PREDICTION]:**")
            st.dataframe(top_data, use_container_width=True, hide_index=True)

    return st.session_state.get("_active_mode_tab", 0)
|
|
|
|
| |
| |
| |
|
|
def render_rmsip_comparison(pred_modes: dict, gt: dict, n_res: int):
    """Render the RMSIP dashboard: global score banner plus overlap heatmap.

    Parses the ground-truth modes from ``gt``, computes the global RMSIP
    between predicted and ground-truth modes (both taken in ascending key
    order), shows it in a colour-coded banner with an explanatory expander,
    and plots the per-mode |cos θ| matrix as a heatmap.

    Args:
        pred_modes: predicted modes keyed by mode index.
        gt: ground-truth dict handed to ``_parse_gt_modes``.
        n_res: number of residues, forwarded to ``_parse_gt_modes``.

    Returns:
        None. Shows an info message and returns early when there are no
        ground-truth eigenvectors or no predicted modes.
    """
    gt_modes = _parse_gt_modes(gt, n_res)
    if not gt_modes:
        st.info("No ground truth eigenvectors available for RMSIP computation.")
        return
    if not pred_modes:
        st.info("No predicted modes available for RMSIP computation.")
        return

    pred_keys = sorted(pred_modes.keys())
    gt_keys = sorted(gt_modes.keys())

    # Global RMSIP over all predicted vs. all ground-truth modes.
    rmsip = _rmsip([pred_modes[k] for k in pred_keys], [gt_modes[k] for k in gt_keys])

    # Traffic-light styling for the banner.
    if rmsip > 0.7:
        score_color, verdict = "#10b981", "🟢 Excellent subspace alignment"
    elif rmsip > 0.4:
        score_color, verdict = "#f59e0b", "🟡 Partial alignment"
    else:
        score_color, verdict = "#ef4444", "🔴 Low overlap — modes differ significantly"

    # The HTML is kept flush-left so Markdown never sees it as an indented
    # code block. Only the numeric score and fixed strings are interpolated.
    st.markdown(f"""
<div style="background: linear-gradient(135deg, #1e1b4b, #312e81); border: 1px solid #6366f1;
border-radius: 12px; padding: 16px 24px; margin-bottom: 16px;">
<div style="font-size: 0.85rem; color: #a5b4fc; text-transform: uppercase; letter-spacing: 0.05em;">
RMSIP — Root Mean Square Inner Product
</div>
<div style="font-size: 2.2rem; font-weight: 800; color: {score_color};">
{rmsip:.4f}
</div>
<div style="font-size: 0.8rem; color: #94a3b8; margin-top: 4px;">
{verdict}
</div>
</div>
""", unsafe_allow_html=True)

    with st.expander("ℹ️ What is RMSIP?", expanded=False):
        st.markdown("""
**RMSIP** measures how well two sets of normal modes span the same subspace.
`RMSIP = sqrt( mean(|<φ_pred | φ_gt>|²) )` over all pairs of modes.
- **RMSIP = 1** → predicted modes are identical to NMA ground truth
- **RMSIP = 0** → completely orthogonal subspaces (no overlap)
- **> 0.7** is considered excellent for NMA predictions in the literature
""")

    # Per-mode overlap matrix (rows: predicted, columns: ground truth).
    mat = _per_mode_overlap(pred_modes, gt_modes)

    fig = go.Figure(go.Heatmap(
        z=mat,
        x=[f"GT M{k}" for k in gt_keys],
        y=[f"Pred M{k}" for k in pred_keys],
        colorscale="Viridis",
        zmin=0, zmax=1,
        text=np.round(mat, 3),
        texttemplate="%{text:.2f}",
        textfont={"size": 11},
        hovertemplate="<b>%{y}</b> vs <b>%{x}</b><br>|cos θ| = %{z:.4f}<extra></extra>",
    ))
    fig.update_layout(
        title="[PREDICTION vs GT] Per-mode overlap matrix |cos θ|",
        template="plotly_dark",
        height=300,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(30,27,75,0.5)",
        margin=dict(l=80, r=20, t=50, b=60),
        xaxis_title="Ground Truth Modes",
        yaxis_title="Predicted Modes",
    )
    st.plotly_chart(fig, use_container_width=True)
|
|
|
|
| |
| |
| |
|
|
def render_pred_vs_gt_displacement(pred_modes: dict, gt: dict, n_res: int, seq: str = ""):
    """Plot per-residue displacement magnitude, prediction vs ground truth.

    At most the first two mode indices (0 and 1) are drawn, one chart per
    column. Within each chart both magnitude profiles are scaled by their own
    maximum, and the correlation coefficient between the two scaled profiles
    (``np.corrcoef``) is shown in the title.

    Args:
        pred_modes: predicted modes keyed by mode index; each value is a 2-D
            per-residue array.
        gt: ground-truth dict handed to ``_parse_gt_modes``.
        n_res: number of residues; sets the x axis and hover labels.
        seq: residue letters for hover labels; positions beyond its length
            are shown as ``?``.

    Returns:
        None. Shows an info message and returns early when there are no
        ground-truth eigenvectors or no predicted modes.
    """
    gt_modes = _parse_gt_modes(gt, n_res)
    if not gt_modes:
        st.info("No ground truth eigenvectors for comparison.")
        return
    if not pred_modes:
        # st.columns() rejects a column count of zero, so stop before that.
        st.info("No predicted modes for comparison.")
        return

    residues = list(range(1, n_res + 1))
    hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]

    pred_colors = ["#6366f1", "#818cf8", "#a5b4fc", "#c7d2fe"]
    gt_colors = ["#10b981", "#34d399", "#6ee7b7", "#a7f3d0"]

    n_cols = min(len(pred_modes), len(gt_modes), 2)
    cols = st.columns(n_cols)

    for k, col in enumerate(cols):
        # Only indices present on both sides can be compared.
        if k not in pred_modes or k not in gt_modes:
            continue
        with col:
            pred_mags = np.linalg.norm(pred_modes[k], axis=1)
            gt_mags = np.linalg.norm(gt_modes[k], axis=1)

            # Scale each profile by its own maximum so shapes are comparable;
            # the epsilon guards an all-zero mode.
            pred_norm = pred_mags / (pred_mags.max() + 1e-8)
            gt_norm = gt_mags / (gt_mags.max() + 1e-8)

            corr = float(np.corrcoef(pred_norm, gt_norm)[0, 1])

            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=residues, y=pred_norm,
                name="PETIMOT Pred",
                mode="lines",
                line=dict(color=pred_colors[k % len(pred_colors)], width=2),
                text=hover,
                hovertemplate="%{text}<br>Pred: %{y:.3f}<extra></extra>",
            ))
            fig.add_trace(go.Scatter(
                x=residues, y=gt_norm,
                name="NMA Ground Truth",
                mode="lines",
                line=dict(color=gt_colors[k % len(gt_colors)], width=2, dash="dot"),
                text=hover,
                hovertemplate="%{text}<br>GT: %{y:.3f}<extra></extra>",
            ))
            fig.update_layout(
                title=f"Mode {k} [r = {corr:.3f}]",
                template="plotly_dark",
                height=280,
                paper_bgcolor="rgba(0,0,0,0)",
                plot_bgcolor="rgba(30,27,75,0.3)",
                xaxis_title="Residue",
                yaxis_title="Normalised Δ",
                legend=dict(orientation="h", y=1.15),
                margin=dict(l=50, r=10, t=50, b=40),
            )
            st.plotly_chart(fig, use_container_width=True)
|
|
|
|
| |
| |
| |
|
|
def render_mode_correlation(modes: dict):
    """Heatmap of the correlation between per-residue displacement profiles.

    For every mode (ascending key order) the per-residue displacement
    magnitude is computed, and the correlation matrix of those profiles
    (``np.corrcoef``) is drawn. Axis labels are taken from the same sorted
    keys as the data so rows and columns always name the mode they show.

    Args:
        modes: modes keyed by mode index; each value is a 2-D per-residue
            array.

    Returns:
        None. Does nothing when fewer than two modes are given.
    """
    if len(modes) < 2:
        return

    keys = sorted(modes.keys())
    labels = [f"M{k}" for k in keys]
    profiles = [np.linalg.norm(modes[k], axis=1) for k in keys]
    corr = np.corrcoef(profiles)

    fig = go.Figure(go.Heatmap(
        z=corr,
        x=labels,
        y=labels,
        colorscale="RdBu_r", zmin=-1, zmax=1,
        text=np.round(corr, 2), texttemplate="%{text:.2f}",
        textfont={"size": 12},
    ))
    fig.update_layout(
        title="Mode Displacement Correlation [PREDICTION]",
        template="plotly_dark", height=300, width=300,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
        margin=dict(l=30, r=30, t=40, b=30),
    )
    st.plotly_chart(fig, use_container_width=False)
|
|
|
|
| |
| |
| |
|
|
def render_eigenvalue_spectrum(eigenvalues: np.ndarray):
    """Bar chart of the eigenvalues with a cumulative-percentage line.

    The bars show the eigenvalues as given. The line on the secondary axis
    shows the running sum as a percentage of the total. If the total is zero
    the line is drawn at 0 % instead of dividing by zero.

    Args:
        eigenvalues: eigenvalues as an array or any sequence convertible with
            ``np.asarray``; None or an empty sequence renders nothing.

    Returns:
        None.
    """
    if eigenvalues is None or len(eigenvalues) == 0:
        return

    values = np.asarray(eigenvalues, dtype=float)
    labels = [f"λ{k+1}" for k in range(len(values))]

    total = values.sum()
    if total != 0:
        cum = np.cumsum(values) / total * 100
    else:
        cum = np.zeros_like(values)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=values, marker_color="#6366f1", name="Eigenvalue",
    ))
    fig.add_trace(go.Scatter(
        x=labels,
        y=cum, mode="lines+markers", name="Cumul. variance %",
        marker=dict(color="#ef4444", size=6),
        line=dict(color="#ef4444", width=2), yaxis="y2",
    ))
    fig.update_layout(
        title="Eigenvalue Spectrum [GROUND TRUTH]",
        template="plotly_dark", height=250,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
        yaxis=dict(title="Eigenvalue"),
        yaxis2=dict(title="Cumul. %", overlaying="y", side="right", range=[0, 105]),
        legend=dict(orientation="h", y=1.15),
        margin=dict(l=40, r=40, t=40, b=30),
    )
    st.plotly_chart(fig, use_container_width=True)
|
|
|
|
| |
| |
| |
|
|
def render_mode_dotproduct(modes: dict, seq: str = ""):
    """Show directional agreement (cos θ) between predicted modes.

    Left column: heatmap of the cosine between every pair of modes, each
    mode flattened to a single vector. Pairs involving a mode whose norm is
    at most 1e-8 are shown as 0. Right column: per-residue cosine between the
    two lowest-keyed modes, coloured by absolute value (> 0.7 red, > 0.3
    amber, otherwise green).

    Args:
        modes: predicted modes keyed by mode index; each value is a 2-D
            per-residue array. Modes are processed in ascending key order.
        seq: residue letters for hover labels; positions beyond its length
            are shown as ``?``.

    Returns:
        None. Shows an info message when fewer than two modes are given.
    """
    if len(modes) < 2:
        st.info("Need ≥2 modes for directional agreement analysis.")
        return

    keys = sorted(modes.keys())
    labels = [f"M{k}" for k in keys]
    n_res = len(modes[keys[0]])
    residues = np.arange(1, n_res + 1)

    # Global cosine matrix: normalise each flattened mode, then one product.
    flat = np.stack([np.asarray(modes[k]).ravel() for k in keys])
    norms = np.linalg.norm(flat, axis=1)
    usable = norms > 1e-8
    unit = flat / np.where(usable, norms, 1.0)[:, None]
    # Zero out every pair that involves a degenerate (near-zero) mode.
    global_dp = (unit @ unit.T) * np.outer(usable, usable)

    col1, col2 = st.columns([1, 2])
    with col1:
        fig = go.Figure(go.Heatmap(
            z=global_dp,
            x=labels,
            y=labels,
            colorscale="RdBu_r", zmin=-1, zmax=1,
            text=np.round(global_dp, 3), texttemplate="%{text:.3f}",
            textfont={"size": 11},
        ))
        fig.update_layout(
            title="Global 3D Dot Product (cos θ) [PREDICTION]",
            template="plotly_dark", height=300, width=300,
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
            margin=dict(l=30, r=30, t=40, b=30),
        )
        st.plotly_chart(fig, use_container_width=False)
        st.markdown("> **±1** = parallel modes · **0** = orthogonal (independent)")

    with col2:
        first_key, second_key = keys[0], keys[1]
        pair_label = f"M{first_key}·M{second_key}"
        st.markdown(f"**Per-residue directional agreement {pair_label} [PREDICTION]:**")

        # Row-wise cosine between the two modes; the epsilon guards residues
        # with no displacement in either mode.
        first = np.asarray(modes[first_key], dtype=float)
        second = np.asarray(modes[second_key], dtype=float)
        per_res = (first * second).sum(axis=1) / (
            np.linalg.norm(first, axis=1) * np.linalg.norm(second, axis=1) + 1e-12
        )

        hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
        fig2 = go.Figure(go.Bar(
            x=residues, y=per_res,
            marker_color=["#ef4444" if abs(c) > 0.7 else "#f59e0b" if abs(c) > 0.3 else "#10b981"
                          for c in per_res],
            text=hover,
            hovertemplate="%{text}<br>cos(θ): %{y:.3f}<extra></extra>",
        ))
        fig2.add_hline(y=0, line_dash="dash", line_color="#64748b")
        fig2.update_layout(
            template="plotly_dark", height=300,
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
            xaxis_title="Residue", yaxis_title=f"cos(θ) {pair_label}",
            yaxis_range=[-1.05, 1.05], margin=dict(l=50, r=20, t=10, b=40),
        )
        st.plotly_chart(fig2, use_container_width=True)
|
|
|
|
| |
| |
| |
|
|
def render_mode_overlay(modes: dict, seq: str = ""):
    """Overlay the per-residue displacement magnitude of the modes on one chart.

    Draws one line per mode for mode indices ``0 .. min(len(modes), 6) - 1``,
    looked up as ``modes[k]``; further modes are not drawn.

    Args:
        modes: predicted modes keyed by mode index; each value is a 2-D
            per-residue array.
        seq: residue letters for hover labels; positions beyond its length
            are shown as ``?``.

    Returns:
        None. Does nothing when ``modes`` is empty.
    """
    n_modes = len(modes)
    if n_modes == 0:
        return

    n_res = len(next(iter(modes.values())))
    colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b", "#ec4899", "#8b5cf6"]
    hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
    residues = list(range(1, n_res + 1))  # shared x axis for every trace

    fig = go.Figure()
    for k in range(min(n_modes, 6)):
        mags = np.linalg.norm(modes[k], axis=1)
        fig.add_trace(go.Scatter(
            x=residues, y=mags,
            mode="lines", name=f"Mode {k}",
            line=dict(color=colors[k % len(colors)], width=2),
            text=hover,
            hovertemplate="%{text}<br>%{y:.3f} Å<extra>Mode " + str(k) + "</extra>",
        ))
    fig.update_layout(
        title="All Modes — Displacement Magnitude Overlay [PREDICTION]",
        template="plotly_dark", height=350,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        xaxis_title="Residue", yaxis_title="Displacement (Å)",
        legend=dict(orientation="h", y=1.12),
        margin=dict(l=50, r=20, t=50, b=40),
    )
    st.plotly_chart(fig, use_container_width=True)
|
|