# Source: Petimot / app/components/mode_panel.py
# Last commit: 128b46f by Valmbd — "fix: correct GT eigvects parsing + add pt_bb_to_pdb for offline 3D viz"
"""Mode panel — per-mode stats, RMSIP, alignment, dot-product, overlay."""
import streamlit as st
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
# ═══════════════════════════════════════════════════════════════════
# Core metrics
# ═══════════════════════════════════════════════════════════════════
def _rmsip(pred_modes: np.ndarray, gt_modes: np.ndarray) -> float:
"""Root Mean Square Inner Product between two subspaces.
RMSIP ∈ [0,1]: 1 = perfect alignment, 0 = orthogonal subspaces.
Args:
pred_modes: (N, 3*m) or list of (N, 3) arrays β€” predicted modes
gt_modes: (N, 3*m) or list of (N, 3) arrays β€” ground truth modes
"""
# Flatten to (3N, m) column vectors
if isinstance(pred_modes, list):
P = np.column_stack([v.flatten() for v in pred_modes])
else:
P = pred_modes
if isinstance(gt_modes, list):
G = np.column_stack([v.flatten() for v in gt_modes])
else:
G = gt_modes
# Normalize
P = P / (np.linalg.norm(P, axis=0, keepdims=True) + 1e-12)
G = G / (np.linalg.norm(G, axis=0, keepdims=True) + 1e-12)
dot = P.T @ G # (m, n) matrix of dot products
return float(np.sqrt((dot ** 2).mean()))
def _per_mode_overlap(pred_modes: dict, gt_modes: dict) -> np.ndarray:
"""Per-mode overlap (|cos ΞΈ|) matrix between pred and GT modes.
Returns (n_pred, n_gt) matrix.
"""
keys_pred = sorted(pred_modes.keys())
keys_gt = sorted(gt_modes.keys())
mat = np.zeros((len(keys_pred), len(keys_gt)))
for i, kp in enumerate(keys_pred):
vp = pred_modes[kp].flatten()
np_ = np.linalg.norm(vp)
for j, kg in enumerate(keys_gt):
vg = gt_modes[kg].flatten()
ng = np.linalg.norm(vg)
if np_ > 1e-8 and ng > 1e-8:
mat[i, j] = abs(np.dot(vp, vg)) / (np_ * ng)
return mat
def _parse_gt_modes(gt: dict, n_res: int) -> dict:
"""Extract per-mode vectors from GT .pt data dict.
GT format: eigvects shape (3N, n_modes) or (N, 3, n_modes).
Returns {mode_idx: (N, 3) array}.
"""
if gt is None:
return {}
ev = gt.get("eigvects", None)
if ev is None:
return {}
if hasattr(ev, "numpy"):
ev = ev.numpy()
ev = np.array(ev)
gt_modes = {}
if ev.ndim == 2:
# shape (3N, K) β€” standard NMA format
n_modes = ev.shape[1]
if ev.shape[0] == 3 * n_res:
for k in range(n_modes):
gt_modes[k] = ev[:, k].reshape(n_res, 3)
elif ev.shape[0] == n_res:
# shape (N, K) β€” magnitude only
for k in range(n_modes):
col = ev[:, k]
vec = np.zeros((n_res, 3))
vec[:, 0] = col
gt_modes[k] = vec
elif ev.ndim == 3:
# shape (N, 3, K)
for k in range(ev.shape[-1]):
gt_modes[k] = ev[:, :, k]
return gt_modes
# ═══════════════════════════════════════════════════════════════════
# Mode selector panel
# ═══════════════════════════════════════════════════════════════════
def render_mode_panel(
    modes: dict,
    seq: str = "",
    eigenvalues: np.ndarray = None,
    gt: dict = None,
) -> int:
    """Render one tab per predicted mode with summary statistics.

    Each tab shows:
    * the mean / max / std of the per-residue displacement norm;
    * the GT eigenvalue for that mode, or the residue count when none is given;
    * the |cos θ| overlap with the same-index ground-truth mode, when available;
    * the five most mobile residues.

    Args:
        modes: ``{k: (N, 3) array}``; keys are expected to be ``0 .. len(modes) - 1``.
        seq: One-letter sequence used for residue labels ("?" where missing).
        eigenvalues: Optional GT eigenvalues, indexed by mode.
        gt: Optional ground-truth dict, parsed with ``_parse_gt_modes``.

    Returns:
        ``st.session_state["_active_mode_tab"]`` if something else set it,
        otherwise 0. ``st.tabs`` does not report which tab is open, so this
        function never changes that value itself.
    """
    n_modes = len(modes)
    if n_modes == 0:
        st.warning("No modes available")
        return 0

    n_res = len(next(iter(modes.values())))
    gt_modes = _parse_gt_modes(gt, n_res)

    tabs = st.tabs([f"Mode {k}" for k in range(n_modes)])
    for k, tab in enumerate(tabs):
        with tab:
            vecs = modes[k]
            mags = np.linalg.norm(vecs, axis=1)

            c1, c2, c3, c4, c5 = st.columns(5)
            c1.metric("Mean Δ", f"{mags.mean():.3f} Å")
            c2.metric("Max Δ", f"{mags.max():.3f} Å")
            c3.metric("Std Δ", f"{mags.std():.3f} Å")
            if eigenvalues is not None and k < len(eigenvalues):
                c4.metric("λ (GT)", f"{eigenvalues[k]:.4f}")
            else:
                c4.metric("Residues", f"{n_res}")

            # Single-mode directional overlap with the same-index GT mode
            # (a cosine, not the subspace RMSIP).
            if k in gt_modes:
                gm = gt_modes[k].flatten()
                pm = vecs.flatten()
                cos = abs(np.dot(pm, gm)) / (np.linalg.norm(pm) * np.linalg.norm(gm) + 1e-12)
                c5.metric("Overlap GT", f"{cos:.3f}",
                          help="|cos θ| = directional agreement with ground truth mode")

            # Descending order: order[:5] are the five largest displacements,
            # so the entry at position j has rank j + 1.
            order = np.argsort(mags)[::-1]
            top_data = [
                {
                    "Residue": f"{seq[i] if i < len(seq) else '?'}{i + 1}",
                    "Δ (Å)": f"{mags[i]:.3f}",
                    "Rank": f"#{rank}",
                }
                for rank, i in enumerate(order[:5], start=1)
            ]
            st.markdown("**Most mobile residues [PREDICTION]:**")
            st.dataframe(top_data, use_container_width=True, hide_index=True)

    return st.session_state.get("_active_mode_tab", 0)
