Petimot / app /components /prediction_analysis.py
Valmbd's picture
Fix: rgba fillcolor bug + replace statsmodels lowess with scipy binned trendline
ee42d0e verified
"""Enhanced prediction analysis β€” sign-invariant modes and per-residue normalization."""
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def canonicalize_sign(modes: dict) -> dict:
    """Return a copy of *modes* with each eigenvector's global sign fixed.

    Eigenvectors are only defined up to a global ±1 sign. Each vector is
    flipped, if needed, so that its entry with the largest absolute value is
    positive. Because ``argmax(|v|)`` picks the same index for ``v`` and
    ``-v``, both map to the same representative, giving consistent
    visualization across runs/proteins.

    Args:
        modes: mapping from mode key to an array-like of displacement vectors
            (any shape; it is flattened only to locate the reference entry).

    Returns:
        A new dict with the same keys. Every value is a fresh ndarray; the
        inputs are never modified. Empty arrays are returned as copies.
    """
    canonical = {}
    for key, vecs in modes.items():
        arr = np.asarray(vecs)
        flat = arr.ravel()
        # np.argmax raises on an empty array, so empty modes pass through unchanged.
        if flat.size and flat[np.argmax(np.abs(flat))] < 0:
            canonical[key] = -arr  # flip the global sign
        else:
            canonical[key] = arr.copy()
    return canonical
def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray:
    """Return per-residue displacement magnitudes scaled so the largest is 1.

    Args:
        vecs: (N, 3) displacement vectors (ndarray or nested sequence).

    Returns:
        (N,) magnitudes divided by their maximum, so values lie in [0, 1].
        When the maximum is at most 1e-12 (an all-zero / numerically null
        mode), the raw magnitudes are returned unscaled to avoid dividing by
        ~0. An empty (0, 3) input yields an empty array.
    """
    mags = np.linalg.norm(np.asarray(vecs, dtype=float), axis=1)
    if mags.size == 0:
        # .max() raises on an empty array.
        return mags
    max_m = mags.max()
    return mags / max_m if max_m > 1e-12 else mags
def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray:
    """Cosine between each residue's displacement and its local backbone direction.

    The local backbone direction of residue i is CA_{i+1} - CA_i. The last
    residue reuses the previous segment (CA_{N-1} - CA_{N-2}).

    Args:
        vecs: (N, 3) displacement vectors.
        ca_coords: (N, 3) CA coordinates; only the first N rows are used.

    Returns:
        (N,) signed cosines in [-1, 1], normalized by both the displacement
        and the backbone-segment lengths: +1 = along the backbone,
        -1 = against it, 0 = perpendicular. Residues whose displacement or
        backbone segment has norm below 1e-8 get 0. If N < 2, every residue
        gets 0, because no CA→CA segment exists.

    Raises:
        ValueError: if ca_coords has fewer rows than vecs.
    """
    disp = np.asarray(vecs, dtype=float)
    ca = np.asarray(ca_coords, dtype=float)
    n = len(disp)
    if len(ca) < n:
        raise ValueError(f"ca_coords has {len(ca)} rows but vecs has {n}")
    projections = np.zeros(n)
    if n < 2:
        return projections

    ca = ca[:n]
    backbone = np.empty_like(ca)
    backbone[:-1] = ca[1:] - ca[:-1]
    backbone[-1] = backbone[-2]  # last residue reuses the preceding CA→CA segment

    bb_norm = np.linalg.norm(backbone, axis=1)
    disp_mag = np.linalg.norm(disp, axis=1)
    # Degenerate segments / null displacements keep a projection of 0.
    valid = (bb_norm >= 1e-8) & (disp_mag >= 1e-8)
    dots = np.sum(disp[valid] * backbone[valid], axis=1)
    projections[valid] = dots / (disp_mag[valid] * bb_norm[valid])
    return projections
