"""3D protein motion visualization — PyMOL-grade viewer using py3Dmol. Features: - Cartoon / Stick / Sphere / Surface / Line representations - Displacement arrows with cone arrowheads - Secondary structure-aware coloring (helix=magenta, sheet=yellow, coil=white) - Animated mode oscillation (generates multi-frame PDB) - Ground Truth vs Prediction comparison viewer - Deformed structure visualization - Dark background (#0f0b25) """ import streamlit as st import numpy as np import streamlit.components.v1 as components try: from stmol import showmol import py3Dmol HAS_STMOL = True except ImportError: HAS_STMOL = False # ═══════════════════════════════════════════════════ # Color palettes # ═══════════════════════════════════════════════════ AA_COLORS = { "A": "#f59e0b", "I": "#f59e0b", "L": "#f59e0b", "M": "#f59e0b", "F": "#f59e0b", "W": "#f59e0b", "V": "#f59e0b", "P": "#eab308", "D": "#ef4444", "E": "#ef4444", "K": "#3b82f6", "R": "#3b82f6", "H": "#818cf8", "S": "#10b981", "T": "#10b981", "N": "#10b981", "Q": "#10b981", "C": "#facc15", "Y": "#10b981", "G": "#94a3b8", "X": "#64748b", } SS_COLORS = { "H": "#ec4899", # Helix → magenta "E": "#f59e0b", # Sheet → amber "C": "#94a3b8", # Coil → grey } BG_COLOR = "#0f0b25" def _mag_to_color(t: float) -> str: """Blue → Indigo → Magenta → Red gradient based on normalized magnitude [0,1].""" t = max(0.0, min(1.0, t)) # Smooth 4-stop gradient: blue → indigo → magenta → red stops = [ (0.00, (59, 130, 246)), # blue-500 (0.33, (99, 102, 241)), # indigo-500 (0.66, (168, 85, 247)), # violet-500 (1.00, (239, 68, 68)), # red-500 ] for i in range(len(stops) - 1): t0, c0 = stops[i] t1, c1 = stops[i + 1] if t <= t1: f = (t - t0) / (t1 - t0) if t1 > t0 else 0 r = int(c0[0] + (c1[0] - c0[0]) * f) g = int(c0[1] + (c1[1] - c0[1]) * f) b = int(c0[2] + (c1[2] - c0[2]) * f) return f"rgb({r},{g},{b})" return f"rgb({stops[-1][1][0]},{stops[-1][1][1]},{stops[-1][1][2]})" def _get_color(idx: int, intensity: float, n_res: int, seq: str, scheme: str) -> str: """Get color for a residue based on the color scheme.""" if scheme == "magnitude": return _mag_to_color(intensity) elif scheme == "rainbow": import colorsys h = idx / max(n_res - 1, 1) r, g, b = [int(255 * c) for c in colorsys.hsv_to_rgb(h, 0.85, 0.92)] return f"rgb({r},{g},{b})" elif scheme == "residue_type": aa = seq[idx] if idx < len(seq) else "X" return AA_COLORS.get(aa, "#94a3b8") elif scheme == "bfactor": return _mag_to_color(intensity) elif scheme == "secondary": return "#c7d2fe" # Will be handled at cartoon level else: return "#6366f1" def _parse_ss_from_pdb(pdb_text: str) -> dict: """Parse secondary structure from HELIX/SHEET records in PDB.""" ss = {} # residue_num -> 'H' or 'E' for line in pdb_text.split("\n"): if line.startswith("HELIX"): try: start = int(line[21:25].strip()) end = int(line[33:37].strip()) for r in range(start, end + 1): ss[r] = "H" except (ValueError, IndexError): pass elif line.startswith("SHEET"): try: start = int(line[22:26].strip()) end = int(line[33:37].strip()) for r in range(start, end + 1): ss[r] = "E" except (ValueError, IndexError): pass return ss def _generate_deformed_pdb(pdb_text: str, ca_coords: np.ndarray, mode_vecs: np.ndarray, amplitude: float, phase: float = 1.0) -> str: """Generate a PDB with CA atoms displaced along the mode vector. phase: -1 to 1, controls direction of displacement. """ lines = pdb_text.split("\n") ca_idx = 0 new_lines = [] for line in lines: if line.startswith("ATOM") and " CA " in line: if ca_idx < len(ca_coords): disp = mode_vecs[ca_idx] * amplitude * phase x = ca_coords[ca_idx][0] + disp[0] y = ca_coords[ca_idx][1] + disp[1] z = ca_coords[ca_idx][2] + disp[2] # Replace XYZ coordinates (cols 30-54 in PDB format) new_line = line[:30] + f"{x:8.3f}{y:8.3f}{z:8.3f}" + line[54:] new_lines.append(new_line) ca_idx += 1 else: new_lines.append(line) else: new_lines.append(line) return "\n".join(new_lines) # ═══════════════════════════════════════════════════ # Main viewer # ═══════════════════════════════════════════════════ def render_motion_viewer( pdb_text: str, ca_coords: np.ndarray, mode_vecs: np.ndarray, seq: str = "", amplitude: float = 3.0, arrow_scale: float = 1.0, color_scheme: str = "magnitude", show_cartoon: bool = True, show_labels: bool = True, min_displacement: float = 0.01, width: int = 800, height: int = 500, key: str = "viewer", style: str = "cartoon", show_surface: bool = False, surface_opacity: float = 0.15, ): """Render PyMOL-style interactive 3D viewer with displacement arrows.""" if not HAS_STMOL: st.error("Install `stmol` and `py3Dmol`: `pip install stmol py3Dmol`") return n_res = len(ca_coords) mags = np.linalg.norm(mode_vecs, axis=1) max_mag = mags.max() + 1e-8 view = py3Dmol.view(width=width, height=height) view.addModel(pdb_text, "pdb") # ── Backbone representation ── if style == "cartoon": if color_scheme == "magnitude": view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}}) for i in range(n_res): t = mags[i] / max_mag col = _mag_to_color(t) view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}}) elif color_scheme == "secondary": ss = _parse_ss_from_pdb(pdb_text) view.setStyle({"cartoon": {"color": "#94a3b8", "opacity": 0.85}}) for i in range(n_res): ss_type = ss.get(i + 1, "C") col = SS_COLORS.get(ss_type, "#94a3b8") view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.9}}) elif color_scheme == "rainbow": view.setStyle({"cartoon": {"color": "spectrum", "opacity": 0.85}}) elif color_scheme == "residue_type": view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}}) for i in range(min(n_res, len(seq))): col = AA_COLORS.get(seq[i], "#94a3b8") view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}}) else: view.setStyle({"cartoon": {"color": "#c7d2fe", "opacity": 0.7}}) elif style == "stick": view.setStyle({"stick": {"radius": 0.12, "colorscheme": "Jmol"}}) elif style == "sphere": view.setStyle({"sphere": {"scale": 0.3, "colorscheme": "Jmol"}}) elif style == "line": view.setStyle({"line": {"colorscheme": "Jmol"}}) # ── Surface overlay ── if show_surface: if color_scheme == "magnitude": # Color surface by displacement view.addSurface(py3Dmol.VDW, { "opacity": surface_opacity, "color": "white", }) for i in range(n_res): t = mags[i] / max_mag col = _mag_to_color(t) view.addSurface(py3Dmol.VDW, { "opacity": surface_opacity, "color": col, }, {"resi": i + 1}) else: view.addSurface(py3Dmol.VDW, { "opacity": surface_opacity, "color": "#6366f1", }) # ── Displacement arrows ── for i in range(n_res): if mags[i] < min_displacement: continue s = ca_coords[i] d = mode_vecs[i] * amplitude e = s + d t = mags[i] / max_mag col = _get_color(i, t, n_res, seq, color_scheme) # Shaft base_r = 0.06 * arrow_scale shaft_r = base_r + base_r * 1.5 * t view.addCylinder({ "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])}, "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])}, "radius": shaft_r, "color": col, "fromCap": True, }) # Arrowhead dn = d / (np.linalg.norm(d) + 1e-8) tip = e + dn * 0.3 * amplitude * arrow_scale tip_r = shaft_r * 2.5 view.addCylinder({ "start": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])}, "end": {"x": float(tip[0]), "y": float(tip[1]), "z": float(tip[2])}, "radius": tip_r, "color": col, "toCap": True, }) # ── Labels ── if show_labels and n_res > 0: top_n = min(5, n_res) top_idx = np.argsort(mags)[-top_n:][::-1] for rank, idx in enumerate(top_idx): if mags[idx] < min_displacement: continue pos = ca_coords[idx] + mode_vecs[idx] * amplitude * 0.5 aa = seq[idx] if idx < len(seq) else "?" view.addLabel( f"{aa}{idx + 1} ({mags[idx]:.2f}Å)", { "position": {"x": float(pos[0]), "y": float(pos[1] + 1.5), "z": float(pos[2])}, "fontSize": 11, "fontColor": "white", "backgroundColor": "#1e1b4b", "backgroundOpacity": 0.9, "borderColor": "#6366f1", "borderThickness": 1.5, }, ) view.setBackgroundColor(BG_COLOR) view.zoomTo() showmol(view, height=height, width=width) # ═══════════════════════════════════════════════════ # Deformed structure viewer # ═══════════════════════════════════════════════════ def render_deformed_viewer( pdb_text: str, ca_coords: np.ndarray, mode_vecs: np.ndarray, seq: str = "", amplitude: float = 3.0, width: int = 800, height: int = 450, ): """Show three structures: original, +deformed, -deformed superimposed.""" if not HAS_STMOL: return n_res = len(ca_coords) mags = np.linalg.norm(mode_vecs, axis=1) max_mag = mags.max() + 1e-8 view = py3Dmol.view(width=width, height=height) # Original structure (white) view.addModel(pdb_text, "pdb") view.setStyle({"model": 0}, {"cartoon": {"color": "#64748b", "opacity": 0.4}}) # +Deformed (blue) pdb_plus = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, +1.0) view.addModel(pdb_plus, "pdb") view.setStyle({"model": 1}, {"cartoon": {"color": "#6366f1", "opacity": 0.7}}) # -Deformed (red) pdb_minus = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, -1.0) view.addModel(pdb_minus, "pdb") view.setStyle({"model": 2}, {"cartoon": {"color": "#ef4444", "opacity": 0.7}}) view.setBackgroundColor(BG_COLOR) view.zoomTo() showmol(view, height=height, width=width) # ═══════════════════════════════════════════════════ # Mode animation (oscillating structure) # ═══════════════════════════════════════════════════ def render_animated_viewer( pdb_text: str, ca_coords: np.ndarray, mode_vecs: np.ndarray, amplitude: float = 3.0, n_frames: int = 20, width: int = 800, height: int = 450, key: str = "anim", ): """Render an animated viewer that oscillates the structure along the mode vector. Uses multi-model PDB + 3Dmol.js animate() to show smooth mode oscillation. """ if not HAS_STMOL: return n_res = len(ca_coords) mags = np.linalg.norm(mode_vecs, axis=1) max_mag = mags.max() + 1e-8 # Build multi-model PDB for animation frames # Oscillate: phase goes from 0 → +1 → 0 → -1 → 0 phases = np.sin(np.linspace(0, 2 * np.pi, n_frames, endpoint=False)) multi_pdb = "" for frame, phase in enumerate(phases): deformed = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, phase) multi_pdb += f"MODEL {frame + 1}\n" for line in deformed.split("\n"): if line.startswith(("ATOM", "HETATM", "TER")): multi_pdb += line + "\n" multi_pdb += "ENDMDL\n" # Use raw HTML/JS for smooth animation html = f"""
""" components.html(html, height=height + 10, width=width + 10) # ═══════════════════════════════════════════════════ # GT vs Prediction comparison viewer # ═══════════════════════════════════════════════════ def render_pred_vs_gt_viewer( pdb_text: str, ca_coords: np.ndarray, pred_vecs: np.ndarray, gt_vecs: np.ndarray, seq: str = "", amplitude: float = 3.0, arrow_scale: float = 1.0, width: int = 900, height: int = 400, ): """Side-by-side Prediction vs Ground Truth 3D viewer.""" if not HAS_STMOL: return n_res = len(ca_coords) pred_mags = np.linalg.norm(pred_vecs, axis=1) gt_mags = np.linalg.norm(gt_vecs, axis=1) col_pred, col_gt = st.columns(2) with col_pred: st.caption(f"**PETIMOT Prediction** · μ={pred_mags.mean():.3f}Å · max={pred_mags.max():.3f}Å") max_m = pred_mags.max() + 1e-8 vw = py3Dmol.view(width=width // 2, height=height) vw.addModel(pdb_text, "pdb") vw.setStyle({"cartoon": {"color": "white", "opacity": 0.3}}) for i in range(n_res): t = pred_mags[i] / max_m vw.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.8}}) for i in range(n_res): if pred_mags[i] < 0.01: continue s = ca_coords[i] d = pred_vecs[i] * amplitude e = s + d t = pred_mags[i] / max_m vw.addCylinder({ "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])}, "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])}, "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale, "color": "#6366f1", }) vw.setBackgroundColor(BG_COLOR) vw.zoomTo() showmol(vw, height=height, width=width // 2) with col_gt: st.caption(f"**Ground Truth (NMA)** · μ={gt_mags.mean():.3f}Å · max={gt_mags.max():.3f}Å") max_m = gt_mags.max() + 1e-8 vw2 = py3Dmol.view(width=width // 2, height=height) vw2.addModel(pdb_text, "pdb") vw2.setStyle({"cartoon": {"color": "white", "opacity": 0.3}}) for i in range(n_res): t = gt_mags[i] / max_m vw2.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.8}}) for i in range(n_res): if gt_mags[i] < 0.01: continue s = ca_coords[i] d = gt_vecs[i] * amplitude e = s + d t = gt_mags[i] / max_m vw2.addCylinder({ "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])}, "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])}, "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale, "color": "#10b981", }) vw2.setBackgroundColor(BG_COLOR) vw2.zoomTo() showmol(vw2, height=height, width=width // 2) # ═══════════════════════════════════════════════════ # Mode comparison grid # ═══════════════════════════════════════════════════ def render_mode_comparison( pdb_text: str, ca_coords: np.ndarray, modes: dict, seq: str = "", amplitude: float = 3.0, arrow_scale: float = 1.0, width: int = 900, height: int = 350, ): """Render side-by-side mode comparison grid with PyMOL-style.""" if not HAS_STMOL: return n_modes = min(4, len(modes)) if n_modes == 0: st.warning("No modes to display") return mode_colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"] cols = st.columns(n_modes) for k in range(n_modes): with cols[k]: vecs = modes[k] mags = np.linalg.norm(vecs, axis=1) max_m = mags.max() + 1e-8 st.caption(f"**Mode {k}** · μ={mags.mean():.3f}Å · max={mags.max():.3f}Å") per_col_w = width // n_modes view = py3Dmol.view(width=per_col_w, height=height) view.addModel(pdb_text, "pdb") view.setStyle({"cartoon": {"color": "white", "opacity": 0.3}}) for i in range(len(ca_coords)): t = mags[i] / max_m view.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.7}}) for i in range(len(ca_coords)): if mags[i] < 0.01: continue s = ca_coords[i] d = vecs[i] * amplitude e = s + d t = mags[i] / max_m view.addCylinder({ "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])}, "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])}, "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale, "color": mode_colors[k], }) view.setBackgroundColor(BG_COLOR) view.zoomTo() showmol(view, height=height, width=per_col_w)