File size: 13,367 Bytes
b47954d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee42d0e
 
 
 
 
 
b47954d
 
 
 
 
ee42d0e
b47954d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""Enhanced prediction analysis β€” sign-invariant modes and per-residue normalization."""
import numpy as np
import streamlit as st
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def canonicalize_sign(modes: dict) -> dict:
    """Fix the arbitrary global sign of every mode.

    An eigenvector and its negation describe the same motion, so predicted
    modes carry an arbitrary ±1 sign. Each mode is oriented so that its single
    largest-magnitude component (searched over the flattened array) is
    positive, which keeps plots consistent between runs and proteins.

    Args:
        modes: mapping of mode key -> displacement array.

    Returns:
        New dict with the same keys. Every value is a new array; the input
        arrays are never modified.
    """

    def _orient(vecs):
        # Locate the dominant component once, then decide whether to flip.
        flat = vecs.flatten()
        dominant = flat[np.argmax(np.abs(flat))]
        if dominant < 0:
            return -vecs
        return vecs.copy()

    return {key: _orient(vecs) for key, vecs in modes.items()}


def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray:
    """Scale per-residue displacement lengths so the largest one equals 1.

    Args:
        vecs: (N, 3) displacement vectors.

    Returns:
        (N,) array of Euclidean lengths divided by their maximum, so values
        lie in [0, 1]. If the maximum is not above 1e-12 (essentially no
        motion), the raw lengths are returned unscaled.
    """
    lengths = np.linalg.norm(vecs, axis=1)
    peak = lengths.max()
    if peak > 1e-12:
        return lengths / peak
    # Near-zero motion everywhere: dividing would only amplify noise.
    return lengths


def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray:
    """Cosine between each residue's displacement and its local backbone bond.

    The local backbone direction of residue ``i`` is the CA_i → CA_{i+1}
    bond. The last residue reuses the final bond CA_{N-2} → CA_{N-1}. The
    returned value is ``cos θ = (d · b) / (|d| |b|)``, i.e. the projection
    normalized by BOTH the displacement length and the bond length. It lies
    in [-1, 1]: positive means along the chain direction, negative means
    against it. Residues whose displacement or bond length is below 1e-8
    get 0.

    Args:
        vecs: (N, 3) displacement vectors (array-like accepted).
        ca_coords: (M, 3) CA coordinates with M >= N. Only the first N rows
            are used.

    Returns:
        (N,) float array of cosines. All zeros when N < 2, because no bond
        exists to project onto.

    Raises:
        ValueError: if fewer CA coordinates than displacement vectors are given.
    """
    vecs = np.asarray(vecs, dtype=float)
    n = len(vecs)
    projections = np.zeros(n)
    if n < 2:
        # A single residue (or none) has no CA->CA bond to project onto.
        return projections

    ca = np.asarray(ca_coords, dtype=float)
    if len(ca) < n:
        raise ValueError(f"need at least {n} CA coordinates, got {len(ca)}")
    ca = ca[:n]

    backbone = np.empty_like(ca)
    backbone[:-1] = ca[1:] - ca[:-1]  # forward bond CA_i -> CA_{i+1}
    backbone[-1] = backbone[-2]       # last residue: bond CA_{N-2} -> CA_{N-1}

    bb_norm = np.linalg.norm(backbone, axis=1)
    disp_mag = np.linalg.norm(vecs, axis=1)
    # Only skip lengths that are definitely tiny; NaN lengths stay "valid" so
    # bad input surfaces as NaN instead of a silent 0.
    valid = ~(bb_norm < 1e-8) & ~(disp_mag < 1e-8)
    dots = np.sum(vecs[valid] * backbone[valid], axis=1)
    projections[valid] = dots / (disp_mag[valid] * bb_norm[valid])
    return projections


def render_prediction_analysis(
    modes: dict,
    seq: str,
    ca_coords: np.ndarray = None,
    coverage: np.ndarray = None,
    eigenvalues: np.ndarray = None,
    gt_modes: dict = None,
    protein_name: str = "",
):
    """Comprehensive prediction analysis panel.

    Shows four Streamlit tabs:
    1. Normalized displacement heatmap (all modes × residues)
    2. Sign-canonical direction analysis
    3. Prediction vs ground truth comparison (if available)
    4. Per-residue statistics table

    Args:
        modes: Predicted modes. Values are per-residue displacement arrays
            (``np.linalg.norm(..., axis=1)`` is taken on them, so presumably
            ``(N, 3)``). They are looked up as ``modes[0] .. modes[len - 1]``,
            so keys must be the integers ``0 .. len(modes) - 1``.
        seq: One-letter sequence used for labels. Positions past its end are
            shown as ``?``.
        ca_coords: Optional CA coordinates, one row per residue. The direction
            tab is only drawn when their count equals N.
        coverage: Optional per-residue coverage (defaults to all ones).
            Residues beyond its length are shown as ``—``.
        eigenvalues: Accepted for interface compatibility; not used here.
        gt_modes: Optional ground-truth modes with the same key convention.
        protein_name: Prefix of the downloadable CSV file name.
    """
    if not modes:
        st.info("No predicted modes to analyze.")
        return

    # Canonicalize signs so mode plots are comparable across runs.
    modes_c = canonicalize_sign(modes)
    n_modes = len(modes_c)
    n_res = len(next(iter(modes_c.values())))
    if n_res == 0:
        st.info("Predicted modes contain no residues.")
        return

    if coverage is None:
        coverage = np.ones(n_res)
    coverage = np.asarray(coverage, dtype=float)

    # Per-mode magnitudes are shared by the heatmap tab and the table tab.
    # Computing them once here avoids re-normalizing a whole mode for every
    # table cell (which was O(N^2) per mode).
    abs_mags = np.zeros((n_modes, n_res))
    rel_norms = np.zeros((n_modes, n_res))
    for k in range(n_modes):
        abs_mags[k] = np.linalg.norm(modes_c[k], axis=1)
        rel_norms[k] = per_residue_relative_norm(modes_c[k])

    # ── Tab layout ──
    tab_norm, tab_dir, tab_compare, tab_table = st.tabs([
        "📊 Normalized Displacement", "🧭 Direction Analysis",
        "⚖️ Pred vs GT", "📋 Per-Residue Table"
    ])
    with tab_norm:
        _render_displacement_tab(seq, coverage, abs_mags, rel_norms)
    with tab_dir:
        _render_direction_tab(modes_c, ca_coords, n_res)
    with tab_compare:
        _render_comparison_tab(modes_c, gt_modes)
    with tab_table:
        _render_residue_table_tab(seq, coverage, abs_mags, rel_norms, protein_name)


def _residue_label(seq: str, i: int) -> str:
    """Return the 1-based label of residue ``i`` (e.g. ``"K12"``); ``?`` past the end of ``seq``."""
    return f"{seq[i] if i < len(seq) else '?'}{i + 1}"


def _render_displacement_tab(seq, coverage, abs_mags, rel_norms):
    """Tab 1: absolute and relative displacement heatmaps above a coverage bar."""
    n_modes, n_res = abs_mags.shape
    n_cov = len(coverage)
    mode_labels = [f"Mode {k}" for k in range(n_modes)]

    # Coverage may be shorter than the structure; show "—" rather than
    # raising IndexError, matching the per-residue table.
    cov_text = [f"{coverage[j]:.2f}" if j < n_cov else "—" for j in range(n_res)]
    hover = [[f"{_residue_label(seq, j)}<br>"
              f"Abs: {abs_mags[k][j]:.3f}Å<br>"
              f"Rel: {rel_norms[k][j]:.2%}<br>"
              f"Cov: {cov_text[j]}"
              for j in range(n_res)] for k in range(n_modes)]

    fig = make_subplots(rows=3, cols=1, row_heights=[0.4, 0.4, 0.2],
                        shared_xaxes=True, vertical_spacing=0.06,
                        subplot_titles=["Absolute Displacement (Å)",
                                        "Relative Displacement (0-1)",
                                        "Coverage"])

    # Absolute heatmap
    fig.add_trace(go.Heatmap(
        z=abs_mags, colorscale="YlOrRd", y=mode_labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Å", x=1.01, len=0.35, y=0.85),
    ), row=1, col=1)

    # Relative heatmap
    fig.add_trace(go.Heatmap(
        z=rel_norms, colorscale="Viridis", zmin=0, zmax=1, y=mode_labels,
        text=hover, hovertemplate="%{text}<extra></extra>",
        colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5),
    ), row=2, col=1)

    # Coverage bar: bars sit at 0-based x to line up with the heatmap
    # columns, while the hover shows the 1-based residue number.
    cov_shown = coverage[:n_res]
    bar_x = list(range(len(cov_shown)))
    fig.add_trace(go.Bar(
        x=bar_x, y=cov_shown,
        customdata=[i + 1 for i in bar_x],
        marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in cov_shown],
        hovertemplate="Res %{customdata}<br>Coverage: %{y:.3f}<extra></extra>",
        showlegend=False,
    ), row=3, col=1)

    # Sequence ticks, thinned out for long chains.
    step = max(1, n_res // 50)
    tick_vals = list(range(0, n_res, step))
    fig.update_xaxes(tickvals=tick_vals,
                     ticktext=[_residue_label(seq, i) for i in tick_vals],
                     tickangle=45, tickfont=dict(size=8), row=3, col=1)

    fig.update_layout(
        template="plotly_dark", height=550,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(30,27,75,0.3)",
        margin=dict(l=60, r=80, t=30, b=50),
    )
    st.plotly_chart(fig, use_container_width=True)

    # Key insight: three largest displacements of the first four modes.
    for k in range(min(n_modes, 4)):
        top3 = np.argsort(abs_mags[k])[-3:][::-1]
        top_str = ", ".join(f"**{_residue_label(seq, i)}** ({abs_mags[k][i]:.2f}Å)"
                            for i in top3)
        st.markdown(f"Mode {k} hotspots: {top_str}")


def _render_direction_tab(modes_c, ca_coords, n_res):
    """Tab 2: |cos θ| between each residue's displacement and its local CA→CA bond."""
    if ca_coords is None or len(ca_coords) != n_res:
        st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)")
        return
    ca_coords = np.asarray(ca_coords, dtype=float)
    n_modes = len(modes_c)

    st.markdown("""
    **Direction Analysis**: Projects displacement onto the local backbone direction (CA→CA).
    - 🔵 **Blue** = motion along backbone (stretching/compressing)
    - 🔴 **Red** = motion perpendicular to backbone (lateral/hinge)
    - Sign is arbitrary for eigenvectors → we show absolute cosine similarity
    """)

    # |cos θ| is invariant to the arbitrary ±1 eigenvector sign. Compute it
    # once per mode; the line plot reuses the first four rows.
    dir_matrix = np.zeros((n_modes, n_res))
    for k in range(n_modes):
        dir_matrix[k] = np.abs(per_residue_direction(modes_c[k], ca_coords))

    # (line colour, translucent fill) for the first four modes.
    styles = [
        ("#6366f1", "rgba(99,102,241,0.12)"),
        ("#ef4444", "rgba(239,68,68,0.12)"),
        ("#10b981", "rgba(16,185,129,0.12)"),
        ("#f59e0b", "rgba(245,158,11,0.12)"),
    ]
    residues = list(range(1, n_res + 1))
    fig_dir = go.Figure()
    for k, (line_color, fill_color) in enumerate(styles[:n_modes]):
        fig_dir.add_trace(go.Scatter(
            x=residues, y=dir_matrix[k],
            mode="lines", name=f"Mode {k}",
            line=dict(color=line_color, width=1.5),
            fill="tozeroy", fillcolor=fill_color,
            hovertemplate="Res %{x}<br>|cos θ|: %{y:.3f}<extra>Mode " + str(k) + "</extra>",
        ))

    # For uniformly random 3-D directions the mean |cos θ| is 0.5.
    fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8",
                      annotation_text="isotropic threshold")
    fig_dir.update_layout(
        template="plotly_dark", height=350,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        xaxis_title="Residue", yaxis_title="|cos θ| (backbone projection)",
        yaxis_range=[0, 1.05],
        margin=dict(l=50, r=20, t=30, b=50),
    )
    st.plotly_chart(fig_dir, use_container_width=True)

    # Direction heatmap. "RdBu" (not "RdBu_r") maps 0 -> red and 1 -> blue,
    # which matches the legend above: blue = along the backbone (|cos θ| ≈ 1),
    # red = perpendicular (|cos θ| ≈ 0).
    st.markdown("**Per-residue × mode direction matrix:**")
    fig_dh = go.Figure(go.Heatmap(
        z=dir_matrix, colorscale="RdBu", zmin=0, zmax=1,
        y=[f"Mode {k}" for k in range(n_modes)],
        colorbar=dict(title="|cos θ|"),
    ))
    fig_dh.update_layout(
        template="plotly_dark", height=200,
        paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
        margin=dict(l=60, r=60, t=10, b=30),
    )
    st.plotly_chart(fig_dh, use_container_width=True)