# Line colour and matching translucent fill for the first four modes.
_MODE_STYLES = (
    ("#6366f1", "rgba(99,102,241,0.12)"),
    ("#ef4444", "rgba(239,68,68,0.12)"),
    ("#10b981", "rgba(16,185,129,0.12)"),
    ("#f59e0b", "rgba(245,158,11,0.12)"),
)
# |cos θ| colour scale for the direction heatmap: 0 (perpendicular) is red and
# 1 (along the backbone) is blue, so the colours match _DIRECTION_LEGEND.
_DIRECTION_COLORSCALE = [[0.0, "#ef4444"], [0.5, "#e2e8f0"], [1.0, "#3b82f6"]]
_DIRECTION_LEGEND = (
    "**Direction Analysis**: Projects displacement onto the local backbone direction (CA→CA).\n"
    "- 🔵 **Blue** = motion along backbone (stretching/compressing)\n"
    "- 🔴 **Red** = motion perpendicular to backbone (lateral/hinge)\n"
    "- Sign is arbitrary for eigenvectors → we show absolute cosine similarity\n"
)
# Background colours shared by every figure in this panel.
_PAPER_BG = "rgba(0,0,0,0)"
_PLOT_BG = "rgba(30,27,75,0.3)"
# Max number of modes shown in per-mode line plots, hotspot lists and the table.
_MAX_DETAIL_MODES = 4


def _residue_label(seq: str, i: int) -> str:
    """Return an 'A12'-style label for 0-based residue *i* ('?' past the end of seq)."""
    return f"{seq[i] if i < len(seq) else '?'}{i + 1}"


def _coverage_array(coverage, n_res: int) -> np.ndarray:
    """Return coverage as a float array of length n_res (ones if None, NaN-padded if short)."""
    if coverage is None:
        return np.ones(n_res)
    cov = np.asarray(coverage, dtype=float).ravel()[:n_res]
    if len(cov) < n_res:
        cov = np.concatenate([cov, np.full(n_res - len(cov), np.nan)])
    return cov


def _format_coverage(value: float, spec: str) -> str:
    """Format one coverage value, using an em dash for missing (NaN) entries."""
    return "—" if np.isnan(value) else format(value, spec)


def render_prediction_analysis(
    modes: dict,
    seq: str,
    ca_coords: np.ndarray = None,
    coverage: np.ndarray = None,
    eigenvalues: np.ndarray = None,
    gt_modes: dict = None,
    protein_name: str = "",
):
    """Render the comprehensive prediction analysis panel as four Streamlit tabs.

    Tabs:
        1. Normalized displacement heatmaps (all modes × residues) + coverage.
        2. Sign-invariant direction analysis against the local CA→CA backbone.
        3. Prediction vs ground-truth displacement profiles (if provided).
        4. Per-residue statistics table with CSV download.

    Args:
        modes: predicted modes, mode key -> (N, 3) displacement vectors. Modes
            are shown in sorted key order (presumably int keys 0..K-1).
        seq: amino-acid sequence; residues past its end are labelled '?'.
        ca_coords: optional (N, 3) CA coordinates; required for tab 2.
        coverage: optional per-residue coverage, defaulting to all ones.
            Missing trailing entries are shown as '—'.
        eigenvalues: accepted for interface compatibility; currently unused.
        gt_modes: optional ground-truth modes, paired with predictions by key.
        protein_name: prefix for the downloaded CSV file name.
    """
    if not modes:
        st.info("No predicted modes to analyse.")
        return

    # Signs are canonicalized for consistency, although every view below uses
    # only magnitudes or |cos θ|, which are sign-invariant anyway.
    modes_c = canonicalize_sign(modes)
    mode_keys = sorted(modes_c)
    mode_vecs = [modes_c[k] for k in mode_keys]
    n_res = len(mode_vecs[0])
    if n_res == 0:
        st.info("No predicted modes to analyse.")
        return
    coverage = _coverage_array(coverage, n_res)

    # Computed once and shared by the heatmap tab and the table tab.
    abs_mags = np.array([np.linalg.norm(v, axis=1) for v in mode_vecs])
    rel_norms = np.array([per_residue_relative_norm(v) for v in mode_vecs])

    tab_norm, tab_dir, tab_compare, tab_table = st.tabs([
        "📊 Normalized Displacement", "🧭 Direction Analysis",
        "⚖️ Pred vs GT", "📋 Per-Residue Table",
    ])
    with tab_norm:
        _render_norm_tab(mode_keys, abs_mags, rel_norms, seq, coverage)
    with tab_dir:
        _render_direction_tab(mode_keys, mode_vecs, ca_coords)
    with tab_compare:
        _render_compare_tab(mode_keys, rel_norms, gt_modes)
    with tab_table:
        _render_table_tab(mode_keys, abs_mags, rel_norms, seq, coverage, protein_name)


