| """Enhanced prediction analysis — sign-invariant modes and per-residue normalization.""" |
| import numpy as np |
| import streamlit as st |
| import plotly.graph_objects as go |
| from plotly.subplots import make_subplots |
|
|
|
|
def canonicalize_sign(modes: dict) -> dict:
    """Return a copy of *modes* with a deterministic global sign per eigenvector.

    Eigenvectors are defined only up to a global factor of ±1. Each mode is
    flipped, if necessary, so that its component with the largest absolute
    value is positive. Ties are resolved by the first such component in
    flattened (row-major) order.

    Args:
        modes: mapping of mode key -> array-like of displacement components.

    Returns:
        A new dict with the same keys. Every value is a new array; the input
        arrays are never modified or aliased.
    """
    canonical = {}
    for key, vecs in modes.items():
        arr = np.asarray(vecs)
        if arr.size == 0:
            # np.argmax raises on empty input; an empty mode has no sign to fix.
            canonical[key] = arr.copy()
            continue
        flat = arr.ravel()
        dominant = flat[np.argmax(np.abs(flat))]
        canonical[key] = -arr if dominant < 0 else arr.copy()
    return canonical
|
|
|
|
def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray:
    """Scale per-residue displacement magnitudes relative to the largest one.

    Args:
        vecs: (N, 3) displacement vectors.

    Returns:
        (N,) magnitudes divided by the maximum magnitude, so the most mobile
        residue maps to 1. If the maximum does not exceed 1e-12 the raw
        (near-zero) magnitudes are returned unscaled to avoid dividing by
        zero. An input with zero residues yields an empty array.
    """
    mags = np.linalg.norm(np.asarray(vecs), axis=1)
    if mags.size == 0:
        # ndarray.max() raises on an empty array; there is nothing to scale.
        return mags
    peak = mags.max()
    return mags / peak if peak > 1e-12 else mags
|
|
|
|
def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray:
    """Compute the direction of each displacement relative to the local backbone.

    The local backbone direction of residue i is CA_i → CA_{i+1}; the last
    residue reuses the direction of the preceding segment (CA_{N-2} → CA_{N-1}).
    The result is the cosine of the angle between displacement and backbone:
    positive = along the backbone, negative = against it.

    Args:
        vecs: (N, 3) displacement vectors.
        ca_coords: (N, 3) CA coordinates. Rows beyond the first N are ignored.

    Returns:
        (N,) cosines in [-1, 1]. Entries are 0 where the displacement or the
        backbone segment is shorter than 1e-8, and for a single-residue input
        (which has no backbone segment).

    Raises:
        IndexError: if ``ca_coords`` has fewer rows than ``vecs``.
    """
    disp = np.asarray(vecs, dtype=float)
    ca = np.asarray(ca_coords, dtype=float)
    n = len(disp)
    projections = np.zeros(n)

    if len(ca) < n:
        raise IndexError(
            f"ca_coords has {len(ca)} rows but vecs has {n}; one CA coordinate per residue is required"
        )
    if n < 2:
        return projections

    ca = ca[:n]
    backbone = np.empty_like(ca)
    backbone[:-1] = ca[1:] - ca[:-1]
    backbone[-1] = backbone[-2]  # last residue: reuse the preceding segment

    bb_norm = np.linalg.norm(backbone, axis=1)
    disp_mag = np.linalg.norm(disp, axis=1)

    # Degenerate residues keep the default 0 instead of dividing by ~0.
    usable = ~((bb_norm < 1e-8) | (disp_mag < 1e-8))
    dots = np.einsum("ij,ij->i", disp, backbone)
    projections[usable] = dots[usable] / (disp_mag[usable] * bb_norm[usable])
    return projections
|
|
|
|
# Per-mode line plots, hotspot summaries, GT comparisons and table columns are
# limited to the first few modes; heatmaps always show every mode.
_DETAIL_MODE_LIMIT = 4

# (line colour, translucent fill colour) for each mode in the direction plot.
_MODE_COLORS = (
    ("#6366f1", "rgba(99,102,241,0.12)"),
    ("#ef4444", "rgba(239,68,68,0.12)"),
    ("#10b981", "rgba(16,185,129,0.12)"),
    ("#f59e0b", "rgba(245,158,11,0.12)"),
)


