m0ksh commited on
Commit
1b8a737
·
verified ·
1 Parent(s): 893f11c

Sync from GitHub (preserve manual model files)

Browse files
StreamlitApp/StreamlitApp.py CHANGED
@@ -10,7 +10,7 @@ import html as _html
10
  from sklearn.manifold import TSNE
11
 
12
  # Utils map to sidebar pages: predict / analyze / optimize / visualize / tsne, plus shared_ui.
13
- from utils.predict import load_model, predict_amp, encode_sequence
14
  from utils.analyze import aa_composition, compute_properties
15
  from utils.optimize import optimize_sequence
16
  from utils.shared_ui import (
@@ -684,11 +684,11 @@ elif page == "t-SNE":
684
  embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
685
 
686
  # Use penultimate model representation as embedding features.
687
- embedding_extractor = torch.nn.Sequential(*list(model.layers)[:-1])
688
 
689
  # Build embeddings, then predict label/conf for each sequence (for hover + coloring).
690
  for i, s in enumerate(sequences):
691
- x = torch.tensor(encode_sequence(s), dtype=torch.float32).unsqueeze(0)
692
  with torch.no_grad():
693
  emb = embedding_extractor(x).squeeze().numpy()
694
  embeddings_list.append(emb)
 
10
  from sklearn.manifold import TSNE
11
 
12
  # Utils map to sidebar pages: predict / analyze / optimize / visualize / tsne, plus shared_ui.
13
+ from utils.predict import load_model, predict_amp, encode_sequence, get_embedding_extractor
14
  from utils.analyze import aa_composition, compute_properties
15
  from utils.optimize import optimize_sequence
16
  from utils.shared_ui import (
 
684
  embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
685
 
686
  # Use penultimate model representation as embedding features.
687
+ embedding_extractor = get_embedding_extractor(model)
688
 
689
  # Build embeddings, then predict label/conf for each sequence (for hover + coloring).
690
  for i, s in enumerate(sequences):
691
+ x = torch.tensor(encode_sequence(s, model), dtype=torch.float32).unsqueeze(0)
692
  with torch.no_grad():
693
  emb = embedding_extractor(x).squeeze().numpy()
694
  embeddings_list.append(emb)
StreamlitApp/utils/predict.py CHANGED
@@ -1,13 +1,19 @@
1
- # Predict page (and shared): load AMP model, one-hot encode, run predict_amp.
2
  import pathlib
3
  import numpy as np
4
  import torch
5
  import streamlit as st
6
  from torch import nn
 
 
 
 
 
 
 
7
 
8
- # Lightweight MLP used for AMP binary classification.
9
  class FastMLP(nn.Module):
10
- def __init__(self, input_dim=1024):
11
  super(FastMLP, self).__init__()
12
  self.layers = nn.Sequential(
13
  nn.Linear(input_dim, 512),
@@ -15,21 +21,48 @@ class FastMLP(nn.Module):
15
  nn.Dropout(0.3),
16
  nn.Linear(512, 128),
17
  nn.ReLU(),
18
- nn.Linear(128, 1) # Single output for binary classification
19
  )
20
 
21
  def forward(self, x):
22
  return self.layers(x)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @st.cache_resource
25
  def load_model():
26
- # Load model weights once per Streamlit process.
27
- # Always resolve relative to the StreamlitApp folder, not the process CWD.
28
  streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
29
  repo_root = streamlitapp_dir.parent
30
 
31
  candidates = [
32
  repo_root / "MLModels" / "ampMLModel.pt",
 
33
  repo_root / "models" / "ampMLModel.pt",
34
  streamlitapp_dir / "models" / "ampMLModel.pt",
35
  ]
@@ -37,43 +70,73 @@ def load_model():
37
 
38
  if not model_path.exists():
39
  raise FileNotFoundError(
40
- "Model file 'ampMLModel.pt' not found in any of:\n"
41
  f"- {repo_root / 'MLModels' / 'ampMLModel.pt'}\n"
 
42
  f"- {repo_root / 'models' / 'ampMLModel.pt'}\n"
43
  f"- {streamlitapp_dir / 'models' / 'ampMLModel.pt'}\n"
44
  )
45
 
46
- # Instantiate architecture and hydrate weights from disk.
47
- model = FastMLP(input_dim=1024)
48
- model.load_state_dict(torch.load(str(model_path), map_location="cpu"))
49
- model.eval()
50
- return model
 
 
 
 
 
 
51
 
52
- def encode_sequence(seq, max_len=51):
53
- # Convert sequence to a padded/truncated flattened one-hot vector (1024 dims).
54
- amino_acids = "ACDEFGHIKLMNPQRSTVWY"
55
- aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}
56
 
57
- # Encode each residue as a one-hot row, then flatten to vector features.
58
- one_hot = np.zeros((max_len, len(amino_acids)))
59
- for i, aa in enumerate(seq[:max_len]):
60
- if aa in aa_to_idx:
61
- one_hot[i, aa_to_idx[aa]] = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- flat = one_hot.flatten()
64
 
65
- if len(flat) < 1024:
66
- flat = np.pad(flat, (0, 1024 - len(flat)))
 
 
 
67
 
68
- return flat
69
 
70
- def predict_amp(sequence, model):
71
  # Run AMP inference and return predicted label plus AMP probability.
72
- x = torch.tensor(encode_sequence(sequence), dtype=torch.float32).unsqueeze(0)
 
73
 
74
- # Sigmoid(logit) gives AMP probability in [0, 1].
75
  with torch.no_grad():
76
- logits = model(x)
77
  prob = torch.sigmoid(logits).item()
78
 
79
  label = "AMP" if prob >= 0.5 else "Non-AMP"
 
1
+ # Predict page (and shared): ProtBERT embedding + MLP classifier inference.
2
  import pathlib
3
  import numpy as np
4
  import torch
5
  import streamlit as st
6
  from torch import nn
7
+ from transformers import AutoModel, AutoTokenizer
8
+
9
+
10
+ MODEL_INPUT_DIM = 1024
11
+ MODEL_ARCH = "FastMLP"
12
+ PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
13
+
14
 
 
15
  class FastMLP(nn.Module):
16
+ def __init__(self, input_dim=MODEL_INPUT_DIM):
17
  super(FastMLP, self).__init__()
18
  self.layers = nn.Sequential(
19
  nn.Linear(input_dim, 512),
 
21
  nn.Dropout(0.3),
22
  nn.Linear(512, 128),
23
  nn.ReLU(),
24
+ nn.Linear(128, 1), # Single output logit for binary classification
25
  )
26
 
27
  def forward(self, x):
28
  return self.layers(x)
29
 
30
+
31
+ def _load_checkpoint(path: pathlib.Path):
32
+ # Accept either raw state_dict (legacy) or structured checkpoint dict.
33
+ obj = torch.load(str(path), map_location="cpu")
34
+ if isinstance(obj, dict) and "state_dict" in obj:
35
+ return obj["state_dict"], obj.get("meta", {})
36
+ if isinstance(obj, dict):
37
+ return obj, {}
38
+ raise ValueError(
39
+ f"Unsupported model checkpoint format at '{path}'. "
40
+ "Expected a PyTorch state_dict or {'state_dict': ..., 'meta': ...}."
41
+ )
42
+
43
+
44
+ def _infer_first_layer_input_dim(state_dict: dict) -> int | None:
45
+ w = state_dict.get("layers.0.weight")
46
+ if w is None:
47
+ return None
48
+ if hasattr(w, "shape") and len(w.shape) == 2:
49
+ return int(w.shape[1])
50
+ return None
51
+
52
+
53
+ def _normalize_sequence(sequence: str) -> str:
54
+ return "".join(c for c in str(sequence).upper() if not c.isspace())
55
+
56
+
57
  @st.cache_resource
58
  def load_model():
59
+ # Load AMP classifier weights + ProtBERT encoder once per Streamlit process.
 
60
  streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
61
  repo_root = streamlitapp_dir.parent
62
 
63
  candidates = [
64
  repo_root / "MLModels" / "ampMLModel.pt",
65
+ repo_root / "MLModels" / "fast_mlp_amp.pt",
66
  repo_root / "models" / "ampMLModel.pt",
67
  streamlitapp_dir / "models" / "ampMLModel.pt",
68
  ]
 
70
 
71
  if not model_path.exists():
72
  raise FileNotFoundError(
73
+ "Classifier checkpoint not found in any of:\n"
74
  f"- {repo_root / 'MLModels' / 'ampMLModel.pt'}\n"
75
+ f"- {repo_root / 'MLModels' / 'fast_mlp_amp.pt'}\n"
76
  f"- {repo_root / 'models' / 'ampMLModel.pt'}\n"
77
  f"- {streamlitapp_dir / 'models' / 'ampMLModel.pt'}\n"
78
  )
79
 
80
+ state_dict, _meta = _load_checkpoint(model_path)
81
+ inferred_input_dim = _infer_first_layer_input_dim(state_dict)
82
+ if inferred_input_dim != MODEL_INPUT_DIM:
83
+ raise ValueError(
84
+ "Model/input mismatch. Loaded classifier expects "
85
+ f"{inferred_input_dim} input features; ProtBERT pooled embeddings are {MODEL_INPUT_DIM}-dim."
86
+ )
87
+
88
+ classifier = FastMLP(input_dim=MODEL_INPUT_DIM)
89
+ classifier.load_state_dict(state_dict)
90
+ classifier.eval()
91
 
92
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93
+ tokenizer = AutoTokenizer.from_pretrained(PROTBERT_MODEL_NAME)
94
+ encoder = AutoModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
95
+ encoder.eval()
96
 
97
+ return {
98
+ "classifier": classifier,
99
+ "tokenizer": tokenizer,
100
+ "encoder": encoder,
101
+ "device": device,
102
+ "classifier_path": str(model_path),
103
+ }
104
+
105
+
106
+ def encode_sequence(seq, model_bundle):
107
+ # Convert peptide sequence to ProtBERT mean-pooled embedding (1024 dims).
108
+ clean = _normalize_sequence(seq)
109
+ spaced = " ".join(list(clean))
110
+ tokenizer = model_bundle["tokenizer"]
111
+ encoder = model_bundle["encoder"]
112
+ device = model_bundle["device"]
113
+
114
+ tokens = tokenizer(
115
+ spaced,
116
+ return_tensors="pt",
117
+ truncation=True,
118
+ padding=True,
119
+ ).to(device)
120
+ with torch.no_grad():
121
+ outputs = encoder(**tokens)
122
+ emb = outputs.last_hidden_state.mean(dim=1).squeeze(0).detach().cpu().numpy()
123
+ return emb.astype(np.float32)
124
 
 
125
 
126
+ def get_embedding_extractor(model_bundle):
127
+ classifier = model_bundle["classifier"]
128
+ extractor = torch.nn.Sequential(*list(classifier.layers)[:-1])
129
+ extractor.eval()
130
+ return extractor
131
 
 
132
 
133
+ def predict_amp(sequence, model_bundle):
134
  # Run AMP inference and return predicted label plus AMP probability.
135
+ x = torch.tensor(encode_sequence(sequence, model_bundle), dtype=torch.float32).unsqueeze(0)
136
+ classifier = model_bundle["classifier"]
137
 
 
138
  with torch.no_grad():
139
+ logits = classifier(x)
140
  prob = torch.sigmoid(logits).item()
141
 
142
  label = "AMP" if prob >= 0.5 else "Non-AMP"
StreamlitApp/utils/tsne.py CHANGED
@@ -5,17 +5,18 @@ from sklearn.manifold import TSNE
5
  import streamlit as st
6
  import torch
7
  import numpy as np
8
- from utils.predict import encode_sequence
9
 
10
  def tsne_visualization(sequences, model):
11
  # Project model embeddings into 2D and render a quick scatter plot.
12
  st.info("Generating embeddings... this may take a moment.")
13
  embeddings = []
 
14
  for seq in sequences:
15
- x = torch.tensor(encode_sequence(seq), dtype=torch.float32).unsqueeze(0)
16
  with torch.no_grad():
17
  # Use an early hidden layer as a compact learned representation.
18
- emb = model.layers[0](x)
19
  embeddings.append(emb.numpy().flatten())
20
 
21
  embeddings = np.vstack(embeddings)
 
5
  import streamlit as st
6
  import torch
7
  import numpy as np
8
+ from utils.predict import encode_sequence, get_embedding_extractor
9
 
10
  def tsne_visualization(sequences, model):
11
  # Project model embeddings into 2D and render a quick scatter plot.
12
  st.info("Generating embeddings... this may take a moment.")
13
  embeddings = []
14
+ embedding_extractor = get_embedding_extractor(model)
15
  for seq in sequences:
16
+ x = torch.tensor(encode_sequence(seq, model), dtype=torch.float32).unsqueeze(0)
17
  with torch.no_grad():
18
  # Use an early hidden layer as a compact learned representation.
19
+ emb = embedding_extractor(x)
20
  embeddings.append(emb.numpy().flatten())
21
 
22
  embeddings = np.vstack(embeddings)
requirements.txt CHANGED
@@ -7,3 +7,4 @@ matplotlib>=3.7.0
7
  plotly>=5.14.0
8
  requests>=2.28.0
9
  py3dmol>=2.0.0
 
 
7
  plotly>=5.14.0
8
  requests>=2.28.0
9
  py3dmol>=2.0.0
10
+ transformers>=4.40.0