# ═══════════════════════════════════════════════════════════════════
# RMSIP comparison
# ═══════════════════════════════════════════════════════════════════
def render_rmsip_comparison(pred_modes: dict, gt: dict, n_res: int):
    """Render the RMSIP dashboard: global subspace score plus per-mode overlap heatmap.

    Compares the predicted modes with the ground-truth eigenvectors parsed from
    ``gt`` for ``n_res`` residues. Shows an info message and returns early when
    either set is empty.
    """
    gt_modes = _parse_gt_modes(gt, n_res)
    if not gt_modes:
        st.info("No ground truth eigenvectors available for RMSIP computation.")
        return
    if not pred_modes:
        st.info("No predicted modes available for RMSIP computation.")
        return

    pred_keys = sorted(pred_modes)
    gt_keys = sorted(gt_modes)

    # Global RMSIP over the full predicted and GT sets
    rmsip = _rmsip([pred_modes[k] for k in pred_keys], [gt_modes[k] for k in gt_keys])
    if rmsip > 0.7:
        color, verdict = "#10b981", "🟢 Excellent subspace alignment"
    elif rmsip > 0.4:
        color, verdict = "#f59e0b", "🟡 Partial alignment"
    else:
        color, verdict = "#ef4444", "🔴 Low overlap — modes differ significantly"

    # Emit the HTML card on a single line so that indentation can never turn it
    # into a markdown code block.
    card = (
        '<div style="background: linear-gradient(135deg, #1e1b4b, #312e81); '
        'border: 1px solid #6366f1; border-radius: 12px; padding: 16px 24px; margin-bottom: 16px;">'
        '<div style="font-size: 0.85rem; color: #a5b4fc; text-transform: uppercase; letter-spacing: 0.05em;">'
        "RMSIP — Root Mean Square Inner Product"
        "</div>"
        f'<div style="font-size: 2.2rem; font-weight: 800; color: {color};">{rmsip:.4f}</div>'
        f'<div style="font-size: 0.8rem; color: #94a3b8; margin-top: 4px;">{verdict}</div>'
        "</div>"
    )
    st.markdown(card, unsafe_allow_html=True)

    with st.expander("ℹ️ What is RMSIP?", expanded=False):
        st.markdown("""
**RMSIP** measures how well two sets of normal modes span the same subspace.

`RMSIP = sqrt( (1/K) · Σᵢ Σⱼ (φᵢ_pred · φⱼ_gt)² )`, where all mode vectors are
unit-normalised and K is the number of modes in the smaller set.

- **RMSIP = 1** → predicted modes span the same subspace as the NMA ground truth
- **RMSIP = 0** → completely orthogonal subspaces (no overlap)
- **> 0.7** is considered excellent for NMA predictions in the literature
""")

    # Per-mode overlap matrix (rows/columns in the same sorted-key order as the labels)
    mat = _per_mode_overlap(pred_modes, gt_modes)
    fig = go.Figure(go.Heatmap(
        z=mat,
        x=[f"GT M{k}" for k in gt_keys],
        y=[f"Pred M{k}" for k in pred_keys],
        colorscale="Viridis",
        zmin=0, zmax=1,
        text=np.round(mat, 3),
        texttemplate="%{text:.2f}",
        textfont={"size": 11},
        hovertemplate="<b>%{y}</b> vs <b>%{x}</b><br>|cos θ| = %{z:.4f}<extra></extra>",
    ))
    fig.update_layout(
        title="[PREDICTION vs GT] Per-mode overlap matrix |cos θ|",
        template="plotly_dark",
        # Grow with the number of predicted modes so rows stay readable.
        height=max(300, 120 + 40 * len(pred_keys)),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(30,27,75,0.5)",
        margin=dict(l=80, r=20, t=50, b=60),
        xaxis_title="Ground Truth Modes",
        yaxis_title="Predicted Modes",
    )
    st.plotly_chart(fig, use_container_width=True)
