| """3D protein motion visualization β PyMOL-grade viewer using py3Dmol. |
| |
| Features: |
| - Cartoon / Stick / Sphere / Surface / Line representations |
| - Displacement arrows with cone arrowheads |
| - Secondary structure-aware coloring (helix=magenta, sheet=yellow, coil=white) |
| - Animated mode oscillation (generates multi-frame PDB) |
| - Ground Truth vs Prediction comparison viewer |
| - Deformed structure visualization |
| - Dark background (#0f0b25) |
| """ |
import json

import numpy as np
import streamlit as st
import streamlit.components.v1 as components

try:
    from stmol import showmol
    import py3Dmol
    HAS_STMOL = True
except ImportError:
    HAS_STMOL = False
|
|
|
|
| |
| |
| |
# One-letter amino-acid code -> hex color. Letters that share a color are
# grouped; groups are listed so the dict keeps its original key order.
AA_COLORS = {
    aa: color
    for letters, color in (
        ("AILMFWV", "#f59e0b"),  # hydrophobic
        ("P", "#eab308"),        # proline
        ("DE", "#ef4444"),       # acidic
        ("KR", "#3b82f6"),       # basic
        ("H", "#818cf8"),        # histidine
        ("STNQ", "#10b981"),     # polar
        ("C", "#facc15"),        # cysteine
        ("Y", "#10b981"),        # tyrosine (shares the polar color)
        ("G", "#94a3b8"),        # glycine
        ("X", "#64748b"),        # unknown residue
    )
    for aa in letters
}

# Secondary-structure code -> cartoon color ("C" is the fallback code for
# residues not covered by a HELIX/SHEET record).
SS_COLORS = dict(
    H="#ec4899",  # helix
    E="#f59e0b",  # sheet strand
    C="#94a3b8",  # coil
)

# Background color shared by every viewer in this module.
BG_COLOR = "#0f0b25"
|
|
|
|
def _mag_to_color(t: float) -> str:
    """Map a normalized magnitude onto a blue → indigo → purple → red ramp.

    ``t`` is clamped to [0, 1]. Each channel is linearly interpolated between
    the two surrounding stops and truncated with ``int``. The result is a CSS
    ``"rgb(r,g,b)"`` string.
    """
    t = max(0.0, min(1.0, t))
    ramp = (
        (0.00, (59, 130, 246)),   # blue
        (0.33, (99, 102, 241)),   # indigo
        (0.66, (168, 85, 247)),   # purple
        (1.00, (239, 68, 68)),    # red
    )
    for (lo_t, lo_rgb), (hi_t, hi_rgb) in zip(ramp, ramp[1:]):
        if t > hi_t:
            continue
        span = hi_t - lo_t
        frac = (t - lo_t) / span if span > 0 else 0
        channels = [int(a + (b - a) * frac) for a, b in zip(lo_rgb, hi_rgb)]
        return "rgb({},{},{})".format(*channels)
    # Not reachable after clamping; kept as a defensive fallback to the last stop.
    return "rgb({},{},{})".format(*ramp[-1][1])
|
|
|
|
def _get_color(idx: int, intensity: float, n_res: int, seq: str, scheme: str) -> str:
    """Return the color used for residue ``idx`` under ``scheme``.

    - ``"magnitude"`` / ``"bfactor"``: ``intensity`` through ``_mag_to_color``.
    - ``"rainbow"``: hue spread over the residue index (0 .. n_res - 1).
    - ``"residue_type"``: ``AA_COLORS`` lookup of ``seq[idx]`` ("X" past the
      end of ``seq``), falling back to slate gray for unknown letters.
    - ``"secondary"``: one flat light-indigo color.
    - anything else: indigo.
    """
    if scheme in ("magnitude", "bfactor"):
        return _mag_to_color(intensity)

    if scheme == "rainbow":
        import colorsys

        hue = idx / max(n_res - 1, 1)
        red, green, blue = (int(255 * ch) for ch in colorsys.hsv_to_rgb(hue, 0.85, 0.92))
        return f"rgb({red},{green},{blue})"

    if scheme == "residue_type":
        letter = "X" if idx >= len(seq) else seq[idx]
        return AA_COLORS.get(letter, "#94a3b8")

    return "#c7d2fe" if scheme == "secondary" else "#6366f1"