def _residue_label(seq: str, idx: int) -> str:
    """Return a 1-based residue label such as ``A12`` ('?' if *seq* is too short)."""
    return f"{seq[idx] if idx < len(seq) else '?'}{idx + 1}"


def _aligned_coverage(coverage, n_res: int) -> np.ndarray:
    """Return coverage as a float array of exactly *n_res* values (NaN = missing)."""
    if coverage is None:
        return np.ones(n_res)
    cov = np.asarray(coverage, dtype=float).ravel()[:n_res]
    if cov.size < n_res:
        cov = np.concatenate([cov, np.full(n_res - cov.size, np.nan)])
    return cov


def _format_coverage(value: float, spec: str) -> str:
    """Format one coverage value, showing an em dash where coverage is missing."""
    return "—" if np.isnan(value) else format(value, spec)


def _render_displacement_tab(seq, coverage, abs_mags, rel_norms):
    """Draw the absolute/relative displacement heatmaps, coverage bars and hotspots."""
    n_modes, n_res = abs_mags.shape
    mode_labels = [f"Mode {k}" for k in range(n_modes)]

    hover = [
        [
            f"{_residue_label(seq, j)}<br>"
            f"Abs: {abs_mags[k][j]:.3f}Å<br>"
            f"Rel: {rel_norms[k][j]:.2%}<br>"
            f"Cov: {_format_coverage(coverage[j], '.2f')}"
            for j in range(n_res)
        ]
        for k in range(n_modes)
    ]

    fig = make_subplots(
        rows=3, cols=1, row_heights=[0.4, 0.4, 0.2],
        shared_xaxes=True, vertical_spacing=0.06,
        subplot_titles=["Absolute Displacement (Å)",
                        "Relative Displacement (0-1)",
                        "Coverage"],
    )

    fig.add_trace(go.Heatmap(
        z=abs_mags, colorscale="YlOrRd",
        y=mode_labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Å", x=1.01, len=0.35, y=0.85),
    ), row=1, col=1)

    fig.add_trace(go.Heatmap(
        z=rel_norms, colorscale="Viridis", zmin=0, zmax=1,
        y=mode_labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5),
    ), row=2, col=1)

    # x stays 0-based so the bars line up with the heatmap columns on the
    # shared axis; customdata carries the 1-based residue number for hover.
    fig.add_trace(go.Bar(
        x=list(range(n_res)), y=coverage,
        customdata=list(range(1, n_res + 1)),
        marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in coverage],
        hovertemplate="Res %{customdata}<br>Coverage: %{y:.3f}<extra></extra>",
        showlegend=False,
    ), row=3, col=1)

    # Label at most ~50 residues so tick text stays readable.
    step = max(1, n_res // 50)
    tick_vals = list(range(0, n_res, step))
    tick_text = [_residue_label(seq, i) for i in tick_vals]
    fig.update_xaxes(tickvals=tick_vals, ticktext=tick_text, tickangle=45,
                     tickfont=dict(size=8), row=3, col=1)

    fig.update_layout(
        template="plotly_dark", height=550,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(30,27,75,0.3)",
        margin=dict(l=60, r=80, t=30, b=50),
    )
    st.plotly_chart(fig, use_container_width=True)

    # Three most mobile residues per mode, largest first.
    for k in range(min(n_modes, _DETAIL_MODE_LIMIT)):
        top3 = np.argsort(abs_mags[k])[-3:][::-1]
        top_str = ", ".join(
            f"**{_residue_label(seq, i)}** ({abs_mags[k][i]:.2f}Å)" for i in top3
        )
        st.markdown(f"Mode {k} hotspots: {top_str}")


def _render_direction_tab(modes_c, ca_coords, n_modes, n_res):
    """Draw |cos θ| between displacement and local backbone, as lines and as a heatmap."""
    if ca_coords is None or len(ca_coords) != n_res:
        st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)")
        return

    st.markdown(
        "**Direction Analysis**: Projects displacement onto the local backbone direction (CA→CA).\n"
        "- 🔵 **Blue** = motion along backbone (stretching/compressing)\n"
        "- 🔴 **Red** = motion perpendicular to backbone (lateral/hinge)\n"
        "- Sign is arbitrary for eigenvectors — we show absolute cosine similarity\n"
    )

    # Computed once for all modes; the line plot reuses the first rows.
    dir_matrix = np.zeros((n_modes, n_res))
    for k in range(n_modes):
        dir_matrix[k] = np.abs(per_residue_direction(modes_c[k], ca_coords))

    fig_dir = go.Figure()
    residue_numbers = list(range(1, n_res + 1))
    for k, (line_color, fill_color) in enumerate(_MODE_COLORS[:n_modes]):
        fig_dir.add_trace(go.Scatter(
            x=residue_numbers, y=dir_matrix[k],
            mode="lines", name=f"Mode {k}",
            line=dict(color=line_color, width=1.5),
            fill="tozeroy",
            fillcolor=fill_color,
            hovertemplate="Res %{x}<br>|cos θ|: %{y:.3f}<extra>Mode " + str(k) + "</extra>",
        ))

    fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8",
                      annotation_text="isotropic threshold")

    fig_dir.update_layout(
        template="plotly_dark", height=350,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        xaxis_title="Residue", yaxis_title="|cos θ| (backbone projection)",
        yaxis_range=[0, 1.05],
        margin=dict(l=50, r=20, t=30, b=50),
    )
    st.plotly_chart(fig_dir, use_container_width=True)

    st.markdown("**Per-residue × mode direction matrix:**")
    # "RdBu" runs red (low) -> blue (high), so |cos θ| near 1 (along the
    # backbone) is blue and near 0 (perpendicular) is red, matching the
    # legend above. The previous "RdBu_r" showed the opposite colours.
    fig_dh = go.Figure(go.Heatmap(
        z=dir_matrix, colorscale="RdBu", zmin=0, zmax=1,
        y=[f"Mode {k}" for k in range(n_modes)],
        colorbar=dict(title="|cos θ|"),
    ))
    fig_dh.update_layout(
        template="plotly_dark", height=200,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        margin=dict(l=60, r=60, t=10, b=30),
    )
    st.plotly_chart(fig_dh, use_container_width=True)