# ═══════════════════════════════════════════════════════════════════
# Per-residue Pred vs GT comparison
# ═══════════════════════════════════════════════════════════════════
def render_pred_vs_gt_displacement(pred_modes: dict, gt: dict, n_res: int, seq: str = ""):
    """Compare per-residue displacement magnitude of predicted vs ground-truth modes.

    Draws up to two side-by-side charts, for mode indices 0 and 1 when they are
    present in both sets. Each magnitude profile is scaled by its own maximum,
    so the curves share a [0, 1] scale. The chart title shows the Pearson
    correlation r of the two profiles, or "n/a" when either profile is constant.
    """
    gt_modes = _parse_gt_modes(gt, n_res)
    if not gt_modes:
        st.info("No ground truth eigenvectors for comparison.")
        return
    n_cols = min(len(pred_modes), len(gt_modes), 2)
    if n_cols == 0:
        # st.columns(0) is invalid; nothing to compare without predicted modes.
        st.info("No predicted modes for comparison.")
        return

    residues = list(range(1, n_res + 1))
    hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
    pred_colors = ["#6366f1", "#818cf8", "#a5b4fc", "#c7d2fe"]
    gt_colors = ["#10b981", "#34d399", "#6ee7b7", "#a7f3d0"]

    for k, col in enumerate(st.columns(n_cols)):
        if k not in pred_modes or k not in gt_modes:
            continue
        with col:
            pred_mags = np.linalg.norm(pred_modes[k], axis=1)
            gt_mags = np.linalg.norm(gt_modes[k], axis=1)
            # Normalize both to the same scale for comparison
            pred_norm = pred_mags / (pred_mags.max() + 1e-8)
            gt_norm = gt_mags / (gt_mags.max() + 1e-8)

            # Pearson r is undefined for a constant profile; np.corrcoef would
            # emit a RuntimeWarning and return nan.
            if pred_norm.std() > 0 and gt_norm.std() > 0:
                r_label = f"{float(np.corrcoef(pred_norm, gt_norm)[0, 1]):.3f}"
            else:
                r_label = "n/a"

            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=residues, y=pred_norm,
                name="PETIMOT Pred",
                mode="lines",
                line=dict(color=pred_colors[k % len(pred_colors)], width=2),
                text=hover,
                hovertemplate="%{text}<br>Pred: %{y:.3f}<extra></extra>",
            ))
            fig.add_trace(go.Scatter(
                x=residues, y=gt_norm,
                name="NMA Ground Truth",
                mode="lines",
                line=dict(color=gt_colors[k % len(gt_colors)], width=2, dash="dot"),
                text=hover,
                hovertemplate="%{text}<br>GT: %{y:.3f}<extra></extra>",
            ))
            fig.update_layout(
                title=f"Mode {k} [r = {r_label}]",
                template="plotly_dark",
                height=280,
                paper_bgcolor="rgba(0,0,0,0)",
                plot_bgcolor="rgba(30,27,75,0.3)",
                xaxis_title="Residue",
                yaxis_title="Normalised Δ",
                legend=dict(orientation="h", y=1.15),
                margin=dict(l=50, r=10, t=50, b=40),
            )
            st.plotly_chart(fig, use_container_width=True)