|
|
|
|
def _parse_ss_from_pdb(pdb_text: str) -> dict:
    """Build a ``{residue number: "H" | "E"}`` map from HELIX/SHEET records.

    Start/end residue numbers are read from the fixed PDB columns:
    HELIX initSeqNum ``[21:25]``, SHEET initSeqNum ``[22:26]``, and
    endSeqNum ``[33:37]`` for both. Chain IDs and insertion codes are not
    considered. Records whose numbers do not parse are skipped. Later
    records overwrite earlier ones for the same residue number.
    """
    # (record prefix, slice holding the first residue number, SS code)
    layouts = (
        ("HELIX", slice(21, 25), "H"),
        ("SHEET", slice(22, 26), "E"),
    )
    end_field = slice(33, 37)

    assignments = {}
    for row in pdb_text.split("\n"):
        for prefix, start_field, code in layouts:
            if not row.startswith(prefix):
                continue
            try:
                first = int(row[start_field].strip())
                last = int(row[end_field].strip())
            except (ValueError, IndexError):
                pass
            else:
                assignments.update(dict.fromkeys(range(first, last + 1), code))
            break
    return assignments
|
|
|
|
def _generate_deformed_pdb(pdb_text: str, ca_coords: np.ndarray,
                           mode_vecs: np.ndarray, amplitude: float,
                           phase: float = 1.0) -> str:
    """Return a copy of ``pdb_text`` with CA atoms moved along a mode vector.

    The k-th ``ATOM`` record whose atom name is CA gets its x/y/z (PDB
    columns 31-54, ``%8.3f`` each) replaced by
    ``ca_coords[k] + mode_vecs[k] * amplitude * phase``. All other lines,
    and CA records beyond ``len(ca_coords)``, are copied unchanged.

    Args:
        pdb_text: PDB-format text.
        ca_coords: CA coordinates, one row per CA record in file order.
        mode_vecs: Displacement vectors, same row order as ``ca_coords``.
        amplitude: Scale applied to ``mode_vecs``.
        phase: Signed multiplier (callers use values in -1..1) that sets the
            direction and fraction of the displacement.

    Returns:
        The modified PDB text, re-joined with ``"\\n"``.
    """
    new_lines = []
    ca_idx = 0
    for line in pdb_text.split("\n"):
        # The atom name lives in columns 13-16 (line[12:16]). Checking that field,
        # instead of searching the whole record for " CA ", avoids false matches
        # in other columns (e.g. a segment ID "CA" in columns 73-76). Such a
        # match would shift every later CA onto the wrong row of ca_coords.
        is_ca = line.startswith("ATOM") and line[12:16].strip() == "CA"
        if is_ca and ca_idx < len(ca_coords):
            disp = mode_vecs[ca_idx] * amplitude * phase
            x, y, z = (ca_coords[ca_idx][k] + disp[k] for k in range(3))
            new_lines.append(line[:30] + f"{x:8.3f}{y:8.3f}{z:8.3f}" + line[54:])
            ca_idx += 1
        else:
            new_lines.append(line)
    return "\n".join(new_lines)