def _render_comparison_tab(modes_c, gt_modes, rel_norms):
    """Overlay relative displacement profiles of prediction and ground truth per mode."""
    if gt_modes is None or len(gt_modes) == 0:
        st.info("No ground truth available for comparison. "
                "Ground truth is only available for proteins in the training database.")
        return

    gt_c = canonicalize_sign(gt_modes)
    n_modes, n_res = rel_norms.shape
    residue_numbers = list(range(1, n_res + 1))

    st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**")

    for k in range(min(n_modes, len(gt_c), _DETAIL_MODE_LIMIT)):
        gt_arr = np.asarray(gt_c[k])
        if len(gt_arr) != n_res:
            # Profiles of different length cannot be correlated residue by residue.
            st.warning(f"Mode {k}: ground truth has {len(gt_arr)} residues but the "
                       f"prediction has {n_res}; skipping comparison.")
            continue

        pred_rel = rel_norms[k]
        gt_rel = per_residue_relative_norm(gt_arr)

        fig_cmp = go.Figure()
        fig_cmp.add_trace(go.Scatter(
            x=residue_numbers, y=gt_rel,
            mode="lines", name="Ground Truth",
            line=dict(color="#10b981", width=2),
        ))
        fig_cmp.add_trace(go.Scatter(
            x=residue_numbers, y=pred_rel,
            mode="lines", name="Prediction",
            line=dict(color="#6366f1", width=2, dash="dot"),
        ))

        # A constant profile has zero variance: r is NaN, reported as such
        # without emitting a runtime warning.
        with np.errstate(invalid="ignore", divide="ignore"):
            corr = np.corrcoef(pred_rel, gt_rel)[0, 1]
        rmse = np.sqrt(np.mean((pred_rel - gt_rel) ** 2))

        fig_cmp.update_layout(
            template="plotly_dark", height=200,
            title=f"Mode {k} — r={corr:.3f}, RMSE={rmse:.3f}",
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
            margin=dict(l=40, r=20, t=40, b=30),
            legend=dict(orientation="h", y=1.15),
        )
        st.plotly_chart(fig_cmp, use_container_width=True)


