Petimot / app /components /viewer_3d.py
Valmbd's picture
Update: Benchmark page, mode comparison, RMSIP analysis, 36k predictions
5e3ee28 verified
"""3D protein motion visualization — PyMOL-grade viewer using py3Dmol.
Features:
- Cartoon / Stick / Sphere / Surface / Line representations
- Displacement arrows with wide capped-cylinder arrowheads
- Secondary structure-aware coloring (helix=magenta, sheet=amber, coil=grey)
- Animated mode oscillation (generates multi-frame PDB)
- Ground Truth vs Prediction comparison viewer
- Deformed structure visualization
- Dark background (#0f0b25)
"""
import streamlit as st
import numpy as np
import streamlit.components.v1 as components
try:
from stmol import showmol
import py3Dmol
HAS_STMOL = True
except ImportError:
HAS_STMOL = False
# ═══════════════════════════════════════════════════
# Color palettes
# ═══════════════════════════════════════════════════
# Per-residue colors for the "residue_type" scheme, grouped by side-chain class.
# Insertion order is A I L M F W V P D E K R H S T N Q C Y G X.
AA_COLORS = {
    **dict.fromkeys("AILMFWV", "#f59e0b"),  # hydrophobic
    "P": "#eab308",                         # proline
    **dict.fromkeys("DE", "#ef4444"),       # acidic
    **dict.fromkeys("KR", "#3b82f6"),       # basic
    "H": "#818cf8",                         # histidine
    **dict.fromkeys("STNQ", "#10b981"),     # polar
    "C": "#facc15",                         # cysteine
    "Y": "#10b981",                         # tyrosine (shares the polar color)
    "G": "#94a3b8",                         # glycine
    "X": "#64748b",                         # unknown residue
}
# Cartoon colors keyed by the codes _parse_ss_from_pdb emits ("H", "E");
# "C" is the fallback used for residues without a HELIX/SHEET record.
SS_COLORS = {
    "H": "#ec4899",  # helix → magenta/pink
    "E": "#f59e0b",  # sheet → amber
    "C": "#94a3b8",  # coil → grey
}
BG_COLOR = "#0f0b25"  # dark viewer background
def _mag_to_color(t: float) -> str:
"""Blue β†’ Indigo β†’ Magenta β†’ Red gradient based on normalized magnitude [0,1]."""
t = max(0.0, min(1.0, t))
# Smooth 4-stop gradient: blue β†’ indigo β†’ magenta β†’ red
stops = [
(0.00, (59, 130, 246)), # blue-500
(0.33, (99, 102, 241)), # indigo-500
(0.66, (168, 85, 247)), # violet-500
(1.00, (239, 68, 68)), # red-500
]
for i in range(len(stops) - 1):
t0, c0 = stops[i]
t1, c1 = stops[i + 1]
if t <= t1:
f = (t - t0) / (t1 - t0) if t1 > t0 else 0
r = int(c0[0] + (c1[0] - c0[0]) * f)
g = int(c0[1] + (c1[1] - c0[1]) * f)
b = int(c0[2] + (c1[2] - c0[2]) * f)
return f"rgb({r},{g},{b})"
return f"rgb({stops[-1][1][0]},{stops[-1][1][1]},{stops[-1][1][2]})"
def _get_color(idx: int, intensity: float, n_res: int, seq: str, scheme: str) -> str:
"""Get color for a residue based on the color scheme."""
if scheme == "magnitude":
return _mag_to_color(intensity)
elif scheme == "rainbow":
import colorsys
h = idx / max(n_res - 1, 1)
r, g, b = [int(255 * c) for c in colorsys.hsv_to_rgb(h, 0.85, 0.92)]
return f"rgb({r},{g},{b})"
elif scheme == "residue_type":
aa = seq[idx] if idx < len(seq) else "X"
return AA_COLORS.get(aa, "#94a3b8")
elif scheme == "bfactor":
return _mag_to_color(intensity)
elif scheme == "secondary":
return "#c7d2fe" # Will be handled at cartoon level
else:
return "#6366f1"
def _parse_ss_from_pdb(pdb_text: str) -> dict:
"""Parse secondary structure from HELIX/SHEET records in PDB."""
ss = {} # residue_num -> 'H' or 'E'
for line in pdb_text.split("\n"):
if line.startswith("HELIX"):
try:
start = int(line[21:25].strip())
end = int(line[33:37].strip())
for r in range(start, end + 1):
ss[r] = "H"
except (ValueError, IndexError):
pass
elif line.startswith("SHEET"):
try:
start = int(line[22:26].strip())
end = int(line[33:37].strip())
for r in range(start, end + 1):
ss[r] = "E"
except (ValueError, IndexError):
pass
return ss
def _generate_deformed_pdb(pdb_text: str, ca_coords: np.ndarray,
mode_vecs: np.ndarray, amplitude: float,
phase: float = 1.0) -> str:
"""Generate a PDB with CA atoms displaced along the mode vector.
phase: -1 to 1, controls direction of displacement.
"""
lines = pdb_text.split("\n")
ca_idx = 0
new_lines = []
for line in lines:
if line.startswith("ATOM") and " CA " in line:
if ca_idx < len(ca_coords):
disp = mode_vecs[ca_idx] * amplitude * phase
x = ca_coords[ca_idx][0] + disp[0]
y = ca_coords[ca_idx][1] + disp[1]
z = ca_coords[ca_idx][2] + disp[2]
# Replace XYZ coordinates (cols 30-54 in PDB format)
new_line = line[:30] + f"{x:8.3f}{y:8.3f}{z:8.3f}" + line[54:]
new_lines.append(new_line)
ca_idx += 1
else:
new_lines.append(line)
else:
new_lines.append(line)
return "\n".join(new_lines)
# ═══════════════════════════════════════════════════
# Main viewer
# ═══════════════════════════════════════════════════
def render_motion_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    color_scheme: str = "magnitude",
    show_cartoon: bool = True,
    show_labels: bool = True,
    min_displacement: float = 0.01,
    width: int = 800,
    height: int = 500,
    key: str = "viewer",
    style: str = "cartoon",
    show_surface: bool = False,
    surface_opacity: float = 0.15,
):
    """Render a PyMOL-style interactive 3D viewer with displacement arrows.

    The structure in ``pdb_text`` is drawn with ``style`` and
    ``color_scheme``, optionally overlaid with a translucent VDW surface.
    Every residue whose displacement norm reaches ``min_displacement`` gets
    an arrow from its CA position along ``mode_vecs[i] * amplitude``.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates. Row ``i`` pairs with ``mode_vecs[i]`` and
            is styled through the 3Dmol selection ``{"resi": i + 1}``.
            NOTE(review): that assumes residues are numbered 1..N in
            ``pdb_text`` — confirm for the structures passed in.
        mode_vecs: Per-residue displacement vectors, one row per residue
            (norms are taken along axis 1).
        seq: One-letter sequence used for ``"residue_type"`` coloring and
            label text.
        amplitude: Scale applied to ``mode_vecs`` for arrow length.
        arrow_scale: Scale applied to arrow radius and arrowhead length.
        color_scheme: ``"magnitude"``, ``"secondary"``, ``"rainbow"``,
            ``"residue_type"`` or ``"bfactor"``. Any other value gives flat
            colors.
        show_cartoon: Not read by this function; kept for API compatibility.
        show_labels: Label up to five residues with the largest displacement.
        min_displacement: Displacement norm below which a residue gets no
            arrow and no label.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        key: Not read by this function; kept for API compatibility.
        style: ``"cartoon"``, ``"stick"``, ``"sphere"`` or ``"line"``.
        show_surface: Add a translucent VDW surface.
        surface_opacity: Opacity of the surface overlay.
    """
    if not HAS_STMOL:
        st.error("Install `stmol` and `py3Dmol`: `pip install stmol py3Dmol`")
        return
    n_res = len(ca_coords)
    mags = np.linalg.norm(mode_vecs, axis=1)
    # ndarray.max() raises on an empty array, so fall back to 0 for an empty
    # structure; the epsilon keeps the later mags / max_mag divisions finite.
    max_mag = (mags.max() if mags.size else 0.0) + 1e-8

    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdb_text, "pdb")
    _style_backbone(view, pdb_text, style, color_scheme, seq, mags, max_mag, n_res)
    if show_surface:
        _add_surface_overlay(view, color_scheme, mags, max_mag, n_res, surface_opacity)
    _add_displacement_arrows(
        view, ca_coords, mode_vecs, mags, max_mag, seq, color_scheme,
        amplitude, arrow_scale, min_displacement,
    )
    if show_labels and n_res > 0:
        _add_top_displacement_labels(
            view, ca_coords, mode_vecs, mags, seq, amplitude, min_displacement,
        )
    view.setBackgroundColor(BG_COLOR)
    view.zoomTo()
    showmol(view, height=height, width=width)