|
|
|
|
| |
| |
| |
def render_motion_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    color_scheme: str = "magnitude",
    show_cartoon: bool = True,
    show_labels: bool = True,
    min_displacement: float = 0.01,
    width: int = 800,
    height: int = 500,
    key: str = "viewer",
    style: str = "cartoon",
    show_surface: bool = False,
    surface_opacity: float = 0.15,
):
    """Render a PyMOL-style interactive 3D viewer with displacement arrows.

    Row ``i`` of ``ca_coords`` / ``mode_vecs`` is styled through the py3Dmol
    selection ``{"resi": i + 1}``. NOTE(review): this assumes residues in
    ``pdb_text`` are numbered 1..N — confirm for PDBs with offset numbering.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates, one row per residue.
        mode_vecs: Displacement vectors, same row order as ``ca_coords``.
        seq: One-letter sequence used for labels and ``residue_type`` colors.
        amplitude: Scale applied to ``mode_vecs`` for arrows and label
            positions; a negative value flips the arrows.
        arrow_scale: Scale for arrow radii and arrowhead length.
        color_scheme: ``"magnitude"``, ``"secondary"``, ``"rainbow"``,
            ``"residue_type"`` or ``"bfactor"``. Unknown values (and
            ``"bfactor"`` for the cartoon) use flat colors.
        show_cartoon: Not read by this function (``style`` picks the
            representation); kept for API compatibility.
        show_labels: Label up to five residues with the largest displacement.
        min_displacement: Residues whose vector norm is below this get
            neither an arrow nor a label.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        key: Not read by this function; kept for API compatibility.
        style: ``"cartoon"``, ``"stick"``, ``"sphere"`` or ``"line"``.
        show_surface: Overlay a translucent VDW surface.
        surface_opacity: Opacity of that surface.
    """
    if not HAS_STMOL:
        st.error("Install `stmol` and `py3Dmol`: `pip install stmol py3Dmol`")
        return

    n_res = len(ca_coords)
    mags = np.linalg.norm(mode_vecs, axis=1)
    # ndarray.max() raises on an empty array; treat "no residues" as zero motion.
    max_mag = (mags.max() if mags.size else 0.0) + 1e-8

    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdb_text, "pdb")

    # --- Base representation ---
    if style == "cartoon":
        if color_scheme == "magnitude":
            view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}})
            for i in range(n_res):
                col = _mag_to_color(mags[i] / max_mag)
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}})
        elif color_scheme == "secondary":
            ss = _parse_ss_from_pdb(pdb_text)
            view.setStyle({"cartoon": {"color": "#94a3b8", "opacity": 0.85}})
            for i in range(n_res):
                col = SS_COLORS.get(ss.get(i + 1, "C"), "#94a3b8")
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.9}})
        elif color_scheme == "rainbow":
            view.setStyle({"cartoon": {"color": "spectrum", "opacity": 0.85}})
        elif color_scheme == "residue_type":
            view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}})
            for i, aa in enumerate(seq[:n_res]):
                col = AA_COLORS.get(aa, "#94a3b8")
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}})
        else:
            view.setStyle({"cartoon": {"color": "#c7d2fe", "opacity": 0.7}})
    elif style == "stick":
        view.setStyle({"stick": {"radius": 0.12, "colorscheme": "Jmol"}})
    elif style == "sphere":
        view.setStyle({"sphere": {"scale": 0.3, "colorscheme": "Jmol"}})
    elif style == "line":
        view.setStyle({"line": {"colorscheme": "Jmol"}})

    # --- Optional translucent surface ---
    if show_surface:
        if color_scheme == "magnitude":
            view.addSurface(py3Dmol.VDW, {"opacity": surface_opacity, "color": "white"})
            # NOTE(review): one addSurface call per residue can be slow on large
            # proteins; consider binning residues by color if this becomes a problem.
            for i in range(n_res):
                col = _mag_to_color(mags[i] / max_mag)
                view.addSurface(
                    py3Dmol.VDW,
                    {"opacity": surface_opacity, "color": col},
                    {"resi": i + 1},
                )
        else:
            view.addSurface(py3Dmol.VDW, {"opacity": surface_opacity, "color": "#6366f1"})

    # --- Displacement arrows ---
    def _point(v):
        # Plain Python floats for the JS-side shape spec (as the original casts did).
        return {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])}

    base_r = 0.06 * arrow_scale
    # abs(): with a negative amplitude the direction of ``d`` (and ``unit``) is
    # already flipped. A negative head length would fold the head back over the shaft.
    head_len = 0.3 * abs(amplitude) * arrow_scale
    for i in range(n_res):
        if mags[i] < min_displacement:
            continue

        start = ca_coords[i]
        d = mode_vecs[i] * amplitude
        end = start + d
        t = mags[i] / max_mag
        col = _get_color(i, t, n_res, seq, color_scheme)

        # Shaft: thicker for residues with larger relative displacement.
        shaft_r = base_r + base_r * 1.5 * t
        view.addCylinder({
            "start": _point(start),
            "end": _point(end),
            "radius": shaft_r,
            "color": col,
            "fromCap": True,
        })

        # Head: a short, wider cylinder continuing past the end of the shaft.
        unit = d / (np.linalg.norm(d) + 1e-8)
        tip = end + unit * head_len
        view.addCylinder({
            "start": _point(end),
            "end": _point(tip),
            "radius": shaft_r * 2.5,
            "color": col,
            "toCap": True,
        })

    # --- Labels for the most mobile residues ---
    if show_labels and n_res > 0:
        top_n = min(5, n_res)
        top_idx = np.argsort(mags)[-top_n:][::-1]
        for idx in top_idx:
            if mags[idx] < min_displacement:
                continue
            pos = ca_coords[idx] + mode_vecs[idx] * amplitude * 0.5
            aa = seq[idx] if idx < len(seq) else "?"
            view.addLabel(
                f"{aa}{idx + 1} ({mags[idx]:.2f}Å)",
                {
                    "position": {"x": float(pos[0]), "y": float(pos[1] + 1.5), "z": float(pos[2])},
                    "fontSize": 11,
                    "fontColor": "white",
                    "backgroundColor": "#1e1b4b",
                    "backgroundOpacity": 0.9,
                    "borderColor": "#6366f1",
                    "borderThickness": 1.5,
                },
            )

    view.setBackgroundColor(BG_COLOR)
    view.zoomTo()
    showmol(view, height=height, width=width)