# ═══════════════════════════════════════════════════════════════════
# Mode correlation matrix
# ═══════════════════════════════════════════════════════════════════
def render_mode_correlation(modes: dict):
    """Heatmap of Pearson correlations between the modes' displacement profiles.

    Each mode is reduced to its per-residue displacement norm. Entry (i, j) is
    the correlation of profiles i and j, with modes taken in sorted key order.
    Nothing is drawn for fewer than two modes.
    """
    if len(modes) < 2:
        return
    keys = sorted(modes)
    profiles = [np.linalg.norm(modes[k], axis=1) for k in keys]
    corr = np.corrcoef(profiles)
    # Label the axes from the same sorted keys used to build the matrix, so
    # labels and rows line up even when the keys are not 0..n-1.
    labels = [f"M{k}" for k in keys]

    fig = go.Figure(go.Heatmap(
        z=corr,
        x=labels,
        y=labels,
        colorscale="RdBu_r", zmin=-1, zmax=1,
        text=np.round(corr, 2), texttemplate="%{text:.2f}",
        textfont={"size": 12},
    ))
    fig.update_layout(
        title="Mode Displacement Correlation [PREDICTION]",
        template="plotly_dark", height=300, width=300,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
        margin=dict(l=30, r=30, t=40, b=30),
    )
    st.plotly_chart(fig, use_container_width=False)
# ═══════════════════════════════════════════════════════════════════
# Eigenvalue spectrum
# ═══════════════════════════════════════════════════════════════════
def render_eigenvalue_spectrum(eigenvalues: np.ndarray):
    """Bar chart of eigenvalues with their cumulative share (%) on a second y-axis.

    Bars are labelled 1-based (λ1, λ2, ...). The cumulative line shows the
    running sum as a percentage of the total (labelled "Cumul. variance %").
    It is omitted when the total is zero or non-finite, because the percentage
    is undefined there. Nothing is drawn for ``None`` or empty input.
    """
    if eigenvalues is None or len(eigenvalues) == 0:
        return
    # Accept plain lists / array-likes as well as ndarrays.
    values = np.asarray(eigenvalues, dtype=float)
    labels = [f"λ{k + 1}" for k in range(len(values))]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=values, marker_color="#6366f1", name="Eigenvalue",
    ))
    total = values.sum()
    if np.isfinite(total) and total != 0:
        cum = np.cumsum(values) / total * 100
        fig.add_trace(go.Scatter(
            x=labels,
            y=cum, mode="lines+markers", name="Cumul. variance %",
            marker=dict(color="#ef4444", size=6),
            line=dict(color="#ef4444", width=2), yaxis="y2",
        ))
    fig.update_layout(
        title="Eigenvalue Spectrum [GROUND TRUTH]",
        template="plotly_dark", height=250,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
        yaxis=dict(title="Eigenvalue"),
        yaxis2=dict(title="Cumul. %", overlaying="y", side="right", range=[0, 105]),
        legend=dict(orientation="h", y=1.15),
        margin=dict(l=40, r=40, t=40, b=30),
    )
    st.plotly_chart(fig, use_container_width=True)
