"""Enhanced prediction analysis — sign-invariant modes and per-residue normalization.""" import numpy as np import streamlit as st import plotly.graph_objects as go from plotly.subplots import make_subplots def canonicalize_sign(modes: dict) -> dict: """Make eigenvectors sign-consistent. Eigenvectors are defined up to ±1 global sign. We canonicalize by choosing the sign such that the component with the largest absolute value is positive. This ensures consistent visualization across different runs/proteins. """ canonical = {} for k, vecs in modes.items(): # Flatten to (3N,), find component with max absolute value flat = vecs.flatten() max_idx = np.argmax(np.abs(flat)) if flat[max_idx] < 0: canonical[k] = -vecs # Flip sign else: canonical[k] = vecs.copy() return canonical def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray: """Normalize displacement magnitudes to [0, 1] relative to max. Args: vecs: (N, 3) displacement vectors Returns: (N,) relative magnitudes in [0, 1] """ mags = np.linalg.norm(vecs, axis=1) max_m = mags.max() return mags / max_m if max_m > 1e-12 else mags def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray: """Compute relative direction of displacement vs protein backbone. Projects displacement onto local backbone direction (CA_i → CA_{i+1}). Returns signed projection: positive = along backbone, negative = against. Args: vecs: (N, 3) displacement vectors ca_coords: (N, 3) CA coordinates Returns: (N,) signed projections normalized by displacement magnitude """ n = len(vecs) projections = np.zeros(n) for i in range(n): # Local backbone direction if i < n - 1: backbone = ca_coords[i + 1] - ca_coords[i] else: backbone = ca_coords[i] - ca_coords[i - 1] bb_norm = np.linalg.norm(backbone) if bb_norm < 1e-8: continue disp_mag = np.linalg.norm(vecs[i]) if disp_mag < 1e-8: continue # Cosine angle between displacement and backbone direction projections[i] = np.dot(vecs[i], backbone) / (disp_mag * bb_norm) return projections def render_prediction_analysis( modes: dict, seq: str, ca_coords: np.ndarray = None, coverage: np.ndarray = None, eigenvalues: np.ndarray = None, gt_modes: dict = None, protein_name: str = "", ): """Comprehensive prediction analysis panel. Shows: 1. Normalized displacement heatmap (all modes × residues) 2. Sign-canonical direction analysis 3. Prediction vs ground truth comparison (if available) 4. Per-residue statistics table """ # Canonicalize signs modes_c = canonicalize_sign(modes) n_modes = len(modes_c) n_res = len(list(modes_c.values())[0]) if coverage is None: coverage = np.ones(n_res) # ── Tab layout ── tab_norm, tab_dir, tab_compare, tab_table = st.tabs([ "📊 Normalized Displacement", "🧭 Direction Analysis", "⚖️ Pred vs GT", "📋 Per-Residue Table" ]) # ═══════════════════════════════════════════ # Tab 1: Normalized displacement heatmap # ═══════════════════════════════════════════ with tab_norm: # Compute relative norms for all modes rel_norms = np.zeros((n_modes, n_res)) abs_mags = np.zeros((n_modes, n_res)) for k in range(n_modes): abs_mags[k] = np.linalg.norm(modes_c[k], axis=1) rel_norms[k] = per_residue_relative_norm(modes_c[k]) # Hover text with sequence hover = [[f"{seq[j] if j < len(seq) else '?'}{j+1}
" f"Abs: {abs_mags[k][j]:.3f}Å
" f"Rel: {rel_norms[k][j]:.2%}
" f"Cov: {coverage[j]:.2f}" for j in range(n_res)] for k in range(n_modes)] fig = make_subplots(rows=3, cols=1, row_heights=[0.4, 0.4, 0.2], shared_xaxes=True, vertical_spacing=0.06, subplot_titles=["Absolute Displacement (Å)", "Relative Displacement (0-1)", "Coverage"]) # Absolute heatmap fig.add_trace(go.Heatmap( z=abs_mags, colorscale="YlOrRd", y=[f"Mode {k}" for k in range(n_modes)], text=hover, hovertemplate="%{text}", colorbar=dict(title="Å", x=1.01, len=0.35, y=0.85), ), row=1, col=1) # Relative heatmap fig.add_trace(go.Heatmap( z=rel_norms, colorscale="Viridis", zmin=0, zmax=1, y=[f"Mode {k}" for k in range(n_modes)], text=hover, hovertemplate="%{text}", colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5), ), row=2, col=1) # Coverage bar fig.add_trace(go.Bar( x=list(range(n_res)), y=coverage[:n_res], marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in coverage[:n_res]], hovertemplate="Res %{x}
Coverage: %{y:.3f}", showlegend=False, ), row=3, col=1) # Sequence ticks step = max(1, n_res // 50) tick_vals = list(range(0, n_res, step)) tick_text = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in tick_vals] fig.update_xaxes(tickvals=tick_vals, ticktext=tick_text, tickangle=45, tickfont=dict(size=8), row=3, col=1) fig.update_layout( template="plotly_dark", height=550, paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)", margin=dict(l=60, r=80, t=30, b=50), ) st.plotly_chart(fig, use_container_width=True) # Key insight for k in range(min(n_modes, 4)): top3 = np.argsort(abs_mags[k])[-3:][::-1] top_str = ", ".join([f"**{seq[i] if i", )) fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8", annotation_text="isotropic threshold") fig_dir.update_layout( template="plotly_dark", height=350, paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)", xaxis_title="Residue", yaxis_title="|cos θ| (backbone projection)", yaxis_range=[0, 1.05], margin=dict(l=50, r=20, t=30, b=50), ) st.plotly_chart(fig_dir, use_container_width=True) # Direction heatmap st.markdown("**Per-residue × mode direction matrix:**") dir_matrix = np.zeros((n_modes, n_res)) for k in range(n_modes): dir_matrix[k] = np.abs(per_residue_direction(modes_c[k], ca_coords)) fig_dh = go.Figure(go.Heatmap( z=dir_matrix, colorscale="RdBu_r", zmin=0, zmax=1, y=[f"Mode {k}" for k in range(n_modes)], colorbar=dict(title="|cos θ|"), )) fig_dh.update_layout( template="plotly_dark", height=200, paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)", margin=dict(l=60, r=60, t=10, b=30), ) st.plotly_chart(fig_dh, use_container_width=True) else: st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)") # ═══════════════════════════════════════════ # Tab 3: Prediction vs Ground Truth # ═══════════════════════════════════════════ with tab_compare: if gt_modes is not None and len(gt_modes) > 0: gt_c = canonicalize_sign(gt_modes) n_gt = len(gt_c) st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**") for k in range(min(n_modes, n_gt, 4)): pred_mag = np.linalg.norm(modes_c[k], axis=1) gt_mag = np.linalg.norm(gt_c[k], axis=1) # Normalize both to [0, 1] pred_rel = pred_mag / (pred_mag.max() + 1e-12) gt_rel = gt_mag / (gt_mag.max() + 1e-12) fig_cmp = go.Figure() fig_cmp.add_trace(go.Scatter( x=list(range(1, n_res + 1)), y=gt_rel, mode="lines", name="Ground Truth", line=dict(color="#10b981", width=2), )) fig_cmp.add_trace(go.Scatter( x=list(range(1, n_res + 1)), y=pred_rel, mode="lines", name="Prediction", line=dict(color="#6366f1", width=2, dash="dot"), )) # Correlation corr = np.corrcoef(pred_rel, gt_rel)[0, 1] rmse = np.sqrt(np.mean((pred_rel - gt_rel) ** 2)) fig_cmp.update_layout( template="plotly_dark", height=200, title=f"Mode {k} — r={corr:.3f}, RMSE={rmse:.3f}", paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)", margin=dict(l=40, r=20, t=40, b=30), legend=dict(orientation="h", y=1.15), ) st.plotly_chart(fig_cmp, use_container_width=True) else: st.info("No ground truth available for comparison. " "Ground truth is only available for proteins in the training database.") # ═══════════════════════════════════════════ # Tab 4: Per-residue table # ═══════════════════════════════════════════ with tab_table: import pandas as pd rows = [] for i in range(n_res): row = { "Residue": i + 1, "AA": seq[i] if i < len(seq) else "?", "Coverage": f"{coverage[i]:.3f}" if i < len(coverage) else "—", } for k in range(min(n_modes, 4)): mag = np.linalg.norm(modes_c[k][i]) rel = per_residue_relative_norm(modes_c[k])[i] row[f"M{k} (Å)"] = f"{mag:.3f}" row[f"M{k} rel"] = f"{rel:.2%}" rows.append(row) df = pd.DataFrame(rows) st.dataframe(df, use_container_width=True, height=500, column_config={ "Residue": st.column_config.NumberColumn(width="small"), "AA": st.column_config.TextColumn(width="small"), }) # Download CSV csv = df.to_csv(index=False) st.download_button("📥 Download CSV", csv, f"{protein_name}_analysis.csv", "text/csv")