def _xyz(point) -> dict:
    """Convert a 3-element coordinate into a 3Dmol ``{"x", "y", "z"}`` dict of floats."""
    return {"x": float(point[0]), "y": float(point[1]), "z": float(point[2])}


def _style_backbone(view, pdb_text, style, color_scheme, seq, mags, max_mag, n_res):
    """Apply the base representation (cartoon/stick/sphere/line) and its coloring."""
    if style == "cartoon":
        if color_scheme == "magnitude":
            # Faint white base, then each residue recolored by relative |displacement|.
            view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}})
            for i in range(n_res):
                col = _mag_to_color(mags[i] / max_mag)
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}})
        elif color_scheme == "secondary":
            ss = _parse_ss_from_pdb(pdb_text)
            view.setStyle({"cartoon": {"color": "#94a3b8", "opacity": 0.85}})
            for i in range(n_res):
                col = SS_COLORS.get(ss.get(i + 1, "C"), "#94a3b8")
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.9}})
        elif color_scheme == "rainbow":
            view.setStyle({"cartoon": {"color": "spectrum", "opacity": 0.85}})
        elif color_scheme == "residue_type":
            view.setStyle({"cartoon": {"color": "white", "opacity": 0.4}})
            for i in range(min(n_res, len(seq))):
                col = AA_COLORS.get(seq[i], "#94a3b8")
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": col, "opacity": 0.85}})
        else:
            view.setStyle({"cartoon": {"color": "#c7d2fe", "opacity": 0.7}})
    elif style == "stick":
        view.setStyle({"stick": {"radius": 0.12, "colorscheme": "Jmol"}})
    elif style == "sphere":
        view.setStyle({"sphere": {"scale": 0.3, "colorscheme": "Jmol"}})
    elif style == "line":
        view.setStyle({"line": {"colorscheme": "Jmol"}})