|
|
|
|
| |
| |
| |
def render_deformed_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    width: int = 800,
    height: int = 450,
):
    """Show three superimposed structures: original, +deformed and -deformed.

    Model 0 is ``pdb_text`` as given (slate, 40% opacity). Models 1 and 2 are
    built by ``_generate_deformed_pdb`` with phase +1 (indigo) and -1 (red),
    both at 70% opacity. Returns silently when stmol/py3Dmol are unavailable.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates, one row per CA record.
        mode_vecs: Displacement vectors, same row order as ``ca_coords``.
        seq: Not read by this function; kept for API compatibility.
        amplitude: Displacement scale for the deformed copies.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
    """
    if not HAS_STMOL:
        return

    view = py3Dmol.view(width=width, height=height)

    # (phase, cartoon style) in addModel order; ``None`` keeps the original coordinates.
    layers = (
        (None, {"color": "#64748b", "opacity": 0.4}),
        (+1.0, {"color": "#6366f1", "opacity": 0.7}),
        (-1.0, {"color": "#ef4444", "opacity": 0.7}),
    )
    for model_idx, (phase, cartoon) in enumerate(layers):
        if phase is None:
            model_text = pdb_text
        else:
            model_text = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, phase)
        view.addModel(model_text, "pdb")
        view.setStyle({"model": model_idx}, {"cartoon": cartoon})

    view.setBackgroundColor(BG_COLOR)
    view.zoomTo()
    showmol(view, height=height, width=width)