# ═══════════════════════════════════════════════════════════════════
# Dot-product directional agreement
# ═══════════════════════════════════════════════════════════════════
def render_mode_dotproduct(modes: dict, seq: str = ""):
    """Show directional agreement (cos θ) between predicted modes.

    Left: a heatmap of the signed cosine between every pair of flattened mode
    vectors. Pairs involving a mode with norm <= 1e-8 are shown as 0.
    Right: the per-residue cosine between the 3-D displacement vectors of the
    first two modes, in sorted key order.
    """
    n_modes = len(modes)
    if n_modes < 2:
        st.info("Need ≥2 modes for directional agreement analysis.")
        return

    keys = sorted(modes)
    labels = [f"M{k}" for k in keys]
    n_res = len(modes[keys[0]])
    residues = np.arange(1, n_res + 1)

    # Global cosine matrix over the flattened mode vectors
    flat = [modes[k].flatten() for k in keys]
    norms = [np.linalg.norm(v) for v in flat]
    global_dp = np.zeros((n_modes, n_modes))
    for i, (vi, ni) in enumerate(zip(flat, norms)):
        for j, (vj, nj) in enumerate(zip(flat, norms)):
            if ni > 1e-8 and nj > 1e-8:
                global_dp[i, j] = np.dot(vi, vj) / (ni * nj)

    col1, col2 = st.columns([1, 2])
    with col1:
        fig = go.Figure(go.Heatmap(
            z=global_dp,
            x=labels,
            y=labels,
            colorscale="RdBu_r", zmin=-1, zmax=1,
            text=np.round(global_dp, 3), texttemplate="%{text:.3f}",
            textfont={"size": 11},
        ))
        fig.update_layout(
            title="Global 3D Dot Product (cos θ) [PREDICTION]",
            template="plotly_dark", height=300, width=300,
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.5)",
            margin=dict(l=30, r=30, t=40, b=30),
        )
        st.plotly_chart(fig, use_container_width=False)
        st.markdown("> **±1** = parallel modes · **0** = orthogonal (independent)")

    with col2:
        pair = f"{labels[0]}·{labels[1]}"
        st.markdown(f"**Per-residue directional agreement {pair} [PREDICTION]:**")
        a = np.asarray(modes[keys[0]], dtype=float)
        b = np.asarray(modes[keys[1]], dtype=float)
        # Row-wise cosine; the epsilon makes zero-length residue vectors yield 0.
        per_res = (a * b).sum(axis=1) / (
            np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + 1e-12
        )
        hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
        bar_colors = [
            "#ef4444" if abs(c) > 0.7 else "#f59e0b" if abs(c) > 0.3 else "#10b981"
            for c in per_res
        ]
        fig2 = go.Figure(go.Bar(
            x=residues, y=per_res,
            marker_color=bar_colors,
            text=hover,
            hovertemplate="%{text}<br>cos(θ): %{y:.3f}<extra></extra>",
        ))
        fig2.add_hline(y=0, line_dash="dash", line_color="#64748b")
        fig2.update_layout(
            template="plotly_dark", height=300,
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
            xaxis_title="Residue", yaxis_title=f"cos(θ) {pair}",
            yaxis_range=[-1.05, 1.05], margin=dict(l=50, r=20, t=10, b=40),
        )
        st.plotly_chart(fig2, use_container_width=True)
# ═══════════════════════════════════════════════════════════════════
# Mode overlay chart
# ═══════════════════════════════════════════════════════════════════
def render_mode_overlay(modes: dict, seq: str = "", max_modes: int = 6):
    """Overlay the per-residue displacement magnitude of several modes on one chart.

    Args:
        modes: ``{k: (N, 3) array}`` of predicted modes.
        seq: One-letter sequence used for hover labels ("?" where missing).
        max_modes: Maximum number of modes drawn, taken in sorted key order
            (default 6, matching the size of the colour palette).
    """
    if not modes:
        return
    n_res = len(next(iter(modes.values())))
    colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b", "#ec4899", "#8b5cf6"]
    hover = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
    residues = list(range(1, n_res + 1))  # shared x-axis, built once

    fig = go.Figure()
    for idx, k in enumerate(sorted(modes)[:max_modes]):
        mags = np.linalg.norm(modes[k], axis=1)
        fig.add_trace(go.Scatter(
            x=residues, y=mags,
            mode="lines", name=f"Mode {k}",
            line=dict(color=colors[idx % len(colors)], width=2),
            text=hover,
            hovertemplate="%{text}<br>%{y:.3f} Å<extra>Mode " + str(k) + "</extra>",
        ))
    fig.update_layout(
        title="All Modes — Displacement Magnitude Overlay [PREDICTION]",
        template="plotly_dark", height=350,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        xaxis_title="Residue", yaxis_title="Displacement (Å)",
        legend=dict(orientation="h", y=1.12),
        margin=dict(l=50, r=20, t=50, b=40),
    )
    st.plotly_chart(fig, use_container_width=True)