def _add_surface_overlay(view, color_scheme, mags, max_mag, n_res, opacity):
    """Add a translucent VDW surface, colored per residue in "magnitude" mode."""
    if color_scheme == "magnitude":
        view.addSurface(py3Dmol.VDW, {"opacity": opacity, "color": "white"})
        # NOTE(review): this issues one extra addSurface call per residue
        # (N + 1 in total), which is presumably slow on large proteins;
        # consider grouping residues into a few color bins.
        for i in range(n_res):
            col = _mag_to_color(mags[i] / max_mag)
            view.addSurface(py3Dmol.VDW, {"opacity": opacity, "color": col}, {"resi": i + 1})
    else:
        view.addSurface(py3Dmol.VDW, {"opacity": opacity, "color": "#6366f1"})


def _add_displacement_arrows(view, ca_coords, mode_vecs, mags, max_mag, seq,
                             color_scheme, amplitude, arrow_scale, min_displacement):
    """Draw a shaft cylinder plus a wider capped head cylinder per moving residue."""
    n_res = len(ca_coords)
    for i in range(n_res):
        if mags[i] < min_displacement:
            continue
        start = ca_coords[i]
        d = mode_vecs[i] * amplitude
        end = start + d
        t = mags[i] / max_mag
        col = _get_color(i, t, n_res, seq, color_scheme)
        # Shaft radius grows linearly with relative magnitude: 1x .. 2.5x base.
        base_r = 0.06 * arrow_scale
        shaft_r = base_r + base_r * 1.5 * t
        view.addCylinder({
            "start": _xyz(start),
            "end": _xyz(end),
            "radius": shaft_r,
            "color": col,
            "fromCap": True,
        })
        # Head: a short cylinder 2.5x wider than the shaft, past the shaft end.
        # Its length scales with amplitude * arrow_scale, not this residue's norm.
        unit = d / (np.linalg.norm(d) + 1e-8)
        tip = end + unit * 0.3 * amplitude * arrow_scale
        view.addCylinder({
            "start": _xyz(end),
            "end": _xyz(tip),
            "radius": shaft_r * 2.5,
            "color": col,
            "toCap": True,
        })


