Add Streamlit explorer app with Docker deployment
Browse files- .gitignore +1 -0
- Dockerfile +28 -0
- README.md +41 -66
- app/app.py +145 -0
- app/components/__init__.py +0 -0
- app/components/mode_panel.py +139 -0
- app/components/prediction_analysis.py +320 -0
- app/components/sequence_viewer.py +268 -0
- app/components/viewer_3d.py +203 -0
- app/pages/1_๐_Explorer.py +169 -0
- app/pages/2_๐ฎ_Inference.py +162 -0
- app/pages/3_๐_Statistics.py +158 -0
- app/requirements.txt +15 -0
- app/utils/__init__.py +0 -0
- app/utils/data_loader.py +87 -0
- app/utils/download.py +121 -0
- app/utils/inference.py +98 -0
- claude.md +203 -0
- requirements.txt +15 -6
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# System deps
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
build-essential git && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Python deps
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
# Copy app code
|
| 15 |
+
COPY . .
|
| 16 |
+
|
| 17 |
+
# Install PETIMOT package
|
| 18 |
+
RUN pip install --no-cache-dir -e .
|
| 19 |
+
|
| 20 |
+
# Streamlit config
|
| 21 |
+
RUN mkdir -p /root/.streamlit
|
| 22 |
+
RUN echo '[server]\nheadless = true\nport = 7860\nenableCORS = false\nenableXsrfProtection = false\n' > /root/.streamlit/config.toml
|
| 23 |
+
|
| 24 |
+
EXPOSE 7860
|
| 25 |
+
|
| 26 |
+
HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
|
| 27 |
+
|
| 28 |
+
ENTRYPOINT ["streamlit", "run", "app/app.py", "--server.port=7860", "--server.address=0.0.0.0"]
|
README.md
CHANGED
|
@@ -1,71 +1,46 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
## Usage
|
| 20 |
|
| 21 |
-
### Reproduce paper results
|
| 22 |
-
|
| 23 |
-
1. Download resources from [Figshare](https://figshare.com/s/ab400d852b4669a83b64):
|
| 24 |
-
- Download `default_2025-02-07_21-54-02_epoch_33.pt` into the `weights/` directory
|
| 25 |
-
- Download and extract `ground_truth.zip` into the `ground_truth/` directory
|
| 26 |
-
|
| 27 |
-
2. Run inference and evaluation:
|
| 28 |
-
```bash
|
| 29 |
-
python -m petimot infer_and_evaluate \
|
| 30 |
-
--model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
|
| 31 |
-
--list-path eval_list.txt \
|
| 32 |
-
--ground-truth-path ground_truth/ \
|
| 33 |
-
--prediction-path predictions/ \
|
| 34 |
-
--evaluation-path evaluation/
|
| 35 |
-
```
|
| 36 |
-
|
| 37 |
-
### Compare with baseline methods
|
| 38 |
-
|
| 39 |
-
1. Download baseline predictions from [Figshare](https://figshare.com/s/ab400d852b4669a83b64) :
|
| 40 |
-
- Download and extract `baseline_predictions.zip` into the `baselines/` directory
|
| 41 |
-
|
| 42 |
-
2. Run evaluation:
|
| 43 |
```bash
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
```
|
| 49 |
-
|
| 50 |
-
Available baseline predictions:
|
| 51 |
-
- AlphaFlow (distilled)
|
| 52 |
-
- ESMFlow (distilled)
|
| 53 |
-
- Normal Mode Analysis
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
### Predict motions for your own PDB files
|
| 58 |
-
|
| 59 |
-
```bash
|
| 60 |
-
# Single PDB structure
|
| 61 |
-
python -m petimot infer \
|
| 62 |
-
--model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
|
| 63 |
-
--list-path protein.pdb \
|
| 64 |
-
--output-path predictions/
|
| 65 |
-
|
| 66 |
-
# Multiple structures (provide paths in a text file)
|
| 67 |
-
python -m petimot infer \
|
| 68 |
-
--model-path weights/default_2025-02-07_21-54-02_epoch_33.pt \
|
| 69 |
-
--list-path protein_list.txt \
|
| 70 |
-
--output-path predictions/
|
| 71 |
```
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PETIMOT Explorer
|
| 3 |
+
emoji: ๐งฌ
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: "1.40.0"
|
| 8 |
+
app_file: app/app.py
|
| 9 |
+
pinned: true
|
| 10 |
+
license: gpl-3.0
|
| 11 |
+
tags:
|
| 12 |
+
- protein
|
| 13 |
+
- motion
|
| 14 |
+
- GNN
|
| 15 |
+
- bioinformatics
|
| 16 |
+
- structural-biology
|
| 17 |
+
short_description: "Explore SE(3)-equivariant protein motion predictions"
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# ๐งฌ PETIMOT Explorer
|
| 21 |
+
|
| 22 |
+
**Protein Motion Inference from Sparse Data**
|
| 23 |
+
|
| 24 |
+
Interactive explorer for protein motion predictions using SE(3)-equivariant Graph Neural Networks.
|
| 25 |
+
|
| 26 |
+
## Features
|
| 27 |
+
|
| 28 |
+
- ๐ **Explorer** โ Browse pre-computed predictions for ~36K proteins
|
| 29 |
+
- ๐ฎ **Inference** โ Predict motion for any protein (PDB ID or upload)
|
| 30 |
+
- ๐ **Statistics** โ Dataset-wide analysis and distributions
|
| 31 |
+
- ๐จ **3D Viewer** โ Interactive motion visualization with displacement arrows
|
| 32 |
+
- ๐งฌ **Sequence View** โ Per-residue displacement heatmap with coverage overlay
|
| 33 |
+
|
| 34 |
+
## Paper
|
| 35 |
+
|
| 36 |
+
> Lombard, Grudinin & Laine โ *PETIMOT: SE(3)-Equivariant GNNs for Protein Motion Prediction*
|
| 37 |
+
> [arXiv 2504.02839](https://arxiv.org/abs/2504.02839)
|
| 38 |
|
| 39 |
## Usage
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
```bash
|
| 42 |
+
git clone https://github.com/PhyloSofS-Team/PETIMOT
|
| 43 |
+
cd PETIMOT
|
| 44 |
+
pip install -r app/requirements.txt
|
| 45 |
+
streamlit run app/app.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
```
|
app/app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os, sys
|
| 3 |
+
|
| 4 |
+
# โโ Page Config โโ
|
| 5 |
+
st.set_page_config(
|
| 6 |
+
page_title="PETIMOT Explorer",
|
| 7 |
+
page_icon="๐งฌ",
|
| 8 |
+
layout="wide",
|
| 9 |
+
initial_sidebar_state="expanded",
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# โโ Ensure PETIMOT is importable โโ
|
| 13 |
+
PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
if PETIMOT_ROOT not in sys.path:
|
| 15 |
+
sys.path.insert(0, PETIMOT_ROOT)
|
| 16 |
+
|
| 17 |
+
# โโ Custom CSS โโ
|
| 18 |
+
st.markdown("""
|
| 19 |
+
<style>
|
| 20 |
+
/* Dark theme overrides */
|
| 21 |
+
.stApp { background-color: #0f0d1a; }
|
| 22 |
+
.block-container { padding-top: 1rem; }
|
| 23 |
+
|
| 24 |
+
/* Sidebar styling */
|
| 25 |
+
section[data-testid="stSidebar"] {
|
| 26 |
+
background-color: #1a1730;
|
| 27 |
+
border-right: 1px solid #2d2b55;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
/* Headers */
|
| 31 |
+
h1, h2, h3 { color: #c4b5fd !important; }
|
| 32 |
+
|
| 33 |
+
/* Metric cards */
|
| 34 |
+
[data-testid="stMetric"] {
|
| 35 |
+
background-color: #1e1b4b;
|
| 36 |
+
border: 1px solid #312e81;
|
| 37 |
+
border-radius: 12px;
|
| 38 |
+
padding: 12px 16px;
|
| 39 |
+
}
|
| 40 |
+
[data-testid="stMetricLabel"] { color: #a5b4fc !important; }
|
| 41 |
+
[data-testid="stMetricValue"] { color: #e0e7ff !important; }
|
| 42 |
+
|
| 43 |
+
/* Dataframe */
|
| 44 |
+
.stDataFrame { border-radius: 8px; overflow: hidden; }
|
| 45 |
+
|
| 46 |
+
/* Tabs */
|
| 47 |
+
.stTabs [data-baseweb="tab"] {
|
| 48 |
+
background-color: #1e1b4b;
|
| 49 |
+
border-radius: 8px 8px 0 0;
|
| 50 |
+
color: #a5b4fc;
|
| 51 |
+
}
|
| 52 |
+
.stTabs [data-baseweb="tab"][aria-selected="true"] {
|
| 53 |
+
background-color: #312e81;
|
| 54 |
+
color: white;
|
| 55 |
+
}
|
| 56 |
+
</style>
|
| 57 |
+
""", unsafe_allow_html=True)
|
| 58 |
+
|
| 59 |
+
# โโ Sidebar โโ
|
| 60 |
+
with st.sidebar:
|
| 61 |
+
st.image("https://raw.githubusercontent.com/PhyloSofS-Team/PETIMOT/main/logo.png",
|
| 62 |
+
use_container_width=True)
|
| 63 |
+
st.markdown("# ๐งฌ PETIMOT")
|
| 64 |
+
st.markdown("**Protein Motion from Sparse Data**")
|
| 65 |
+
st.markdown("SE(3)-Equivariant GNNs")
|
| 66 |
+
st.divider()
|
| 67 |
+
|
| 68 |
+
# Global settings
|
| 69 |
+
st.markdown("### โ๏ธ Settings")
|
| 70 |
+
|
| 71 |
+
weights_dir = os.path.join(PETIMOT_ROOT, "weights")
|
| 72 |
+
pt_files = []
|
| 73 |
+
if os.path.isdir(weights_dir):
|
| 74 |
+
for root, dirs, files in os.walk(weights_dir):
|
| 75 |
+
for f in files:
|
| 76 |
+
if f.endswith(".pt"):
|
| 77 |
+
pt_files.append(os.path.join(root, f))
|
| 78 |
+
|
| 79 |
+
if pt_files:
|
| 80 |
+
selected_weights = st.selectbox(
|
| 81 |
+
"Model weights",
|
| 82 |
+
pt_files,
|
| 83 |
+
format_func=lambda x: os.path.basename(x),
|
| 84 |
+
key="weights"
|
| 85 |
+
)
|
| 86 |
+
else:
|
| 87 |
+
selected_weights = None
|
| 88 |
+
st.warning("No weights found in `weights/`")
|
| 89 |
+
|
| 90 |
+
st.divider()
|
| 91 |
+
st.markdown("""
|
| 92 |
+
**Links**
|
| 93 |
+
- [Paper](https://arxiv.org/abs/2504.02839)
|
| 94 |
+
- [GitHub](https://github.com/PhyloSofS-Team/PETIMOT)
|
| 95 |
+
- [Data](https://figshare.com/s/ab400d852b4669a83b64)
|
| 96 |
+
""")
|
| 97 |
+
st.caption("GPL-3.0 ยท Lombard, Grudinin & Laine")
|
| 98 |
+
|
| 99 |
+
# โโ Main Page โโ
|
| 100 |
+
st.title("๐งฌ PETIMOT Explorer")
|
| 101 |
+
st.markdown("""
|
| 102 |
+
Explore protein motion predictions from the PETIMOT framework.
|
| 103 |
+
Navigate using the sidebar pages:
|
| 104 |
+
|
| 105 |
+
| Page | Description |
|
| 106 |
+
|------|-------------|
|
| 107 |
+
| ๐ **Explorer** | Browse pre-computed predictions for ~36K proteins |
|
| 108 |
+
| ๐ฎ **Inference** | Predict motion for a new protein (PDB ID or upload) |
|
| 109 |
+
| ๐ **Statistics** | Dataset-wide analysis and distributions |
|
| 110 |
+
""")
|
| 111 |
+
|
| 112 |
+
# โโ Data Status โโ
|
| 113 |
+
from app.utils.download import check_data_status, ensure_weights
|
| 114 |
+
|
| 115 |
+
status = check_data_status(PETIMOT_ROOT)
|
| 116 |
+
|
| 117 |
+
col1, col2, col3 = st.columns(3)
|
| 118 |
+
with col1:
|
| 119 |
+
st.metric("Ground Truth", f"{status['ground_truth']:,}",
|
| 120 |
+
delta="โ
" if status['has_gt'] else "Missing")
|
| 121 |
+
with col2:
|
| 122 |
+
st.metric("Predictions", f"{status['predictions']:,}",
|
| 123 |
+
delta="โ
" if status['has_predictions'] else "Not yet computed")
|
| 124 |
+
with col3:
|
| 125 |
+
st.metric("Model Weights", "4.7M params",
|
| 126 |
+
delta="โ
" if status['has_weights'] else "Missing")
|
| 127 |
+
|
| 128 |
+
# Auto-download if missing
|
| 129 |
+
if not status['has_weights']:
|
| 130 |
+
st.divider()
|
| 131 |
+
st.warning("โ ๏ธ Model weights not found.")
|
| 132 |
+
if st.button("โฌ๏ธ Download weights from Figshare (18 MB)", type="primary"):
|
| 133 |
+
with st.spinner("Downloading..."):
|
| 134 |
+
wt = ensure_weights(PETIMOT_ROOT)
|
| 135 |
+
if wt:
|
| 136 |
+
st.success(f"โ
Weights downloaded: {os.path.basename(wt)}")
|
| 137 |
+
st.rerun()
|
| 138 |
+
else:
|
| 139 |
+
st.error("Download failed. Please manually download from "
|
| 140 |
+
"[Figshare](https://figshare.com/s/ab400d852b4669a83b64) "
|
| 141 |
+
"and place in `weights/`")
|
| 142 |
+
|
| 143 |
+
if not status['has_predictions'] and status['has_weights']:
|
| 144 |
+
st.info("๐ก No pre-computed predictions yet. Use the **Inference** page to predict "
|
| 145 |
+
"individual proteins, or run batch inference from the Colab notebook.")
|
app/components/__init__.py
ADDED
|
File without changes
|
app/components/mode_panel.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mode selection panel with per-mode statistics."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import numpy as np
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def render_mode_panel(
|
| 8 |
+
modes: dict,
|
| 9 |
+
seq: str = "",
|
| 10 |
+
eigenvalues: np.ndarray = None,
|
| 11 |
+
) -> int:
|
| 12 |
+
"""Render mode selector with per-mode stats. Returns selected mode index.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
modes: {k: np.ndarray (N, 3)} displacement vectors per mode
|
| 16 |
+
seq: Amino acid sequence
|
| 17 |
+
eigenvalues: Ground truth eigenvalues (if available)
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Selected mode index
|
| 21 |
+
"""
|
| 22 |
+
n_modes = len(modes)
|
| 23 |
+
if n_modes == 0:
|
| 24 |
+
st.warning("No modes available")
|
| 25 |
+
return 0
|
| 26 |
+
|
| 27 |
+
# Mode tabs
|
| 28 |
+
tabs = st.tabs([f"Mode {k}" for k in range(n_modes)])
|
| 29 |
+
|
| 30 |
+
selected = 0
|
| 31 |
+
for k in range(n_modes):
|
| 32 |
+
with tabs[k]:
|
| 33 |
+
vecs = modes[k]
|
| 34 |
+
mags = np.linalg.norm(vecs, axis=1)
|
| 35 |
+
n_res = len(mags)
|
| 36 |
+
|
| 37 |
+
# Stats columns
|
| 38 |
+
c1, c2, c3, c4 = st.columns(4)
|
| 39 |
+
c1.metric("Mean", f"{mags.mean():.3f} ร
")
|
| 40 |
+
c2.metric("Max", f"{mags.max():.3f} ร
")
|
| 41 |
+
c3.metric("Std", f"{mags.std():.3f} ร
")
|
| 42 |
+
|
| 43 |
+
if eigenvalues is not None and k < len(eigenvalues):
|
| 44 |
+
c4.metric("ฮป", f"{eigenvalues[k]:.4f}")
|
| 45 |
+
else:
|
| 46 |
+
c4.metric("Residues", f"{n_res}")
|
| 47 |
+
|
| 48 |
+
# Top mobile residues
|
| 49 |
+
top5 = np.argsort(mags)[-5:][::-1]
|
| 50 |
+
top_data = []
|
| 51 |
+
for idx in top5:
|
| 52 |
+
aa = seq[idx] if idx < len(seq) else "?"
|
| 53 |
+
top_data.append({
|
| 54 |
+
"Residue": f"{aa}{idx + 1}",
|
| 55 |
+
"Displacement": f"{mags[idx]:.3f} ร
",
|
| 56 |
+
"Rank": f"#{np.where(np.argsort(mags)[::-1] == idx)[0][0] + 1}",
|
| 57 |
+
})
|
| 58 |
+
st.markdown("**Most mobile residues:**")
|
| 59 |
+
st.dataframe(top_data, use_container_width=True, hide_index=True)
|
| 60 |
+
|
| 61 |
+
selected = st.session_state.get("_active_mode_tab", 0)
|
| 62 |
+
return selected
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def render_mode_correlation(modes: dict):
|
| 66 |
+
"""Render mode correlation matrix as Plotly heatmap."""
|
| 67 |
+
n_modes = len(modes)
|
| 68 |
+
if n_modes < 2:
|
| 69 |
+
return
|
| 70 |
+
|
| 71 |
+
# Compute displacement profile correlation
|
| 72 |
+
profiles = []
|
| 73 |
+
for k in sorted(modes.keys()):
|
| 74 |
+
mags = np.linalg.norm(modes[k], axis=1)
|
| 75 |
+
profiles.append(mags)
|
| 76 |
+
|
| 77 |
+
corr = np.corrcoef(profiles)
|
| 78 |
+
|
| 79 |
+
fig = go.Figure(go.Heatmap(
|
| 80 |
+
z=corr,
|
| 81 |
+
x=[f"M{k}" for k in range(n_modes)],
|
| 82 |
+
y=[f"M{k}" for k in range(n_modes)],
|
| 83 |
+
colorscale="RdBu_r",
|
| 84 |
+
zmin=-1, zmax=1,
|
| 85 |
+
text=np.round(corr, 2),
|
| 86 |
+
texttemplate="%{text:.2f}",
|
| 87 |
+
textfont={"size": 12},
|
| 88 |
+
))
|
| 89 |
+
|
| 90 |
+
fig.update_layout(
|
| 91 |
+
title="Mode Displacement Correlation",
|
| 92 |
+
template="plotly_dark",
|
| 93 |
+
height=300, width=300,
|
| 94 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 95 |
+
plot_bgcolor="rgba(30,27,75,0.5)",
|
| 96 |
+
margin=dict(l=30, r=30, t=40, b=30),
|
| 97 |
+
)
|
| 98 |
+
st.plotly_chart(fig, use_container_width=False)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def render_eigenvalue_spectrum(eigenvalues: np.ndarray):
|
| 102 |
+
"""Render eigenvalue bar chart with cumulative variance line."""
|
| 103 |
+
if eigenvalues is None or len(eigenvalues) == 0:
|
| 104 |
+
return
|
| 105 |
+
|
| 106 |
+
fig = go.Figure()
|
| 107 |
+
|
| 108 |
+
# Bars
|
| 109 |
+
fig.add_trace(go.Bar(
|
| 110 |
+
x=[f"ฮป{k+1}" for k in range(len(eigenvalues))],
|
| 111 |
+
y=eigenvalues,
|
| 112 |
+
marker_color="#6366f1",
|
| 113 |
+
name="Eigenvalue",
|
| 114 |
+
))
|
| 115 |
+
|
| 116 |
+
# Cumulative variance line
|
| 117 |
+
cum = np.cumsum(eigenvalues) / eigenvalues.sum() * 100
|
| 118 |
+
fig.add_trace(go.Scatter(
|
| 119 |
+
x=[f"ฮป{k+1}" for k in range(len(eigenvalues))],
|
| 120 |
+
y=cum,
|
| 121 |
+
mode="lines+markers",
|
| 122 |
+
name="Cumul. variance %",
|
| 123 |
+
marker=dict(color="#ef4444", size=6),
|
| 124 |
+
line=dict(color="#ef4444", width=2),
|
| 125 |
+
yaxis="y2",
|
| 126 |
+
))
|
| 127 |
+
|
| 128 |
+
fig.update_layout(
|
| 129 |
+
title="Eigenvalue Spectrum",
|
| 130 |
+
template="plotly_dark",
|
| 131 |
+
height=250,
|
| 132 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 133 |
+
plot_bgcolor="rgba(30,27,75,0.5)",
|
| 134 |
+
yaxis=dict(title="Eigenvalue"),
|
| 135 |
+
yaxis2=dict(title="Cumul. %", overlaying="y", side="right", range=[0, 105]),
|
| 136 |
+
legend=dict(orientation="h", y=1.15),
|
| 137 |
+
margin=dict(l=40, r=40, t=40, b=30),
|
| 138 |
+
)
|
| 139 |
+
st.plotly_chart(fig, use_container_width=True)
|
app/components/prediction_analysis.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enhanced prediction analysis โ sign-invariant modes and per-residue normalization."""
|
| 2 |
+
import numpy as np
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
from plotly.subplots import make_subplots
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def canonicalize_sign(modes: dict) -> dict:
|
| 9 |
+
"""Make eigenvectors sign-consistent.
|
| 10 |
+
|
| 11 |
+
Eigenvectors are defined up to ยฑ1 global sign. We canonicalize by choosing
|
| 12 |
+
the sign such that the component with the largest absolute value is positive.
|
| 13 |
+
This ensures consistent visualization across different runs/proteins.
|
| 14 |
+
"""
|
| 15 |
+
canonical = {}
|
| 16 |
+
for k, vecs in modes.items():
|
| 17 |
+
# Flatten to (3N,), find component with max absolute value
|
| 18 |
+
flat = vecs.flatten()
|
| 19 |
+
max_idx = np.argmax(np.abs(flat))
|
| 20 |
+
if flat[max_idx] < 0:
|
| 21 |
+
canonical[k] = -vecs # Flip sign
|
| 22 |
+
else:
|
| 23 |
+
canonical[k] = vecs.copy()
|
| 24 |
+
return canonical
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def per_residue_relative_norm(vecs: np.ndarray) -> np.ndarray:
|
| 28 |
+
"""Normalize displacement magnitudes to [0, 1] relative to max.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
vecs: (N, 3) displacement vectors
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
(N,) relative magnitudes in [0, 1]
|
| 35 |
+
"""
|
| 36 |
+
mags = np.linalg.norm(vecs, axis=1)
|
| 37 |
+
max_m = mags.max()
|
| 38 |
+
return mags / max_m if max_m > 1e-12 else mags
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def per_residue_direction(vecs: np.ndarray, ca_coords: np.ndarray) -> np.ndarray:
|
| 42 |
+
"""Compute relative direction of displacement vs protein backbone.
|
| 43 |
+
|
| 44 |
+
Projects displacement onto local backbone direction (CA_i โ CA_{i+1}).
|
| 45 |
+
Returns signed projection: positive = along backbone, negative = against.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
vecs: (N, 3) displacement vectors
|
| 49 |
+
ca_coords: (N, 3) CA coordinates
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
(N,) signed projections normalized by displacement magnitude
|
| 53 |
+
"""
|
| 54 |
+
n = len(vecs)
|
| 55 |
+
projections = np.zeros(n)
|
| 56 |
+
|
| 57 |
+
for i in range(n):
|
| 58 |
+
# Local backbone direction
|
| 59 |
+
if i < n - 1:
|
| 60 |
+
backbone = ca_coords[i + 1] - ca_coords[i]
|
| 61 |
+
else:
|
| 62 |
+
backbone = ca_coords[i] - ca_coords[i - 1]
|
| 63 |
+
|
| 64 |
+
bb_norm = np.linalg.norm(backbone)
|
| 65 |
+
if bb_norm < 1e-8:
|
| 66 |
+
continue
|
| 67 |
+
|
| 68 |
+
disp_mag = np.linalg.norm(vecs[i])
|
| 69 |
+
if disp_mag < 1e-8:
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
# Cosine angle between displacement and backbone direction
|
| 73 |
+
projections[i] = np.dot(vecs[i], backbone) / (disp_mag * bb_norm)
|
| 74 |
+
|
| 75 |
+
return projections
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def render_prediction_analysis(
|
| 79 |
+
modes: dict,
|
| 80 |
+
seq: str,
|
| 81 |
+
ca_coords: np.ndarray = None,
|
| 82 |
+
coverage: np.ndarray = None,
|
| 83 |
+
eigenvalues: np.ndarray = None,
|
| 84 |
+
gt_modes: dict = None,
|
| 85 |
+
protein_name: str = "",
|
| 86 |
+
):
|
| 87 |
+
"""Comprehensive prediction analysis panel.
|
| 88 |
+
|
| 89 |
+
Shows:
|
| 90 |
+
1. Normalized displacement heatmap (all modes ร residues)
|
| 91 |
+
2. Sign-canonical direction analysis
|
| 92 |
+
3. Prediction vs ground truth comparison (if available)
|
| 93 |
+
4. Per-residue statistics table
|
| 94 |
+
"""
|
| 95 |
+
# Canonicalize signs
|
| 96 |
+
modes_c = canonicalize_sign(modes)
|
| 97 |
+
n_modes = len(modes_c)
|
| 98 |
+
n_res = len(list(modes_c.values())[0])
|
| 99 |
+
|
| 100 |
+
if coverage is None:
|
| 101 |
+
coverage = np.ones(n_res)
|
| 102 |
+
|
| 103 |
+
# โโ Tab layout โโ
|
| 104 |
+
tab_norm, tab_dir, tab_compare, tab_table = st.tabs([
|
| 105 |
+
"๐ Normalized Displacement", "๐งญ Direction Analysis",
|
| 106 |
+
"โ๏ธ Pred vs GT", "๐ Per-Residue Table"
|
| 107 |
+
])
|
| 108 |
+
|
| 109 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 110 |
+
# Tab 1: Normalized displacement heatmap
|
| 111 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 112 |
+
with tab_norm:
|
| 113 |
+
# Compute relative norms for all modes
|
| 114 |
+
rel_norms = np.zeros((n_modes, n_res))
|
| 115 |
+
abs_mags = np.zeros((n_modes, n_res))
|
| 116 |
+
for k in range(n_modes):
|
| 117 |
+
abs_mags[k] = np.linalg.norm(modes_c[k], axis=1)
|
| 118 |
+
rel_norms[k] = per_residue_relative_norm(modes_c[k])
|
| 119 |
+
|
| 120 |
+
# Hover text with sequence
|
| 121 |
+
hover = [[f"{seq[j] if j < len(seq) else '?'}{j+1}<br>"
|
| 122 |
+
f"Abs: {abs_mags[k][j]:.3f}ร
<br>"
|
| 123 |
+
f"Rel: {rel_norms[k][j]:.2%}<br>"
|
| 124 |
+
f"Cov: {coverage[j]:.2f}"
|
| 125 |
+
for j in range(n_res)] for k in range(n_modes)]
|
| 126 |
+
|
| 127 |
+
fig = make_subplots(rows=3, cols=1, row_heights=[0.4, 0.4, 0.2],
|
| 128 |
+
shared_xaxes=True, vertical_spacing=0.06,
|
| 129 |
+
subplot_titles=["Absolute Displacement (ร
)",
|
| 130 |
+
"Relative Displacement (0-1)",
|
| 131 |
+
"Coverage"])
|
| 132 |
+
|
| 133 |
+
# Absolute heatmap
|
| 134 |
+
fig.add_trace(go.Heatmap(
|
| 135 |
+
z=abs_mags, colorscale="YlOrRd",
|
| 136 |
+
y=[f"Mode {k}" for k in range(n_modes)],
|
| 137 |
+
text=hover, hovertemplate="%{text}<extra></extra>",
|
| 138 |
+
colorbar=dict(title="ร
", x=1.01, len=0.35, y=0.85),
|
| 139 |
+
), row=1, col=1)
|
| 140 |
+
|
| 141 |
+
# Relative heatmap
|
| 142 |
+
fig.add_trace(go.Heatmap(
|
| 143 |
+
z=rel_norms, colorscale="Viridis", zmin=0, zmax=1,
|
| 144 |
+
y=[f"Mode {k}" for k in range(n_modes)],
|
| 145 |
+
text=hover, hovertemplate="%{text}<extra></extra>",
|
| 146 |
+
colorbar=dict(title="Rel", x=1.08, len=0.35, y=0.5),
|
| 147 |
+
), row=2, col=1)
|
| 148 |
+
|
| 149 |
+
# Coverage bar
|
| 150 |
+
fig.add_trace(go.Bar(
|
| 151 |
+
x=list(range(n_res)), y=coverage[:n_res],
|
| 152 |
+
marker_color=["#10b981" if c > 0.5 else "#ef4444" for c in coverage[:n_res]],
|
| 153 |
+
hovertemplate="Res %{x}<br>Coverage: %{y:.3f}<extra></extra>",
|
| 154 |
+
showlegend=False,
|
| 155 |
+
), row=3, col=1)
|
| 156 |
+
|
| 157 |
+
# Sequence ticks
|
| 158 |
+
step = max(1, n_res // 50)
|
| 159 |
+
tick_vals = list(range(0, n_res, step))
|
| 160 |
+
tick_text = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in tick_vals]
|
| 161 |
+
fig.update_xaxes(tickvals=tick_vals, ticktext=tick_text, tickangle=45,
|
| 162 |
+
tickfont=dict(size=8), row=3, col=1)
|
| 163 |
+
|
| 164 |
+
fig.update_layout(
|
| 165 |
+
template="plotly_dark", height=550,
|
| 166 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 167 |
+
plot_bgcolor="rgba(30,27,75,0.3)",
|
| 168 |
+
margin=dict(l=60, r=80, t=30, b=50),
|
| 169 |
+
)
|
| 170 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 171 |
+
|
| 172 |
+
# Key insight
|
| 173 |
+
for k in range(min(n_modes, 4)):
|
| 174 |
+
top3 = np.argsort(abs_mags[k])[-3:][::-1]
|
| 175 |
+
top_str = ", ".join([f"**{seq[i] if i<len(seq) else '?'}{i+1}** ({abs_mags[k][i]:.2f}ร
)"
|
| 176 |
+
for i in top3])
|
| 177 |
+
st.markdown(f"Mode {k} hotspots: {top_str}")
|
| 178 |
+
|
| 179 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 180 |
+
# Tab 2: Direction analysis
|
| 181 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 182 |
+
with tab_dir:
|
| 183 |
+
if ca_coords is not None and len(ca_coords) == n_res:
|
| 184 |
+
st.markdown("""
|
| 185 |
+
**Direction Analysis**: Projects displacement onto the local backbone direction (CAโCA).
|
| 186 |
+
- ๐ต **Blue** = motion along backbone (stretching/compressing)
|
| 187 |
+
- ๐ด **Red** = motion perpendicular to backbone (lateral/hinge)
|
| 188 |
+
- Sign is arbitrary for eigenvectors โ we show absolute cosine similarity
|
| 189 |
+
""")
|
| 190 |
+
|
| 191 |
+
fig_dir = go.Figure()
|
| 192 |
+
colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
|
| 193 |
+
|
| 194 |
+
for k in range(min(n_modes, 4)):
|
| 195 |
+
proj = per_residue_direction(modes_c[k], ca_coords)
|
| 196 |
+
# Show absolute cosine (sign-invariant)
|
| 197 |
+
abs_proj = np.abs(proj)
|
| 198 |
+
|
| 199 |
+
fig_dir.add_trace(go.Scatter(
|
| 200 |
+
x=list(range(1, n_res + 1)), y=abs_proj,
|
| 201 |
+
mode="lines", name=f"Mode {k}",
|
| 202 |
+
line=dict(color=colors[k], width=1.5),
|
| 203 |
+
fill="tozeroy",
|
| 204 |
+
fillcolor=colors[k].replace(")", ",0.1)").replace("#", "rgba(").replace(
|
| 205 |
+
"rgba(6366f1", "rgba(99,102,241").replace(
|
| 206 |
+
"rgba(ef4444", "rgba(239,68,68").replace(
|
| 207 |
+
"rgba(10b981", "rgba(16,185,129").replace(
|
| 208 |
+
"rgba(f59e0b", "rgba(245,158,11"),
|
| 209 |
+
hovertemplate="Res %{x}<br>|cos ฮธ|: %{y:.3f}<extra>Mode " + str(k) + "</extra>",
|
| 210 |
+
))
|
| 211 |
+
|
| 212 |
+
fig_dir.add_hline(y=0.5, line_dash="dash", line_color="#94a3b8",
|
| 213 |
+
annotation_text="isotropic threshold")
|
| 214 |
+
|
| 215 |
+
fig_dir.update_layout(
|
| 216 |
+
template="plotly_dark", height=350,
|
| 217 |
+
paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
|
| 218 |
+
xaxis_title="Residue", yaxis_title="|cos ฮธ| (backbone projection)",
|
| 219 |
+
yaxis_range=[0, 1.05],
|
| 220 |
+
margin=dict(l=50, r=20, t=30, b=50),
|
| 221 |
+
)
|
| 222 |
+
st.plotly_chart(fig_dir, use_container_width=True)
|
| 223 |
+
|
| 224 |
+
# Direction heatmap
|
| 225 |
+
st.markdown("**Per-residue ร mode direction matrix:**")
|
| 226 |
+
dir_matrix = np.zeros((n_modes, n_res))
|
| 227 |
+
for k in range(n_modes):
|
| 228 |
+
dir_matrix[k] = np.abs(per_residue_direction(modes_c[k], ca_coords))
|
| 229 |
+
|
| 230 |
+
fig_dh = go.Figure(go.Heatmap(
|
| 231 |
+
z=dir_matrix, colorscale="RdBu_r", zmin=0, zmax=1,
|
| 232 |
+
y=[f"Mode {k}" for k in range(n_modes)],
|
| 233 |
+
colorbar=dict(title="|cos ฮธ|"),
|
| 234 |
+
))
|
| 235 |
+
fig_dh.update_layout(
|
| 236 |
+
template="plotly_dark", height=200,
|
| 237 |
+
paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
|
| 238 |
+
margin=dict(l=60, r=60, t=10, b=30),
|
| 239 |
+
)
|
| 240 |
+
st.plotly_chart(fig_dh, use_container_width=True)
|
| 241 |
+
else:
|
| 242 |
+
st.info("Direction analysis requires CA coordinates (ground truth or PDB needed)")
|
| 243 |
+
|
| 244 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 245 |
+
# Tab 3: Prediction vs Ground Truth
|
| 246 |
+
# โโโโโ๏ฟฝ๏ฟฝโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 247 |
+
with tab_compare:
|
| 248 |
+
if gt_modes is not None and len(gt_modes) > 0:
|
| 249 |
+
gt_c = canonicalize_sign(gt_modes)
|
| 250 |
+
n_gt = len(gt_c)
|
| 251 |
+
|
| 252 |
+
st.markdown("**Pred vs GT displacement profiles (sign-canonicalized):**")
|
| 253 |
+
|
| 254 |
+
for k in range(min(n_modes, n_gt, 4)):
|
| 255 |
+
pred_mag = np.linalg.norm(modes_c[k], axis=1)
|
| 256 |
+
gt_mag = np.linalg.norm(gt_c[k], axis=1)
|
| 257 |
+
|
| 258 |
+
# Normalize both to [0, 1]
|
| 259 |
+
pred_rel = pred_mag / (pred_mag.max() + 1e-12)
|
| 260 |
+
gt_rel = gt_mag / (gt_mag.max() + 1e-12)
|
| 261 |
+
|
| 262 |
+
fig_cmp = go.Figure()
|
| 263 |
+
fig_cmp.add_trace(go.Scatter(
|
| 264 |
+
x=list(range(1, n_res + 1)), y=gt_rel,
|
| 265 |
+
mode="lines", name="Ground Truth",
|
| 266 |
+
line=dict(color="#10b981", width=2),
|
| 267 |
+
))
|
| 268 |
+
fig_cmp.add_trace(go.Scatter(
|
| 269 |
+
x=list(range(1, n_res + 1)), y=pred_rel,
|
| 270 |
+
mode="lines", name="Prediction",
|
| 271 |
+
line=dict(color="#6366f1", width=2, dash="dot"),
|
| 272 |
+
))
|
| 273 |
+
|
| 274 |
+
# Correlation
|
| 275 |
+
corr = np.corrcoef(pred_rel, gt_rel)[0, 1]
|
| 276 |
+
rmse = np.sqrt(np.mean((pred_rel - gt_rel) ** 2))
|
| 277 |
+
|
| 278 |
+
fig_cmp.update_layout(
|
| 279 |
+
template="plotly_dark", height=200,
|
| 280 |
+
title=f"Mode {k} โ r={corr:.3f}, RMSE={rmse:.3f}",
|
| 281 |
+
paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(30,27,75,0.3)",
|
| 282 |
+
margin=dict(l=40, r=20, t=40, b=30),
|
| 283 |
+
legend=dict(orientation="h", y=1.15),
|
| 284 |
+
)
|
| 285 |
+
st.plotly_chart(fig_cmp, use_container_width=True)
|
| 286 |
+
else:
|
| 287 |
+
st.info("No ground truth available for comparison. "
|
| 288 |
+
"Ground truth is only available for proteins in the training database.")
|
| 289 |
+
|
| 290 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 291 |
+
# Tab 4: Per-residue table
|
| 292 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 293 |
+
with tab_table:
|
| 294 |
+
import pandas as pd
|
| 295 |
+
|
| 296 |
+
rows = []
|
| 297 |
+
for i in range(n_res):
|
| 298 |
+
row = {
|
| 299 |
+
"Residue": i + 1,
|
| 300 |
+
"AA": seq[i] if i < len(seq) else "?",
|
| 301 |
+
"Coverage": f"{coverage[i]:.3f}" if i < len(coverage) else "โ",
|
| 302 |
+
}
|
| 303 |
+
for k in range(min(n_modes, 4)):
|
| 304 |
+
mag = np.linalg.norm(modes_c[k][i])
|
| 305 |
+
rel = per_residue_relative_norm(modes_c[k])[i]
|
| 306 |
+
row[f"M{k} (ร
)"] = f"{mag:.3f}"
|
| 307 |
+
row[f"M{k} rel"] = f"{rel:.2%}"
|
| 308 |
+
rows.append(row)
|
| 309 |
+
|
| 310 |
+
df = pd.DataFrame(rows)
|
| 311 |
+
st.dataframe(df, use_container_width=True, height=500,
|
| 312 |
+
column_config={
|
| 313 |
+
"Residue": st.column_config.NumberColumn(width="small"),
|
| 314 |
+
"AA": st.column_config.TextColumn(width="small"),
|
| 315 |
+
})
|
| 316 |
+
|
| 317 |
+
# Download CSV
|
| 318 |
+
csv = df.to_csv(index=False)
|
| 319 |
+
st.download_button("๐ฅ Download CSV", csv,
|
| 320 |
+
f"{protein_name}_analysis.csv", "text/csv")
|
app/components/sequence_viewer.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Interactive sequence viewer with per-residue displacement heatmap."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Amino acid property classification
|
| 7 |
+
AA_PROPS = {
|
| 8 |
+
"A": "hydrophobic", "I": "hydrophobic", "L": "hydrophobic", "M": "hydrophobic",
|
| 9 |
+
"F": "hydrophobic", "W": "hydrophobic", "V": "hydrophobic", "P": "hydrophobic",
|
| 10 |
+
"D": "charged", "E": "charged", "K": "charged", "R": "charged", "H": "charged",
|
| 11 |
+
"S": "polar", "T": "polar", "N": "polar", "Q": "polar", "C": "polar", "Y": "polar",
|
| 12 |
+
"G": "special", "X": "unknown",
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
PROP_COLORS = {
|
| 16 |
+
"hydrophobic": "#f59e0b",
|
| 17 |
+
"charged": "#ef4444",
|
| 18 |
+
"polar": "#10b981",
|
| 19 |
+
"special": "#94a3b8",
|
| 20 |
+
"unknown": "#64748b",
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def render_sequence_viewer(
|
| 25 |
+
seq: str,
|
| 26 |
+
displacements: np.ndarray,
|
| 27 |
+
coverage: np.ndarray = None,
|
| 28 |
+
mode_label: str = "Mode 0",
|
| 29 |
+
max_chars_per_row: int = 80,
|
| 30 |
+
):
|
| 31 |
+
"""Render interactive HTML sequence viewer with displacement coloring.
|
| 32 |
+
|
| 33 |
+
Each residue is displayed as a colored cell where:
|
| 34 |
+
- Background: displacement magnitude (white โ red gradient)
|
| 35 |
+
- Border: coverage (thick = low coverage)
|
| 36 |
+
- Tooltip: residue info
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
seq: Amino acid sequence (1-letter codes)
|
| 40 |
+
displacements: Per-residue displacement magnitudes (N,)
|
| 41 |
+
coverage: Per-residue coverage (N,) in [0, 1]
|
| 42 |
+
mode_label: Label for the current mode
|
| 43 |
+
max_chars_per_row: Characters per row before wrapping
|
| 44 |
+
"""
|
| 45 |
+
n = len(seq)
|
| 46 |
+
if coverage is None:
|
| 47 |
+
coverage = np.ones(n)
|
| 48 |
+
|
| 49 |
+
max_d = displacements.max() + 1e-8
|
| 50 |
+
|
| 51 |
+
html = f"""
|
| 52 |
+
<style>
|
| 53 |
+
.seq-container {{
|
| 54 |
+
font-family: 'Consolas', 'Monaco', monospace;
|
| 55 |
+
background: #1e1b4b;
|
| 56 |
+
border-radius: 8px;
|
| 57 |
+
padding: 12px;
|
| 58 |
+
margin: 8px 0;
|
| 59 |
+
}}
|
| 60 |
+
.seq-header {{
|
| 61 |
+
color: #a5b4fc;
|
| 62 |
+
font-size: 13px;
|
| 63 |
+
margin-bottom: 8px;
|
| 64 |
+
font-weight: bold;
|
| 65 |
+
}}
|
| 66 |
+
.seq-row {{
|
| 67 |
+
display: flex;
|
| 68 |
+
flex-wrap: wrap;
|
| 69 |
+
gap: 1px;
|
| 70 |
+
margin-bottom: 2px;
|
| 71 |
+
}}
|
| 72 |
+
.res {{
|
| 73 |
+
display: inline-flex;
|
| 74 |
+
align-items: center;
|
| 75 |
+
justify-content: center;
|
| 76 |
+
width: 14px;
|
| 77 |
+
height: 22px;
|
| 78 |
+
font-size: 10px;
|
| 79 |
+
font-weight: bold;
|
| 80 |
+
border-radius: 2px;
|
| 81 |
+
cursor: pointer;
|
| 82 |
+
transition: transform 0.1s;
|
| 83 |
+
position: relative;
|
| 84 |
+
}}
|
| 85 |
+
.res:hover {{
|
| 86 |
+
transform: scale(1.8);
|
| 87 |
+
z-index: 10;
|
| 88 |
+
box-shadow: 0 0 8px rgba(99, 102, 241, 0.8);
|
| 89 |
+
}}
|
| 90 |
+
.res:hover::after {{
|
| 91 |
+
content: attr(data-tooltip);
|
| 92 |
+
position: absolute;
|
| 93 |
+
top: -38px;
|
| 94 |
+
left: 50%;
|
| 95 |
+
transform: translateX(-50%);
|
| 96 |
+
background: #312e81;
|
| 97 |
+
color: white;
|
| 98 |
+
padding: 4px 8px;
|
| 99 |
+
border-radius: 4px;
|
| 100 |
+
font-size: 10px;
|
| 101 |
+
white-space: nowrap;
|
| 102 |
+
z-index: 100;
|
| 103 |
+
border: 1px solid #6366f1;
|
| 104 |
+
}}
|
| 105 |
+
.seq-ruler {{
|
| 106 |
+
display: flex;
|
| 107 |
+
gap: 1px;
|
| 108 |
+
margin-bottom: 1px;
|
| 109 |
+
}}
|
| 110 |
+
.ruler-mark {{
|
| 111 |
+
width: 14px;
|
| 112 |
+
font-size: 7px;
|
| 113 |
+
color: #64748b;
|
| 114 |
+
text-align: center;
|
| 115 |
+
}}
|
| 116 |
+
.legend {{
|
| 117 |
+
display: flex;
|
| 118 |
+
gap: 16px;
|
| 119 |
+
margin-top: 8px;
|
| 120 |
+
font-size: 11px;
|
| 121 |
+
color: #94a3b8;
|
| 122 |
+
}}
|
| 123 |
+
.legend-item {{
|
| 124 |
+
display: flex;
|
| 125 |
+
align-items: center;
|
| 126 |
+
gap: 4px;
|
| 127 |
+
}}
|
| 128 |
+
.legend-swatch {{
|
| 129 |
+
width: 12px;
|
| 130 |
+
height: 12px;
|
| 131 |
+
border-radius: 2px;
|
| 132 |
+
}}
|
| 133 |
+
</style>
|
| 134 |
+
<div class="seq-container">
|
| 135 |
+
<div class="seq-header">{mode_label} โ Per-residue displacement ({n} residues)</div>
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
# Build rows
|
| 139 |
+
for row_start in range(0, n, max_chars_per_row):
|
| 140 |
+
row_end = min(row_start + max_chars_per_row, n)
|
| 141 |
+
|
| 142 |
+
# Ruler
|
| 143 |
+
html += '<div class="seq-ruler">'
|
| 144 |
+
for i in range(row_start, row_end):
|
| 145 |
+
if (i + 1) % 10 == 0:
|
| 146 |
+
html += f'<div class="ruler-mark">{i + 1}</div>'
|
| 147 |
+
elif (i + 1) % 5 == 0:
|
| 148 |
+
html += '<div class="ruler-mark">ยท</div>'
|
| 149 |
+
else:
|
| 150 |
+
html += '<div class="ruler-mark"></div>'
|
| 151 |
+
html += "</div>"
|
| 152 |
+
|
| 153 |
+
# Residues
|
| 154 |
+
html += '<div class="seq-row">'
|
| 155 |
+
for i in range(row_start, row_end):
|
| 156 |
+
aa = seq[i] if i < len(seq) else "X"
|
| 157 |
+
d = displacements[i]
|
| 158 |
+
c = coverage[i] if i < len(coverage) else 1.0
|
| 159 |
+
t = d / max_d # Normalized displacement
|
| 160 |
+
|
| 161 |
+
# Background: displacement heatmap (dark purple โ bright red)
|
| 162 |
+
r = int(30 + 225 * t)
|
| 163 |
+
g = int(27 + 20 * (1 - t))
|
| 164 |
+
b = int(75 - 50 * t)
|
| 165 |
+
bg = f"rgb({r},{g},{b})"
|
| 166 |
+
|
| 167 |
+
# Text color: white for high displacement, light for low
|
| 168 |
+
txt_color = "white" if t > 0.3 else "#a5b4fc"
|
| 169 |
+
|
| 170 |
+
# Border: thicker = lower coverage
|
| 171 |
+
border_w = max(0, int(3 * (1 - c)))
|
| 172 |
+
border = f"{border_w}px solid #ef4444" if border_w > 0 else "none"
|
| 173 |
+
|
| 174 |
+
prop = AA_PROPS.get(aa, "unknown")
|
| 175 |
+
tooltip = f"{aa}{i+1} | {d:.3f}ร
| cov={c:.2f} | {prop}"
|
| 176 |
+
|
| 177 |
+
html += (
|
| 178 |
+
f'<div class="res" style="background:{bg};color:{txt_color};'
|
| 179 |
+
f'border:{border}" data-tooltip="{tooltip}">{aa}</div>'
|
| 180 |
+
)
|
| 181 |
+
html += "</div>"
|
| 182 |
+
|
| 183 |
+
# Legend
|
| 184 |
+
html += """
|
| 185 |
+
<div class="legend">
|
| 186 |
+
<div class="legend-item">
|
| 187 |
+
<div class="legend-swatch" style="background:linear-gradient(90deg,#1e1b4b,#ff3030)"></div>
|
| 188 |
+
Low โ High displacement
|
| 189 |
+
</div>
|
| 190 |
+
<div class="legend-item">
|
| 191 |
+
<div class="legend-swatch" style="border:2px solid #ef4444;background:none"></div>
|
| 192 |
+
Red border = low coverage
|
| 193 |
+
</div>
|
| 194 |
+
</div>
|
| 195 |
+
</div>
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
st.markdown(html, unsafe_allow_html=True)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def render_displacement_chart(
|
| 202 |
+
displacements: dict,
|
| 203 |
+
seq: str = "",
|
| 204 |
+
coverage: np.ndarray = None,
|
| 205 |
+
):
|
| 206 |
+
"""Render interactive displacement profile chart using Plotly.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
displacements: {mode_idx: np.ndarray of per-residue magnitudes}
|
| 210 |
+
seq: Amino acid sequence
|
| 211 |
+
coverage: Per-residue coverage
|
| 212 |
+
"""
|
| 213 |
+
import plotly.graph_objects as go
|
| 214 |
+
from plotly.subplots import make_subplots
|
| 215 |
+
|
| 216 |
+
n_modes = len(displacements)
|
| 217 |
+
n_res = len(list(displacements.values())[0])
|
| 218 |
+
residues = np.arange(1, n_res + 1)
|
| 219 |
+
|
| 220 |
+
# Hover text with AA identity
|
| 221 |
+
hover_text = [f"{seq[i] if i < len(seq) else '?'}{i+1}" for i in range(n_res)]
|
| 222 |
+
|
| 223 |
+
fig = make_subplots(
|
| 224 |
+
rows=2, cols=1, row_heights=[0.75, 0.25],
|
| 225 |
+
shared_xaxes=True, vertical_spacing=0.08,
|
| 226 |
+
subplot_titles=["Displacement by Mode", "Coverage"]
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b", "#ec4899", "#8b5cf6"]
|
| 230 |
+
|
| 231 |
+
for k, d in displacements.items():
|
| 232 |
+
mags = np.linalg.norm(d, axis=1) if d.ndim == 2 else d
|
| 233 |
+
fig.add_trace(go.Scatter(
|
| 234 |
+
x=residues, y=mags,
|
| 235 |
+
mode="lines",
|
| 236 |
+
name=f"Mode {k} (ฮผ={mags.mean():.3f}ร
)",
|
| 237 |
+
line=dict(color=colors[k % len(colors)], width=1.5),
|
| 238 |
+
fill="tozeroy",
|
| 239 |
+
fillcolor=colors[k % len(colors)].replace(")", ",0.1)").replace("rgb", "rgba"),
|
| 240 |
+
text=hover_text,
|
| 241 |
+
hovertemplate="%{text}<br>Displacement: %{y:.3f}ร
<extra>Mode " + str(k) + "</extra>",
|
| 242 |
+
), row=1, col=1)
|
| 243 |
+
|
| 244 |
+
# Coverage
|
| 245 |
+
if coverage is not None:
|
| 246 |
+
fig.add_trace(go.Scatter(
|
| 247 |
+
x=residues, y=coverage[:n_res],
|
| 248 |
+
mode="lines",
|
| 249 |
+
name="Coverage",
|
| 250 |
+
line=dict(color="#94a3b8", width=1.5),
|
| 251 |
+
fill="tozeroy",
|
| 252 |
+
fillcolor="rgba(148,163,184,0.15)",
|
| 253 |
+
hovertemplate="%{x}<br>Coverage: %{y:.3f}<extra></extra>",
|
| 254 |
+
), row=2, col=1)
|
| 255 |
+
|
| 256 |
+
fig.update_layout(
|
| 257 |
+
template="plotly_dark",
|
| 258 |
+
height=400,
|
| 259 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 260 |
+
plot_bgcolor="rgba(30,27,75,0.5)",
|
| 261 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
|
| 262 |
+
margin=dict(l=50, r=20, t=40, b=30),
|
| 263 |
+
)
|
| 264 |
+
fig.update_xaxes(title_text="Residue", row=2, col=1)
|
| 265 |
+
fig.update_yaxes(title_text="Displacement (ร
)", row=1, col=1)
|
| 266 |
+
fig.update_yaxes(title_text="Coverage", range=[0, 1.1], row=2, col=1)
|
| 267 |
+
|
| 268 |
+
st.plotly_chart(fig, use_container_width=True, key=f"disp_chart")
|
app/components/viewer_3d.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""3D protein motion visualization using py3Dmol via stmol."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from stmol import showmol
|
| 7 |
+
import py3Dmol
|
| 8 |
+
HAS_STMOL = True
|
| 9 |
+
except ImportError:
|
| 10 |
+
HAS_STMOL = False
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def render_motion_viewer(
|
| 14 |
+
pdb_text: str,
|
| 15 |
+
ca_coords: np.ndarray,
|
| 16 |
+
mode_vecs: np.ndarray,
|
| 17 |
+
seq: str = "",
|
| 18 |
+
amplitude: float = 3.0,
|
| 19 |
+
arrow_scale: float = 1.0,
|
| 20 |
+
color_scheme: str = "magnitude",
|
| 21 |
+
show_cartoon: bool = True,
|
| 22 |
+
show_labels: bool = True,
|
| 23 |
+
min_displacement: float = 0.01,
|
| 24 |
+
width: int = 800,
|
| 25 |
+
height: int = 500,
|
| 26 |
+
key: str = "viewer",
|
| 27 |
+
):
|
| 28 |
+
"""Render interactive 3D motion viewer with displacement arrows.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
pdb_text: PDB file content as string
|
| 32 |
+
ca_coords: CA coordinates (N, 3)
|
| 33 |
+
mode_vecs: Displacement vectors for one mode (N, 3)
|
| 34 |
+
seq: Amino acid sequence (1-letter)
|
| 35 |
+
amplitude: Arrow length multiplier
|
| 36 |
+
arrow_scale: Arrow thickness multiplier
|
| 37 |
+
color_scheme: "magnitude"|"rainbow"|"bfactor"|"chain"|"residue_type"
|
| 38 |
+
show_cartoon: Show cartoon backbone
|
| 39 |
+
show_labels: Label top mobile residues
|
| 40 |
+
min_displacement: Hide arrows below this threshold
|
| 41 |
+
width, height: Viewer dimensions
|
| 42 |
+
key: Streamlit widget key
|
| 43 |
+
"""
|
| 44 |
+
if not HAS_STMOL:
|
| 45 |
+
st.error("Install `stmol`: `pip install stmol`")
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
n_res = len(ca_coords)
|
| 49 |
+
mags = np.linalg.norm(mode_vecs, axis=1)
|
| 50 |
+
max_mag = mags.max() + 1e-8
|
| 51 |
+
|
| 52 |
+
view = py3Dmol.view(width=width, height=height)
|
| 53 |
+
view.addModel(pdb_text, "pdb")
|
| 54 |
+
|
| 55 |
+
# Backbone style
|
| 56 |
+
if show_cartoon:
|
| 57 |
+
view.setStyle({"cartoon": {"color": "#e2e8f0", "opacity": 0.45}})
|
| 58 |
+
else:
|
| 59 |
+
view.setStyle({"stick": {"radius": 0.06, "color": "#94a3b8"}})
|
| 60 |
+
|
| 61 |
+
# Draw displacement arrows
|
| 62 |
+
for i in range(n_res):
|
| 63 |
+
if mags[i] < min_displacement:
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
s = ca_coords[i]
|
| 67 |
+
d = mode_vecs[i] * amplitude
|
| 68 |
+
e = s + d
|
| 69 |
+
t = mags[i] / max_mag # Normalized intensity [0, 1]
|
| 70 |
+
|
| 71 |
+
# Color assignment
|
| 72 |
+
col = _get_color(i, t, n_res, seq, color_scheme)
|
| 73 |
+
|
| 74 |
+
# Arrow shaft โ radius proportional to displacement
|
| 75 |
+
base_r = 0.08 * arrow_scale
|
| 76 |
+
shaft_r = base_r + base_r * t
|
| 77 |
+
view.addCylinder({
|
| 78 |
+
"start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
|
| 79 |
+
"end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
|
| 80 |
+
"radius": shaft_r,
|
| 81 |
+
"color": col,
|
| 82 |
+
"fromCap": True,
|
| 83 |
+
})
|
| 84 |
+
|
| 85 |
+
# Arrow tip (cone-like)
|
| 86 |
+
dn = d / (np.linalg.norm(d) + 1e-8)
|
| 87 |
+
tip = e + dn * 0.25 * amplitude * arrow_scale
|
| 88 |
+
tip_r = shaft_r * 2.2
|
| 89 |
+
view.addCylinder({
|
| 90 |
+
"start": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
|
| 91 |
+
"end": {"x": float(tip[0]), "y": float(tip[1]), "z": float(tip[2])},
|
| 92 |
+
"radius": tip_r,
|
| 93 |
+
"color": col,
|
| 94 |
+
"toCap": True,
|
| 95 |
+
})
|
| 96 |
+
|
| 97 |
+
# Label top-5 mobile residues
|
| 98 |
+
if show_labels and n_res > 0:
|
| 99 |
+
top_n = min(5, n_res)
|
| 100 |
+
top_idx = np.argsort(mags)[-top_n:][::-1]
|
| 101 |
+
for idx in top_idx:
|
| 102 |
+
if mags[idx] < min_displacement:
|
| 103 |
+
continue
|
| 104 |
+
pos = ca_coords[idx]
|
| 105 |
+
aa = seq[idx] if idx < len(seq) else "?"
|
| 106 |
+
view.addLabel(
|
| 107 |
+
f"{aa}{idx + 1}\n{mags[idx]:.2f}ร
",
|
| 108 |
+
{
|
| 109 |
+
"position": {"x": float(pos[0]), "y": float(pos[1] + 2.5), "z": float(pos[2])},
|
| 110 |
+
"fontSize": 11,
|
| 111 |
+
"fontColor": "white",
|
| 112 |
+
"backgroundColor": "#312e81",
|
| 113 |
+
"backgroundOpacity": 0.85,
|
| 114 |
+
"borderColor": "#6366f1",
|
| 115 |
+
"borderThickness": 1,
|
| 116 |
+
},
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
view.zoomTo()
|
| 120 |
+
showmol(view, height=height, width=width)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _get_color(idx: int, intensity: float, n_res: int, seq: str, scheme: str) -> str:
|
| 124 |
+
"""Get color for a residue based on the color scheme."""
|
| 125 |
+
if scheme == "magnitude":
|
| 126 |
+
# Blue โ Purple โ Red gradient
|
| 127 |
+
r = int(99 + 156 * intensity)
|
| 128 |
+
g = int(102 - 62 * intensity)
|
| 129 |
+
b = int(241 - 180 * intensity)
|
| 130 |
+
return f"rgb({r},{g},{b})"
|
| 131 |
+
|
| 132 |
+
elif scheme == "rainbow":
|
| 133 |
+
import colorsys
|
| 134 |
+
h = idx / max(n_res - 1, 1)
|
| 135 |
+
r, g, b = [int(255 * c) for c in colorsys.hsv_to_rgb(h, 0.85, 0.92)]
|
| 136 |
+
return f"rgb({r},{g},{b})"
|
| 137 |
+
|
| 138 |
+
elif scheme == "residue_type":
|
| 139 |
+
aa = seq[idx] if idx < len(seq) else "X"
|
| 140 |
+
hydrophobic = "AILMFWVP"
|
| 141 |
+
charged = "DEKRH"
|
| 142 |
+
polar = "STNQCY"
|
| 143 |
+
if aa in hydrophobic: return "#f59e0b"
|
| 144 |
+
if aa in charged: return "#ef4444"
|
| 145 |
+
if aa in polar: return "#10b981"
|
| 146 |
+
return "#94a3b8"
|
| 147 |
+
|
| 148 |
+
elif scheme == "bfactor":
|
| 149 |
+
r = int(255 * intensity)
|
| 150 |
+
g = int(100 * (1 - intensity))
|
| 151 |
+
b = int(50 * (1 - intensity))
|
| 152 |
+
return f"rgb({r},{g},{b})"
|
| 153 |
+
|
| 154 |
+
else:
|
| 155 |
+
return "#6366f1"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def render_mode_comparison(
|
| 159 |
+
pdb_text: str,
|
| 160 |
+
ca_coords: np.ndarray,
|
| 161 |
+
modes: dict,
|
| 162 |
+
seq: str = "",
|
| 163 |
+
amplitude: float = 3.0,
|
| 164 |
+
arrow_scale: float = 1.0,
|
| 165 |
+
width: int = 900,
|
| 166 |
+
height: int = 350,
|
| 167 |
+
):
|
| 168 |
+
"""Render side-by-side mode comparison grid."""
|
| 169 |
+
if not HAS_STMOL:
|
| 170 |
+
st.error("Install `stmol`")
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
n_modes = min(4, len(modes))
|
| 174 |
+
if n_modes == 0:
|
| 175 |
+
st.warning("No modes to display")
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
colors = ["#6366f1", "#ef4444", "#10b981", "#f59e0b"]
|
| 179 |
+
|
| 180 |
+
cols = st.columns(n_modes)
|
| 181 |
+
for k in range(n_modes):
|
| 182 |
+
with cols[k]:
|
| 183 |
+
vecs = modes[k]
|
| 184 |
+
mags = np.linalg.norm(vecs, axis=1)
|
| 185 |
+
st.caption(f"**Mode {k}** ยท ฮผ={mags.mean():.3f}ร
ยท max={mags.max():.3f}ร
")
|
| 186 |
+
|
| 187 |
+
view = py3Dmol.view(width=width // n_modes, height=height)
|
| 188 |
+
view.addModel(pdb_text, "pdb")
|
| 189 |
+
view.setStyle({"cartoon": {"color": "#e2e8f0", "opacity": 0.35}})
|
| 190 |
+
|
| 191 |
+
max_m = mags.max() + 1e-8
|
| 192 |
+
for i in range(len(ca_coords)):
|
| 193 |
+
if mags[i] < 0.01: continue
|
| 194 |
+
s = ca_coords[i]; d = vecs[i] * amplitude; e = s + d
|
| 195 |
+
t = mags[i] / max_m
|
| 196 |
+
view.addCylinder({
|
| 197 |
+
"start": {"x": float(s[0]), "y": float(s[1]), "z": float(s[2])},
|
| 198 |
+
"end": {"x": float(e[0]), "y": float(e[1]), "z": float(e[2])},
|
| 199 |
+
"radius": 0.08 * arrow_scale + 0.05 * t * arrow_scale,
|
| 200 |
+
"color": colors[k],
|
| 201 |
+
})
|
| 202 |
+
view.zoomTo()
|
| 203 |
+
showmol(view, height=height, width=width // n_modes)
|
app/pages/1_๐_Explorer.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๐ Database Explorer โ Browse pre-computed PETIMOT predictions."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import os, sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
# Imports
|
| 7 |
+
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 8 |
+
PETIMOT_ROOT = os.path.dirname(ROOT)
|
| 9 |
+
if PETIMOT_ROOT not in sys.path:
|
| 10 |
+
sys.path.insert(0, PETIMOT_ROOT)
|
| 11 |
+
|
| 12 |
+
from app.utils.data_loader import (
|
| 13 |
+
find_predictions_dir, load_prediction_index, load_modes, load_ground_truth
|
| 14 |
+
)
|
| 15 |
+
from app.components.viewer_3d import render_motion_viewer, render_mode_comparison
|
| 16 |
+
from app.components.sequence_viewer import render_sequence_viewer, render_displacement_chart
|
| 17 |
+
from app.components.mode_panel import render_mode_correlation, render_eigenvalue_spectrum
|
| 18 |
+
from app.components.prediction_analysis import render_prediction_analysis
|
| 19 |
+
|
| 20 |
+
st.header("๐ Database Explorer")
|
| 21 |
+
|
| 22 |
+
# โโ Find predictions โโ
|
| 23 |
+
pred_dir = find_predictions_dir(PETIMOT_ROOT)
|
| 24 |
+
gt_dir = os.path.join(PETIMOT_ROOT, "ground_truth")
|
| 25 |
+
|
| 26 |
+
if not pred_dir:
|
| 27 |
+
st.error("No predictions found. Run inference first (notebook cell 4.3 or Inference page).")
|
| 28 |
+
st.stop()
|
| 29 |
+
|
| 30 |
+
# โโ Load index โโ
|
| 31 |
+
with st.spinner("Loading prediction index..."):
|
| 32 |
+
df = load_prediction_index(pred_dir)
|
| 33 |
+
|
| 34 |
+
if df.empty:
|
| 35 |
+
st.warning("No predictions in index.")
|
| 36 |
+
st.stop()
|
| 37 |
+
|
| 38 |
+
st.success(f"**{len(df):,}** proteins indexed from `{os.path.basename(pred_dir)}`")
|
| 39 |
+
|
| 40 |
+
# โโ Filters โโ
|
| 41 |
+
with st.expander("๐ง Filters", expanded=False):
|
| 42 |
+
col1, col2, col3 = st.columns(3)
|
| 43 |
+
with col1:
|
| 44 |
+
search = st.text_input("๐ Search by name", "", placeholder="e.g. 1ake")
|
| 45 |
+
with col2:
|
| 46 |
+
len_range = st.slider("Sequence length", int(df.seq_len.min()), int(df.seq_len.max()),
|
| 47 |
+
(int(df.seq_len.min()), int(df.seq_len.max())))
|
| 48 |
+
with col3:
|
| 49 |
+
sort_by = st.selectbox("Sort by", ["name", "seq_len", "mean_disp_m0", "max_disp_m0"],
|
| 50 |
+
index=2)
|
| 51 |
+
|
| 52 |
+
mask = (df.seq_len >= len_range[0]) & (df.seq_len <= len_range[1])
|
| 53 |
+
if search:
|
| 54 |
+
mask &= df.name.str.contains(search, case=False, na=False)
|
| 55 |
+
df_filtered = df[mask].sort_values(sort_by, ascending=(sort_by == "name"))
|
| 56 |
+
st.markdown(f"Showing **{len(df_filtered)}** / {len(df)} proteins")
|
| 57 |
+
|
| 58 |
+
# โโ Table โโ
|
| 59 |
+
selected_idx = st.dataframe(
|
| 60 |
+
df_filtered[["name", "seq_len", "n_modes", "mean_disp_m0", "max_disp_m0", "top_residue"]].rename(
|
| 61 |
+
columns={"name": "Protein", "seq_len": "Length", "n_modes": "Modes",
|
| 62 |
+
"mean_disp_m0": "Mean ฮ (M0)", "max_disp_m0": "Max ฮ (M0)", "top_residue": "Top Res"}
|
| 63 |
+
),
|
| 64 |
+
use_container_width=True, hide_index=True,
|
| 65 |
+
on_select="rerun", selection_mode="single-row", height=350,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# โโ Protein detail panel โโ
|
| 69 |
+
selected_rows = selected_idx.selection.rows if selected_idx.selection.rows else []
|
| 70 |
+
if not selected_rows:
|
| 71 |
+
st.info("๐ Click a row to view detailed analysis")
|
| 72 |
+
st.stop()
|
| 73 |
+
|
| 74 |
+
protein_name = df_filtered.iloc[selected_rows[0]]["name"]
|
| 75 |
+
st.divider()
|
| 76 |
+
st.subheader(f"๐งฌ {protein_name}")
|
| 77 |
+
|
| 78 |
+
# Load data
|
| 79 |
+
modes = load_modes(pred_dir, protein_name)
|
| 80 |
+
gt = load_ground_truth(gt_dir, protein_name)
|
| 81 |
+
|
| 82 |
+
if not modes:
|
| 83 |
+
st.error(f"No mode files found for {protein_name}")
|
| 84 |
+
st.stop()
|
| 85 |
+
|
| 86 |
+
n_res = len(list(modes.values())[0])
|
| 87 |
+
seq = gt.get("seq", "X" * n_res) if gt else "X" * n_res
|
| 88 |
+
ca = gt["bb"][:, 1] if gt and "bb" in gt else np.zeros((n_res, 3))
|
| 89 |
+
coverage = gt.get("coverage", np.ones(n_res)) if gt else np.ones(n_res)
|
| 90 |
+
eigenvalues = gt.get("eigvals", None) if gt else None
|
| 91 |
+
pdb_text = None
|
| 92 |
+
|
| 93 |
+
pdb_path = os.path.join(PETIMOT_ROOT, "pdbs", f"{protein_name}.pdb")
|
| 94 |
+
if os.path.exists(pdb_path):
|
| 95 |
+
with open(pdb_path) as f:
|
| 96 |
+
pdb_text = f.read()
|
| 97 |
+
|
| 98 |
+
# โโ Sidebar controls โโ
|
| 99 |
+
with st.sidebar:
|
| 100 |
+
st.divider()
|
| 101 |
+
st.markdown(f"### ๐๏ธ {protein_name}")
|
| 102 |
+
mode_idx = st.slider("Mode", 0, len(modes) - 1, 0, key="mode_sel")
|
| 103 |
+
amplitude = st.slider("Arrow amplitude", 0.5, 15.0, 3.0, 0.5, key="amp")
|
| 104 |
+
arrow_scale = st.slider("Arrow thickness", 0.3, 3.0, 1.0, 0.1, key="arrow_s")
|
| 105 |
+
color_scheme = st.selectbox("Color", ["magnitude", "rainbow", "residue_type", "bfactor"], key="col_s")
|
| 106 |
+
show_labels = st.checkbox("Show labels", True, key="labels")
|
| 107 |
+
min_disp = st.slider("Min displacement", 0.0, 0.1, 0.01, 0.005, key="min_d")
|
| 108 |
+
|
| 109 |
+
# โโ 3D viewer + stats โโ
|
| 110 |
+
col_3d, col_info = st.columns([2, 1])
|
| 111 |
+
|
| 112 |
+
with col_3d:
|
| 113 |
+
current_mode = modes[mode_idx]
|
| 114 |
+
mags = np.linalg.norm(current_mode, axis=1)
|
| 115 |
+
|
| 116 |
+
if pdb_text and ca is not None and np.any(ca != 0):
|
| 117 |
+
render_motion_viewer(
|
| 118 |
+
pdb_text=pdb_text, ca_coords=ca, mode_vecs=current_mode, seq=seq,
|
| 119 |
+
amplitude=amplitude, arrow_scale=arrow_scale, color_scheme=color_scheme,
|
| 120 |
+
show_labels=show_labels, min_displacement=min_disp,
|
| 121 |
+
width=700, height=480,
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
st.info("No PDB structure โ showing displacement data only.")
|
| 125 |
+
|
| 126 |
+
with col_info:
|
| 127 |
+
st.metric("Residues", n_res)
|
| 128 |
+
st.metric(f"Mode {mode_idx} mean ฮ", f"{mags.mean():.3f} ร
")
|
| 129 |
+
st.metric(f"Mode {mode_idx} max ฮ", f"{mags.max():.3f} ร
")
|
| 130 |
+
top5 = np.argsort(mags)[-5:][::-1]
|
| 131 |
+
st.markdown("**Top mobile residues:**")
|
| 132 |
+
for rank, idx in enumerate(top5):
|
| 133 |
+
aa = seq[idx] if idx < len(seq) else "?"
|
| 134 |
+
st.markdown(f"`#{rank+1}` **{aa}{idx+1}** โ {mags[idx]:.3f} ร
")
|
| 135 |
+
if eigenvalues is not None:
|
| 136 |
+
render_eigenvalue_spectrum(eigenvalues)
|
| 137 |
+
|
| 138 |
+
# โโ Sequence viewer โโ
|
| 139 |
+
st.markdown("### ๐งฌ Sequence ร Displacement")
|
| 140 |
+
render_sequence_viewer(seq, mags, coverage, mode_label=f"Mode {mode_idx}")
|
| 141 |
+
|
| 142 |
+
# โโ Displacement chart โโ
|
| 143 |
+
st.markdown("### ๐ Displacement Profiles")
|
| 144 |
+
render_displacement_chart(modes, seq, coverage)
|
| 145 |
+
|
| 146 |
+
# โโ Mode comparison โโ
|
| 147 |
+
st.markdown("### ๐ Mode Comparison")
|
| 148 |
+
col_corr, col_grid = st.columns([1, 2])
|
| 149 |
+
with col_corr:
|
| 150 |
+
render_mode_correlation(modes)
|
| 151 |
+
with col_grid:
|
| 152 |
+
if pdb_text and np.any(ca != 0):
|
| 153 |
+
render_mode_comparison(pdb_text, ca, modes, seq, amplitude, arrow_scale)
|
| 154 |
+
|
| 155 |
+
# โโ Deep Analysis โโ
|
| 156 |
+
st.divider()
|
| 157 |
+
st.markdown("### ๐ฌ Prediction Analysis")
|
| 158 |
+
|
| 159 |
+
gt_modes = None
|
| 160 |
+
if gt and "eigvects" in gt:
|
| 161 |
+
n_gt_modes = min(4, gt["eigvects"].shape[1] if gt["eigvects"].ndim == 2 else 4)
|
| 162 |
+
ev = gt["eigvects"][:, :n_gt_modes] # (3N, K)
|
| 163 |
+
ev = ev.reshape(-1, 3, n_gt_modes).transpose(0, 2, 1) # (N, K, 3)
|
| 164 |
+
gt_modes = {k: ev[:, k] for k in range(n_gt_modes)}
|
| 165 |
+
|
| 166 |
+
render_prediction_analysis(
|
| 167 |
+
modes=modes, seq=seq, ca_coords=ca, coverage=coverage,
|
| 168 |
+
eigenvalues=eigenvalues, gt_modes=gt_modes, protein_name=protein_name
|
| 169 |
+
)
|
app/pages/2_๐ฎ_Inference.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๐ฎ Inference โ Predict motion for a new protein."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import os, sys, tempfile
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 7 |
+
PETIMOT_ROOT = os.path.dirname(ROOT)
|
| 8 |
+
if PETIMOT_ROOT not in sys.path:
|
| 9 |
+
sys.path.insert(0, PETIMOT_ROOT)
|
| 10 |
+
|
| 11 |
+
from app.utils.inference import run_inference, download_pdb
|
| 12 |
+
from app.components.viewer_3d import render_motion_viewer, render_mode_comparison
|
| 13 |
+
from app.components.sequence_viewer import render_sequence_viewer, render_displacement_chart
|
| 14 |
+
from app.components.mode_panel import render_mode_correlation, render_eigenvalue_spectrum
|
| 15 |
+
|
| 16 |
+
st.header("๐ฎ Custom Inference")
|
| 17 |
+
st.markdown("Predict protein motion modes for any structure. Runs on CPU (~5-30s).")
|
| 18 |
+
|
| 19 |
+
# โโ Input method โโ
|
| 20 |
+
input_mode = st.radio("Input method", ["PDB ID (RCSB)", "Upload PDB file"], horizontal=True)
|
| 21 |
+
|
| 22 |
+
pdb_path = None
|
| 23 |
+
|
| 24 |
+
if input_mode == "PDB ID (RCSB)":
|
| 25 |
+
col1, col2 = st.columns([3, 1])
|
| 26 |
+
with col1:
|
| 27 |
+
pdb_id = st.text_input("PDB ID", "1akeA", placeholder="e.g. 1akeA, 4akeA",
|
| 28 |
+
help="4-char PDB code + optional chain letter")
|
| 29 |
+
with col2:
|
| 30 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
| 31 |
+
fetch = st.button("๐ Fetch", use_container_width=True)
|
| 32 |
+
|
| 33 |
+
if fetch and pdb_id:
|
| 34 |
+
with st.spinner(f"Downloading {pdb_id} from RCSB..."):
|
| 35 |
+
pdb_path = download_pdb(pdb_id)
|
| 36 |
+
if pdb_path:
|
| 37 |
+
st.success(f"Downloaded {pdb_id}")
|
| 38 |
+
else:
|
| 39 |
+
st.error(f"Could not download {pdb_id}. Check the PDB ID.")
|
| 40 |
+
|
| 41 |
+
else:
|
| 42 |
+
uploaded = st.file_uploader("Upload PDB", type=["pdb"], key="pdb_upload")
|
| 43 |
+
if uploaded:
|
| 44 |
+
tmp = tempfile.NamedTemporaryFile(suffix=".pdb", delete=False)
|
| 45 |
+
tmp.write(uploaded.read())
|
| 46 |
+
tmp.close()
|
| 47 |
+
pdb_path = tmp.name
|
| 48 |
+
st.success(f"Uploaded: {uploaded.name}")
|
| 49 |
+
|
| 50 |
+
# โโ Weights selection โโ
|
| 51 |
+
weights_path = st.session_state.get("weights", None)
|
| 52 |
+
if not weights_path:
|
| 53 |
+
weights_dir = os.path.join(PETIMOT_ROOT, "weights")
|
| 54 |
+
pt_files = []
|
| 55 |
+
if os.path.isdir(weights_dir):
|
| 56 |
+
for root, dirs, files in os.walk(weights_dir):
|
| 57 |
+
for f in files:
|
| 58 |
+
if f.endswith(".pt"):
|
| 59 |
+
pt_files.append(os.path.join(root, f))
|
| 60 |
+
if pt_files:
|
| 61 |
+
weights_path = pt_files[0]
|
| 62 |
+
|
| 63 |
+
if not weights_path:
|
| 64 |
+
st.error("No model weights found in `weights/`. Download them from Figshare.")
|
| 65 |
+
st.stop()
|
| 66 |
+
|
| 67 |
+
# โโ Run inference โโ
|
| 68 |
+
if pdb_path:
|
| 69 |
+
if st.button("๐ Run PETIMOT Inference", type="primary", use_container_width=True):
|
| 70 |
+
with st.spinner("Running inference... (loading embeddings + forward pass)"):
|
| 71 |
+
try:
|
| 72 |
+
result = run_inference(pdb_path, weights_path)
|
| 73 |
+
except Exception as e:
|
| 74 |
+
st.error(f"Inference failed: {e}")
|
| 75 |
+
st.exception(e)
|
| 76 |
+
st.stop()
|
| 77 |
+
|
| 78 |
+
st.session_state["inference_result"] = result
|
| 79 |
+
st.success(f"โ
Predicted {len(result['modes'])} modes for {result['name']} ({result['n_res']} residues)")
|
| 80 |
+
|
| 81 |
+
# โโ Display results โโ
|
| 82 |
+
result = st.session_state.get("inference_result", None)
|
| 83 |
+
if result:
|
| 84 |
+
modes = result["modes"]
|
| 85 |
+
ca = result["ca_coords"]
|
| 86 |
+
seq = result["seq"]
|
| 87 |
+
pdb_text = result["pdb_text"]
|
| 88 |
+
n_res = result["n_res"]
|
| 89 |
+
|
| 90 |
+
st.divider()
|
| 91 |
+
st.subheader(f"๐งฌ {result['name']} โ {n_res} residues, {len(modes)} modes")
|
| 92 |
+
|
| 93 |
+
# Controls
|
| 94 |
+
with st.sidebar:
|
| 95 |
+
st.divider()
|
| 96 |
+
st.markdown(f"### ๐๏ธ {result['name']}")
|
| 97 |
+
mode_idx = st.slider("Mode", 0, max(0, len(modes) - 1), 0, key="inf_mode")
|
| 98 |
+
amplitude = st.slider("Amplitude", 0.5, 15.0, 3.0, 0.5, key="inf_amp")
|
| 99 |
+
arrow_scale = st.slider("Arrow size", 0.3, 3.0, 1.0, 0.1, key="inf_arrow")
|
| 100 |
+
color_scheme = st.selectbox("Colors",
|
| 101 |
+
["magnitude", "rainbow", "residue_type", "bfactor"],
|
| 102 |
+
key="inf_color")
|
| 103 |
+
|
| 104 |
+
# 3D viewer
|
| 105 |
+
current_mode = modes.get(mode_idx, list(modes.values())[0])
|
| 106 |
+
mags = np.linalg.norm(current_mode, axis=1)
|
| 107 |
+
|
| 108 |
+
col_3d, col_stats = st.columns([2, 1])
|
| 109 |
+
with col_3d:
|
| 110 |
+
render_motion_viewer(
|
| 111 |
+
pdb_text=pdb_text, ca_coords=ca, mode_vecs=current_mode, seq=seq,
|
| 112 |
+
amplitude=amplitude, arrow_scale=arrow_scale, color_scheme=color_scheme,
|
| 113 |
+
width=700, height=500, key="inf_viewer",
|
| 114 |
+
)
|
| 115 |
+
with col_stats:
|
| 116 |
+
st.metric("Residues", n_res)
|
| 117 |
+
st.metric(f"Mode {mode_idx} mean", f"{mags.mean():.3f} ร
")
|
| 118 |
+
st.metric(f"Mode {mode_idx} max", f"{mags.max():.3f} ร
")
|
| 119 |
+
top3 = np.argsort(mags)[-3:][::-1]
|
| 120 |
+
st.markdown("**Top mobile:**")
|
| 121 |
+
for idx in top3:
|
| 122 |
+
aa = seq[idx] if idx < len(seq) else "?"
|
| 123 |
+
st.markdown(f"**{aa}{idx+1}** โ {mags[idx]:.3f} ร
")
|
| 124 |
+
|
| 125 |
+
# Sequence viewer
|
| 126 |
+
st.markdown("### ๐งฌ Sequence ร Displacement")
|
| 127 |
+
render_sequence_viewer(seq, mags, mode_label=f"Mode {mode_idx}")
|
| 128 |
+
|
| 129 |
+
# Displacement chart
|
| 130 |
+
st.markdown("### ๐ Profiles")
|
| 131 |
+
render_displacement_chart(modes, seq)
|
| 132 |
+
|
| 133 |
+
# Mode comparison
|
| 134 |
+
if len(modes) > 1:
|
| 135 |
+
st.markdown("### ๐ Mode Comparison")
|
| 136 |
+
render_mode_comparison(pdb_text, ca, modes, seq, amplitude, arrow_scale)
|
| 137 |
+
render_mode_correlation(modes)
|
| 138 |
+
|
| 139 |
+
# Export
|
| 140 |
+
st.divider()
|
| 141 |
+
st.markdown("### ๐พ Export")
|
| 142 |
+
col_e1, col_e2 = st.columns(2)
|
| 143 |
+
with col_e1:
|
| 144 |
+
# CSV export
|
| 145 |
+
import pandas as pd
|
| 146 |
+
export_data = []
|
| 147 |
+
for k, v in modes.items():
|
| 148 |
+
m = np.linalg.norm(v, axis=1)
|
| 149 |
+
for i in range(len(m)):
|
| 150 |
+
export_data.append({
|
| 151 |
+
"residue": i + 1,
|
| 152 |
+
"aa": seq[i] if i < len(seq) else "?",
|
| 153 |
+
"mode": k,
|
| 154 |
+
"dx": v[i, 0], "dy": v[i, 1], "dz": v[i, 2],
|
| 155 |
+
"magnitude": m[i],
|
| 156 |
+
})
|
| 157 |
+
csv = pd.DataFrame(export_data).to_csv(index=False)
|
| 158 |
+
st.download_button("๐ฅ Download modes (CSV)", csv,
|
| 159 |
+
f"{result['name']}_modes.csv", "text/csv")
|
| 160 |
+
with col_e2:
|
| 161 |
+
st.download_button("๐ฅ Download PDB", pdb_text,
|
| 162 |
+
f"{result['name']}.pdb", "chemical/x-pdb")
|
app/pages/3_๐_Statistics.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๐ Statistics โ Dataset-wide analysis of PETIMOT predictions."""
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import os, sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
from plotly.subplots import make_subplots
|
| 7 |
+
|
| 8 |
+
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
+
PETIMOT_ROOT = os.path.dirname(ROOT)
|
| 10 |
+
if PETIMOT_ROOT not in sys.path:
|
| 11 |
+
sys.path.insert(0, PETIMOT_ROOT)
|
| 12 |
+
|
| 13 |
+
from app.utils.data_loader import find_predictions_dir, load_prediction_index
|
| 14 |
+
|
| 15 |
+
st.header("๐ Dataset Statistics")
|
| 16 |
+
|
| 17 |
+
pred_dir = find_predictions_dir(PETIMOT_ROOT)
|
| 18 |
+
if not pred_dir:
|
| 19 |
+
st.warning("No predictions found. Run batch inference first.")
|
| 20 |
+
st.stop()
|
| 21 |
+
|
| 22 |
+
with st.spinner("Building index..."):
|
| 23 |
+
df = load_prediction_index(pred_dir)
|
| 24 |
+
|
| 25 |
+
if df.empty:
|
| 26 |
+
st.warning("Empty index")
|
| 27 |
+
st.stop()
|
| 28 |
+
|
| 29 |
+
st.success(f"**{len(df):,}** proteins analyzed")
|
| 30 |
+
|
| 31 |
+
# โโ Summary metrics โโ
|
| 32 |
+
c1, c2, c3, c4, c5 = st.columns(5)
|
| 33 |
+
c1.metric("Proteins", f"{len(df):,}")
|
| 34 |
+
c2.metric("Median length", f"{df.seq_len.median():.0f}")
|
| 35 |
+
c3.metric("Mean ฮ (M0)", f"{df.mean_disp_m0.mean():.3f} ร
")
|
| 36 |
+
c4.metric("Max ฮ (M0)", f"{df.max_disp_m0.max():.3f} ร
")
|
| 37 |
+
c5.metric("Avg modes", f"{df.n_modes.mean():.1f}")
|
| 38 |
+
|
| 39 |
+
# โโ Distribution plots โโ
|
| 40 |
+
st.markdown("### ๐ Distributions")
|
| 41 |
+
|
| 42 |
+
fig = make_subplots(rows=2, cols=2,
|
| 43 |
+
subplot_titles=[
|
| 44 |
+
"Sequence Length Distribution",
|
| 45 |
+
"Mean Mode-0 Displacement",
|
| 46 |
+
"Max Mode-0 Displacement",
|
| 47 |
+
"Length vs Mean Displacement",
|
| 48 |
+
])
|
| 49 |
+
|
| 50 |
+
# 1. Sequence length
|
| 51 |
+
fig.add_trace(go.Histogram(
|
| 52 |
+
x=df.seq_len, nbinsx=60, marker_color="#6366f1", name="Length",
|
| 53 |
+
hovertemplate="Length: %{x}<br>Count: %{y}<extra></extra>",
|
| 54 |
+
), row=1, col=1)
|
| 55 |
+
fig.add_vline(x=df.seq_len.median(), line_dash="dash", line_color="#ef4444",
|
| 56 |
+
annotation_text=f"Med={df.seq_len.median():.0f}", row=1, col=1)
|
| 57 |
+
|
| 58 |
+
# 2. Mean displacement
|
| 59 |
+
fig.add_trace(go.Histogram(
|
| 60 |
+
x=df.mean_disp_m0, nbinsx=60, marker_color="#10b981", name="Mean ฮ",
|
| 61 |
+
hovertemplate="Mean ฮ: %{x:.3f}ร
<br>Count: %{y}<extra></extra>",
|
| 62 |
+
), row=1, col=2)
|
| 63 |
+
|
| 64 |
+
# 3. Max displacement
|
| 65 |
+
fig.add_trace(go.Histogram(
|
| 66 |
+
x=df.max_disp_m0, nbinsx=60, marker_color="#f59e0b", name="Max ฮ",
|
| 67 |
+
hovertemplate="Max ฮ: %{x:.3f}ร
<br>Count: %{y}<extra></extra>",
|
| 68 |
+
), row=2, col=1)
|
| 69 |
+
|
| 70 |
+
# 4. Scatter: length vs displacement
|
| 71 |
+
fig.add_trace(go.Scattergl(
|
| 72 |
+
x=df.seq_len, y=df.mean_disp_m0,
|
| 73 |
+
mode="markers",
|
| 74 |
+
marker=dict(size=3, color=df.max_disp_m0, colorscale="Viridis",
|
| 75 |
+
showscale=True, colorbar=dict(title="Max ฮ")),
|
| 76 |
+
name="Proteins",
|
| 77 |
+
hovertemplate="%{text}<br>Length: %{x}<br>Mean ฮ: %{y:.3f}ร
<extra></extra>",
|
| 78 |
+
text=df.name,
|
| 79 |
+
), row=2, col=2)
|
| 80 |
+
|
| 81 |
+
fig.update_layout(
|
| 82 |
+
template="plotly_dark",
|
| 83 |
+
height=600,
|
| 84 |
+
showlegend=False,
|
| 85 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 86 |
+
plot_bgcolor="rgba(30,27,75,0.5)",
|
| 87 |
+
margin=dict(l=40, r=40, t=40, b=30),
|
| 88 |
+
)
|
| 89 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 90 |
+
|
| 91 |
+
# โโ Ground truth analysis (if available) โโ
|
| 92 |
+
gt_dir = os.path.join(PETIMOT_ROOT, "ground_truth")
|
| 93 |
+
if os.path.isdir(gt_dir) and len(os.listdir(gt_dir)) > 0:
|
| 94 |
+
st.markdown("### ๐ฏ Ground Truth Analysis")
|
| 95 |
+
st.info("Loading eigenvalue statistics from ground truth...")
|
| 96 |
+
|
| 97 |
+
import torch, glob
|
| 98 |
+
gt_files = sorted(glob.glob(os.path.join(gt_dir, "*.pt")))[:2000]
|
| 99 |
+
|
| 100 |
+
eigendata = []
|
| 101 |
+
for gf in gt_files:
|
| 102 |
+
try:
|
| 103 |
+
d = torch.load(gf, map_location="cpu", weights_only=True)
|
| 104 |
+
ev = d["eigvals"].numpy() if "eigvals" in d else None
|
| 105 |
+
cov = d["coverage"].numpy() if "coverage" in d else None
|
| 106 |
+
if ev is not None:
|
| 107 |
+
eigendata.append({
|
| 108 |
+
"name": os.path.basename(gf).replace(".pt", ""),
|
| 109 |
+
"n_res": len(d["bb"]),
|
| 110 |
+
"eigval_0": float(ev[0]),
|
| 111 |
+
"eigval_ratio": float(ev[0] / (ev[1] + 1e-12)) if len(ev) > 1 else 0,
|
| 112 |
+
"mode1_var_frac": float(ev[0] / (ev.sum() + 1e-12)),
|
| 113 |
+
"mean_coverage": float(cov.mean()) if cov is not None else 1.0,
|
| 114 |
+
})
|
| 115 |
+
except Exception:
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
if eigendata:
|
| 119 |
+
import pandas as pd
|
| 120 |
+
df_gt = pd.DataFrame(eigendata)
|
| 121 |
+
|
| 122 |
+
fig2 = make_subplots(rows=1, cols=3,
|
| 123 |
+
subplot_titles=["Dominant Eigenvalue (ฮปโ)", "Mode Dominance (ฮปโ/ฮปโ)", "Mode 1 Variance %"])
|
| 124 |
+
|
| 125 |
+
fig2.add_trace(go.Histogram(
|
| 126 |
+
x=df_gt.eigval_0, nbinsx=60, marker_color="#ec4899", name="ฮปโ",
|
| 127 |
+
), row=1, col=1)
|
| 128 |
+
|
| 129 |
+
fig2.add_trace(go.Histogram(
|
| 130 |
+
x=df_gt.eigval_ratio, nbinsx=60, marker_color="#8b5cf6", name="ฮปโ/ฮปโ",
|
| 131 |
+
), row=1, col=2)
|
| 132 |
+
|
| 133 |
+
fig2.add_trace(go.Histogram(
|
| 134 |
+
x=df_gt.mode1_var_frac * 100, nbinsx=50, marker_color="#06b6d4", name="Var %",
|
| 135 |
+
), row=1, col=3)
|
| 136 |
+
|
| 137 |
+
fig2.update_layout(
|
| 138 |
+
template="plotly_dark", height=300, showlegend=False,
|
| 139 |
+
paper_bgcolor="rgba(0,0,0,0)",
|
| 140 |
+
plot_bgcolor="rgba(30,27,75,0.5)",
|
| 141 |
+
margin=dict(l=40, r=20, t=40, b=30),
|
| 142 |
+
)
|
| 143 |
+
st.plotly_chart(fig2, use_container_width=True)
|
| 144 |
+
|
| 145 |
+
# Summary
|
| 146 |
+
st.markdown(f"""
|
| 147 |
+
| Metric | Value |
|
| 148 |
+
|--------|-------|
|
| 149 |
+
| Samples analyzed | {len(df_gt):,} |
|
| 150 |
+
| ฮปโ median | {df_gt.eigval_0.median():.4f} |
|
| 151 |
+
| ฮปโ/ฮปโ median | {df_gt.eigval_ratio.median():.2f} |
|
| 152 |
+
| Mode 1 variance | {df_gt.mode1_var_frac.median()*100:.1f}% median |
|
| 153 |
+
| Coverage | {df_gt.mean_coverage.mean():.3f} ยฑ {df_gt.mean_coverage.std():.3f} |
|
| 154 |
+
""")
|
| 155 |
+
|
| 156 |
+
# โโ Data table (searchable) โโ
|
| 157 |
+
st.markdown("### ๐ Full Data Table")
|
| 158 |
+
st.dataframe(df, use_container_width=True, height=400)
|
app/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PETIMOT Streamlit Explorer
|
| 2 |
+
streamlit>=1.30.0
|
| 3 |
+
stmol>=0.0.9
|
| 4 |
+
py3Dmol>=2.0.0
|
| 5 |
+
plotly>=5.18.0
|
| 6 |
+
pandas>=2.0.0
|
| 7 |
+
numpy>=1.24.0
|
| 8 |
+
torch>=2.0.0
|
| 9 |
+
torch_geometric>=2.0.0
|
| 10 |
+
transformers>=4.30.0
|
| 11 |
+
sentencepiece>=0.1.99
|
| 12 |
+
scipy>=1.10.0
|
| 13 |
+
biopython>=1.80
|
| 14 |
+
requests>=2.28.0
|
| 15 |
+
tqdm>=4.65.0
|
app/utils/__init__.py
ADDED
|
File without changes
|
app/utils/data_loader.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data loading utilities for pre-computed PETIMOT predictions."""
|
| 2 |
+
import os, json, glob, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def find_predictions_dir(root: str) -> str | None:
|
| 10 |
+
"""Find the predictions directory (most recent model)."""
|
| 11 |
+
pred_root = os.path.join(root, "predictions")
|
| 12 |
+
if not os.path.isdir(pred_root):
|
| 13 |
+
return None
|
| 14 |
+
subdirs = [os.path.join(pred_root, d) for d in os.listdir(pred_root)
|
| 15 |
+
if os.path.isdir(os.path.join(pred_root, d))]
|
| 16 |
+
if not subdirs:
|
| 17 |
+
return None
|
| 18 |
+
return max(subdirs, key=os.path.getmtime)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@lru_cache(maxsize=1)
|
| 22 |
+
def load_prediction_index(pred_dir: str) -> pd.DataFrame:
|
| 23 |
+
"""Build index of all predicted proteins with metadata."""
|
| 24 |
+
rows = []
|
| 25 |
+
mode_files = glob.glob(os.path.join(pred_dir, "*_mode_0.txt"))
|
| 26 |
+
|
| 27 |
+
for mf in mode_files:
|
| 28 |
+
base = os.path.basename(mf).replace("_mode_0.txt", "")
|
| 29 |
+
# Load mode 0 for stats
|
| 30 |
+
try:
|
| 31 |
+
vecs = np.loadtxt(mf)
|
| 32 |
+
n_res = len(vecs)
|
| 33 |
+
mag = np.linalg.norm(vecs, axis=1)
|
| 34 |
+
|
| 35 |
+
# Count available modes
|
| 36 |
+
n_modes = 0
|
| 37 |
+
for k in range(10):
|
| 38 |
+
if os.path.exists(os.path.join(pred_dir, f"{base}_mode_{k}.txt")):
|
| 39 |
+
n_modes += 1
|
| 40 |
+
else:
|
| 41 |
+
break
|
| 42 |
+
|
| 43 |
+
rows.append({
|
| 44 |
+
"name": base,
|
| 45 |
+
"seq_len": n_res,
|
| 46 |
+
"n_modes": n_modes,
|
| 47 |
+
"mean_disp_m0": float(mag.mean()),
|
| 48 |
+
"max_disp_m0": float(mag.max()),
|
| 49 |
+
"top_residue": int(np.argmax(mag)) + 1,
|
| 50 |
+
})
|
| 51 |
+
except Exception:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
return pd.DataFrame(rows).sort_values("name").reset_index(drop=True)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_modes(pred_dir: str, name: str) -> dict[int, np.ndarray]:
|
| 58 |
+
"""Load all mode files for a protein."""
|
| 59 |
+
modes = {}
|
| 60 |
+
for k in range(10):
|
| 61 |
+
for pfx in [f"extracted_{name}", name]:
|
| 62 |
+
mf = os.path.join(pred_dir, f"{pfx}_mode_{k}.txt")
|
| 63 |
+
if os.path.exists(mf):
|
| 64 |
+
modes[k] = np.loadtxt(mf)
|
| 65 |
+
break
|
| 66 |
+
return modes
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_ground_truth(gt_dir: str, name: str) -> dict | None:
|
| 70 |
+
"""Load ground truth data for a protein."""
|
| 71 |
+
path = os.path.join(gt_dir, f"{name}.pt")
|
| 72 |
+
if not os.path.exists(path):
|
| 73 |
+
return None
|
| 74 |
+
try:
|
| 75 |
+
data = torch.load(path, map_location="cpu", weights_only=True)
|
| 76 |
+
return {k: v.numpy() if isinstance(v, torch.Tensor) else v
|
| 77 |
+
for k, v in data.items()}
|
| 78 |
+
except Exception:
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_pdb_text(pdb_path: str) -> str | None:
|
| 83 |
+
"""Load PDB file as text."""
|
| 84 |
+
if not os.path.exists(pdb_path):
|
| 85 |
+
return None
|
| 86 |
+
with open(pdb_path) as f:
|
| 87 |
+
return f.read()
|
app/utils/download.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto-download PETIMOT data from Figshare on first run."""
|
| 2 |
+
import os, zipfile, requests
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
FIGSHARE_PRIVATE_KEY = "ab400d852b4669a83b64"
|
| 8 |
+
FIGSHARE_FILES = {
|
| 9 |
+
"ground_truth.zip": "52349453",
|
| 10 |
+
"default_2025-02-07_21-54-02_epoch_33.pt": "52349456",
|
| 11 |
+
"baseline_predictions.zip": "52349480",
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def download_file(url: str, dest: str, desc: str = "") -> bool:
|
| 16 |
+
"""Download a file with progress bar."""
|
| 17 |
+
try:
|
| 18 |
+
r = requests.get(url, stream=True, allow_redirects=True, timeout=60)
|
| 19 |
+
r.raise_for_status()
|
| 20 |
+
total = int(r.headers.get("content-length", 0))
|
| 21 |
+
with open(dest, "wb") as f:
|
| 22 |
+
with tqdm(total=total, unit="B", unit_scale=True, desc=desc) as pbar:
|
| 23 |
+
for chunk in r.iter_content(8192):
|
| 24 |
+
f.write(chunk)
|
| 25 |
+
pbar.update(len(chunk))
|
| 26 |
+
if os.path.getsize(dest) < 1000:
|
| 27 |
+
os.remove(dest)
|
| 28 |
+
return False
|
| 29 |
+
return True
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"Download failed: {e}")
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def ensure_weights(root: str) -> str | None:
|
| 36 |
+
"""Ensure model weights are available. Returns path to weights or None."""
|
| 37 |
+
weights_dir = os.path.join(root, "weights")
|
| 38 |
+
os.makedirs(weights_dir, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
# Check for existing weights
|
| 41 |
+
for f in os.listdir(weights_dir):
|
| 42 |
+
if f.endswith(".pt"):
|
| 43 |
+
return os.path.join(weights_dir, f)
|
| 44 |
+
|
| 45 |
+
# Try downloading from Figshare
|
| 46 |
+
wt_name = "default_2025-02-07_21-54-02_epoch_33.pt"
|
| 47 |
+
wt_path = os.path.join(weights_dir, wt_name)
|
| 48 |
+
fid = FIGSHARE_FILES[wt_name]
|
| 49 |
+
url = f"https://figshare.com/ndownloader/files/{fid}?private_link={FIGSHARE_PRIVATE_KEY}"
|
| 50 |
+
|
| 51 |
+
print(f"โฌ๏ธ Downloading model weights ({wt_name})...")
|
| 52 |
+
if download_file(url, wt_path, "weights"):
|
| 53 |
+
print(f"โ
Weights saved to {wt_path}")
|
| 54 |
+
return wt_path
|
| 55 |
+
|
| 56 |
+
# Try Figshare API
|
| 57 |
+
try:
|
| 58 |
+
api_url = f"https://api.figshare.com/v2/articles/28679143/files"
|
| 59 |
+
r = requests.get(api_url, timeout=10)
|
| 60 |
+
if r.ok:
|
| 61 |
+
for f in r.json():
|
| 62 |
+
if "epoch" in f["name"] and f["name"].endswith(".pt"):
|
| 63 |
+
if download_file(f["download_url"], wt_path, "weights"):
|
| 64 |
+
return wt_path
|
| 65 |
+
except:
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ensure_ground_truth(root: str) -> bool:
|
| 72 |
+
"""Ensure ground truth data is available."""
|
| 73 |
+
gt_dir = os.path.join(root, "ground_truth")
|
| 74 |
+
os.makedirs(gt_dir, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
if len(list(Path(gt_dir).rglob("*.pt"))) > 0:
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
# Try downloading
|
| 80 |
+
zip_path = os.path.join(root, "ground_truth.zip")
|
| 81 |
+
fid = FIGSHARE_FILES["ground_truth.zip"]
|
| 82 |
+
url = f"https://figshare.com/ndownloader/files/{fid}?private_link={FIGSHARE_PRIVATE_KEY}"
|
| 83 |
+
|
| 84 |
+
print(f"โฌ๏ธ Downloading ground truth (958 MB)...")
|
| 85 |
+
if download_file(url, zip_path, "ground_truth"):
|
| 86 |
+
print("๐ฆ Extracting...")
|
| 87 |
+
with zipfile.ZipFile(zip_path) as z:
|
| 88 |
+
z.extractall(root)
|
| 89 |
+
os.remove(zip_path)
|
| 90 |
+
return True
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def check_data_status(root: str) -> dict:
|
| 95 |
+
"""Check what data is available."""
|
| 96 |
+
gt_dir = os.path.join(root, "ground_truth")
|
| 97 |
+
weights_dir = os.path.join(root, "weights")
|
| 98 |
+
pred_dir = os.path.join(root, "predictions")
|
| 99 |
+
|
| 100 |
+
n_gt = len(list(Path(gt_dir).rglob("*.pt"))) if os.path.isdir(gt_dir) else 0
|
| 101 |
+
|
| 102 |
+
n_weights = 0
|
| 103 |
+
if os.path.isdir(weights_dir):
|
| 104 |
+
n_weights = len([f for f in os.listdir(weights_dir) if f.endswith(".pt")])
|
| 105 |
+
|
| 106 |
+
n_pred = 0
|
| 107 |
+
if os.path.isdir(pred_dir):
|
| 108 |
+
for d in os.listdir(pred_dir):
|
| 109 |
+
dp = os.path.join(pred_dir, d)
|
| 110 |
+
if os.path.isdir(dp):
|
| 111 |
+
n_pred = len([f for f in os.listdir(dp) if f.endswith("_mode_0.txt")])
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
return {
|
| 115 |
+
"ground_truth": n_gt,
|
| 116 |
+
"weights": n_weights,
|
| 117 |
+
"predictions": n_pred,
|
| 118 |
+
"has_weights": n_weights > 0,
|
| 119 |
+
"has_gt": n_gt > 0,
|
| 120 |
+
"has_predictions": n_pred > 0,
|
| 121 |
+
}
|
app/utils/inference.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PETIMOT inference utilities for custom proteins."""
|
| 2 |
+
import os, sys, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Ensure PETIMOT is importable
|
| 7 |
+
PETIMOT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
if PETIMOT_ROOT not in sys.path:
|
| 9 |
+
sys.path.insert(0, PETIMOT_ROOT)
|
| 10 |
+
|
| 11 |
+
EMBEDDING_DIM_MAP = {"prostt5": 1024, "esmc_300m": 960, "esmc_600m": 1152}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def run_inference(pdb_path: str, weights_path: str, config_path: str = None,
|
| 15 |
+
output_dir: str = "/tmp/petimot_pred") -> dict:
|
| 16 |
+
"""Run PETIMOT inference on a single PDB file.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
pdb_path: Path to input PDB file
|
| 20 |
+
weights_path: Path to model weights .pt
|
| 21 |
+
config_path: Path to config YAML (default: configs/default.yaml)
|
| 22 |
+
output_dir: Where to save predictions
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
dict with modes, ca_coords, seq, etc.
|
| 26 |
+
"""
|
| 27 |
+
from petimot.infer.infer import infer
|
| 28 |
+
from petimot.data.pdb_utils import load_backbone_coordinates
|
| 29 |
+
|
| 30 |
+
if config_path is None:
|
| 31 |
+
config_path = os.path.join(PETIMOT_ROOT, "configs", "default.yaml")
|
| 32 |
+
|
| 33 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
# Run inference
|
| 36 |
+
infer(model_path=weights_path, config_file=config_path,
|
| 37 |
+
input_list=[pdb_path], output_path=output_dir)
|
| 38 |
+
|
| 39 |
+
# Collect results
|
| 40 |
+
stem = os.path.splitext(os.path.basename(weights_path))[0]
|
| 41 |
+
pred_subdir = os.path.join(output_dir, stem)
|
| 42 |
+
basename = os.path.splitext(os.path.basename(pdb_path))[0]
|
| 43 |
+
|
| 44 |
+
# Load structure
|
| 45 |
+
bb_data = load_backbone_coordinates(pdb_path, allow_hetatm=True)
|
| 46 |
+
ca = bb_data["bb"][:, 1].numpy()
|
| 47 |
+
seq = bb_data.get("seq", "X" * len(ca))
|
| 48 |
+
if not isinstance(seq, str):
|
| 49 |
+
seq = "X" * len(ca)
|
| 50 |
+
|
| 51 |
+
# Load predicted modes
|
| 52 |
+
modes = {}
|
| 53 |
+
for k in range(10):
|
| 54 |
+
for pfx in [f"extracted_{basename}", basename]:
|
| 55 |
+
mf = os.path.join(pred_subdir, f"{pfx}_mode_{k}.txt")
|
| 56 |
+
if os.path.exists(mf):
|
| 57 |
+
modes[k] = np.loadtxt(mf)
|
| 58 |
+
break
|
| 59 |
+
|
| 60 |
+
with open(pdb_path) as f:
|
| 61 |
+
pdb_text = f.read()
|
| 62 |
+
|
| 63 |
+
return {
|
| 64 |
+
"name": basename,
|
| 65 |
+
"ca_coords": ca,
|
| 66 |
+
"seq": seq,
|
| 67 |
+
"modes": modes,
|
| 68 |
+
"pdb_text": pdb_text,
|
| 69 |
+
"pred_dir": pred_subdir,
|
| 70 |
+
"n_res": len(ca),
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def download_pdb(pdb_id: str, output_dir: str = "/tmp/petimot_pdbs") -> str | None:
|
| 75 |
+
"""Download PDB from RCSB."""
|
| 76 |
+
import requests
|
| 77 |
+
|
| 78 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 79 |
+
code4 = pdb_id[:4].lower()
|
| 80 |
+
chain = pdb_id[4:].upper() if len(pdb_id) > 4 else ""
|
| 81 |
+
out_path = os.path.join(output_dir, f"{pdb_id}.pdb")
|
| 82 |
+
|
| 83 |
+
if os.path.exists(out_path):
|
| 84 |
+
return out_path
|
| 85 |
+
|
| 86 |
+
r = requests.get(f"https://files.rcsb.org/download/{code4}.pdb", timeout=30)
|
| 87 |
+
if not r.ok:
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
lines = r.text.split("\n")
|
| 91 |
+
if chain:
|
| 92 |
+
lines = [l for l in lines
|
| 93 |
+
if (l.startswith("ATOM") and len(l) > 21 and l[21] == chain)
|
| 94 |
+
or not l.startswith(("ATOM", "HETATM"))]
|
| 95 |
+
|
| 96 |
+
with open(out_path, "w") as f:
|
| 97 |
+
f.write("\n".join(lines))
|
| 98 |
+
return out_path
|
claude.md
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PETIMOT โ Development State
|
| 2 |
+
|
| 3 |
+
> Last updated: 2026-03-20
|
| 4 |
+
> Maintainer: Valentin Lombard (valou82160@gmail.com)
|
| 5 |
+
|
| 6 |
+
## Project Overview
|
| 7 |
+
|
| 8 |
+
**PETIMOT** = SE(3)-equivariant GNN for Protein Motion prediction from sparse data.
|
| 9 |
+
- Paper: [arXiv 2504.02839](https://arxiv.org/abs/2504.02839) โ Lombard, Grudinin & Laine
|
| 10 |
+
- Public repo: `PhyloSofS-Team/PETIMOT`
|
| 11 |
+
- Private repos: `Vlmbd/petimot_temp`, `Vlmbd/petimot` (contain training code, identical to each other)
|
| 12 |
+
|
| 13 |
+
## Architecture
|
| 14 |
+
|
| 15 |
+
- **Model**: `ProteinMotionMPNN` in `petimot/model/neural_net.py`
|
| 16 |
+
- SE(3)-equivariant message passing neural network
|
| 17 |
+
- Input: PLM embeddings (ProstT5 1024d, or ESM 960d/1152d) + KNN graph
|
| 18 |
+
- Output: K motion modes as (N, 3) displacement vectors per protein
|
| 19 |
+
- 15 independent layers, ~4.7M params (default config)
|
| 20 |
+
|
| 21 |
+
- **Loss functions** in `petimot/model/loss.py`:
|
| 22 |
+
- `compute_nsse_loss`: Normalized SSE with Hungarian assignment (per-mode matching)
|
| 23 |
+
- `compute_rmsip_loss`: Root Mean Square Inner Product (subspace similarity)
|
| 24 |
+
- `compute_ortho_loss`: Independence Score (orthogonality regularizer)
|
| 25 |
+
- Default weights: 0.5รNSSE + 0.5รRMSIP + 0.0รIS
|
| 26 |
+
|
| 27 |
+
- **Data** in `petimot/data/`:
|
| 28 |
+
- `BaseDataset` / `InferenceDataset` in `data_set.py` (public repo)
|
| 29 |
+
- `TrainingDataset` in private repos โ extends BaseDataset with ground truth `.pt` loading
|
| 30 |
+
- Ground truth `.pt` format: `{eigvects: (3N,K), eigvals: (K,), bb: (N,4,3), seq: str, coverage: (N,)}`
|
| 31 |
+
- Embeddings: cached as `{name}_{model}.pt` files in `embeddings/` directory
|
| 32 |
+
|
| 33 |
+
## Repository Structure
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
PETIMOT/
|
| 37 |
+
โโโ configs/default.yaml # Default hyperparameters
|
| 38 |
+
โโโ petimot/
|
| 39 |
+
โ โโโ __main__.py # CLI: infer, evaluate (+ train in private repo)
|
| 40 |
+
โ โโโ model/
|
| 41 |
+
โ โ โโโ neural_net.py # ProteinMotionMPNN
|
| 42 |
+
โ โ โโโ loss.py # NSSE, RMSIP, IS losses
|
| 43 |
+
โ โ โโโ optimizer.py # get_optimizer factory
|
| 44 |
+
โ โโโ data/
|
| 45 |
+
โ โ โโโ data_set.py # BaseDataset, InferenceDataset
|
| 46 |
+
โ โ โโโ embeddings.py # EmbeddingManager (ProstT5/ESM)
|
| 47 |
+
โ โ โโโ pdb_utils.py # load_backbone_coordinates
|
| 48 |
+
โ โโโ infer/infer.py # Inference pipeline
|
| 49 |
+
โ โโโ eval/eval.py # Evaluation with metrics
|
| 50 |
+
โ โโโ utils/
|
| 51 |
+
โ โโโ seeding.py # set_seed (reproducibility)
|
| 52 |
+
โ โโโ rigid_utils.py # Quaternion/rotation utilities
|
| 53 |
+
โโโ full_train_list.txt # ~26K training samples
|
| 54 |
+
โโโ full_val_list.txt # ~5K validation samples
|
| 55 |
+
โโโ eval_list.txt # Test set
|
| 56 |
+
โโโ split_script.py # Family-based split generation
|
| 57 |
+
โโโ PETIMOT_workflow.ipynb # โญ Main deliverable โ Colab notebook
|
| 58 |
+
โโโ weights/ # Pretrained models
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Colab Notebook (`PETIMOT_workflow.ipynb`)
|
| 62 |
+
|
| 63 |
+
### Current State: โ
Functional, iterating on polish
|
| 64 |
+
|
| 65 |
+
The notebook is the primary deliverable. It's built via Python patch scripts in `/tmp/` and contains 37 cells across 9 sections:
|
| 66 |
+
|
| 67 |
+
| # | Section | Status | Notes |
|
| 68 |
+
|---|---------|--------|-------|
|
| 69 |
+
| 0 | Setup | โ
| Install (no torch reinstall), GPU check, Drive mount, WandB |
|
| 70 |
+
| 1 | Data | โ
| Figshare manual download + auto-extract, rich dataset stats (6 panels) |
|
| 71 |
+
| 2 | Training | โ
| Full loop with gradient norms, ETA, best/last checkpoints |
|
| 72 |
+
| 3 | Monitoring | โ
| Plotly dashboard, per-sample validation analysis |
|
| 73 |
+
| 4 | Inference | โ
| Single PDB, batch, upload, auto-detect weights |
|
| 74 |
+
| 5 | Visualization | โ
| 5-panel analysis dashboard, 3D py3Dmol, animation, mode grid |
|
| 75 |
+
| 6 | Trajectory Export | โ
| Multi-model PDB for PyMOL/ChimeraX |
|
| 76 |
+
| 7 | Evaluation | โ
| Test set eval, CSV export, baseline comparison |
|
| 77 |
+
| 8 | Ablations | โ
| 10 config presets, comparison plots |
|
| 78 |
+
|
| 79 |
+
### Known Issues & Workarounds
|
| 80 |
+
|
| 81 |
+
1. **Figshare download**: Private link URLs don't support programmatic download (wget/curl/requests all fail). Solution: manual browser download + Colab upload. Cell 1.1 auto-detects and extracts uploads.
|
| 82 |
+
|
| 83 |
+
2. **numpy binary incompatibility**: Installing packages can break numpy. Solution: after cell 0.1, do Runtime โ Restart session, then skip to 0.2.
|
| 84 |
+
|
| 85 |
+
3. **`torch.linalg.eigh` cusolver error**: Happens with AMP (FP16) in `rigid_utils.py` quaternion computation. Solution: cell 2.4 sets `torch.backends.cuda.preferred_linalg_library("magma")`.
|
| 86 |
+
|
| 87 |
+
4. **Split file mismatch**: Config references `val_list.txt` but repo has `full_val_list.txt`. Cell 2.2 auto-detects the correct files.
|
| 88 |
+
|
| 89 |
+
5. **ProstT5 embedding computation**: Takes ~10min on A100, longer on T4. Embeddings are cached in `embeddings/` (symlinked to Drive if mounted). First run is slow, subsequent runs are fast.
|
| 90 |
+
|
| 91 |
+
6. **Batch size**: Default is 16 (too small for large GPUs). On RTX PRO 6000 (96GB), use 128-256.
|
| 92 |
+
|
| 93 |
+
### How the Notebook is Built
|
| 94 |
+
|
| 95 |
+
The notebook is modified via Python scripts that:
|
| 96 |
+
1. Load the `.ipynb` JSON
|
| 97 |
+
2. Find cells by title (e.g., `'5.1' in line`)
|
| 98 |
+
3. Replace source with new code, ensuring proper `\n` line terminators
|
| 99 |
+
4. Save back to disk
|
| 100 |
+
|
| 101 |
+
**Critical**: Each line in `source` array MUST end with `\n` (except the last line) for Colab to execute it. This was a major early bug.
|
| 102 |
+
|
| 103 |
+
Example patch script pattern:
|
| 104 |
+
```python
|
| 105 |
+
import json
|
| 106 |
+
with open('PETIMOT_workflow.ipynb') as f:
|
| 107 |
+
nb = json.load(f)
|
| 108 |
+
for i, c in enumerate(nb['cells']):
|
| 109 |
+
if any('CELL_ID' in l for l in c['source']):
|
| 110 |
+
lines = new_src.split("\n")
|
| 111 |
+
nb['cells'][i]['source'] = [l + "\n" for l in lines[:-1]] + [lines[-1]]
|
| 112 |
+
break
|
| 113 |
+
json.dump(nb, open('PETIMOT_workflow.ipynb', 'w'), indent=1)
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
## Training Details (from private repo)
|
| 117 |
+
|
| 118 |
+
- **TrainingDataset** (cell 2.1):
|
| 119 |
+
- Loads `.pt` ground truth files listed in a text file
|
| 120 |
+
- Eigenvectors reshaped: `(3N, K)` โ `(N, 3, K)` โ `(N, K, 3)`, scaled by `โN`
|
| 121 |
+
- Multiplicative Gaussian noise on eigvects + embeddings for augmentation
|
| 122 |
+
- Random embeddings option for ablation (`rand_emb=True`)
|
| 123 |
+
|
| 124 |
+
- **process_epoch** (cell 2.3):
|
| 125 |
+
- `set_grad_enabled(training)` per loss component
|
| 126 |
+
- AMP with GradScaler, gradient clipping at 10
|
| 127 |
+
- `optimizer.zero_grad(set_to_none=True)` for speed
|
| 128 |
+
- Tracks: NSSE, min_NSSE, RMSIP, ortho, success rate, gradient norms
|
| 129 |
+
|
| 130 |
+
- **train_petimot** (cell 2.3):
|
| 131 |
+
- AdamW + ReduceLROnPlateau (factor=0.5, patience from config)
|
| 132 |
+
- Saves `best.pt` + `last.pt` in `weights/{run_name}/`
|
| 133 |
+
- Auto-loads best model after training
|
| 134 |
+
- Optional WandB logging, optional resume from checkpoint
|
| 135 |
+
|
| 136 |
+
## Data Sources
|
| 137 |
+
|
| 138 |
+
- **Figshare** (private link): https://figshare.com/s/ab400d852b4669a83b64
|
| 139 |
+
- `ground_truth.zip` (958 MB, ~36K `.pt` files)
|
| 140 |
+
- `default_2025-02-07_21-54-02_epoch_33.pt` (18 MB, pretrained weights)
|
| 141 |
+
- `baseline_predictions.zip` (23 MB โ AlphaFlow, ESMFlow, NMA)
|
| 142 |
+
- File IDs: ground_truth=52349453, weights=52349456, baselines=52349480
|
| 143 |
+
|
| 144 |
+
- **Local** (user's machine):
|
| 145 |
+
- `/Users/valentin/Documents/Petimot/` โ public repo clone
|
| 146 |
+
- `/Users/valentin/Documents/petimot_private/` โ private repo clone
|
| 147 |
+
- `/Users/valentin/Documents/petimot_temp/` โ private repo (has weights .pt files)
|
| 148 |
+
|
| 149 |
+
## Target Audience
|
| 150 |
+
|
| 151 |
+
The notebook is designed for **ML/bioinformatics professors and researchers**. Key design decisions:
|
| 152 |
+
- Rich, publication-quality visualizations (not toy demos)
|
| 153 |
+
- Extensive inline comments and docstrings with tensor shapes
|
| 154 |
+
- Interactive Colab form controls for parameters
|
| 155 |
+
- Auto-detect/auto-extract for user-friendly data setup
|
| 156 |
+
- All 8 sections are independent after setup (can run inference without training)
|
| 157 |
+
|
| 158 |
+
## Streamlit App (`app/`)
|
| 159 |
+
|
| 160 |
+
### Current State: โ
Built, needs testing
|
| 161 |
+
|
| 162 |
+
Interactive web explorer for PETIMOT predictions. Replaces notebook sections 5-8.
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
app/
|
| 166 |
+
โโโ app.py # Main entry + sidebar + dark theme
|
| 167 |
+
โโโ requirements.txt # Dependencies
|
| 168 |
+
โโโ pages/
|
| 169 |
+
โ โโโ 1_๐_Explorer.py # Browse pre-computed DB (search, filter, 3D, sequence)
|
| 170 |
+
โ โโโ 2_๐ฎ_Inference.py # Upload PDB โ predict โ visualize + export
|
| 171 |
+
โ โโโ 3_๐_Statistics.py # Dataset-wide distributions + eigenvalue analysis
|
| 172 |
+
โโโ components/
|
| 173 |
+
โ โโโ viewer_3d.py # py3Dmol with arrows, labels, mode comparison grid
|
| 174 |
+
โ โโโ sequence_viewer.py # HTML sequence heatmap + Plotly displacement chart
|
| 175 |
+
โ โโโ mode_panel.py # Mode tabs, correlation matrix, eigenvalue spectrum
|
| 176 |
+
โโโ utils/
|
| 177 |
+
โโโ data_loader.py # Load predictions index, modes, ground truth
|
| 178 |
+
โโโ inference.py # PETIMOT inference wrapper + RCSB PDB download
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
**Run locally:** `cd PETIMOT && streamlit run app/app.py`
|
| 182 |
+
|
| 183 |
+
**Key features:**
|
| 184 |
+
- Dark theme with purple accent palette
|
| 185 |
+
- Real-time sliders for amplitude, arrow size, color scheme
|
| 186 |
+
- Searchable/sortable protein table with click-to-detail
|
| 187 |
+
- HTML sequence viewer with displacement-colored cells + hover tooltips
|
| 188 |
+
- Plotly charts (displacement profiles, eigenvalue spectrum, correlations)
|
| 189 |
+
- CSV + PDB export for inference results
|
| 190 |
+
- PDB ID fetch from RCSB or file upload
|
| 191 |
+
|
| 192 |
+
**Dependencies:** stmol (py3Dmol for Streamlit), plotly, streamlit
|
| 193 |
+
|
| 194 |
+
**Needs:** Pre-computed predictions in `predictions/` directory for Explorer page to work.
|
| 195 |
+
|
| 196 |
+
## Dependencies
|
| 197 |
+
|
| 198 |
+
```
|
| 199 |
+
torch>=2.0.0, torch_geometric, torch_scatter, torch_sparse
|
| 200 |
+
transformers==4.48.3, sentencepiece==0.2.0
|
| 201 |
+
scipy, typer==0.15.1, tqdm, numpy
|
| 202 |
+
wandb, plotly, py3Dmol, biopython, pandas, ipywidgets, gdown, requests
|
| 203 |
+
```
|
requirements.txt
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
torch>=2.0.0
|
| 2 |
torch_geometric>=2.0.0
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
tqdm>=4.65.0
|
| 7 |
-
|
| 8 |
-
typer==0.15.1
|
| 9 |
-
numpy<2
|
|
|
|
| 1 |
+
# PETIMOT Explorer โ HuggingFace Spaces
|
| 2 |
+
streamlit>=1.30.0
|
| 3 |
+
stmol>=0.0.9
|
| 4 |
+
py3Dmol>=2.0.0
|
| 5 |
+
plotly>=5.18.0
|
| 6 |
+
pandas>=2.0.0
|
| 7 |
+
numpy>=1.24.0
|
| 8 |
torch>=2.0.0
|
| 9 |
torch_geometric>=2.0.0
|
| 10 |
+
torch_scatter
|
| 11 |
+
torch_sparse
|
| 12 |
+
transformers>=4.30.0
|
| 13 |
+
sentencepiece>=0.1.99
|
| 14 |
+
scipy>=1.10.0
|
| 15 |
+
biopython>=1.80
|
| 16 |
+
requests>=2.28.0
|
| 17 |
tqdm>=4.65.0
|
| 18 |
+
typer>=0.9.0
|
|
|
|
|
|