def _render_norm_tab(mode_keys, abs_mags, rel_norms, seq, coverage):
    """Tab 1: absolute/relative displacement heatmaps, coverage bar and hotspots."""
    n_modes, n_res = abs_mags.shape
    labels = [f"Mode {k}" for k in mode_keys]
    hover = [
        [
            f"{_residue_label(seq, j)}<br>"
            f"Abs: {abs_mags[m, j]:.3f}Å<br>"
            f"Rel: {rel_norms[m, j]:.2%}<br>"
            f"Cov: {_format_coverage(coverage[j], '.2f')}"
            for j in range(n_res)
        ]
        for m in range(n_modes)
    ]

    fig = make_subplots(rows=3, cols=1, row_heights=[0.4, 0.4, 0.2],
                        shared_xaxes=True, vertical_spacing=0.06,
                        subplot_titles=["Absolute Displacement (Å)",
                                        "Relative Displacement (0-1)",
                                        "Coverage"])
    fig.add_trace(go.Heatmap(
        z=abs_mags, colorscale="YlOrRd", y=labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Å", x=1.01, len=0.35, y=0.85),
    ), row=1, col=1)
    fig.add_trace(go.Heatmap(
        z=rel_norms, colorscale="Viridis", zmin=0, zmax=1, y=labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5),
    ), row=2, col=1)
    # Bars sit at 0-based x like the heatmaps; customdata carries the 1-based
    # residue number so the hover matches the tick labels.
    fig.add_trace(go.Bar(
        x=list(range(n_res)), y=coverage,
        customdata=np.arange(1, n_res + 1),
        marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in coverage],
        hovertemplate="Res %{customdata}<br>Coverage: %{y:.3f}<extra></extra>",
        showlegend=False,
    ), row=3, col=1)

    # Sequence ticks: at most ~50 labels to stay readable.
    step = max(1, n_res // 50)
    tick_vals = list(range(0, n_res, step))
    fig.update_xaxes(tickvals=tick_vals,
                     ticktext=[_residue_label(seq, i) for i in tick_vals],
                     tickangle=45, tickfont=dict(size=8), row=3, col=1)
    fig.update_layout(
        template="plotly_dark", height=550,
        paper_bgcolor=_PAPER_BG, plot_bgcolor=_PLOT_BG,
        margin=dict(l=60, r=80, t=30, b=50),
    )
    st.plotly_chart(fig, use_container_width=True)

    # Key insight: the three largest absolute displacements per mode.
    for m in range(min(n_modes, _MAX_DETAIL_MODES)):
        top3 = np.argsort(abs_mags[m])[-3:][::-1]
        top_str = ", ".join(
            f"**{_residue_label(seq, i)}** ({abs_mags[m, i]:.2f}Å)" for i in top3
        )
        st.markdown(f"Mode {mode_keys[m]} hotspots: {top_str}")


def _render_direction_tab(mode_keys, mode_vecs, ca_coords):
    """Tab 2: |cos θ| between displacement and backbone, as line plots and a heatmap."""
    n_res = len(mode_vecs[0])
    if ca_coords is None or len(ca_coords) != n_res:
        st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)")
        return

    st.markdown(_DIRECTION_LEGEND)
    # |cos θ| per mode × residue, computed once for both figures (sign-invariant).
    dir_matrix = np.array([np.abs(per_residue_direction(v, ca_coords)) for v in mode_vecs])
    residues = list(range(1, n_res + 1))

    fig_dir = go.Figure()
    for m, (line_color, fill_color) in enumerate(_MODE_STYLES[:len(mode_keys)]):
        key = mode_keys[m]
        fig_dir.add_trace(go.Scatter(
            x=residues, y=dir_matrix[m],
            mode="lines", name=f"Mode {key}",
            line=dict(color=line_color, width=1.5),
            fill="tozeroy", fillcolor=fill_color,
            hovertemplate="Res %{x}<br>|cos θ|: %{y:.3f}<extra>Mode " + str(key) + "</extra>",
        ))
    # 0.5 is the mean |cos θ| of a uniformly random 3D direction.
    fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8",
                      annotation_text="isotropic threshold")
    fig_dir.update_layout(
        template="plotly_dark", height=350,
        paper_bgcolor=_PAPER_BG, plot_bgcolor=_PLOT_BG,
        xaxis_title="Residue", yaxis_title="|cos θ| (backbone projection)",
        yaxis_range=[0, 1.05],
        margin=dict(l=50, r=20, t=30, b=50),
    )
    st.plotly_chart(fig_dir, use_container_width=True)

    st.markdown("**Per-residue × mode direction matrix:**")
    fig_dh = go.Figure(go.Heatmap(
        z=dir_matrix, x=residues, colorscale=_DIRECTION_COLORSCALE, zmin=0, zmax=1,
        y=[f"Mode {k}" for k in mode_keys],
        colorbar=dict(title="|cos θ|"),
    ))
    fig_dh.update_layout(
        template="plotly_dark", height=200,
        paper_bgcolor=_PAPER_BG, plot_bgcolor=_PLOT_BG,
        margin=dict(l=60, r=60, t=10, b=30),
    )
    st.plotly_chart(fig_dh, use_container_width=True)