def _add_top_displacement_labels(view, ca_coords, mode_vecs, mags, seq,
                                 amplitude, min_displacement, top_n=5):
    """Label the ``top_n`` most-displaced residues, skipping sub-threshold ones."""
    top_n = min(top_n, len(ca_coords))
    if top_n <= 0:
        return  # a [-0:] slice below would otherwise select every residue
    # argsort is ascending: take the last top_n and reverse for largest-first.
    top_idx = np.argsort(mags)[-top_n:][::-1]
    for idx in top_idx:
        if mags[idx] < min_displacement:
            continue
        # Anchor at the arrow midpoint, lifted 1.5 units along y.
        pos = ca_coords[idx] + mode_vecs[idx] * amplitude * 0.5
        aa = seq[idx] if idx < len(seq) else "?"
        view.addLabel(
            f"{aa}{idx + 1} ({mags[idx]:.2f}Å)",
            {
                "position": {"x": float(pos[0]), "y": float(pos[1] + 1.5), "z": float(pos[2])},
                "fontSize": 11,
                "fontColor": "white",
                "backgroundColor": "#1e1b4b",
                "backgroundOpacity": 0.9,
                "borderColor": "#6366f1",
                "borderThickness": 1.5,
            },
        )
# ═══════════════════════════════════════════════════
# Deformed structure viewer
# ═══════════════════════════════════════════════════
def render_deformed_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    width: int = 800,
    height: int = 450,
):
    """Superimpose the original structure with its +/- mode-deformed copies.

    Three models are added to one viewer:

    - model 0: the input structure (slate grey, opacity 0.4);
    - model 1: CA atoms shifted by ``+amplitude * mode_vecs`` (indigo);
    - model 2: CA atoms shifted by ``-amplitude * mode_vecs`` (red).

    Only CA lines are moved (see ``_generate_deformed_pdb``). Returns silently
    when stmol/py3Dmol are unavailable.

    Args:
        seq: Not read by this function; kept for API compatibility.
    """
    if not HAS_STMOL:
        return
    view = py3Dmol.view(width=width, height=height)
    # (pdb text, cartoon color, opacity) in model-index order.
    layers = [
        (pdb_text, "#64748b", 0.4),
        (_generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, +1.0), "#6366f1", 0.7),
        (_generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, -1.0), "#ef4444", 0.7),
    ]
    for model_idx, (text, color, opacity) in enumerate(layers):
        view.addModel(text, "pdb")
        view.setStyle({"model": model_idx}, {"cartoon": {"color": color, "opacity": opacity}})
    view.setBackgroundColor(BG_COLOR)
    view.zoomTo()
    showmol(view, height=height, width=width)
# ═══════════════════════════════════════════════════
# Mode animation (oscillating structure)
# ═══════════════════════════════════════════════════
def render_animated_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    mode_vecs: np.ndarray,
    amplitude: float = 3.0,
    n_frames: int = 20,
    width: int = 800,
    height: int = 450,
    key: str = "anim",
):
    """Render an animated viewer that oscillates the structure along a mode.

    Builds a multi-model PDB with one ``MODEL`` per frame. In frame ``f`` the
    CA atoms are displaced by ``sin(phase_f) * amplitude * mode_vecs``, with
    ``phase_f`` evenly spaced over one period. The result is played in a raw
    3Dmol.js viewer embedded via ``streamlit.components.v1.html``.

    Args:
        pdb_text: Structure in PDB format.
        ca_coords: CA coordinates passed to ``_generate_deformed_pdb``.
        mode_vecs: Per-residue displacement vectors.
        amplitude: Peak displacement multiplier.
        n_frames: Frames per oscillation period; values below 1 are treated
            as 1.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        key: Suffix of the viewer's DOM id (``viewport_{key}``), presumably
            to keep ids distinct when several viewers share a page.

    Returns silently when stmol/py3Dmol are unavailable.
    """
    import json

    if not HAS_STMOL:
        return
    n_frames = max(int(n_frames), 1)
    # One period: phase 0 → +1 → 0 → -1 → back toward 0.
    phases = np.sin(np.linspace(0, 2 * np.pi, n_frames, endpoint=False))
    chunks = []
    for frame, phase in enumerate(phases, start=1):
        deformed = _generate_deformed_pdb(pdb_text, ca_coords, mode_vecs, amplitude, phase)
        chunks.append(f"MODEL {frame}\n")
        chunks.extend(
            line + "\n"
            for line in deformed.split("\n")
            if line.startswith(("ATOM", "HETATM", "TER"))
        )
        chunks.append("ENDMDL\n")
    multi_pdb = "".join(chunks)
    # Embed as a JSON string literal (valid JavaScript) rather than a template
    # literal, so backticks, "${" or backslashes in the data cannot break the
    # script. Escaping "</" keeps a stray "</script>" from ending the tag.
    pdb_js = json.dumps(multi_pdb).replace("</", "<\\/")
    html = f"""
<div id="viewport_{key}" style="width:{width}px; height:{height}px; position:relative; border-radius: 12px; overflow: hidden;">
</div>
<script src="https://3Dmol.org/build/3Dmol-min.js"></script>
<script>
(function() {{
    let viewer = $3Dmol.createViewer("viewport_{key}", {{
        backgroundColor: "{BG_COLOR}",
        antialias: true,
    }});
    let pdbData = {pdb_js};
    viewer.addModelsAsFrames(pdbData, "pdb");
    // Rainbow ("spectrum") cartoon; the frames themselves carry the motion.
    viewer.setStyle({{}}, {{cartoon: {{color: "spectrum", opacity: 0.85}}}});
    viewer.zoomTo();
    viewer.animate({{loop: "forward", reps: 0, interval: 80}});
    viewer.render();
}})();
</script>
"""
    components.html(html, height=height + 10, width=width + 10)