def _render_comparison_tab(modes_c, gt_modes):
    """Tab 3: overlay relative displacement profiles of prediction and ground truth."""
    if gt_modes is None or len(gt_modes) == 0:
        st.info("No ground truth available for comparison. "
                "Ground truth is only available for proteins in the training database.")
        return

    gt_c = canonicalize_sign(gt_modes)
    st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**")

    for k in range(min(len(modes_c), len(gt_c), 4)):
        pred_mag = np.linalg.norm(modes_c[k], axis=1)
        gt_mag = np.linalg.norm(gt_c[k], axis=1)

        # Prediction and ground truth may cover different numbers of
        # residues; compare the common prefix instead of failing to broadcast.
        n_common = min(len(pred_mag), len(gt_mag))
        if n_common == 0:
            continue

        # Normalize each profile to [0, 1] by its own maximum.
        pred_rel = (pred_mag / (pred_mag.max() + 1e-12))[:n_common]
        gt_rel = (gt_mag / (gt_mag.max() + 1e-12))[:n_common]
        residues = list(range(1, n_common + 1))

        fig_cmp = go.Figure()
        fig_cmp.add_trace(go.Scatter(
            x=residues, y=gt_rel,
            mode="lines", name="Ground Truth",
            line=dict(color="#10b981", width=2),
        ))
        fig_cmp.add_trace(go.Scatter(
            x=residues, y=pred_rel,
            mode="lines", name="Prediction",
            line=dict(color="#6366f1", width=2, dash="dot"),
        ))

        # Pearson r is undefined for a constant profile (np.corrcoef would
        # return NaN and warn), so report "n/a" in that case.
        if n_common > 1 and pred_rel.std() > 1e-12 and gt_rel.std() > 1e-12:
            corr_text = f"{np.corrcoef(pred_rel, gt_rel)[0, 1]:.3f}"
        else:
            corr_text = "n/a"
        rmse = np.sqrt(np.mean((pred_rel - gt_rel) ** 2))

        fig_cmp.update_layout(
            template="plotly_dark", height=200,
            title=f"Mode {k} — r={corr_text}, RMSE={rmse:.3f}",
            paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
            margin=dict(l=40, r=20, t=40, b=30),
            legend=dict(orientation="h", y=1.15),
        )
        st.plotly_chart(fig_cmp, use_container_width=True)


def _render_residue_table_tab(seq, coverage, abs_mags, rel_norms, protein_name):
    """Tab 4: per-residue table (first four modes) with a CSV download button."""
    import pandas as pd

    n_modes, n_res = abs_mags.shape
    n_cov = len(coverage)
    rows = []
    for i in range(n_res):
        row = {
            "Residue": i + 1,
            "AA": seq[i] if i < len(seq) else "?",
            "Coverage": f"{coverage[i]:.3f}" if i < n_cov else "—",
        }
        for k in range(min(n_modes, 4)):
            row[f"M{k} (Å)"] = f"{abs_mags[k][i]:.3f}"
            row[f"M{k} rel"] = f"{rel_norms[k][i]:.2%}"
        rows.append(row)

    df = pd.DataFrame(rows)
    st.dataframe(df, use_container_width=True, height=500,
                 column_config={
                     "Residue": st.column_config.NumberColumn(width="small"),
                     "AA": st.column_config.TextColumn(width="small"),
                 })

    # Download CSV; fall back to a generic prefix so the file is never
    # named just "_analysis.csv".
    csv = df.to_csv(index=False)
    st.download_button("📥 Download CSV", csv,
                       f"{protein_name or 'prediction'}_analysis.csv", "text/csv")