def _render_compare_tab(mode_keys, rel_norms, gt_modes):
    """Tab 3: overlay relative displacement profiles of prediction and ground truth."""
    if not gt_modes:
        st.info("No ground truth available for comparison. "
                "Ground truth is only available for proteins in the training database.")
        return

    gt_c = canonicalize_sign(gt_modes)
    n_res = rel_norms.shape[1]
    residues = list(range(1, n_res + 1))
    st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**")

    # Pair prediction and ground-truth modes by key.
    shared = [(m, key) for m, key in enumerate(mode_keys) if key in gt_c][:_MAX_DETAIL_MODES]
    if not shared:
        st.info("Ground truth has no modes matching the predicted mode keys.")
        return

    for m, key in shared:
        gt_vecs = np.asarray(gt_c[key])
        if len(gt_vecs) != n_res:
            st.warning(f"Mode {key}: ground truth covers {len(gt_vecs)} residues but the "
                       f"prediction covers {n_res}; skipping comparison.")
            continue
        pred_rel = rel_norms[m]
        gt_rel = per_residue_relative_norm(gt_vecs)

        fig_cmp = go.Figure()
        fig_cmp.add_trace(go.Scatter(
            x=residues, y=gt_rel, mode="lines", name="Ground Truth",
            line=dict(color="#10b981", width=2),
        ))
        fig_cmp.add_trace(go.Scatter(
            x=residues, y=pred_rel, mode="lines", name="Prediction",
            line=dict(color="#6366f1", width=2, dash="dot"),
        ))

        # corrcoef divides by the standard deviation; a flat profile would
        # yield NaN plus a RuntimeWarning, so report NaN explicitly instead.
        if pred_rel.std() > 0 and gt_rel.std() > 0:
            corr = float(np.corrcoef(pred_rel, gt_rel)[0, 1])
        else:
            corr = float("nan")
        rmse = float(np.sqrt(np.mean((pred_rel - gt_rel) ** 2)))

        fig_cmp.update_layout(
            template="plotly_dark", height=200,
            title=f"Mode {key} — r={corr:.3f}, RMSE={rmse:.3f}",
            paper_bgcolor=_PAPER_BG, plot_bgcolor=_PLOT_BG,
            margin=dict(l=40, r=20, t=40, b=30),
            legend=dict(orientation="h", y=1.15),
        )
        st.plotly_chart(fig_cmp, use_container_width=True)


def _render_table_tab(mode_keys, abs_mags, rel_norms, seq, coverage, protein_name):
    """Tab 4: per-residue table of coverage and mode magnitudes, with CSV download."""
    import pandas as pd

    n_detail = min(len(mode_keys), _MAX_DETAIL_MODES)
    rows = []
    for i in range(abs_mags.shape[1]):
        row = {
            "Residue": i + 1,
            "AA": seq[i] if i < len(seq) else "?",
            "Coverage": _format_coverage(coverage[i], ".3f"),
        }
        # Reuse the precomputed arrays instead of renormalizing per residue.
        for m in range(n_detail):
            key = mode_keys[m]
            row[f"M{key} (Å)"] = f"{abs_mags[m, i]:.3f}"
            row[f"M{key} rel"] = f"{rel_norms[m, i]:.2%}"
        rows.append(row)

    df = pd.DataFrame(rows)
    st.dataframe(df, use_container_width=True, height=500,
                 column_config={
                     "Residue": st.column_config.NumberColumn(width="small"),
                     "AA": st.column_config.TextColumn(width="small"),
                 })
    st.download_button("📥 Download CSV", df.to_csv(index=False),
                       f"{protein_name or 'prediction'}_analysis.csv", "text/csv")