# ═══════════════════════════════════════════════════
# GT vs Prediction comparison viewer
# ═══════════════════════════════════════════════════
def render_pred_vs_gt_viewer(
    pdb_text: str,
    ca_coords: np.ndarray,
    pred_vecs: np.ndarray,
    gt_vecs: np.ndarray,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    width: int = 900,
    height: int = 400,
    min_displacement: float = 0.01,
):
    """Render prediction and ground-truth motion side by side in two columns.

    Each panel colors the cartoon by that panel's own normalized displacement
    norm, so the two color scales are independent. It also draws one cylinder
    per residue whose norm is at least ``min_displacement``: indigo for the
    prediction, green for the ground truth.

    Args:
        pdb_text: Structure in PDB format, shown in both panels.
        ca_coords: CA coordinates. Row ``i`` is styled via ``{"resi": i + 1}``
            (NOTE(review): assumes residues are numbered 1..N — confirm).
        pred_vecs: Predicted per-residue displacement vectors.
        gt_vecs: Ground-truth per-residue displacement vectors.
        seq: Not read by this function; kept for API compatibility.
        amplitude: Scale applied to the vectors for cylinder length.
        arrow_scale: Scale applied to cylinder radius.
        width: Total width in pixels; each panel gets ``width // 2``.
        height: Panel height in pixels.
        min_displacement: Norm threshold for drawing a cylinder. Defaults to
            the previously hard-coded 0.01.

    Returns silently when stmol/py3Dmol are unavailable.
    """
    if not HAS_STMOL:
        return
    col_pred, col_gt = st.columns(2)
    panels = (
        (col_pred, "**PETIMOT Prediction**", pred_vecs, "#6366f1"),
        (col_gt, "**Ground Truth (NMA)**", gt_vecs, "#10b981"),
    )
    for column, title, vecs, arrow_color in panels:
        with column:
            _render_pred_gt_panel(
                pdb_text, ca_coords, vecs, title, arrow_color,
                amplitude, arrow_scale, width // 2, height, min_displacement,
            )


def _render_pred_gt_panel(pdb_text, ca_coords, vecs, title, arrow_color,
                          amplitude, arrow_scale, width, height, min_displacement):
    """Draw one captioned, magnitude-colored panel with displacement cylinders."""
    mags = np.linalg.norm(vecs, axis=1)
    # Empty input: .max() would raise and .mean() would warn and return NaN.
    mean_mag = mags.mean() if mags.size else 0.0
    peak = mags.max() if mags.size else 0.0
    st.caption(f"{title} · μ={mean_mag:.3f}Å · max={peak:.3f}Å")
    max_m = peak + 1e-8
    n_res = len(ca_coords)
    vw = py3Dmol.view(width=width, height=height)
    vw.addModel(pdb_text, "pdb")
    vw.setStyle({"cartoon": {"color": "white", "opacity": 0.3}})
    for i in range(n_res):
        t = mags[i] / max_m
        vw.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.8}})
    for i in range(n_res):
        if mags[i] < min_displacement:
            continue
        s = ca_coords[i]
        e = s + vecs[i] * amplitude
        t = mags[i] / max_m
        vw.addCylinder({
            "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
            "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
            # Radius grows from 1x to 2x the base with relative magnitude.
            "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale,
            "color": arrow_color,
        })
    vw.setBackgroundColor(BG_COLOR)
    vw.zoomTo()
    showmol(vw, height=height, width=width)