def _render_table_tab(seq, coverage, abs_mags, rel_norms, protein_name):
    """Show the per-residue statistics table and offer it as a CSV download."""
    import pandas as pd

    n_modes, n_res = abs_mags.shape
    shown_modes = range(min(n_modes, _DETAIL_MODE_LIMIT))

    rows = []
    for i in range(n_res):
        row = {
            "Residue": i + 1,
            "AA": seq[i] if i < len(seq) else "?",
            "Coverage": _format_coverage(coverage[i], ".3f"),
        }
        for k in shown_modes:
            # Reuse the precomputed matrices instead of re-normalizing the
            # whole mode for every residue.
            row[f"M{k} (Å)"] = f"{abs_mags[k, i]:.3f}"
            row[f"M{k} rel"] = f"{rel_norms[k, i]:.2%}"
        rows.append(row)

    df = pd.DataFrame(rows)
    st.dataframe(df, use_container_width=True, height=500,
                 column_config={
                     "Residue": st.column_config.NumberColumn(width="small"),
                     "AA": st.column_config.TextColumn(width="small"),
                 })

    csv = df.to_csv(index=False)
    st.download_button("📥 Download CSV", csv,
                       f"{protein_name or 'prediction'}_analysis.csv", "text/csv")


def render_prediction_analysis(
    modes: dict,
    seq: str,
    ca_coords: np.ndarray = None,
    coverage: np.ndarray = None,
    eigenvalues: np.ndarray = None,
    gt_modes: dict = None,
    protein_name: str = "",
):
    """Render the comprehensive prediction analysis panel in Streamlit.

    Shows four tabs:
      1. Normalized displacement heatmaps (all modes × residues) with coverage
      2. Sign-invariant direction analysis (needs CA coordinates)
      3. Prediction vs ground truth comparison (if available)
      4. Per-residue statistics table with CSV download

    Args:
        modes: predicted modes; looked up as ``modes[k]`` for
            k = 0 .. len(modes) - 1, each an (N, 3) displacement array with
            the same N.
        seq: amino-acid sequence; residues beyond its length are labelled '?'.
        ca_coords: optional (N, 3) CA coordinates. The direction tab is only
            drawn when its length equals the number of residues.
        coverage: optional per-residue coverage; defaults to all ones. Values
            above 0.5 are drawn green, others red. Missing trailing values
            are shown as '—'.
        eigenvalues: accepted for interface compatibility; currently unused.
        gt_modes: optional ground-truth modes, keyed like ``modes``.
        protein_name: prefix of the downloaded CSV file name.
    """
    if not modes:
        st.warning("No predicted modes to analyze.")
        return

    modes_c = canonicalize_sign(modes)
    n_modes = len(modes_c)
    n_res = len(next(iter(modes_c.values())))
    if n_res == 0:
        st.warning("Predicted modes contain no residues.")
        return

    coverage = _aligned_coverage(coverage, n_res)

    # Magnitudes are shared by the heatmap, comparison and table tabs, so
    # compute them once up front.
    abs_mags = np.zeros((n_modes, n_res))
    rel_norms = np.zeros((n_modes, n_res))
    for k in range(n_modes):
        abs_mags[k] = np.linalg.norm(modes_c[k], axis=1)
        rel_norms[k] = per_residue_relative_norm(modes_c[k])

    tab_norm, tab_dir, tab_compare, tab_table = st.tabs([
        "📊 Normalized Displacement", "🧭 Direction Analysis",
        "⚖️ Pred vs GT", "📋 Per-Residue Table"
    ])

    with tab_norm:
        _render_displacement_tab(seq, coverage, abs_mags, rel_norms)

    with tab_dir:
        _render_direction_tab(modes_c, ca_coords, n_modes, n_res)

    with tab_compare:
        _render_comparison_tab(modes_c, gt_modes, rel_norms)

    with tab_table:
        _render_table_tab(seq, coverage, abs_mags, rel_norms, protein_name)
|
|