|
|
|
|
| |
| |
| |
def render_animated_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    amplitude: float = 3.0,
    n_frames: int = 20,
    width: int = 800,
    height: int = 450,
    key: str = "anim",
):
    """Render an animated viewer that oscillates the structure along a mode.

    Builds a multi-model PDB with one frame per phase ``sin(2πk / n_frames)``.
    Each frame is made by ``_generate_deformed_pdb`` and keeps only its
    ATOM/HETATM/TER lines. The result is played with 3Dmol.js
    ``addModelsAsFrames`` + ``animate`` inside a Streamlit HTML component.
    3Dmol.js is loaded from 3Dmol.org by the browser.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates, one row per CA record.
        mode_vecs: Displacement vectors, same row order as ``ca_coords``.
        amplitude: Peak displacement scale.
        n_frames: Frames per oscillation period (values below 1 are treated as 1).
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        key: Suffix for the viewer's DOM element id.
    """
    if not HAS_STMOL:
        return

    # One full sine period; endpoint=False so the looped animation does not
    # show the starting frame twice.
    n_frames = max(int(n_frames), 1)
    phases = np.sin(np.linspace(0, 2 * np.pi, n_frames, endpoint=False))

    chunks = []
    for frame, phase in enumerate(phases, start=1):
        deformed = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, phase)
        chunks.append(f"MODEL     {frame:>4}\n")  # serial in columns 11-14
        chunks.extend(
            line + "\n"
            for line in deformed.split("\n")
            if line.startswith(("ATOM", "HETATM", "TER"))
        )
        chunks.append("ENDMDL\n")
    multi_pdb = "".join(chunks)

    # Embed the PDB as a JSON string literal (valid JavaScript). Unlike a raw
    # template literal, this cannot be broken by backticks, "${" or backslashes
    # in user-supplied PDB text. "</" is escaped so the data can never close
    # the surrounding <script> element.
    pdb_js = json.dumps(multi_pdb).replace("</", "<\\/")

    html = f"""
    <div id="viewport_{key}" style="width:{width}px; height:{height}px; position:relative; border-radius: 12px; overflow: hidden;">
    </div>
    <script src="https://3Dmol.org/build/3Dmol-min.js"></script>
    <script>
    (function() {{
        let viewer = $3Dmol.createViewer("viewport_{key}", {{
            backgroundColor: "{BG_COLOR}",
            antialias: true,
        }});

        let pdbData = {pdb_js};
        viewer.addModelsAsFrames(pdbData, "pdb");

        // "spectrum" coloring (gradient along the chain), not displacement magnitude.
        viewer.setStyle({{}}, {{cartoon: {{color: "spectrum", opacity: 0.85}}}});

        viewer.zoomTo();
        viewer.animate({{loop: "forward", reps: 0, interval: 80}});
        viewer.render();
    }})();
    </script>
    """
    components.html(html, height=height + 10, width=width + 10)
|
|
|
|
| |
| |
| |
def render_pred_vs_gt_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    pred_vecs: np.ndarray,
    gt_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    width: int = 900,
    height: int = 400,
    min_displacement: float = 0.01,
):
    """Side-by-side Prediction vs Ground Truth 3D viewer.

    Each column shows the cartoon colored by per-residue vector norm
    (``_mag_to_color``). It also draws one cylinder per residue from the CA
    to ``CA + vec * amplitude``: indigo for the prediction, green for the
    ground truth. Returns silently when stmol/py3Dmol are unavailable.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates, one row per residue (row ``i`` -> resi ``i + 1``).
        pred_vecs: Predicted displacement vectors, same row order.
        gt_vecs: Ground-truth displacement vectors, same row order.
        seq: Not read by this function; kept for API compatibility.
        amplitude: Scale applied to the vectors when drawing cylinders.
        arrow_scale: Scale for cylinder radii.
        width: Total width in pixels; each panel gets ``width // 2``.
        height: Panel height in pixels.
        min_displacement: Residues whose vector norm is below this get no
            cylinder (default matches the previously hard-coded 0.01).
    """
    if not HAS_STMOL:
        return

    n_res = len(ca_coords)
    panel_w = width // 2

    def _draw_panel(vecs, title, arrow_color):
        """Caption + magnitude-colored cartoon + displacement cylinders for one vector set."""
        mags = np.linalg.norm(vecs, axis=1)
        # ndarray.max() raises (and mean() warns) on empty input.
        mean_m, top_m = (mags.mean(), mags.max()) if mags.size else (0.0, 0.0)
        st.caption(f"**{title}** · μ={mean_m:.3f}Å · max={top_m:.3f}Å")
        max_m = top_m + 1e-8

        vw = py3Dmol.view(width=panel_w, height=height)
        vw.addModel(pdb_text, "pdb")
        vw.setStyle({"cartoon": {"color": "white", "opacity": 0.3}})
        for i in range(n_res):
            t = mags[i] / max_m
            vw.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.8}})

        for i in range(n_res):
            if mags[i] < min_displacement:
                continue
            s = ca_coords[i]
            e = s + vecs[i] * amplitude
            t = mags[i] / max_m
            vw.addCylinder({
                "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
                "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
                "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale,
                "color": arrow_color,
            })

        vw.setBackgroundColor(BG_COLOR)
        vw.zoomTo()
        showmol(vw, height=height, width=panel_w)

    col_pred, col_gt = st.columns(2)
    with col_pred:
        _draw_panel(pred_vecs, "PETIMOT Prediction", "#6366f1")
    with col_gt:
        _draw_panel(gt_vecs, "Ground Truth (NMA)", "#10b981")