# ═══════════════════════════════════════════════════
# Mode comparison grid
# ═══════════════════════════════════════════════════
def render_mode_comparison(
    pdb_text: str,
    ca_coords: np.ndarray,
    modes: dict,
    seq: str = "",
    amplitude: float = 3.0,
    arrow_scale: float = 1.0,
    width: int = 900,
    height: int = 350,
    min_displacement: float = 0.01,
):
    """Render up to four modes side by side, one PyMOL-style panel per column.

    Each panel colors the cartoon by that mode's own normalized displacement
    norm. It also draws one cylinder per residue whose norm is at least
    ``min_displacement``, in a fixed per-column color.

    Args:
        pdb_text: Structure in PDB format, shown in every panel.
        ca_coords: CA coordinates. Row ``i`` is styled via ``{"resi": i + 1}``
            (NOTE(review): assumes residues are numbered 1..N — confirm).
        modes: Per-mode displacement arrays, looked up as ``modes[0]`` ..
            ``modes[n - 1]``. A list, array, or dict keyed by 0-based ints
            all work. A dict without those keys falls back to its first
            entries in insertion order, captioned with their own keys.
        seq: Not read by this function; kept for API compatibility.
        amplitude: Scale applied to each mode's vectors for cylinder length.
        arrow_scale: Scale applied to cylinder radius.
        width: Total width in pixels, split evenly across the columns.
        height: Height of each panel in pixels.
        min_displacement: Norm threshold for drawing a cylinder. Defaults to
            the previously hard-coded 0.01.

    Returns silently when stmol/py3Dmol are unavailable.
    """
    if not HAS_STMOL:
        return
    n_modes = min(4, len(modes))
    if n_modes == 0:
        st.warning("No modes to display")
        return
    # Keep the original modes[k] lookup whenever it works; fall back to
    # insertion order only for dicts keyed by something other than 0..n-1
    # (which previously raised KeyError).
    if isinstance(modes, dict) and not all(k in modes for k in range(n_modes)):
        selected = list(modes.items())[:n_modes]
    else:
        selected = [(k, modes[k]) for k in range(n_modes)]
    mode_colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
    per_col_w = width // n_modes
    n_res = len(ca_coords)
    cols = st.columns(n_modes)
    for k, (label, vecs) in enumerate(selected):
        with cols[k]:
            mags = np.linalg.norm(vecs, axis=1)
            # Empty input: .max() would raise and .mean() would warn and return NaN.
            mean_mag = mags.mean() if mags.size else 0.0
            peak = mags.max() if mags.size else 0.0
            max_m = peak + 1e-8
            st.caption(f"**Mode {label}** · μ={mean_mag:.3f}Å · max={peak:.3f}Å")
            view = py3Dmol.view(width=per_col_w, height=height)
            view.addModel(pdb_text, "pdb")
            view.setStyle({"cartoon": {"color": "white", "opacity": 0.3}})
            for i in range(n_res):
                t = mags[i] / max_m
                view.addStyle({"resi": i + 1}, {"cartoon": {"color": _mag_to_color(t), "opacity": 0.7}})
            for i in range(n_res):
                if mags[i] < min_displacement:
                    continue
                s = ca_coords[i]
                e = s + vecs[i] * amplitude
                t = mags[i] / max_m
                view.addCylinder({
                    "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
                    "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
                    # Radius grows from 1x to 2x the base with relative magnitude.
                    "radius": 0.06 * arrow_scale + 0.06 * t * arrow_scale,
                    "color": mode_colors[k],
                })
            view.setBackgroundColor(BG_COLOR)
            view.zoomTo()
            showmol(view, height=height, width=per_col_w)