Sync from GitHub (preserve manual model files)
Browse files- StreamlitApp/StreamlitApp.py +3 -3
- StreamlitApp/utils/predict.py +92 -29
- StreamlitApp/utils/tsne.py +4 -3
- requirements.txt +1 -0
StreamlitApp/StreamlitApp.py
CHANGED
|
@@ -10,7 +10,7 @@ import html as _html
|
|
| 10 |
from sklearn.manifold import TSNE
|
| 11 |
|
| 12 |
# Utils map to sidebar pages: predict / analyze / optimize / visualize / tsne, plus shared_ui.
|
| 13 |
-
from utils.predict import load_model, predict_amp, encode_sequence
|
| 14 |
from utils.analyze import aa_composition, compute_properties
|
| 15 |
from utils.optimize import optimize_sequence
|
| 16 |
from utils.shared_ui import (
|
|
@@ -684,11 +684,11 @@ elif page == "t-SNE":
|
|
| 684 |
embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
|
| 685 |
|
| 686 |
# Use penultimate model representation as embedding features.
|
| 687 |
-
embedding_extractor =
|
| 688 |
|
| 689 |
# Build embeddings, then predict label/conf for each sequence (for hover + coloring).
|
| 690 |
for i, s in enumerate(sequences):
|
| 691 |
-
x = torch.tensor(encode_sequence(s), dtype=torch.float32).unsqueeze(0)
|
| 692 |
with torch.no_grad():
|
| 693 |
emb = embedding_extractor(x).squeeze().numpy()
|
| 694 |
embeddings_list.append(emb)
|
|
|
|
| 10 |
from sklearn.manifold import TSNE
|
| 11 |
|
| 12 |
# Utils map to sidebar pages: predict / analyze / optimize / visualize / tsne, plus shared_ui.
|
| 13 |
+
from utils.predict import load_model, predict_amp, encode_sequence, get_embedding_extractor
|
| 14 |
from utils.analyze import aa_composition, compute_properties
|
| 15 |
from utils.optimize import optimize_sequence
|
| 16 |
from utils.shared_ui import (
|
|
|
|
| 684 |
embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
|
| 685 |
|
| 686 |
# Use penultimate model representation as embedding features.
|
| 687 |
+
embedding_extractor = get_embedding_extractor(model)
|
| 688 |
|
| 689 |
# Build embeddings, then predict label/conf for each sequence (for hover + coloring).
|
| 690 |
for i, s in enumerate(sequences):
|
| 691 |
+
x = torch.tensor(encode_sequence(s, model), dtype=torch.float32).unsqueeze(0)
|
| 692 |
with torch.no_grad():
|
| 693 |
emb = embedding_extractor(x).squeeze().numpy()
|
| 694 |
embeddings_list.append(emb)
|
StreamlitApp/utils/predict.py
CHANGED
|
@@ -1,13 +1,19 @@
|
|
| 1 |
-
# Predict page (and shared):
|
| 2 |
import pathlib
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
import streamlit as st
|
| 6 |
from torch import nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
# Lightweight MLP used for AMP binary classification.
|
| 9 |
class FastMLP(nn.Module):
|
| 10 |
-
def __init__(self, input_dim=
|
| 11 |
super(FastMLP, self).__init__()
|
| 12 |
self.layers = nn.Sequential(
|
| 13 |
nn.Linear(input_dim, 512),
|
|
@@ -15,21 +21,48 @@ class FastMLP(nn.Module):
|
|
| 15 |
nn.Dropout(0.3),
|
| 16 |
nn.Linear(512, 128),
|
| 17 |
nn.ReLU(),
|
| 18 |
-
nn.Linear(128, 1) # Single output for binary classification
|
| 19 |
)
|
| 20 |
|
| 21 |
def forward(self, x):
|
| 22 |
return self.layers(x)
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
@st.cache_resource
|
| 25 |
def load_model():
|
| 26 |
-
# Load
|
| 27 |
-
# Always resolve relative to the StreamlitApp folder, not the process CWD.
|
| 28 |
streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
|
| 29 |
repo_root = streamlitapp_dir.parent
|
| 30 |
|
| 31 |
candidates = [
|
| 32 |
repo_root / "MLModels" / "ampMLModel.pt",
|
|
|
|
| 33 |
repo_root / "models" / "ampMLModel.pt",
|
| 34 |
streamlitapp_dir / "models" / "ampMLModel.pt",
|
| 35 |
]
|
|
@@ -37,43 +70,73 @@ def load_model():
|
|
| 37 |
|
| 38 |
if not model_path.exists():
|
| 39 |
raise FileNotFoundError(
|
| 40 |
-
"
|
| 41 |
f"- {repo_root / 'MLModels' / 'ampMLModel.pt'}\n"
|
|
|
|
| 42 |
f"- {repo_root / 'models' / 'ampMLModel.pt'}\n"
|
| 43 |
f"- {streamlitapp_dir / 'models' / 'ampMLModel.pt'}\n"
|
| 44 |
)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
flat = one_hot.flatten()
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
return flat
|
| 69 |
|
| 70 |
-
def predict_amp(sequence,
|
| 71 |
# Run AMP inference and return predicted label plus AMP probability.
|
| 72 |
-
x = torch.tensor(encode_sequence(sequence), dtype=torch.float32).unsqueeze(0)
|
|
|
|
| 73 |
|
| 74 |
-
# Sigmoid(logit) gives AMP probability in [0, 1].
|
| 75 |
with torch.no_grad():
|
| 76 |
-
logits =
|
| 77 |
prob = torch.sigmoid(logits).item()
|
| 78 |
|
| 79 |
label = "AMP" if prob >= 0.5 else "Non-AMP"
|
|
|
|
| 1 |
+
# Predict page (and shared): ProtBERT embedding + MLP classifier inference.
|
| 2 |
import pathlib
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
import streamlit as st
|
| 6 |
from torch import nn
|
| 7 |
+
from transformers import AutoModel, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
MODEL_INPUT_DIM = 1024
|
| 11 |
+
MODEL_ARCH = "FastMLP"
|
| 12 |
+
PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
|
| 13 |
+
|
| 14 |
|
|
|
|
| 15 |
class FastMLP(nn.Module):
|
| 16 |
+
def __init__(self, input_dim=MODEL_INPUT_DIM):
|
| 17 |
super(FastMLP, self).__init__()
|
| 18 |
self.layers = nn.Sequential(
|
| 19 |
nn.Linear(input_dim, 512),
|
|
|
|
| 21 |
nn.Dropout(0.3),
|
| 22 |
nn.Linear(512, 128),
|
| 23 |
nn.ReLU(),
|
| 24 |
+
nn.Linear(128, 1), # Single output logit for binary classification
|
| 25 |
)
|
| 26 |
|
| 27 |
def forward(self, x):
|
| 28 |
return self.layers(x)
|
| 29 |
|
| 30 |
+
|
| 31 |
+
def _load_checkpoint(path: pathlib.Path):
|
| 32 |
+
# Accept either raw state_dict (legacy) or structured checkpoint dict.
|
| 33 |
+
obj = torch.load(str(path), map_location="cpu")
|
| 34 |
+
if isinstance(obj, dict) and "state_dict" in obj:
|
| 35 |
+
return obj["state_dict"], obj.get("meta", {})
|
| 36 |
+
if isinstance(obj, dict):
|
| 37 |
+
return obj, {}
|
| 38 |
+
raise ValueError(
|
| 39 |
+
f"Unsupported model checkpoint format at '{path}'. "
|
| 40 |
+
"Expected a PyTorch state_dict or {'state_dict': ..., 'meta': ...}."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _infer_first_layer_input_dim(state_dict: dict) -> int | None:
|
| 45 |
+
w = state_dict.get("layers.0.weight")
|
| 46 |
+
if w is None:
|
| 47 |
+
return None
|
| 48 |
+
if hasattr(w, "shape") and len(w.shape) == 2:
|
| 49 |
+
return int(w.shape[1])
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _normalize_sequence(sequence: str) -> str:
|
| 54 |
+
return "".join(c for c in str(sequence).upper() if not c.isspace())
|
| 55 |
+
|
| 56 |
+
|
| 57 |
@st.cache_resource
|
| 58 |
def load_model():
|
| 59 |
+
# Load AMP classifier weights + ProtBERT encoder once per Streamlit process.
|
|
|
|
| 60 |
streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
|
| 61 |
repo_root = streamlitapp_dir.parent
|
| 62 |
|
| 63 |
candidates = [
|
| 64 |
repo_root / "MLModels" / "ampMLModel.pt",
|
| 65 |
+
repo_root / "MLModels" / "fast_mlp_amp.pt",
|
| 66 |
repo_root / "models" / "ampMLModel.pt",
|
| 67 |
streamlitapp_dir / "models" / "ampMLModel.pt",
|
| 68 |
]
|
|
|
|
| 70 |
|
| 71 |
if not model_path.exists():
|
| 72 |
raise FileNotFoundError(
|
| 73 |
+
"Classifier checkpoint not found in any of:\n"
|
| 74 |
f"- {repo_root / 'MLModels' / 'ampMLModel.pt'}\n"
|
| 75 |
+
f"- {repo_root / 'MLModels' / 'fast_mlp_amp.pt'}\n"
|
| 76 |
f"- {repo_root / 'models' / 'ampMLModel.pt'}\n"
|
| 77 |
f"- {streamlitapp_dir / 'models' / 'ampMLModel.pt'}\n"
|
| 78 |
)
|
| 79 |
|
| 80 |
+
state_dict, _meta = _load_checkpoint(model_path)
|
| 81 |
+
inferred_input_dim = _infer_first_layer_input_dim(state_dict)
|
| 82 |
+
if inferred_input_dim != MODEL_INPUT_DIM:
|
| 83 |
+
raise ValueError(
|
| 84 |
+
"Model/input mismatch. Loaded classifier expects "
|
| 85 |
+
f"{inferred_input_dim} input features; ProtBERT pooled embeddings are {MODEL_INPUT_DIM}-dim."
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
classifier = FastMLP(input_dim=MODEL_INPUT_DIM)
|
| 89 |
+
classifier.load_state_dict(state_dict)
|
| 90 |
+
classifier.eval()
|
| 91 |
|
| 92 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 93 |
+
tokenizer = AutoTokenizer.from_pretrained(PROTBERT_MODEL_NAME)
|
| 94 |
+
encoder = AutoModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
|
| 95 |
+
encoder.eval()
|
| 96 |
|
| 97 |
+
return {
|
| 98 |
+
"classifier": classifier,
|
| 99 |
+
"tokenizer": tokenizer,
|
| 100 |
+
"encoder": encoder,
|
| 101 |
+
"device": device,
|
| 102 |
+
"classifier_path": str(model_path),
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def encode_sequence(seq, model_bundle):
|
| 107 |
+
# Convert peptide sequence to ProtBERT mean-pooled embedding (1024 dims).
|
| 108 |
+
clean = _normalize_sequence(seq)
|
| 109 |
+
spaced = " ".join(list(clean))
|
| 110 |
+
tokenizer = model_bundle["tokenizer"]
|
| 111 |
+
encoder = model_bundle["encoder"]
|
| 112 |
+
device = model_bundle["device"]
|
| 113 |
+
|
| 114 |
+
tokens = tokenizer(
|
| 115 |
+
spaced,
|
| 116 |
+
return_tensors="pt",
|
| 117 |
+
truncation=True,
|
| 118 |
+
padding=True,
|
| 119 |
+
).to(device)
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
outputs = encoder(**tokens)
|
| 122 |
+
emb = outputs.last_hidden_state.mean(dim=1).squeeze(0).detach().cpu().numpy()
|
| 123 |
+
return emb.astype(np.float32)
|
| 124 |
|
|
|
|
| 125 |
|
| 126 |
+
def get_embedding_extractor(model_bundle):
|
| 127 |
+
classifier = model_bundle["classifier"]
|
| 128 |
+
extractor = torch.nn.Sequential(*list(classifier.layers)[:-1])
|
| 129 |
+
extractor.eval()
|
| 130 |
+
return extractor
|
| 131 |
|
|
|
|
| 132 |
|
| 133 |
+
def predict_amp(sequence, model_bundle):
|
| 134 |
# Run AMP inference and return predicted label plus AMP probability.
|
| 135 |
+
x = torch.tensor(encode_sequence(sequence, model_bundle), dtype=torch.float32).unsqueeze(0)
|
| 136 |
+
classifier = model_bundle["classifier"]
|
| 137 |
|
|
|
|
| 138 |
with torch.no_grad():
|
| 139 |
+
logits = classifier(x)
|
| 140 |
prob = torch.sigmoid(logits).item()
|
| 141 |
|
| 142 |
label = "AMP" if prob >= 0.5 else "Non-AMP"
|
StreamlitApp/utils/tsne.py
CHANGED
|
@@ -5,17 +5,18 @@ from sklearn.manifold import TSNE
|
|
| 5 |
import streamlit as st
|
| 6 |
import torch
|
| 7 |
import numpy as np
|
| 8 |
-
from utils.predict import encode_sequence
|
| 9 |
|
| 10 |
def tsne_visualization(sequences, model):
|
| 11 |
# Project model embeddings into 2D and render a quick scatter plot.
|
| 12 |
st.info("Generating embeddings... this may take a moment.")
|
| 13 |
embeddings = []
|
|
|
|
| 14 |
for seq in sequences:
|
| 15 |
-
x = torch.tensor(encode_sequence(seq), dtype=torch.float32).unsqueeze(0)
|
| 16 |
with torch.no_grad():
|
| 17 |
# Use an early hidden layer as a compact learned representation.
|
| 18 |
-
emb =
|
| 19 |
embeddings.append(emb.numpy().flatten())
|
| 20 |
|
| 21 |
embeddings = np.vstack(embeddings)
|
|
|
|
| 5 |
import streamlit as st
|
| 6 |
import torch
|
| 7 |
import numpy as np
|
| 8 |
+
from utils.predict import encode_sequence, get_embedding_extractor
|
| 9 |
|
| 10 |
def tsne_visualization(sequences, model):
|
| 11 |
# Project model embeddings into 2D and render a quick scatter plot.
|
| 12 |
st.info("Generating embeddings... this may take a moment.")
|
| 13 |
embeddings = []
|
| 14 |
+
embedding_extractor = get_embedding_extractor(model)
|
| 15 |
for seq in sequences:
|
| 16 |
+
x = torch.tensor(encode_sequence(seq, model), dtype=torch.float32).unsqueeze(0)
|
| 17 |
with torch.no_grad():
|
| 18 |
# Use an early hidden layer as a compact learned representation.
|
| 19 |
+
emb = embedding_extractor(x)
|
| 20 |
embeddings.append(emb.numpy().flatten())
|
| 21 |
|
| 22 |
embeddings = np.vstack(embeddings)
|
requirements.txt
CHANGED
|
@@ -7,3 +7,4 @@ matplotlib>=3.7.0
|
|
| 7 |
plotly>=5.14.0
|
| 8 |
requests>=2.28.0
|
| 9 |
py3dmol>=2.0.0
|
|
|
|
|
|
| 7 |
plotly>=5.14.0
|
| 8 |
requests>=2.28.0
|
| 9 |
py3dmol>=2.0.0
|
| 10 |
+
transformers>=4.40.0
|