Valmbd committed on
Commit b47954d · 1 Parent(s): 442fb8f

Add Streamlit explorer app with Docker deployment
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,28 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # System deps (curl is needed by the HEALTHCHECK below)
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential git curl && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Python deps
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy app code
+ COPY . .
+
+ # Install PETIMOT package
+ RUN pip install --no-cache-dir -e .
+
+ # Streamlit config (printf interprets \n portably, unlike some echo implementations)
+ RUN mkdir -p /root/.streamlit
+ RUN printf '[server]\nheadless = true\nport = 7860\nenableCORS = false\nenableXsrfProtection = false\n' > /root/.streamlit/config.toml
+
+ EXPOSE 7860
+
+ HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
+
+ ENTRYPOINT ["streamlit", "run", "app/app.py", "--server.port=7860", "--server.address=0.0.0.0"]
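As a local sanity check, the Streamlit config that the Dockerfile writes can be reproduced with `printf` (which interprets `\n` escapes portably, unlike some `echo` implementations). The `/tmp` path below is illustrative; the image writes to `/root/.streamlit/config.toml`:

```shell
# Recreate the config the image writes, under /tmp so it can be
# inspected without building the image
mkdir -p /tmp/streamlit-check
printf '[server]\nheadless = true\nport = 7860\nenableCORS = false\nenableXsrfProtection = false\n' \
    > /tmp/streamlit-check/config.toml
cat /tmp/streamlit-check/config.toml
```

The port in this file must match both `EXPOSE` and the `--server.port` flag in the `ENTRYPOINT`.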
README.md CHANGED
@@ -1,71 +1,46 @@
- # PETIMOT: Protein Motion Inference from Sparse Data
-
- PETIMOT (Protein sEquence and sTructure-based Inference of MOTions) predicts protein conformational changes using SE(3)-equivariant graph neural networks and pre-trained protein language models.
-
- ## Installation
-
- ```bash
- # Create and activate conda environment
- conda create -n petimot python=3.9
- conda activate petimot
-
- # Clone and install
- git clone https://github.com/PhyloSofS-Team/PETIMOT.git
- cd petimot
- pip install -r requirements.txt
- ```
-
+ ---
+ title: PETIMOT Explorer
+ emoji: 🧬
+ colorFrom: indigo
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: "1.40.0"
+ app_file: app/app.py
+ pinned: true
+ license: gpl-3.0
+ tags:
+ - protein
+ - motion
+ - GNN
+ - bioinformatics
+ - structural-biology
+ short_description: "Explore SE(3)-equivariant protein motion predictions"
+ ---
+
+ # 🧬 PETIMOT Explorer
+
+ **Protein Motion Inference from Sparse Data**
+
+ Interactive explorer for protein motion predictions using SE(3)-equivariant Graph Neural Networks.
+
+ ## Features
+
+ - 🔍 **Explorer** — Browse pre-computed predictions for ~36K proteins
+ - 🔮 **Inference** — Predict motion for any protein (PDB ID or upload)
+ - 📊 **Statistics** — Dataset-wide analysis and distributions
+ - 🎨 **3D Viewer** — Interactive motion visualization with displacement arrows
+ - 🧬 **Sequence View** — Per-residue displacement heatmap with coverage overlay
+
+ ## Paper
+
+ > Lombard, Grudinin & Laine — *PETIMOT: SE(3)-Equivariant GNNs for Protein Motion Prediction*
+ > [arXiv 2504.02839](https://arxiv.org/abs/2504.02839)

  ## Usage

- ### Reproduce paper results
-
- 1. Download resources from [Figshare](https://figshare.com/s/ab400d852b4669a83b64):
-    - Download `default_2025-02-07_21-54-02_epoch_33.pt` into the `weights/` directory
-    - Download and extract `ground_truth.zip` into the `ground_truth/` directory
-
- 2. Run inference and evaluation:
- ```bash
- python -m petimot infer_and_evaluate \
-     --model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
-     --list-path eval_list.txt \
-     --ground-truth-path ground_truth/ \
-     --prediction-path predictions/ \
-     --evaluation-path evaluation/
- ```
-
- ### Compare with baseline methods
-
- 1. Download baseline predictions from [Figshare](https://figshare.com/s/ab400d852b4669a83b64):
-    - Download and extract `baseline_predictions.zip` into the `baselines/` directory
-
- 2. Run evaluation:
  ```bash
- python -m petimot evaluate \
-     --prediction-path baselines/alphaflow_pdb_distilled/ \
-     --ground-truth-path ground_truth/ \
-     --output-path evaluation/
- ```
-
- Available baseline predictions:
- - AlphaFlow (distilled)
- - ESMFlow (distilled)
- - Normal Mode Analysis
-
- ### Predict motions for your own PDB files
-
- ```bash
- # Single PDB structure
- python -m petimot infer \
-     --model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
-     --list-path protein.pdb \
-     --output-path predictions/
-
- # Multiple structures (provide paths in a text file)
- python -m petimot infer \
-     --model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
-     --list-path protein_list.txt \
-     --output-path predictions/
+ git clone https://github.com/PhyloSofS-Team/PETIMOT
+ cd PETIMOT
+ pip install -r app/requirements.txt
+ streamlit run app/app.py
  ```
app/app.py ADDED
@@ -0,0 +1,145 @@
+ import streamlit as st
+ import os, sys
+
+ # ── Page Config ──
+ st.set_page_config(
+     page_title="PETIMOT Explorer",
+     page_icon="🧬",
+     layout="wide",
+     initial_sidebar_state="expanded",
+ )
+
+ # ── Ensure PETIMOT is importable ──
+ PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ if PETIMOT_ROOT not in sys.path:
+     sys.path.insert(0, PETIMOT_ROOT)
+
+ # ── Custom CSS ──
+ st.markdown("""
+ <style>
+ /* Dark theme overrides */
+ .stApp { background-color: #0f0d1a; }
+ .block-container { padding-top: 1rem; }
+
+ /* Sidebar styling */
+ section[data-testid="stSidebar"] {
+     background-color: #1a1730;
+     border-right: 1px solid #2d2b55;
+ }
+
+ /* Headers */
+ h1, h2, h3 { color: #c4b5fd !important; }
+
+ /* Metric cards */
+ [data-testid="stMetric"] {
+     background-color: #1e1b4b;
+     border: 1px solid #312e81;
+     border-radius: 12px;
+     padding: 12px 16px;
+ }
+ [data-testid="stMetricLabel"] { color: #a5b4fc !important; }
+ [data-testid="stMetricValue"] { color: #e0e7ff !important; }
+
+ /* Dataframe */
+ .stDataFrame { border-radius: 8px; overflow: hidden; }
+
+ /* Tabs */
+ .stTabs [data-baseweb="tab"] {
+     background-color: #1e1b4b;
+     border-radius: 8px 8px 0 0;
+     color: #a5b4fc;
+ }
+ .stTabs [data-baseweb="tab"][aria-selected="true"] {
+     background-color: #312e81;
+     color: white;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # ── Sidebar ──
+ with st.sidebar:
+     st.image("https://raw.githubusercontent.com/PhyloSofS-Team/PETIMOT/main/logo.png",
+              use_container_width=True)
+     st.markdown("# 🧬 PETIMOT")
+     st.markdown("**Protein Motion from Sparse Data**")
+     st.markdown("SE(3)-Equivariant GNNs")
+     st.divider()
+
+     # Global settings
+     st.markdown("### ⚙️ Settings")
+
+     weights_dir = os.path.join(PETIMOT_ROOT, "weights")
+     pt_files = []
+     if os.path.isdir(weights_dir):
+         for root, dirs, files in os.walk(weights_dir):
+             for f in files:
+                 if f.endswith(".pt"):
+                     pt_files.append(os.path.join(root, f))
+
+     if pt_files:
+         selected_weights = st.selectbox(
+             "Model weights",
+             pt_files,
+             format_func=lambda x: os.path.basename(x),
+             key="weights"
+         )
+     else:
+         selected_weights = None
+         st.warning("No weights found in `weights/`")
+
+     st.divider()
+     st.markdown("""
+     **Links**
+     - [Paper](https://arxiv.org/abs/2504.02839)
+     - [GitHub](https://github.com/PhyloSofS-Team/PETIMOT)
+     - [Data](https://figshare.com/s/ab400d852b4669a83b64)
+     """)
+     st.caption("GPL-3.0 · Lombard, Grudinin & Laine")
+
+ # ── Main Page ──
+ st.title("🧬 PETIMOT Explorer")
+ st.markdown("""
+ Explore protein motion predictions from the PETIMOT framework.
+ Navigate using the sidebar pages:
+
+ | Page | Description |
+ |------|-------------|
+ | 🔍 **Explorer** | Browse pre-computed predictions for ~36K proteins |
+ | 🔮 **Inference** | Predict motion for a new protein (PDB ID or upload) |
+ | 📊 **Statistics** | Dataset-wide analysis and distributions |
+ """)
+
+ # ── Data Status ──
+ from app.utils.download import check_data_status, ensure_weights
+
+ status = check_data_status(PETIMOT_ROOT)
+
+ col1, col2, col3 = st.columns(3)
+ with col1:
+     st.metric("Ground Truth", f"{status['ground_truth']:,}",
+               delta="✅" if status['has_gt'] else "Missing")
+ with col2:
+     st.metric("Predictions", f"{status['predictions']:,}",
+               delta="✅" if status['has_predictions'] else "Not yet computed")
+ with col3:
+     st.metric("Model Weights", "4.7M params",
+               delta="✅" if status['has_weights'] else "Missing")
+
+ # Auto-download if missing
+ if not status['has_weights']:
+     st.divider()
+     st.warning("⚠️ Model weights not found.")
+     if st.button("⬇️ Download weights from Figshare (18 MB)", type="primary"):
+         with st.spinner("Downloading..."):
+             wt = ensure_weights(PETIMOT_ROOT)
+         if wt:
+             st.success(f"✅ Weights downloaded: {os.path.basename(wt)}")
+             st.rerun()
+         else:
+             st.error("Download failed. Please manually download from "
+                      "[Figshare](https://figshare.com/s/ab400d852b4669a83b64) "
+                      "and place in `weights/`")
+
+ if not status['has_predictions'] and status['has_weights']:
+     st.info("💡 No pre-computed predictions yet. Use the **Inference** page to predict "
+             "individual proteins, or run batch inference from the Colab notebook.")
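The sidebar's recursive `.pt` discovery can be exercised outside Streamlit. A minimal sketch using `pathlib.Path.rglob`, equivalent to the `os.walk` loop in `app.py` (the throwaway directory stands in for the repo's `weights/`; file names are illustrative):

```python
import tempfile
from pathlib import Path

# Throwaway stand-in for the repo's weights/ directory
weights_dir = Path(tempfile.mkdtemp()) / "weights"
(weights_dir / "nested").mkdir(parents=True)
(weights_dir / "default_epoch_33.pt").touch()
(weights_dir / "nested" / "other.pt").touch()
(weights_dir / "notes.txt").touch()  # non-weight file, must be skipped

# Recursive *.pt scan, mirroring the os.walk loop above
pt_files = sorted(p.name for p in weights_dir.rglob("*.pt"))
print(pt_files)  # ['default_epoch_33.pt', 'other.pt']
```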
app/components/__init__.py ADDED
File without changes
app/components/mode_panel.py ADDED
@@ -0,0 +1,139 @@
+ """Mode selection panel with per-mode statistics."""
+ import streamlit as st
+ import numpy as np
+ import plotly.graph_objects as go
+
+
+ def render_mode_panel(
+     modes: dict,
+     seq: str = "",
+     eigenvalues: np.ndarray = None,
+ ) -> int:
+     """Render mode selector with per-mode stats. Returns selected mode index.
+
+     Args:
+         modes: {k: np.ndarray (N, 3)} displacement vectors per mode
+         seq: Amino acid sequence
+         eigenvalues: Ground truth eigenvalues (if available)
+
+     Returns:
+         Selected mode index
+     """
+     n_modes = len(modes)
+     if n_modes == 0:
+         st.warning("No modes available")
+         return 0
+
+     # Mode tabs
+     tabs = st.tabs([f"Mode {k}" for k in range(n_modes)])
+
+     selected = 0
+     for k in range(n_modes):
+         with tabs[k]:
+             vecs = modes[k]
+             mags = np.linalg.norm(vecs, axis=1)
+             n_res = len(mags)
+
+             # Stats columns
+             c1, c2, c3, c4 = st.columns(4)
+             c1.metric("Mean", f"{mags.mean():.3f} Å")
+             c2.metric("Max", f"{mags.max():.3f} Å")
+             c3.metric("Std", f"{mags.std():.3f} Å")
+
+             if eigenvalues is not None and k < len(eigenvalues):
+                 c4.metric("λ", f"{eigenvalues[k]:.4f}")
+             else:
+                 c4.metric("Residues", f"{n_res}")
+
+             # Top mobile residues
+             top5 = np.argsort(mags)[-5:][::-1]
+             top_data = []
+             for idx in top5:
+                 aa = seq[idx] if idx < len(seq) else "?"
+                 top_data.append({
+                     "Residue": f"{aa}{idx + 1}",
+                     "Displacement": f"{mags[idx]:.3f} Å",
+                     "Rank": f"#{np.where(np.argsort(mags)[::-1] == idx)[0][0] + 1}",
+                 })
+             st.markdown("**Most mobile residues:**")
+             st.dataframe(top_data, use_container_width=True, hide_index=True)
+
+     selected = st.session_state.get("_active_mode_tab", 0)
+     return selected
+
+
+ def render_mode_correlation(modes: dict):
+     """Render mode correlation matrix as Plotly heatmap."""
+     n_modes = len(modes)
+     if n_modes < 2:
+         return
+
+     # Compute displacement profile correlation
+     profiles = []
+     for k in sorted(modes.keys()):
+         mags = np.linalg.norm(modes[k], axis=1)
+         profiles.append(mags)
+
+     corr = np.corrcoef(profiles)
+
+     fig = go.Figure(go.Heatmap(
+         z=corr,
+         x=[f"M{k}" for k in range(n_modes)],
+         y=[f"M{k}" for k in range(n_modes)],
+         colorscale="RdBu_r",
+         zmin=-1, zmax=1,
+         text=np.round(corr, 2),
+         texttemplate="%{text:.2f}",
+         textfont={"size": 12},
+     ))
+
+     fig.update_layout(
+         title="Mode Displacement Correlation",
+         template="plotly_dark",
+         height=300, width=300,
+         paper_bgcolor="rgba(0,0,0,0)",
+         plot_bgcolor="rgba(30,27,75,0.5)",
+         margin=dict(l=30, r=30, t=40, b=30),
+     )
+     st.plotly_chart(fig, use_container_width=False)
+
+
+ def render_eigenvalue_spectrum(eigenvalues: np.ndarray):
+     """Render eigenvalue bar chart with cumulative variance line."""
+     if eigenvalues is None or len(eigenvalues) == 0:
+         return
+
+     fig = go.Figure()
+
+     # Bars
+     fig.add_trace(go.Bar(
+         x=[f"λ{k+1}" for k in range(len(eigenvalues))],
+         y=eigenvalues,
+         marker_color="#6366f1",
+         name="Eigenvalue",
+     ))
+
+     # Cumulative variance line
+     cum = np.cumsum(eigenvalues) / eigenvalues.sum() * 100
+     fig.add_trace(go.Scatter(
+         x=[f"λ{k+1}" for k in range(len(eigenvalues))],
+         y=cum,
+         mode="lines+markers",
+         name="Cumul. variance %",
+         marker=dict(color="#ef4444", size=6),
+         line=dict(color="#ef4444", width=2),
+         yaxis="y2",
+     ))
+
+     fig.update_layout(
+         title="Eigenvalue Spectrum",
+         template="plotly_dark",
+         height=250,
+         paper_bgcolor="rgba(0,0,0,0)",
+         plot_bgcolor="rgba(30,27,75,0.5)",
+         yaxis=dict(title="Eigenvalue"),
+         yaxis2=dict(title="Cumul. %", overlaying="y", side="right", range=[0, 105]),
+         legend=dict(orientation="h", y=1.15),
+         margin=dict(l=40, r=40, t=40, b=30),
+     )
+     st.plotly_chart(fig, use_container_width=True)
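The correlation heatmap above reduces each mode to its per-residue magnitude profile before calling `np.corrcoef`. The numerical core, separated from the Plotly rendering (synthetic modes for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# 3 synthetic modes, each a (50, 3) array of per-residue displacement vectors
modes = {k: rng.normal(size=(50, 3)) for k in range(3)}

# Per-residue displacement magnitude profile for each mode
profiles = [np.linalg.norm(modes[k], axis=1) for k in sorted(modes)]

# Pearson correlation between mode profiles: symmetric, ones on the diagonal
corr = np.corrcoef(profiles)
print(corr.shape)  # (3, 3)
```

A correlation near 1 between two modes means their displacement hotspots coincide along the sequence, even if the motion directions differ.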
app/components/prediction_analysis.py ADDED
@@ -0,0 +1,320 @@
+ """Enhanced prediction analysis — sign-invariant modes and per-residue normalization."""
+ import numpy as np
+ import streamlit as st
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+
+
+ def canonicalize_sign(modes: dict) -> dict:
+     """Make eigenvectors sign-consistent.
+
+     Eigenvectors are defined up to ±1 global sign. We canonicalize by choosing
+     the sign such that the component with the largest absolute value is positive.
+     This ensures consistent visualization across different runs/proteins.
+     """
+     canonical = {}
+     for k, vecs in modes.items():
+         # Flatten to (3N,), find component with max absolute value
+         flat = vecs.flatten()
+         max_idx = np.argmax(np.abs(flat))
+         if flat[max_idx] < 0:
+             canonical[k] = -vecs  # Flip sign
+         else:
+             canonical[k] = vecs.copy()
+     return canonical
+
+
+ def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray:
+     """Normalize displacement magnitudes to [0, 1] relative to max.
+
+     Args:
+         vecs: (N, 3) displacement vectors
+
+     Returns:
+         (N,) relative magnitudes in [0, 1]
+     """
+     mags = np.linalg.norm(vecs, axis=1)
+     max_m = mags.max()
+     return mags / max_m if max_m > 1e-12 else mags
+
+
+ def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray:
+     """Compute relative direction of displacement vs protein backbone.
+
+     Projects displacement onto local backbone direction (CA_i → CA_{i+1}).
+     Returns signed projection: positive = along backbone, negative = against.
+
+     Args:
+         vecs: (N, 3) displacement vectors
+         ca_coords: (N, 3) CA coordinates
+
+     Returns:
+         (N,) signed projections normalized by displacement magnitude
+     """
+     n = len(vecs)
+     projections = np.zeros(n)
+
+     for i in range(n):
+         # Local backbone direction
+         if i < n - 1:
+             backbone = ca_coords[i + 1] - ca_coords[i]
+         else:
+             backbone = ca_coords[i] - ca_coords[i - 1]
+
+         bb_norm = np.linalg.norm(backbone)
+         if bb_norm < 1e-8:
+             continue
+
+         disp_mag = np.linalg.norm(vecs[i])
+         if disp_mag < 1e-8:
+             continue
+
+         # Cosine angle between displacement and backbone direction
+         projections[i] = np.dot(vecs[i], backbone) / (disp_mag * bb_norm)
+
+     return projections
+
+
+ def render_prediction_analysis(
+     modes: dict,
+     seq: str,
+     ca_coords: np.ndarray = None,
+     coverage: np.ndarray = None,
+     eigenvalues: np.ndarray = None,
+     gt_modes: dict = None,
+     protein_name: str = "",
+ ):
+     """Comprehensive prediction analysis panel.
+
+     Shows:
+     1. Normalized displacement heatmap (all modes × residues)
+     2. Sign-canonical direction analysis
+     3. Prediction vs ground truth comparison (if available)
+     4. Per-residue statistics table
+     """
+     # Canonicalize signs
+     modes_c = canonicalize_sign(modes)
+     n_modes = len(modes_c)
+     n_res = len(list(modes_c.values())[0])
+
+     if coverage is None:
+         coverage = np.ones(n_res)
+
+     # ── Tab layout ──
+     tab_norm, tab_dir, tab_compare, tab_table = st.tabs([
+         "📊 Normalized Displacement", "🧭 Direction Analysis",
+         "⚖️ Pred vs GT", "📋 Per-Residue Table"
+     ])
+
+     # ═════════════════════════════════════════════
+     # Tab 1: Normalized displacement heatmap
+     # ═════════════════════════════════════════════
+     with tab_norm:
+         # Compute relative norms for all modes
+         rel_norms = np.zeros((n_modes, n_res))
+         abs_mags = np.zeros((n_modes, n_res))
+         for k in range(n_modes):
+             abs_mags[k] = np.linalg.norm(modes_c[k], axis=1)
+             rel_norms[k] = per_residue_relative_norm(modes_c[k])
+
+         # Hover text with sequence
+         hover = [[f"{seq[j] if j < len(seq) else '?'}{j+1}<br>"
+                   f"Abs: {abs_mags[k][j]:.3f}Å<br>"
+                   f"Rel: {rel_norms[k][j]:.2%}<br>"
+                   f"Cov: {coverage[j]:.2f}"
+                   for j in range(n_res)] for k in range(n_modes)]
+
+         fig = make_subplots(rows=3, cols=1, row_heights=[0.4, 0.4, 0.2],
+                             shared_xaxes=True, vertical_spacing=0.06,
+                             subplot_titles=["Absolute Displacement (Å)",
+                                             "Relative Displacement (0-1)",
+                                             "Coverage"])
+
+         # Absolute heatmap
+         fig.add_trace(go.Heatmap(
+             z=abs_mags, colorscale="YlOrRd",
+             y=[f"Mode {k}" for k in range(n_modes)],
+             text=hover, hovertemplate="%{text}<extra></extra>",
+             colorbar=dict(title="Å", x=1.01, len=0.35, y=0.85),
+         ), row=1, col=1)
+
+         # Relative heatmap
+         fig.add_trace(go.Heatmap(
+             z=rel_norms, colorscale="Viridis", zmin=0, zmax=1,
+             y=[f"Mode {k}" for k in range(n_modes)],
+             text=hover, hovertemplate="%{text}<extra></extra>",
+             colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5),
+         ), row=2, col=1)
+
+         # Coverage bar
+         fig.add_trace(go.Bar(
+             x=list(range(n_res)), y=coverage[:n_res],
+             marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in coverage[:n_res]],
+             hovertemplate="Res %{x}<br>Coverage: %{y:.3f}<extra></extra>",
+             showlegend=False,
+         ), row=3, col=1)
+
+         # Sequence ticks
+         step = max(1, n_res // 50)
+         tick_vals = list(range(0, n_res, step))
+         tick_text = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in tick_vals]
+         fig.update_xaxes(tickvals=tick_vals, ticktext=tick_text, tickangle=45,
+                          tickfont=dict(size=8), row=3, col=1)
+
+         fig.update_layout(
+             template="plotly_dark", height=550,
+             paper_bgcolor="rgba(0,0,0,0)",
+             plot_bgcolor="rgba(30,27,75,0.3)",
+             margin=dict(l=60, r=80, t=30, b=50),
+         )
+         st.plotly_chart(fig, use_container_width=True)
+
+         # Key insight
+         for k in range(min(n_modes, 4)):
+             top3 = np.argsort(abs_mags[k])[-3:][::-1]
+             top_str = ", ".join([f"**{seq[i] if i < len(seq) else '?'}{i+1}** ({abs_mags[k][i]:.2f}Å)"
+                                  for i in top3])
+             st.markdown(f"Mode {k} hotspots: {top_str}")
+
+     # ═════════════════════════════════════════════
+     # Tab 2: Direction analysis
+     # ═════════════════════════════════════════════
+     with tab_dir:
+         if ca_coords is not None and len(ca_coords) == n_res:
+             st.markdown("""
+             **Direction Analysis**: Projects displacement onto the local backbone direction (CA→CA).
+             - 🔵 **Blue** = motion along backbone (stretching/compressing)
+             - 🔴 **Red** = motion perpendicular to backbone (lateral/hinge)
+             - Sign is arbitrary for eigenvectors → we show absolute cosine similarity
+             """)
+
+             fig_dir = go.Figure()
+             colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
+             # Matching translucent fills for each line color
+             fills = ["rgba(99,102,241,0.1)", "rgba(239,68,68,0.1)",
+                      "rgba(16,185,129,0.1)", "rgba(245,158,11,0.1)"]
+
+             for k in range(min(n_modes, 4)):
+                 proj = per_residue_direction(modes_c[k], ca_coords)
+                 # Show absolute cosine (sign-invariant)
+                 abs_proj = np.abs(proj)
+
+                 fig_dir.add_trace(go.Scatter(
+                     x=list(range(1, n_res + 1)), y=abs_proj,
+                     mode="lines", name=f"Mode {k}",
+                     line=dict(color=colors[k], width=1.5),
+                     fill="tozeroy",
+                     fillcolor=fills[k],
+                     hovertemplate="Res %{x}<br>|cos θ|: %{y:.3f}<extra>Mode " + str(k) + "</extra>",
+                 ))
+
+             fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8",
+                               annotation_text="isotropic threshold")
+
+             fig_dir.update_layout(
+                 template="plotly_dark", height=350,
+                 paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
+                 xaxis_title="Residue", yaxis_title="|cos θ| (backbone projection)",
+                 yaxis_range=[0, 1.05],
+                 margin=dict(l=50, r=20, t=30, b=50),
+             )
+             st.plotly_chart(fig_dir, use_container_width=True)
+
+             # Direction heatmap
+             st.markdown("**Per-residue × mode direction matrix:**")
+             dir_matrix = np.zeros((n_modes, n_res))
+             for k in range(n_modes):
+                 dir_matrix[k] = np.abs(per_residue_direction(modes_c[k], ca_coords))
+
+             fig_dh = go.Figure(go.Heatmap(
+                 z=dir_matrix, colorscale="RdBu_r", zmin=0, zmax=1,
+                 y=[f"Mode {k}" for k in range(n_modes)],
+                 colorbar=dict(title="|cos θ|"),
+             ))
+             fig_dh.update_layout(
+                 template="plotly_dark", height=200,
+                 paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
+                 margin=dict(l=60, r=60, t=10, b=30),
+             )
+             st.plotly_chart(fig_dh, use_container_width=True)
+         else:
+             st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)")
+
+     # ═════════════════════════════════════════════
+     # Tab 3: Prediction vs Ground Truth
+     # ═════════════════════════════════════════════
+     with tab_compare:
+         if gt_modes is not None and len(gt_modes) > 0:
+             gt_c = canonicalize_sign(gt_modes)
+             n_gt = len(gt_c)
+
+             st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**")
+
+             for k in range(min(n_modes, n_gt, 4)):
+                 pred_mag = np.linalg.norm(modes_c[k], axis=1)
+                 gt_mag = np.linalg.norm(gt_c[k], axis=1)
+
+                 # Normalize both to [0, 1]
+                 pred_rel = pred_mag / (pred_mag.max() + 1e-12)
+                 gt_rel = gt_mag / (gt_mag.max() + 1e-12)
+
+                 fig_cmp = go.Figure()
+                 fig_cmp.add_trace(go.Scatter(
+                     x=list(range(1, n_res + 1)), y=gt_rel,
+                     mode="lines", name="Ground Truth",
+                     line=dict(color="#10b981", width=2),
+                 ))
+                 fig_cmp.add_trace(go.Scatter(
+                     x=list(range(1, n_res + 1)), y=pred_rel,
+                     mode="lines", name="Prediction",
+                     line=dict(color="#6366f1", width=2, dash="dot"),
+                 ))
+
+                 # Correlation
+                 corr = np.corrcoef(pred_rel, gt_rel)[0, 1]
+                 rmse = np.sqrt(np.mean((pred_rel - gt_rel) ** 2))
+
+                 fig_cmp.update_layout(
+                     template="plotly_dark", height=200,
+                     title=f"Mode {k} — r={corr:.3f}, RMSE={rmse:.3f}",
+                     paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
+                     margin=dict(l=40, r=20, t=40, b=30),
+                     legend=dict(orientation="h", y=1.15),
+                 )
+                 st.plotly_chart(fig_cmp, use_container_width=True)
+         else:
+             st.info("No ground truth available for comparison. "
+                     "Ground truth is only available for proteins in the training database.")
+
+     # ═════════════════════════════════════════════
+     # Tab 4: Per-residue table
+     # ═════════════════════════════════════════════
+     with tab_table:
+         import pandas as pd
+
+         rows = []
+         for i in range(n_res):
+             row = {
+                 "Residue": i + 1,
+                 "AA": seq[i] if i < len(seq) else "?",
+                 "Coverage": f"{coverage[i]:.3f}" if i < len(coverage) else "—",
+             }
+             for k in range(min(n_modes, 4)):
+                 mag = np.linalg.norm(modes_c[k][i])
+                 rel = per_residue_relative_norm(modes_c[k])[i]
+                 row[f"M{k} (Å)"] = f"{mag:.3f}"
+                 row[f"M{k} rel"] = f"{rel:.2%}"
+             rows.append(row)
+
+         df = pd.DataFrame(rows)
+         st.dataframe(df, use_container_width=True, height=500,
+                      column_config={
+                          "Residue": st.column_config.NumberColumn(width="small"),
+                          "AA": st.column_config.TextColumn(width="small"),
+                      })
+
+         # Download CSV
+         csv = df.to_csv(index=False)
+         st.download_button("📥 Download CSV", csv,
+                            f"{protein_name}_analysis.csv", "text/csv")
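The point of the sign convention in `canonicalize_sign` is that the two representatives `v` and `-v` of the same eigenvector map to one canonical array. A standalone sketch of that invariant (the helper is restated here so the snippet runs without the app package):

```python
import numpy as np

def canonicalize_sign(modes: dict) -> dict:
    """Flip each mode so its largest-magnitude component is positive."""
    canonical = {}
    for k, vecs in modes.items():
        flat = vecs.flatten()
        if flat[np.argmax(np.abs(flat))] < 0:
            canonical[k] = -vecs
        else:
            canonical[k] = vecs.copy()
    return canonical

rng = np.random.default_rng(1)
v = rng.normal(size=(10, 3))  # one synthetic (N, 3) mode

a = canonicalize_sign({0: v})[0]
b = canonicalize_sign({0: -v})[0]
assert np.array_equal(a, b)  # v and -v canonicalize to the same array
```

Magnitude-based views (heatmaps, profiles) are already sign-invariant; canonicalization matters for anything that draws the vectors themselves, such as displacement arrows.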
app/components/sequence_viewer.py ADDED
@@ -0,0 +1,268 @@
+ """Interactive sequence viewer with per-residue displacement heatmap."""
+ import streamlit as st
+ import numpy as np
+
+
+ # Amino acid property classification
+ AA_PROPS = {
+     "A": "hydrophobic", "I": "hydrophobic", "L": "hydrophobic", "M": "hydrophobic",
+     "F": "hydrophobic", "W": "hydrophobic", "V": "hydrophobic", "P": "hydrophobic",
+     "D": "charged", "E": "charged", "K": "charged", "R": "charged", "H": "charged",
+     "S": "polar", "T": "polar", "N": "polar", "Q": "polar", "C": "polar", "Y": "polar",
+     "G": "special", "X": "unknown",
+ }
+
+ PROP_COLORS = {
+     "hydrophobic": "#f59e0b",
+     "charged": "#ef4444",
+     "polar": "#10b981",
+     "special": "#94a3b8",
+     "unknown": "#64748b",
+ }
+
+
+ def render_sequence_viewer(
+     seq: str,
+     displacements: np.ndarray,
+     coverage: np.ndarray = None,
+     mode_label: str = "Mode 0",
+     max_chars_per_row: int = 80,
+ ):
+     """Render interactive HTML sequence viewer with displacement coloring.
+
+     Each residue is displayed as a colored cell where:
+     - Background: displacement magnitude (dark purple → red gradient)
+     - Border: coverage (thick = low coverage)
+     - Tooltip: residue info
+
+     Args:
+         seq: Amino acid sequence (1-letter codes)
+         displacements: Per-residue displacement magnitudes (N,)
+         coverage: Per-residue coverage (N,) in [0, 1]
+         mode_label: Label for the current mode
+         max_chars_per_row: Characters per row before wrapping
+     """
+     n = len(seq)
+     if coverage is None:
+         coverage = np.ones(n)
+
+     max_d = displacements.max() + 1e-8
+
+     html = f"""
+     <style>
+     .seq-container {{
+         font-family: 'Consolas', 'Monaco', monospace;
+         background: #1e1b4b;
+         border-radius: 8px;
+         padding: 12px;
+         margin: 8px 0;
+     }}
+     .seq-header {{
+         color: #a5b4fc;
+         font-size: 13px;
+         margin-bottom: 8px;
+         font-weight: bold;
+     }}
+     .seq-row {{
+         display: flex;
+         flex-wrap: wrap;
+         gap: 1px;
+         margin-bottom: 2px;
+     }}
+     .res {{
+         display: inline-flex;
+         align-items: center;
+         justify-content: center;
+         width: 14px;
+         height: 22px;
+         font-size: 10px;
+         font-weight: bold;
+         border-radius: 2px;
+         cursor: pointer;
+         transition: transform 0.1s;
+         position: relative;
+     }}
+     .res:hover {{
+         transform: scale(1.8);
+         z-index: 10;
+         box-shadow: 0 0 8px rgba(99, 102, 241, 0.8);
+     }}
+     .res:hover::after {{
+         content: attr(data-tooltip);
+         position: absolute;
+         top: -38px;
+         left: 50%;
+         transform: translateX(-50%);
+         background: #312e81;
+         color: white;
+         padding: 4px 8px;
+         border-radius: 4px;
+         font-size: 10px;
+         white-space: nowrap;
+         z-index: 100;
+         border: 1px solid #6366f1;
+     }}
+     .seq-ruler {{
+         display: flex;
+         gap: 1px;
+         margin-bottom: 1px;
+     }}
+     .ruler-mark {{
+         width: 14px;
+         font-size: 7px;
+         color: #64748b;
+         text-align: center;
+     }}
+     .legend {{
+         display: flex;
+         gap: 16px;
+         margin-top: 8px;
+         font-size: 11px;
+         color: #94a3b8;
+     }}
+     .legend-item {{
+         display: flex;
+         align-items: center;
+         gap: 4px;
+     }}
+     .legend-swatch {{
+         width: 12px;
+         height: 12px;
+         border-radius: 2px;
+     }}
+     </style>
+     <div class="seq-container">
+     <div class="seq-header">{mode_label} — Per-residue displacement ({n} residues)</div>
+     """
+
+     # Build rows
+     for row_start in range(0, n, max_chars_per_row):
+         row_end = min(row_start + max_chars_per_row, n)
+
+         # Ruler
+         html += '<div class="seq-ruler">'
+         for i in range(row_start, row_end):
+             if (i + 1) % 10 == 0:
+                 html += f'<div class="ruler-mark">{i + 1}</div>'
+             elif (i + 1) % 5 == 0:
+                 html += '<div class="ruler-mark">·</div>'
+             else:
+                 html += '<div class="ruler-mark"></div>'
+         html += "</div>"
+
+         # Residues
+         html += '<div class="seq-row">'
+         for i in range(row_start, row_end):
+             aa = seq[i] if i < len(seq) else "X"
+             d = displacements[i]
+             c = coverage[i] if i < len(coverage) else 1.0
+             t = d / max_d  # Normalized displacement
+
+             # Background: displacement heatmap (dark purple → bright red)
+             r = int(30 + 225 * t)
+             g = int(27 + 20 * (1 - t))
+             b = int(75 - 50 * t)
+             bg = f"rgb({r},{g},{b})"
+
+             # Text color: white for high displacement, light for low
+             txt_color = "white" if t > 0.3 else "#a5b4fc"
+
+             # Border: thicker = lower coverage
+             border_w = max(0, int(3 * (1 - c)))
+             border = f"{border_w}px solid #ef4444" if border_w > 0 else "none"
+
+             prop = AA_PROPS.get(aa, "unknown")
+             tooltip = f"{aa}{i+1} | {d:.3f}Å | cov={c:.2f} | {prop}"
+
+             html += (
+                 f'<div class="res" style="background:{bg};color:{txt_color};'
+                 f'border:{border}" data-tooltip="{tooltip}">{aa}</div>'
+             )
+         html += "</div>"
+
+     # Legend
+     html += """
+     <div class="legend">
+         <div class="legend-item">
+             <div class="legend-swatch" style="background:linear-gradient(90deg,#1e1b4b,#ff3030)"></div>
+             Low → High displacement
+         </div>
+         <div class="legend-item">
+             <div class="legend-swatch" style="border:2px solid #ef4444;background:none"></div>
+             Red border = low coverage
+         </div>
+     </div>
+     </div>
+     """
+
+     st.markdown(html, unsafe_allow_html=True)
+
+
+ def render_displacement_chart(
+     displacements: dict,
+     seq: str = "",
+     coverage: np.ndarray = None,
+ ):
+     """Render interactive displacement profile chart using Plotly.
+
+     Args:
+         displacements: {mode_idx: np.ndarray of per-residue magnitudes}
+         seq: Amino acid sequence
+         coverage: Per-residue coverage
+     """
+     import plotly.graph_objects as go
+     from plotly.subplots import make_subplots
+
+     n_modes = len(displacements)
+     n_res = len(list(displacements.values())[0])
218
+ residues = np.arange(1, n_res + 1)
219
+
220
+ # Hover text with AA identity
221
+ hover_text = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
222
+
223
+ fig = make_subplots(
224
+ rows=2, cols=1, row_heights=[0.75, 0.25],
225
+ shared_xaxes=True, vertical_spacing=0.08,
226
+ subplot_titles=["Displacement by Mode", "Coverage"]
227
+ )
228
+
229
+ colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b", "#ec4899", "#8b5cf6"]
230
+
231
+ for k, d in displacements.items():
232
+ mags = np.linalg.norm(d, axis=1) if d.ndim == 2 else d
233
+ fig.add_trace(go.Scatter(
234
+ x=residues, y=mags,
235
+ mode="lines",
236
+ name=f"Mode {k} (ฮผ={mags.mean():.3f}ร…)",
237
+ line=dict(color=colors[k % len(colors)], width=1.5),
238
+ fill="tozeroy",
239
+ fillcolor=colors[k % len(colors)].replace(")", ",0.1)").replace("rgb", "rgba"),
240
+ text=hover_text,
241
+ hovertemplate="%{text}<br>Displacement: %{y:.3f}ร…<extra>Mode " + str(k) + "</extra>",
242
+ ), row=1, col=1)
243
+
244
+ # Coverage
245
+ if coverage is not None:
246
+ fig.add_trace(go.Scatter(
247
+ x=residues, y=coverage[:n_res],
248
+ mode="lines",
249
+ name="Coverage",
250
+ line=dict(color="#94a3b8", width=1.5),
251
+ fill="tozeroy",
252
+ fillcolor="rgba(148,163,184,0.15)",
253
+ hovertemplate="%{x}<br>Coverage: %{y:.3f}<extra></extra>",
254
+ ), row=2, col=1)
255
+
256
+ fig.update_layout(
257
+ template="plotly_dark",
258
+ height=400,
259
+ paper_bgcolor="rgba(0,0,0,0)",
260
+ plot_bgcolor="rgba(30,27,75,0.5)",
261
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
262
+ margin=dict(l=50, r=20, t=40, b=30),
263
+ )
264
+ fig.update_xaxes(title_text="Residue", row=2, col=1)
265
+ fig.update_yaxes(title_text="Displacement (ร…)", row=1, col=1)
266
+ fig.update_yaxes(title_text="Coverage", range=[0, 1.1], row=2, col=1)
267
+
268
+ st.plotly_chart(fig, use_container_width=True, key=f"disp_chart")
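The heatmap colors in the sequence viewer above are computed inline; as a sanity check outside Streamlit, the same dark-purple-to-red mapping can be factored into a small helper (the name `heatmap_rgb` is ours, not part of the app):

```python
def heatmap_rgb(t: float) -> tuple[int, int, int]:
    """Map a normalized displacement t in [0, 1] to the dark-purple -> red
    gradient used by the sequence heatmap (same arithmetic as the app)."""
    r = int(30 + 225 * t)
    g = int(27 + 20 * (1 - t))
    b = int(75 - 50 * t)
    return r, g, b

# Endpoints of the gradient
low = heatmap_rgb(0.0)   # dark purple end
high = heatmap_rgb(1.0)  # bright red end
```

At `t=0` this yields the dark-purple end of the legend's gradient; at `t=1` the bright red end.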
app/components/viewer_3d.py ADDED
@@ -0,0 +1,203 @@
+ """3D protein motion visualization using py3Dmol via stmol."""
+ import streamlit as st
+ import numpy as np
+
+ try:
+     from stmol import showmol
+     import py3Dmol
+     HAS_STMOL = True
+ except ImportError:
+     HAS_STMOL = False
+
+
+ def render_motion_viewer(
+     pdb_text: str,
+     ca_coords: np.ndarray,
+     mode_vecs: np.ndarray,
+     seq: str = "",
+     amplitude: float = 3.0,
+     arrow_scale: float = 1.0,
+     color_scheme: str = "magnitude",
+     show_cartoon: bool = True,
+     show_labels: bool = True,
+     min_displacement: float = 0.01,
+     width: int = 800,
+     height: int = 500,
+     key: str = "viewer",
+ ):
+     """Render an interactive 3D motion viewer with displacement arrows.
+
+     Args:
+         pdb_text: PDB file content as a string
+         ca_coords: CA coordinates (N, 3)
+         mode_vecs: Displacement vectors for one mode (N, 3)
+         seq: Amino acid sequence (1-letter)
+         amplitude: Arrow length multiplier
+         arrow_scale: Arrow thickness multiplier
+         color_scheme: "magnitude"|"rainbow"|"bfactor"|"chain"|"residue_type"
+         show_cartoon: Show cartoon backbone
+         show_labels: Label top mobile residues
+         min_displacement: Hide arrows below this threshold
+         width, height: Viewer dimensions
+         key: Streamlit widget key
+     """
+     if not HAS_STMOL:
+         st.error("Install `stmol`: `pip install stmol`")
+         return
+
+     n_res = len(ca_coords)
+     mags = np.linalg.norm(mode_vecs, axis=1)
+     max_mag = mags.max() + 1e-8
+
+     view = py3Dmol.view(width=width, height=height)
+     view.addModel(pdb_text, "pdb")
+
+     # Backbone style
+     if show_cartoon:
+         view.setStyle({"cartoon": {"color": "#e2e8f0", "opacity": 0.45}})
+     else:
+         view.setStyle({"stick": {"radius": 0.06, "color": "#94a3b8"}})
+
+     # Draw displacement arrows
+     for i in range(n_res):
+         if mags[i] < min_displacement:
+             continue
+
+         s = ca_coords[i]
+         d = mode_vecs[i] * amplitude
+         e = s + d
+         t = mags[i] / max_mag  # Normalized intensity [0, 1]
+
+         # Color assignment
+         col = _get_color(i, t, n_res, seq, color_scheme)
+
+         # Arrow shaft — radius proportional to displacement
+         base_r = 0.08 * arrow_scale
+         shaft_r = base_r + base_r * t
+         view.addCylinder({
+             "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
+             "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
+             "radius": shaft_r,
+             "color": col,
+             "fromCap": True,
+         })
+
+         # Arrow tip (cone-like)
+         dn = d / (np.linalg.norm(d) + 1e-8)
+         tip = e + dn * 0.25 * amplitude * arrow_scale
+         tip_r = shaft_r * 2.2
+         view.addCylinder({
+             "start": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
+             "end": {"x": float(tip[0]), "y": float(tip[1]), "z": float(tip[2])},
+             "radius": tip_r,
+             "color": col,
+             "toCap": True,
+         })
+
+     # Label top-5 mobile residues
+     if show_labels and n_res > 0:
+         top_n = min(5, n_res)
+         top_idx = np.argsort(mags)[-top_n:][::-1]
+         for idx in top_idx:
+             if mags[idx] < min_displacement:
+                 continue
+             pos = ca_coords[idx]
+             aa = seq[idx] if idx < len(seq) else "?"
+             view.addLabel(
+                 f"{aa}{idx + 1}\n{mags[idx]:.2f}Å",
+                 {
+                     "position": {"x": float(pos[0]), "y": float(pos[1] + 2.5), "z": float(pos[2])},
+                     "fontSize": 11,
+                     "fontColor": "white",
+                     "backgroundColor": "#312e81",
+                     "backgroundOpacity": 0.85,
+                     "borderColor": "#6366f1",
+                     "borderThickness": 1,
+                 },
+             )
+
+     view.zoomTo()
+     showmol(view, height=height, width=width)
+
+
+ def _get_color(idx: int, intensity: float, n_res: int, seq: str, scheme: str) -> str:
+     """Get the color for a residue under the given color scheme."""
+     if scheme == "magnitude":
+         # Blue → Purple → Red gradient
+         r = int(99 + 156 * intensity)
+         g = int(102 - 62 * intensity)
+         b = int(241 - 180 * intensity)
+         return f"rgb({r},{g},{b})"
+
+     elif scheme == "rainbow":
+         import colorsys
+         h = idx / max(n_res - 1, 1)
+         r, g, b = [int(255 * c) for c in colorsys.hsv_to_rgb(h, 0.85, 0.92)]
+         return f"rgb({r},{g},{b})"
+
+     elif scheme == "residue_type":
+         aa = seq[idx] if idx < len(seq) else "X"
+         hydrophobic = "AILMFWVP"
+         charged = "DEKRH"
+         polar = "STNQCY"
+         if aa in hydrophobic:
+             return "#f59e0b"
+         if aa in charged:
+             return "#ef4444"
+         if aa in polar:
+             return "#10b981"
+         return "#94a3b8"
+
+     elif scheme == "bfactor":
+         r = int(255 * intensity)
+         g = int(100 * (1 - intensity))
+         b = int(50 * (1 - intensity))
+         return f"rgb({r},{g},{b})"
+
+     else:
+         return "#6366f1"
+
+
+ def render_mode_comparison(
+     pdb_text: str,
+     ca_coords: np.ndarray,
+     modes: dict,
+     seq: str = "",
+     amplitude: float = 3.0,
+     arrow_scale: float = 1.0,
+     width: int = 900,
+     height: int = 350,
+ ):
+     """Render a side-by-side mode comparison grid."""
+     if not HAS_STMOL:
+         st.error("Install `stmol`")
+         return
+
+     n_modes = min(4, len(modes))
+     if n_modes == 0:
+         st.warning("No modes to display")
+         return
+
+     colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
+
+     cols = st.columns(n_modes)
+     for k in range(n_modes):
+         with cols[k]:
+             vecs = modes[k]
+             mags = np.linalg.norm(vecs, axis=1)
+             st.caption(f"**Mode {k}** · μ={mags.mean():.3f}Å · max={mags.max():.3f}Å")
+
+             view = py3Dmol.view(width=width // n_modes, height=height)
+             view.addModel(pdb_text, "pdb")
+             view.setStyle({"cartoon": {"color": "#e2e8f0", "opacity": 0.35}})
+
+             max_m = mags.max() + 1e-8
+             for i in range(len(ca_coords)):
+                 if mags[i] < 0.01:
+                     continue
+                 s = ca_coords[i]
+                 d = vecs[i] * amplitude
+                 e = s + d
+                 t = mags[i] / max_m
+                 view.addCylinder({
+                     "start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
+                     "end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
+                     "radius": 0.08 * arrow_scale + 0.05 * t * arrow_scale,
+                     "color": colors[k],
+                 })
+             view.zoomTo()
+             showmol(view, height=height, width=width // n_modes)
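The arrow geometry in `render_motion_viewer` (shaft from the CA position along the scaled mode vector, plus a short cone-like tip along the same direction) can be checked in isolation; a minimal numpy sketch, where the helper name `arrow_segments` is ours:

```python
import numpy as np

def arrow_segments(ca, vec, amplitude=3.0, arrow_scale=1.0):
    """Start/end of the shaft and the head tip for one displacement arrow,
    mirroring the geometry used by render_motion_viewer."""
    s = np.asarray(ca, dtype=float)
    d = np.asarray(vec, dtype=float) * amplitude
    e = s + d                                   # shaft end
    dn = d / (np.linalg.norm(d) + 1e-8)         # unit direction
    tip = e + dn * 0.25 * amplitude * arrow_scale  # cone tip past the shaft
    return s, e, tip

# Unit displacement along x, drawn at amplitude 2
s, e, tip = arrow_segments([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], amplitude=2.0)
```

The tip extends a fixed fraction (0.25 of the amplitude, scaled by arrow thickness) beyond the shaft end.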
app/pages/1_🔍_Explorer.py ADDED
@@ -0,0 +1,169 @@
+ """🔍 Database Explorer — Browse pre-computed PETIMOT predictions."""
+ import streamlit as st
+ import os, sys
+ import numpy as np
+
+ # Imports
+ ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ PETIMOT_ROOT = os.path.dirname(ROOT)
+ if PETIMOT_ROOT not in sys.path:
+     sys.path.insert(0, PETIMOT_ROOT)
+
+ from app.utils.data_loader import (
+     find_predictions_dir, load_prediction_index, load_modes, load_ground_truth
+ )
+ from app.components.viewer_3d import render_motion_viewer, render_mode_comparison
+ from app.components.sequence_viewer import render_sequence_viewer, render_displacement_chart
+ from app.components.mode_panel import render_mode_correlation, render_eigenvalue_spectrum
+ from app.components.prediction_analysis import render_prediction_analysis
+
+ st.header("🔍 Database Explorer")
+
+ # ── Find predictions ──
+ pred_dir = find_predictions_dir(PETIMOT_ROOT)
+ gt_dir = os.path.join(PETIMOT_ROOT, "ground_truth")
+
+ if not pred_dir:
+     st.error("No predictions found. Run inference first (notebook cell 4.3 or Inference page).")
+     st.stop()
+
+ # ── Load index ──
+ with st.spinner("Loading prediction index..."):
+     df = load_prediction_index(pred_dir)
+
+ if df.empty:
+     st.warning("No predictions in index.")
+     st.stop()
+
+ st.success(f"**{len(df):,}** proteins indexed from `{os.path.basename(pred_dir)}`")
+
+ # ── Filters ──
+ with st.expander("🔧 Filters", expanded=False):
+     col1, col2, col3 = st.columns(3)
+     with col1:
+         search = st.text_input("🔍 Search by name", "", placeholder="e.g. 1ake")
+     with col2:
+         len_range = st.slider("Sequence length", int(df.seq_len.min()), int(df.seq_len.max()),
+                               (int(df.seq_len.min()), int(df.seq_len.max())))
+     with col3:
+         sort_by = st.selectbox("Sort by", ["name", "seq_len", "mean_disp_m0", "max_disp_m0"],
+                                index=2)
+
+ mask = (df.seq_len >= len_range[0]) & (df.seq_len <= len_range[1])
+ if search:
+     mask &= df.name.str.contains(search, case=False, na=False)
+ df_filtered = df[mask].sort_values(sort_by, ascending=(sort_by == "name"))
+ st.markdown(f"Showing **{len(df_filtered)}** / {len(df)} proteins")
+
+ # ── Table ──
+ selected_idx = st.dataframe(
+     df_filtered[["name", "seq_len", "n_modes", "mean_disp_m0", "max_disp_m0", "top_residue"]].rename(
+         columns={"name": "Protein", "seq_len": "Length", "n_modes": "Modes",
+                  "mean_disp_m0": "Mean Δ (M0)", "max_disp_m0": "Max Δ (M0)", "top_residue": "Top Res"}
+     ),
+     use_container_width=True, hide_index=True,
+     on_select="rerun", selection_mode="single-row", height=350,
+ )
+
+ # ── Protein detail panel ──
+ selected_rows = selected_idx.selection.rows if selected_idx.selection.rows else []
+ if not selected_rows:
+     st.info("👆 Click a row to view detailed analysis")
+     st.stop()
+
+ protein_name = df_filtered.iloc[selected_rows[0]]["name"]
+ st.divider()
+ st.subheader(f"🧬 {protein_name}")
+
+ # Load data
+ modes = load_modes(pred_dir, protein_name)
+ gt = load_ground_truth(gt_dir, protein_name)
+
+ if not modes:
+     st.error(f"No mode files found for {protein_name}")
+     st.stop()
+
+ n_res = len(list(modes.values())[0])
+ seq = gt.get("seq", "X" * n_res) if gt else "X" * n_res
+ ca = gt["bb"][:, 1] if gt and "bb" in gt else np.zeros((n_res, 3))
+ coverage = gt.get("coverage", np.ones(n_res)) if gt else np.ones(n_res)
+ eigenvalues = gt.get("eigvals", None) if gt else None
+ pdb_text = None
+
+ pdb_path = os.path.join(PETIMOT_ROOT, "pdbs", f"{protein_name}.pdb")
+ if os.path.exists(pdb_path):
+     with open(pdb_path) as f:
+         pdb_text = f.read()
+
+ # ── Sidebar controls ──
+ with st.sidebar:
+     st.divider()
+     st.markdown(f"### 🎛️ {protein_name}")
+     mode_idx = st.slider("Mode", 0, len(modes) - 1, 0, key="mode_sel")
+     amplitude = st.slider("Arrow amplitude", 0.5, 15.0, 3.0, 0.5, key="amp")
+     arrow_scale = st.slider("Arrow thickness", 0.3, 3.0, 1.0, 0.1, key="arrow_s")
+     color_scheme = st.selectbox("Color", ["magnitude", "rainbow", "residue_type", "bfactor"], key="col_s")
+     show_labels = st.checkbox("Show labels", True, key="labels")
+     min_disp = st.slider("Min displacement", 0.0, 0.1, 0.01, 0.005, key="min_d")
+
+ # ── 3D viewer + stats ──
+ col_3d, col_info = st.columns([2, 1])
+
+ with col_3d:
+     current_mode = modes[mode_idx]
+     mags = np.linalg.norm(current_mode, axis=1)
+
+     if pdb_text and ca is not None and np.any(ca != 0):
+         render_motion_viewer(
+             pdb_text=pdb_text, ca_coords=ca, mode_vecs=current_mode, seq=seq,
+             amplitude=amplitude, arrow_scale=arrow_scale, color_scheme=color_scheme,
+             show_labels=show_labels, min_displacement=min_disp,
+             width=700, height=480,
+         )
+     else:
+         st.info("No PDB structure — showing displacement data only.")
+
+ with col_info:
+     st.metric("Residues", n_res)
+     st.metric(f"Mode {mode_idx} mean Δ", f"{mags.mean():.3f} Å")
+     st.metric(f"Mode {mode_idx} max Δ", f"{mags.max():.3f} Å")
+     top5 = np.argsort(mags)[-5:][::-1]
+     st.markdown("**Top mobile residues:**")
+     for rank, idx in enumerate(top5):
+         aa = seq[idx] if idx < len(seq) else "?"
+         st.markdown(f"`#{rank+1}` **{aa}{idx+1}** — {mags[idx]:.3f} Å")
+     if eigenvalues is not None:
+         render_eigenvalue_spectrum(eigenvalues)
+
+ # ── Sequence viewer ──
+ st.markdown("### 🧬 Sequence × Displacement")
+ render_sequence_viewer(seq, mags, coverage, mode_label=f"Mode {mode_idx}")
+
+ # ── Displacement chart ──
+ st.markdown("### 📈 Displacement Profiles")
+ render_displacement_chart(modes, seq, coverage)
+
+ # ── Mode comparison ──
+ st.markdown("### 🔀 Mode Comparison")
+ col_corr, col_grid = st.columns([1, 2])
+ with col_corr:
+     render_mode_correlation(modes)
+ with col_grid:
+     if pdb_text and np.any(ca != 0):
+         render_mode_comparison(pdb_text, ca, modes, seq, amplitude, arrow_scale)
+
+ # ── Deep Analysis ──
+ st.divider()
+ st.markdown("### 🔬 Prediction Analysis")
+
+ gt_modes = None
+ if gt and "eigvects" in gt:
+     n_gt_modes = min(4, gt["eigvects"].shape[1] if gt["eigvects"].ndim == 2 else 4)
+     ev = gt["eigvects"][:, :n_gt_modes]  # (3N, K)
+     ev = ev.reshape(-1, 3, n_gt_modes).transpose(0, 2, 1)  # (N, K, 3)
+     gt_modes = {k: ev[:, k] for k in range(n_gt_modes)}
+
+ render_prediction_analysis(
+     modes=modes, seq=seq, ca_coords=ca, coverage=coverage,
+     eigenvalues=eigenvalues, gt_modes=gt_modes, protein_name=protein_name
+ )
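The `(3N, K) -> (N, K, 3)` eigenvector reshape on the Explorer page assumes the flattened rows are ordered x1, y1, z1, x2, y2, z2, ...; a toy example with hand-checkable values makes that layout assumption explicit:

```python
import numpy as np

# Toy eigvects: 2 residues (3N = 6 rows) and K = 2 modes.
# Row order is x1, y1, z1, x2, y2, z2 -- the layout the reshape assumes.
eigvects = np.array([
    [1.0, 10.0],   # x of residue 0, modes 0 and 1
    [2.0, 20.0],   # y of residue 0
    [3.0, 30.0],   # z of residue 0
    [4.0, 40.0],   # x of residue 1
    [5.0, 50.0],   # y of residue 1
    [6.0, 60.0],   # z of residue 1
])
K = eigvects.shape[1]
ev = eigvects.reshape(-1, 3, K).transpose(0, 2, 1)  # (N, K, 3)
gt_modes = {k: ev[:, k] for k in range(K)}          # per-mode (N, 3) arrays
```

Mode 0 comes out as `[[1, 2, 3], [4, 5, 6]]`: one 3-vector per residue, which is the shape the viewer components expect.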
app/pages/2_🔮_Inference.py ADDED
@@ -0,0 +1,162 @@
+ """🔮 Inference — Predict motion for a new protein."""
+ import streamlit as st
+ import os, sys, tempfile
+ import numpy as np
+
+ ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ PETIMOT_ROOT = os.path.dirname(ROOT)
+ if PETIMOT_ROOT not in sys.path:
+     sys.path.insert(0, PETIMOT_ROOT)
+
+ from app.utils.inference import run_inference, download_pdb
+ from app.components.viewer_3d import render_motion_viewer, render_mode_comparison
+ from app.components.sequence_viewer import render_sequence_viewer, render_displacement_chart
+ from app.components.mode_panel import render_mode_correlation, render_eigenvalue_spectrum
+
+ st.header("🔮 Custom Inference")
+ st.markdown("Predict protein motion modes for any structure. Runs on CPU (~5-30s).")
+
+ # ── Input method ──
+ input_mode = st.radio("Input method", ["PDB ID (RCSB)", "Upload PDB file"], horizontal=True)
+
+ pdb_path = None
+
+ if input_mode == "PDB ID (RCSB)":
+     col1, col2 = st.columns([3, 1])
+     with col1:
+         pdb_id = st.text_input("PDB ID", "1akeA", placeholder="e.g. 1akeA, 4akeA",
+                                help="4-char PDB code + optional chain letter")
+     with col2:
+         st.markdown("<br>", unsafe_allow_html=True)
+         fetch = st.button("🔍 Fetch", use_container_width=True)
+
+     if fetch and pdb_id:
+         with st.spinner(f"Downloading {pdb_id} from RCSB..."):
+             pdb_path = download_pdb(pdb_id)
+         if pdb_path:
+             st.success(f"Downloaded {pdb_id}")
+         else:
+             st.error(f"Could not download {pdb_id}. Check the PDB ID.")
+
+ else:
+     uploaded = st.file_uploader("Upload PDB", type=["pdb"], key="pdb_upload")
+     if uploaded:
+         tmp = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
+         tmp.write(uploaded.read())
+         tmp.close()
+         pdb_path = tmp.name
+         st.success(f"Uploaded: {uploaded.name}")
+
+ # Persist the path across Streamlit reruns -- clicking the Run button below
+ # triggers a rerun in which the Fetch button reads as False again
+ if pdb_path:
+     st.session_state["pdb_path"] = pdb_path
+ pdb_path = pdb_path or st.session_state.get("pdb_path")
+
+ # ── Weights selection ──
+ weights_path = st.session_state.get("weights", None)
+ if not weights_path:
+     weights_dir = os.path.join(PETIMOT_ROOT, "weights")
+     pt_files = []
+     if os.path.isdir(weights_dir):
+         for root, dirs, files in os.walk(weights_dir):
+             for f in files:
+                 if f.endswith(".pt"):
+                     pt_files.append(os.path.join(root, f))
+     if pt_files:
+         weights_path = pt_files[0]
+
+ if not weights_path:
+     st.error("No model weights found in `weights/`. Download them from Figshare.")
+     st.stop()
+
+ # ── Run inference ──
+ if pdb_path:
+     if st.button("🚀 Run PETIMOT Inference", type="primary", use_container_width=True):
+         with st.spinner("Running inference... (loading embeddings + forward pass)"):
+             try:
+                 result = run_inference(pdb_path, weights_path)
+             except Exception as e:
+                 st.error(f"Inference failed: {e}")
+                 st.exception(e)
+                 st.stop()
+
+             st.session_state["inference_result"] = result
+             st.success(f"✅ Predicted {len(result['modes'])} modes for {result['name']} ({result['n_res']} residues)")
+
+ # ── Display results ──
+ result = st.session_state.get("inference_result", None)
+ if result:
+     modes = result["modes"]
+     ca = result["ca_coords"]
+     seq = result["seq"]
+     pdb_text = result["pdb_text"]
+     n_res = result["n_res"]
+
+     st.divider()
+     st.subheader(f"🧬 {result['name']} — {n_res} residues, {len(modes)} modes")
+
+     # Controls
+     with st.sidebar:
+         st.divider()
+         st.markdown(f"### 🎛️ {result['name']}")
+         mode_idx = st.slider("Mode", 0, max(0, len(modes) - 1), 0, key="inf_mode")
+         amplitude = st.slider("Amplitude", 0.5, 15.0, 3.0, 0.5, key="inf_amp")
+         arrow_scale = st.slider("Arrow size", 0.3, 3.0, 1.0, 0.1, key="inf_arrow")
+         color_scheme = st.selectbox("Colors",
+                                     ["magnitude", "rainbow", "residue_type", "bfactor"],
+                                     key="inf_color")
+
+     # 3D viewer
+     current_mode = modes.get(mode_idx, list(modes.values())[0])
+     mags = np.linalg.norm(current_mode, axis=1)
+
+     col_3d, col_stats = st.columns([2, 1])
+     with col_3d:
+         render_motion_viewer(
+             pdb_text=pdb_text, ca_coords=ca, mode_vecs=current_mode, seq=seq,
+             amplitude=amplitude, arrow_scale=arrow_scale, color_scheme=color_scheme,
+             width=700, height=500, key="inf_viewer",
+         )
+     with col_stats:
+         st.metric("Residues", n_res)
+         st.metric(f"Mode {mode_idx} mean", f"{mags.mean():.3f} Å")
+         st.metric(f"Mode {mode_idx} max", f"{mags.max():.3f} Å")
+         top3 = np.argsort(mags)[-3:][::-1]
+         st.markdown("**Top mobile:**")
+         for idx in top3:
+             aa = seq[idx] if idx < len(seq) else "?"
+             st.markdown(f"**{aa}{idx+1}** — {mags[idx]:.3f} Å")
+
+     # Sequence viewer
+     st.markdown("### 🧬 Sequence × Displacement")
+     render_sequence_viewer(seq, mags, mode_label=f"Mode {mode_idx}")
+
+     # Displacement chart
+     st.markdown("### 📈 Profiles")
+     render_displacement_chart(modes, seq)
+
+     # Mode comparison
+     if len(modes) > 1:
+         st.markdown("### 🔀 Mode Comparison")
+         render_mode_comparison(pdb_text, ca, modes, seq, amplitude, arrow_scale)
+         render_mode_correlation(modes)
+
+     # Export
+     st.divider()
+     st.markdown("### 💾 Export")
+     col_e1, col_e2 = st.columns(2)
+     with col_e1:
+         # CSV export
+         import pandas as pd
+         export_data = []
+         for k, v in modes.items():
+             m = np.linalg.norm(v, axis=1)
+             for i in range(len(m)):
+                 export_data.append({
+                     "residue": i + 1,
+                     "aa": seq[i] if i < len(seq) else "?",
+                     "mode": k,
+                     "dx": v[i, 0], "dy": v[i, 1], "dz": v[i, 2],
+                     "magnitude": m[i],
+                 })
+         csv = pd.DataFrame(export_data).to_csv(index=False)
+         st.download_button("📥 Download modes (CSV)", csv,
+                            f"{result['name']}_modes.csv", "text/csv")
+     with col_e2:
+         st.download_button("📥 Download PDB", pdb_text,
+                            f"{result['name']}.pdb", "chemical/x-pdb")
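The CSV export on the Inference page flattens the mode dictionary into long format, one row per (residue, mode) pair; a self-contained sketch with made-up values:

```python
import numpy as np
import pandas as pd

# Toy prediction: 2 modes over 3 residues (values are invented)
modes = {0: np.array([[1.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0],
                      [0.0, 0.0, 2.0]]),
         1: np.zeros((3, 3))}
seq = "ACD"

rows = []
for k, v in modes.items():
    mags = np.linalg.norm(v, axis=1)        # per-residue magnitude
    for i in range(len(mags)):
        rows.append({"residue": i + 1, "aa": seq[i], "mode": k,
                     "dx": v[i, 0], "dy": v[i, 1], "dz": v[i, 2],
                     "magnitude": mags[i]})
df = pd.DataFrame(rows)  # long format: one row per (residue, mode)
```

Note that selecting the mode column requires `df["mode"]` rather than `df.mode`, since `.mode` is an existing `DataFrame` method.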
app/pages/3_📊_Statistics.py ADDED
@@ -0,0 +1,158 @@
+ """📊 Statistics — Dataset-wide analysis of PETIMOT predictions."""
+ import streamlit as st
+ import os, sys
+ import numpy as np
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+
+ ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ PETIMOT_ROOT = os.path.dirname(ROOT)
+ if PETIMOT_ROOT not in sys.path:
+     sys.path.insert(0, PETIMOT_ROOT)
+
+ from app.utils.data_loader import find_predictions_dir, load_prediction_index
+
+ st.header("📊 Dataset Statistics")
+
+ pred_dir = find_predictions_dir(PETIMOT_ROOT)
+ if not pred_dir:
+     st.warning("No predictions found. Run batch inference first.")
+     st.stop()
+
+ with st.spinner("Building index..."):
+     df = load_prediction_index(pred_dir)
+
+ if df.empty:
+     st.warning("Empty index")
+     st.stop()
+
+ st.success(f"**{len(df):,}** proteins analyzed")
+
+ # ── Summary metrics ──
+ c1, c2, c3, c4, c5 = st.columns(5)
+ c1.metric("Proteins", f"{len(df):,}")
+ c2.metric("Median length", f"{df.seq_len.median():.0f}")
+ c3.metric("Mean Δ (M0)", f"{df.mean_disp_m0.mean():.3f} Å")
+ c4.metric("Max Δ (M0)", f"{df.max_disp_m0.max():.3f} Å")
+ c5.metric("Avg modes", f"{df.n_modes.mean():.1f}")
+
+ # ── Distribution plots ──
+ st.markdown("### 📈 Distributions")
+
+ fig = make_subplots(rows=2, cols=2,
+                     subplot_titles=[
+                         "Sequence Length Distribution",
+                         "Mean Mode-0 Displacement",
+                         "Max Mode-0 Displacement",
+                         "Length vs Mean Displacement",
+                     ])
+
+ # 1. Sequence length
+ fig.add_trace(go.Histogram(
+     x=df.seq_len, nbinsx=60, marker_color="#6366f1", name="Length",
+     hovertemplate="Length: %{x}<br>Count: %{y}<extra></extra>",
+ ), row=1, col=1)
+ fig.add_vline(x=df.seq_len.median(), line_dash="dash", line_color="#ef4444",
+               annotation_text=f"Med={df.seq_len.median():.0f}", row=1, col=1)
+
+ # 2. Mean displacement
+ fig.add_trace(go.Histogram(
+     x=df.mean_disp_m0, nbinsx=60, marker_color="#10b981", name="Mean Δ",
+     hovertemplate="Mean Δ: %{x:.3f}Å<br>Count: %{y}<extra></extra>",
+ ), row=1, col=2)
+
+ # 3. Max displacement
+ fig.add_trace(go.Histogram(
+     x=df.max_disp_m0, nbinsx=60, marker_color="#f59e0b", name="Max Δ",
+     hovertemplate="Max Δ: %{x:.3f}Å<br>Count: %{y}<extra></extra>",
+ ), row=2, col=1)
+
+ # 4. Scatter: length vs displacement
+ fig.add_trace(go.Scattergl(
+     x=df.seq_len, y=df.mean_disp_m0,
+     mode="markers",
+     marker=dict(size=3, color=df.max_disp_m0, colorscale="Viridis",
+                 showscale=True, colorbar=dict(title="Max Δ")),
+     name="Proteins",
+     hovertemplate="%{text}<br>Length: %{x}<br>Mean Δ: %{y:.3f}Å<extra></extra>",
+     text=df.name,
+ ), row=2, col=2)
+
+ fig.update_layout(
+     template="plotly_dark",
+     height=600,
+     showlegend=False,
+     paper_bgcolor="rgba(0,0,0,0)",
+     plot_bgcolor="rgba(30,27,75,0.5)",
+     margin=dict(l=40, r=40, t=40, b=30),
+ )
+ st.plotly_chart(fig, use_container_width=True)
+
+ # ── Ground truth analysis (if available) ──
+ gt_dir = os.path.join(PETIMOT_ROOT, "ground_truth")
+ if os.path.isdir(gt_dir) and len(os.listdir(gt_dir)) > 0:
+     st.markdown("### 🎯 Ground Truth Analysis")
+     st.info("Loading eigenvalue statistics from ground truth...")
+
+     import torch, glob
+     gt_files = sorted(glob.glob(os.path.join(gt_dir, "*.pt")))[:2000]
+
+     eigendata = []
+     for gf in gt_files:
+         try:
+             d = torch.load(gf, map_location="cpu", weights_only=True)
+             ev = d["eigvals"].numpy() if "eigvals" in d else None
+             cov = d["coverage"].numpy() if "coverage" in d else None
+             if ev is not None:
+                 eigendata.append({
+                     "name": os.path.basename(gf).replace(".pt", ""),
+                     "n_res": len(d["bb"]),
+                     "eigval_0": float(ev[0]),
+                     "eigval_ratio": float(ev[0] / (ev[1] + 1e-12)) if len(ev) > 1 else 0,
+                     "mode1_var_frac": float(ev[0] / (ev.sum() + 1e-12)),
+                     "mean_coverage": float(cov.mean()) if cov is not None else 1.0,
+                 })
+         except Exception:
+             continue
+
+     if eigendata:
+         import pandas as pd
+         df_gt = pd.DataFrame(eigendata)
+
+         fig2 = make_subplots(rows=1, cols=3,
+                              subplot_titles=["Dominant Eigenvalue (λ₁)", "Mode Dominance (λ₁/λ₂)", "Mode 1 Variance %"])
+
+         fig2.add_trace(go.Histogram(
+             x=df_gt.eigval_0, nbinsx=60, marker_color="#ec4899", name="λ₁",
+         ), row=1, col=1)
+
+         fig2.add_trace(go.Histogram(
+             x=df_gt.eigval_ratio, nbinsx=60, marker_color="#8b5cf6", name="λ₁/λ₂",
+         ), row=1, col=2)
+
+         fig2.add_trace(go.Histogram(
+             x=df_gt.mode1_var_frac * 100, nbinsx=50, marker_color="#06b6d4", name="Var %",
+         ), row=1, col=3)
+
+         fig2.update_layout(
+             template="plotly_dark", height=300, showlegend=False,
+             paper_bgcolor="rgba(0,0,0,0)",
+             plot_bgcolor="rgba(30,27,75,0.5)",
+             margin=dict(l=40, r=20, t=40, b=30),
+         )
+         st.plotly_chart(fig2, use_container_width=True)
+
+         # Summary
+         st.markdown(f"""
+         | Metric | Value |
+         |--------|-------|
+         | Samples analyzed | {len(df_gt):,} |
+         | λ₁ median | {df_gt.eigval_0.median():.4f} |
+         | λ₁/λ₂ median | {df_gt.eigval_ratio.median():.2f} |
+         | Mode 1 variance | {df_gt.mode1_var_frac.median()*100:.1f}% median |
+         | Coverage | {df_gt.mean_coverage.mean():.3f} ± {df_gt.mean_coverage.std():.3f} |
+         """)
+
+ # ── Data table (searchable) ──
+ st.markdown("### 📋 Full Data Table")
+ st.dataframe(df, use_container_width=True, height=400)
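The per-protein spectrum summaries computed on the Statistics page (dominance ratio λ₁/λ₂ and mode-1 variance fraction λ₁/Σλ) reduce to two lines of numpy; a small sketch, where the helper name `eigen_summary` is ours:

```python
import numpy as np

def eigen_summary(eigvals):
    """Spectrum summaries as used on the Statistics page: how strongly the
    first mode dominates the second, and its share of total variance."""
    ev = np.asarray(eigvals, dtype=float)
    ratio = ev[0] / (ev[1] + 1e-12) if len(ev) > 1 else 0.0  # lambda1/lambda2
    var_frac = ev[0] / (ev.sum() + 1e-12)                    # lambda1 / sum
    return ratio, var_frac

# Toy spectrum: first mode twice the second, half the total variance
ratio, var_frac = eigen_summary([4.0, 2.0, 1.0, 1.0])
```

A `ratio` well above 1 and a large `var_frac` both indicate a motion dominated by a single collective mode.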
app/requirements.txt ADDED
@@ -0,0 +1,15 @@
+ # PETIMOT Streamlit Explorer
+ streamlit>=1.30.0
+ stmol>=0.0.9
+ py3Dmol>=2.0.0
+ plotly>=5.18.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ torch>=2.0.0
+ torch_geometric>=2.0.0
+ transformers>=4.30.0
+ sentencepiece>=0.1.99
+ scipy>=1.10.0
+ biopython>=1.80
+ requests>=2.28.0
+ tqdm>=4.65.0
app/utils/__init__.py ADDED
File without changes
app/utils/data_loader.py ADDED
@@ -0,0 +1,87 @@
1
+ """Data loading utilities for pre-computed PETIMOT predictions."""
2
+ import os, json, glob, torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ from functools import lru_cache
7
+
8
+
9
+ def find_predictions_dir(root: str) -> str | None:
10
+ """Find the predictions directory (most recent model)."""
11
+ pred_root = os.path.join(root, "predictions")
12
+ if not os.path.isdir(pred_root):
13
+ return None
14
+ subdirs = [os.path.join(pred_root, d) for d in os.listdir(pred_root)
15
+ if os.path.isdir(os.path.join(pred_root, d))]
16
+ if not subdirs:
17
+ return None
18
+ return max(subdirs, key=os.path.getmtime)
19
+
20
+
21
+ @lru_cache(maxsize=1)
22
+ def load_prediction_index(pred_dir: str) -> pd.DataFrame:
23
+ """Build index of all predicted proteins with metadata."""
24
+ rows = []
25
+ mode_files = glob.glob(os.path.join(pred_dir, "*_mode_0.txt"))
26
+
27
+ for mf in mode_files:
28
+ base = os.path.basename(mf).replace("_mode_0.txt", "")
29
+ # Load mode 0 for stats
30
+ try:
31
+ vecs = np.loadtxt(mf)
32
+ n_res = len(vecs)
33
+ mag = np.linalg.norm(vecs, axis=1)
34
+
35
+ # Count available modes
36
+ n_modes = 0
37
+ for k in range(10):
38
+ if os.path.exists(os.path.join(pred_dir, f"{base}_mode_{k}.txt")):
39
+ n_modes += 1
40
+ else:
41
+ break
42
+
43
+ rows.append({
44
+ "name": base,
45
+ "seq_len": n_res,
46
+ "n_modes": n_modes,
47
+ "mean_disp_m0": float(mag.mean()),
48
+ "max_disp_m0": float(mag.max()),
49
+ "top_residue": int(np.argmax(mag)) + 1,
50
+ })
51
+ except Exception:
52
+ continue
53
+
54
+ return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)
55
+
56
+
57
+ def load_modes(pred_dir: str, name: str) -> dict[int, np.ndarray]:
58
+ """Load all mode files for a protein."""
59
+ modes = {}
60
+ for k in range(10):
61
+ for pfx in [f"extracted_{name}", name]:
62
+ mf = os.path.join(pred_dir, f"{pfx}_mode_{k}.txt")
63
+ if os.path.exists(mf):
64
+ modes[k] = np.loadtxt(mf)
65
+ break
66
+ return modes
67
+
68
+
69
+ def load_ground_truth(gt_dir: str, name: str) -> dict | None:
70
+ """Load ground truth data for a protein."""
71
+ path = os.path.join(gt_dir, f"{name}.pt")
72
+ if not os.path.exists(path):
73
+ return None
74
+ try:
75
+ data = torch.load(path, map_location="cpu", weights_only=True)
76
+ return {k: v.numpy() if isinstance(v, torch.Tensor) else v
77
+ for k, v in data.items()}
78
+ except Exception:
79
+ return None
80
+
81
+
82
+ def load_pdb_text(pdb_path: str) -> str | None:
83
+ """Load PDB file as text."""
84
+ if not os.path.exists(pdb_path):
85
+ return None
86
+ with open(pdb_path) as f:
87
+ return f.read()
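The per-protein statistics stored by `load_prediction_index` can be reproduced from a single mode file. A self-contained sketch with a synthetic `(N, 3)` displacement file (the file name is made up):

```python
import os
import tempfile
import numpy as np

# Write a synthetic (N, 3) displacement mode, as the inference step would.
tmp = tempfile.mkdtemp()
vecs = np.array([[0.0, 0.0, 0.0],
                 [0.3, 0.4, 0.0],   # |v| = 0.5
                 [0.0, 0.0, 1.0]])  # |v| = 1.0 -> largest displacement
path = os.path.join(tmp, "demo_mode_0.txt")
np.savetxt(path, vecs)

# Same statistics the Explorer index stores per protein.
loaded = np.loadtxt(path)
mag = np.linalg.norm(loaded, axis=1)       # per-residue displacement magnitude
stats = {
    "seq_len": len(loaded),
    "mean_disp_m0": float(mag.mean()),
    "max_disp_m0": float(mag.max()),
    "top_residue": int(np.argmax(mag)) + 1,  # 1-based residue index
}
print(stats)
```

`top_residue` is 1-based so it can be shown directly next to PDB residue numbering in the UI.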
app/utils/download.py ADDED
@@ -0,0 +1,121 @@
+ """Auto-download PETIMOT data from Figshare on first run."""
+ import os, zipfile, requests
+ from pathlib import Path
+ from tqdm import tqdm
+
+
+ FIGSHARE_PRIVATE_KEY = "ab400d852b4669a83b64"
+ FIGSHARE_FILES = {
+     "ground_truth.zip": "52349453",
+     "default_2025-02-07_21-54-02_epoch_33.pt": "52349456",
+     "baseline_predictions.zip": "52349480",
+ }
+
+
+ def download_file(url: str, dest: str, desc: str = "") -> bool:
+     """Download a file with a progress bar."""
+     try:
+         r = requests.get(url, stream=True, allow_redirects=True, timeout=60)
+         r.raise_for_status()
+         total = int(r.headers.get("content-length", 0))
+         with open(dest, "wb") as f:
+             with tqdm(total=total, unit="B", unit_scale=True, desc=desc) as pbar:
+                 for chunk in r.iter_content(8192):
+                     f.write(chunk)
+                     pbar.update(len(chunk))
+         if os.path.getsize(dest) < 1000:
+             # A very small file is an error page, not the data
+             os.remove(dest)
+             return False
+         return True
+     except Exception as e:
+         print(f"Download failed: {e}")
+         return False
+
+
+ def ensure_weights(root: str) -> str | None:
+     """Ensure model weights are available. Returns the weights path or None."""
+     weights_dir = os.path.join(root, "weights")
+     os.makedirs(weights_dir, exist_ok=True)
+
+     # Check for existing weights
+     for f in os.listdir(weights_dir):
+         if f.endswith(".pt"):
+             return os.path.join(weights_dir, f)
+
+     # Try downloading from Figshare
+     wt_name = "default_2025-02-07_21-54-02_epoch_33.pt"
+     wt_path = os.path.join(weights_dir, wt_name)
+     fid = FIGSHARE_FILES[wt_name]
+     url = f"https://figshare.com/ndownloader/files/{fid}?private_link={FIGSHARE_PRIVATE_KEY}"
+
+     print(f"⬇️ Downloading model weights ({wt_name})...")
+     if download_file(url, wt_path, "weights"):
+         print(f"✅ Weights saved to {wt_path}")
+         return wt_path
+
+     # Fall back to the Figshare API
+     try:
+         api_url = "https://api.figshare.com/v2/articles/28679143/files"
+         r = requests.get(api_url, timeout=10)
+         if r.ok:
+             for f in r.json():
+                 if "epoch" in f["name"] and f["name"].endswith(".pt"):
+                     if download_file(f["download_url"], wt_path, "weights"):
+                         return wt_path
+     except Exception:
+         pass
+
+     return None
+
+
+ def ensure_ground_truth(root: str) -> bool:
+     """Ensure ground-truth data is available."""
+     gt_dir = os.path.join(root, "ground_truth")
+     os.makedirs(gt_dir, exist_ok=True)
+
+     if len(list(Path(gt_dir).rglob("*.pt"))) > 0:
+         return True
+
+     # Try downloading
+     zip_path = os.path.join(root, "ground_truth.zip")
+     fid = FIGSHARE_FILES["ground_truth.zip"]
+     url = f"https://figshare.com/ndownloader/files/{fid}?private_link={FIGSHARE_PRIVATE_KEY}"
+
+     print("⬇️ Downloading ground truth (958 MB)...")
+     if download_file(url, zip_path, "ground_truth"):
+         print("📦 Extracting...")
+         with zipfile.ZipFile(zip_path) as z:
+             z.extractall(root)
+         os.remove(zip_path)
+         return True
+     return False
+
+
+ def check_data_status(root: str) -> dict:
+     """Check what data is available."""
+     gt_dir = os.path.join(root, "ground_truth")
+     weights_dir = os.path.join(root, "weights")
+     pred_dir = os.path.join(root, "predictions")
+
+     n_gt = len(list(Path(gt_dir).rglob("*.pt"))) if os.path.isdir(gt_dir) else 0
+
+     n_weights = 0
+     if os.path.isdir(weights_dir):
+         n_weights = len([f for f in os.listdir(weights_dir) if f.endswith(".pt")])
+
+     n_pred = 0
+     if os.path.isdir(pred_dir):
+         for d in os.listdir(pred_dir):
+             dp = os.path.join(pred_dir, d)
+             if os.path.isdir(dp):
+                 n_pred = len([f for f in os.listdir(dp) if f.endswith("_mode_0.txt")])
+                 break
+
+     return {
+         "ground_truth": n_gt,
+         "weights": n_weights,
+         "predictions": n_pred,
+         "has_weights": n_weights > 0,
+         "has_gt": n_gt > 0,
+         "has_predictions": n_pred > 0,
+     }
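`check_data_status` only counts files by suffix, so it can be exercised on a synthetic tree without any downloads. A sketch replicating its counting logic (directory and file names below are hypothetical):

```python
import os
import tempfile
from pathlib import Path

# Synthetic data tree mirroring the layout the status check scans.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "weights"))
os.makedirs(os.path.join(root, "predictions", "default_run"))
os.makedirs(os.path.join(root, "ground_truth"))  # left empty on purpose
Path(root, "weights", "best.pt").touch()
Path(root, "predictions", "default_run", "1abcA_mode_0.txt").touch()

# Same counting logic as check_data_status().
n_gt = len(list(Path(root, "ground_truth").rglob("*.pt")))
n_weights = len([f for f in os.listdir(os.path.join(root, "weights"))
                 if f.endswith(".pt")])
n_pred = len([f for f in os.listdir(os.path.join(root, "predictions", "default_run"))
              if f.endswith("_mode_0.txt")])
status = {"has_gt": n_gt > 0, "has_weights": n_weights > 0,
          "has_predictions": n_pred > 0}
print(status)
```

The app uses these booleans to decide which pages (Explorer, Inference, Statistics) can be enabled.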
app/utils/inference.py ADDED
@@ -0,0 +1,98 @@
+ """PETIMOT inference utilities for custom proteins."""
+ import os, sys, torch
+ import numpy as np
+
+ # Ensure PETIMOT is importable
+ PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ if PETIMOT_ROOT not in sys.path:
+     sys.path.insert(0, PETIMOT_ROOT)
+
+ EMBEDDING_DIM_MAP = {"prostt5": 1024, "esmc_300m": 960, "esmc_600m": 1152}
+
+
+ def run_inference(pdb_path: str, weights_path: str, config_path: str | None = None,
+                   output_dir: str = "/tmp/petimot_pred") -> dict:
+     """Run PETIMOT inference on a single PDB file.
+
+     Args:
+         pdb_path: Path to the input PDB file
+         weights_path: Path to the model weights (.pt)
+         config_path: Path to the config YAML (default: configs/default.yaml)
+         output_dir: Where to save predictions
+
+     Returns:
+         dict with modes, ca_coords, seq, etc.
+     """
+     from petimot.infer.infer import infer
+     from petimot.data.pdb_utils import load_backbone_coordinates
+
+     if config_path is None:
+         config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml")
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Run inference
+     infer(model_path=weights_path, config_file=config_path,
+           input_list=[pdb_path], output_path=output_dir)
+
+     # Collect results
+     stem = os.path.splitext(os.path.basename(weights_path))[0]
+     pred_subdir = os.path.join(output_dir, stem)
+     basename = os.path.splitext(os.path.basename(pdb_path))[0]
+
+     # Load structure
+     bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True)
+     ca = bb_data["bb"][:, 1].numpy()
+     seq = bb_data.get("seq", "X" * len(ca))
+     if not isinstance(seq, str):
+         seq = "X" * len(ca)
+
+     # Load predicted modes
+     modes = {}
+     for k in range(10):
+         for pfx in [f"extracted_{basename}", basename]:
+             mf = os.path.join(pred_subdir, f"{pfx}_mode_{k}.txt")
+             if os.path.exists(mf):
+                 modes[k] = np.loadtxt(mf)
+                 break
+
+     with open(pdb_path) as f:
+         pdb_text = f.read()
+
+     return {
+         "name": basename,
+         "ca_coords": ca,
+         "seq": seq,
+         "modes": modes,
+         "pdb_text": pdb_text,
+         "pred_dir": pred_subdir,
+         "n_res": len(ca),
+     }
+
+
+ def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None:
+     """Download a PDB entry from RCSB, optionally filtered to one chain."""
+     import requests
+
+     os.makedirs(output_dir, exist_ok=True)
+     code4 = pdb_id[:4].lower()
+     chain = pdb_id[4:].upper() if len(pdb_id) > 4 else ""
+     out_path = os.path.join(output_dir, f"{pdb_id}.pdb")
+
+     if os.path.exists(out_path):
+         return out_path
+
+     r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30)
+     if not r.ok:
+         return None
+
+     lines = r.text.split("\n")
+     if chain:
+         # Keep ATOM records of the requested chain; drop other chains and HETATM
+         lines = [l for l in lines
+                  if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain)
+                  or not l.startswith(("ATOM", "HETATM"))]
+
+     with open(out_path, "w") as f:
+         f.write("\n".join(lines))
+     return out_path
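The chain filter in `download_pdb` relies on the fixed-column PDB format: the chain identifier is character 22 (0-based index 21) of an `ATOM` record. A standalone sketch with toy records:

```python
# Toy PDB records (not a real structure); column 22 carries the chain ID.
lines = [
    "HEADER    TOY ENTRY",
    "ATOM      1  CA  ALA A   1      11.104  13.207   2.100  1.00  0.00           C",
    "ATOM      2  CA  GLY B   1       8.130  12.000   3.420  1.00  0.00           C",
    "HETATM    3  O   HOH A 101       5.000   5.000   5.000  1.00  0.00           O",
    "END",
]

chain = "A"
# Same predicate as download_pdb(): keep ATOM records of the requested chain,
# keep non-coordinate records, drop everything else (including HETATM).
kept = [l for l in lines
        if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain)
        or not l.startswith(("ATOM", "HETATM"))]
print(kept)
```

Note that with a chain suffix given, HETATM records (waters, ligands) are dropped entirely, which matches the filter above.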
claude.md ADDED
@@ -0,0 +1,203 @@
+ # PETIMOT — Development State
+
+ > Last updated: 2026-03-20
+ > Maintainer: Valentin Lombard (valou82160@gmail.com)
+
+ ## Project Overview
+
+ **PETIMOT** = SE(3)-equivariant GNN for protein motion prediction from sparse data.
+ - Paper: [arXiv 2504.02839](https://arxiv.org/abs/2504.02839) — Lombard, Grudinin & Laine
+ - Public repo: `PhyloSofS-Team/PETIMOT`
+ - Private repos: `Vlmbd/petimot_temp`, `Vlmbd/petimot` (contain training code, identical to each other)
+
+ ## Architecture
+
+ - **Model**: `ProteinMotionMPNN` in `petimot/model/neural_net.py`
+   - SE(3)-equivariant message-passing neural network
+   - Input: PLM embeddings (ProstT5 1024d, or ESM 960d/1152d) + KNN graph
+   - Output: K motion modes as (N, 3) displacement vectors per protein
+   - 15 independent layers, ~4.7M params (default config)
+
+ - **Loss functions** in `petimot/model/loss.py`:
+   - `compute_nsse_loss`: Normalized SSE with Hungarian assignment (per-mode matching)
+   - `compute_rmsip_loss`: Root Mean Square Inner Product (subspace similarity)
+   - `compute_ortho_loss`: Independence Score (orthogonality regularizer)
+   - Default weights: 0.5×NSSE + 0.5×RMSIP + 0.0×IS
+
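For intuition, RMSIP measures the overlap between two K-dimensional mode subspaces. A minimal sketch of the standard definition, sqrt((1/K) Σᵢⱼ (uᵢ·vⱼ)²) — illustrative only, not the exact `compute_rmsip_loss` implementation:

```python
import numpy as np

def rmsip(U, V):
    """RMSIP between two orthonormal mode bases of shape (3N, K)."""
    overlaps = U.T @ V  # (K, K) matrix of mode inner products
    return np.sqrt((overlaps ** 2).sum() / U.shape[1])

# Identical subspaces give RMSIP = 1; disjoint subspaces give RMSIP = 0.
U = np.eye(6)[:, :2]   # modes e1, e2
V = np.eye(6)[:, :2]   # same subspace
W = np.eye(6)[:, 2:4]  # orthogonal subspace
print(rmsip(U, V), rmsip(U, W))
```

Because it compares subspaces rather than individual modes, RMSIP is insensitive to mode ordering, which is why it complements the per-mode Hungarian-matched NSSE.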
+ - **Data** in `petimot/data/`:
+   - `BaseDataset` / `InferenceDataset` in `data_set.py` (public repo)
+   - `TrainingDataset` in private repos — extends `BaseDataset` with ground-truth `.pt` loading
+   - Ground-truth `.pt` format: `{eigvects: (3N,K), eigvals: (K,), bb: (N,4,3), seq: str, coverage: (N,)}`
+   - Embeddings: cached as `{name}_{model}.pt` files in the `embeddings/` directory
+
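The ground-truth eigenvector layout can be unpacked into per-residue, per-mode 3-vectors; a toy sketch of the `(3N, K)` → `(N, K, 3)` reshape with `√N` scaling described under Training Details (assuming xyz triplets are contiguous per residue, as that reshape pattern implies):

```python
import numpy as np

N, K = 4, 3  # residues, modes (toy sizes)
rng = np.random.default_rng(0)
eigvects = rng.normal(size=(3 * N, K))  # ground-truth layout: (3N, K)

# (3N, K) -> (N, 3, K) -> (N, K, 3), scaled by sqrt(N)
modes = eigvects.reshape(N, 3, K).transpose(0, 2, 1) * np.sqrt(N)

# Residue i, mode k now holds rows 3i..3i+2 of eigvects column k, rescaled
print(modes.shape)
```

After this transform, `modes[i, k]` is directly a 3D displacement arrow for residue `i` in mode `k`, matching the model's `(N, 3)` per-mode output.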
+ ## Repository Structure
+
+ ```
+ PETIMOT/
+ ├── configs/default.yaml       # Default hyperparameters
+ ├── petimot/
+ │   ├── __main__.py            # CLI: infer, evaluate (+ train in private repo)
+ │   ├── model/
+ │   │   ├── neural_net.py      # ProteinMotionMPNN
+ │   │   ├── loss.py            # NSSE, RMSIP, IS losses
+ │   │   └── optimizer.py       # get_optimizer factory
+ │   ├── data/
+ │   │   ├── data_set.py        # BaseDataset, InferenceDataset
+ │   │   ├── embeddings.py      # EmbeddingManager (ProstT5/ESM)
+ │   │   └── pdb_utils.py       # load_backbone_coordinates
+ │   ├── infer/infer.py         # Inference pipeline
+ │   ├── eval/eval.py           # Evaluation with metrics
+ │   └── utils/
+ │       ├── seeding.py         # set_seed (reproducibility)
+ │       └── rigid_utils.py     # Quaternion/rotation utilities
+ ├── full_train_list.txt        # ~26K training samples
+ ├── full_val_list.txt          # ~5K validation samples
+ ├── eval_list.txt              # Test set
+ ├── split_script.py            # Family-based split generation
+ ├── PETIMOT_workflow.ipynb     # ⭐ Main deliverable — Colab notebook
+ └── weights/                   # Pretrained models
+ ```
+
+ ## Colab Notebook (`PETIMOT_workflow.ipynb`)
+
+ ### Current State: ✅ Functional, iterating on polish
+
+ The notebook is the primary deliverable. It is built via Python patch scripts in `/tmp/` and contains 37 cells across 9 sections:
+
+ | # | Section | Status | Notes |
+ |---|---------|--------|-------|
+ | 0 | Setup | ✅ | Install (no torch reinstall), GPU check, Drive mount, WandB |
+ | 1 | Data | ✅ | Figshare manual download + auto-extract, rich dataset stats (6 panels) |
+ | 2 | Training | ✅ | Full loop with gradient norms, ETA, best/last checkpoints |
+ | 3 | Monitoring | ✅ | Plotly dashboard, per-sample validation analysis |
+ | 4 | Inference | ✅ | Single PDB, batch, upload, auto-detect weights |
+ | 5 | Visualization | ✅ | 5-panel analysis dashboard, 3D py3Dmol, animation, mode grid |
+ | 6 | Trajectory Export | ✅ | Multi-model PDB for PyMOL/ChimeraX |
+ | 7 | Evaluation | ✅ | Test-set eval, CSV export, baseline comparison |
+ | 8 | Ablations | ✅ | 10 config presets, comparison plots |
+
+ ### Known Issues & Workarounds
+
+ 1. **Figshare download**: Private-link URLs don't support programmatic download (wget/curl/requests all fail). Solution: manual browser download + Colab upload. Cell 1.1 auto-detects and extracts uploads.
+
+ 2. **numpy binary incompatibility**: Installing packages can break numpy. Solution: after cell 0.1, do Runtime → Restart session, then skip to 0.2.
+
+ 3. **`torch.linalg.eigh` cusolver error**: Happens with AMP (FP16) in the `rigid_utils.py` quaternion computation. Solution: cell 2.4 sets `torch.backends.cuda.preferred_linalg_library("magma")`.
+
+ 4. **Split-file mismatch**: The config references `val_list.txt` but the repo has `full_val_list.txt`. Cell 2.2 auto-detects the correct files.
+
+ 5. **ProstT5 embedding computation**: Takes ~10 min on an A100, longer on a T4. Embeddings are cached in `embeddings/` (symlinked to Drive if mounted). The first run is slow; subsequent runs are fast.
+
+ 6. **Batch size**: The default is 16 (too small for large GPUs). On an RTX PRO 6000 (96 GB), use 128-256.
+
+ ### How the Notebook is Built
+
+ The notebook is modified via Python scripts that:
+ 1. Load the `.ipynb` JSON
+ 2. Find cells by title (e.g., `'5.1' in line`)
+ 3. Replace the source with new code, ensuring proper `\n` line terminators
+ 4. Save back to disk
+
+ **Critical**: Each line in the `source` array MUST end with `\n` (except the last line) for Colab to execute it. This was a major early bug.
+
+ Example patch-script pattern:
+ ```python
+ import json
+
+ with open('PETIMOT_workflow.ipynb') as f:
+     nb = json.load(f)
+ for i, c in enumerate(nb['cells']):
+     if any('CELL_ID' in l for l in c['source']):
+         lines = new_src.split("\n")
+         nb['cells'][i]['source'] = [l + "\n" for l in lines[:-1]] + [lines[-1]]
+         break
+ with open('PETIMOT_workflow.ipynb', 'w') as f:
+     json.dump(nb, f, indent=1)
+ ```
+
+ ## Training Details (from private repo)
+
+ - **TrainingDataset** (cell 2.1):
+   - Loads `.pt` ground-truth files listed in a text file
+   - Eigenvectors reshaped: `(3N, K)` → `(N, 3, K)` → `(N, K, 3)`, scaled by `√N`
+   - Multiplicative Gaussian noise on eigvects + embeddings for augmentation
+   - Random-embeddings option for ablation (`rand_emb=True`)
+
+ - **process_epoch** (cell 2.3):
+   - `set_grad_enabled(training)` per loss component
+   - AMP with GradScaler, gradient clipping at 10
+   - `optimizer.zero_grad(set_to_none=True)` for speed
+   - Tracks: NSSE, min_NSSE, RMSIP, ortho, success rate, gradient norms
+
+ - **train_petimot** (cell 2.3):
+   - AdamW + ReduceLROnPlateau (factor=0.5, patience from config)
+   - Saves `best.pt` + `last.pt` in `weights/{run_name}/`
+   - Auto-loads the best model after training
+   - Optional WandB logging, optional resume from checkpoint
+
+ ## Data Sources
+
+ - **Figshare** (private link): https://figshare.com/s/ab400d852b4669a83b64
+   - `ground_truth.zip` (958 MB, ~36K `.pt` files)
+   - `default_2025-02-07_21-54-02_epoch_33.pt` (18 MB, pretrained weights)
+   - `baseline_predictions.zip` (23 MB — AlphaFlow, ESMFlow, NMA)
+   - File IDs: ground_truth=52349453, weights=52349456, baselines=52349480
+
+ - **Local** (user's machine):
+   - `/Users/valentin/Documents/Petimot/` — public repo clone
+   - `/Users/valentin/Documents/petimot_private/` — private repo clone
+   - `/Users/valentin/Documents/petimot_temp/` — private repo (has weights `.pt` files)
+
+ ## Target Audience
+
+ The notebook is designed for **ML/bioinformatics professors and researchers**. Key design decisions:
+ - Rich, publication-quality visualizations (not toy demos)
+ - Extensive inline comments and docstrings with tensor shapes
+ - Interactive Colab form controls for parameters
+ - Auto-detect/auto-extract for user-friendly data setup
+ - All 8 sections are independent after setup (you can run inference without training)
+
+ ## Streamlit App (`app/`)
+
+ ### Current State: ✅ Built, needs testing
+
+ Interactive web explorer for PETIMOT predictions. Replaces notebook sections 5-8.
+
+ ```
+ app/
+ ├── app.py                  # Main entry + sidebar + dark theme
+ ├── requirements.txt        # Dependencies
+ ├── pages/
+ │   ├── 1_🔍_Explorer.py    # Browse pre-computed DB (search, filter, 3D, sequence)
+ │   ├── 2_🔮_Inference.py   # Upload PDB → predict → visualize + export
+ │   └── 3_📊_Statistics.py  # Dataset-wide distributions + eigenvalue analysis
+ ├── components/
+ │   ├── viewer_3d.py        # py3Dmol with arrows, labels, mode comparison grid
+ │   ├── sequence_viewer.py  # HTML sequence heatmap + Plotly displacement chart
+ │   └── mode_panel.py       # Mode tabs, correlation matrix, eigenvalue spectrum
+ └── utils/
+     ├── data_loader.py      # Load predictions index, modes, ground truth
+     └── inference.py        # PETIMOT inference wrapper + RCSB PDB download
+ ```
+
+ **Run locally:** `cd PETIMOT && streamlit run app/app.py`
+
+ **Key features:**
+ - Dark theme with a purple accent palette
+ - Real-time sliders for amplitude, arrow size, color scheme
+ - Searchable/sortable protein table with click-to-detail
+ - HTML sequence viewer with displacement-colored cells + hover tooltips
+ - Plotly charts (displacement profiles, eigenvalue spectrum, correlations)
+ - CSV + PDB export for inference results
+ - PDB ID fetch from RCSB or file upload
+
+ **Dependencies:** stmol (py3Dmol for Streamlit), plotly, streamlit
+
+ **Needs:** Pre-computed predictions in the `predictions/` directory for the Explorer page to work.
+
+ ## Dependencies
+
+ ```
+ torch>=2.0.0, torch_geometric, torch_scatter, torch_sparse
+ transformers==4.48.3, sentencepiece==0.2.0
+ scipy, typer==0.15.1, tqdm, numpy
+ wandb, plotly, py3Dmol, biopython, pandas, ipywidgets, gdown, requests
+ ```
requirements.txt CHANGED
@@ -1,9 +1,18 @@
+ # PETIMOT Explorer — HuggingFace Spaces
+ streamlit>=1.30.0
+ stmol>=0.0.9
+ py3Dmol>=2.0.0
+ plotly>=5.18.0
+ pandas>=2.0.0
+ numpy>=1.24.0
  torch>=2.0.0
  torch_geometric>=2.0.0
- wandb>=0.19.0
- transformers==4.48.3
- sentencepiece==0.2.0
+ torch_scatter
+ torch_sparse
+ transformers>=4.30.0
+ sentencepiece>=0.1.99
+ scipy>=1.10.0
+ biopython>=1.80
+ requests>=2.28.0
  tqdm>=4.65.0
- scipy>=1.13.0
- typer==0.15.1
- numpy<2
+ typer>=0.9.0