|
|
|
|
| |
| |
| |
def render_mode_comparison(
    pdb_text: str,
    ca_coords: np.ndarray,
    modes: dict,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    width: int = 900,
    height: int = 350,
    min_displacement: float = 0.01,
):
    """Render up to four modes side by side, PyMOL-style.

    Each column shows the cartoon colored by per-residue vector norm. It also
    draws one cylinder per residue from the CA to ``CA + vec * amplitude`` in
    the mode's accent color. Returns silently when stmol/py3Dmol are
    unavailable, and shows a warning when ``modes`` is empty.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates, one row per residue (row ``i`` -> resi ``i + 1``).
        modes: Mode vectors. If indices ``0..n-1`` are all present (a list, or
            a dict keyed by ints), those are used. Otherwise the first ``n``
            entries of the mapping are used, labelled by their keys.
        seq: Not read by this function; kept for API compatibility.
        amplitude: Scale applied to the vectors when drawing cylinders.
        arrow_scale: Scale for cylinder radii.
        width: Total width in pixels, split evenly across the columns.
        height: Panel height in pixels.
        min_displacement: Residues whose vector norm is below this get no
            cylinder (default matches the previously hard-coded 0.01).
    """
    if not HAS_STMOL:
        return

    n_modes = min(4, len(modes))
    if n_modes == 0:
        st.warning("No modes to display")
        return

    # Keep the original ``modes[k]`` lookup when keys 0..n-1 exist. Otherwise a
    # dict keyed by e.g. names would raise KeyError, so take its first entries.
    if isinstance(modes, dict) and not all(k in modes for k in range(n_modes)):
        selected = list(modes.items())[:n_modes]
    else:
        selected = [(k, modes[k]) for k in range(n_modes)]

    mode_colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
    n_res = len(ca_coords)
    per_col_w = width // n_modes

    cols = st.columns(n_modes)
    for slot, (label, vecs) in enumerate(selected):
        with cols[slot]:
            mags = np.linalg.norm(vecs, axis=1)
            # ndarray.max() raises (and mean() warns) on empty input.
            mean_m, top_m = (mags.mean(), mags.max()) if mags.size else (0.0, 0.0)
            max_m = top_m + 1e-8
            st.caption(f"**Mode {label}** · μ={mean_m:.3f}Å · max={top_m:.3f}Å")

            view = py3Dmol.view(width=per_col_w, height=height)
            view.addModel(pdb_text, "pdb")

            view.setStyle({"cartoon": {"color": "white", "opacity": 0.3}})
            for i in range(n_res):
                t = mags[i] / max_m
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.7}})

            for i in range(n_res):
                if mags[i] < min_displacement:
                    continue
                s = ca_coords[i]
                e = s + vecs[i] * amplitude
                t = mags[i] / max_m
                view.addCylinder({
                    "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
                    "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
                    "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale,
                    "color": mode_colors[slot],
                })

            view.setBackgroundColor(BG_COLOR)
            view.zoomTo()
            showmol(view, height=height, width=per_col_w)
|
|