# Predict page (and shared): ProtBERT embedding + MLP classifier inference.
import pathlib
import numpy as np
import torch
import streamlit as st
from torch import nn
from transformers import BertModel, BertTokenizer
MODEL_INPUT_DIM = 1024 # ProtBERT pooled embedding size; MLP first layer must match.
MODEL_ARCH = "FastMLP"
PROTBERT_MODEL_NAME = "Rostlab/prot_bert" # HF id for tokenizer + encoder weights.


class FastMLP(nn.Module):
    # Small classifier head on top of frozen ProtBERT embeddings at inference.
    def __init__(self, input_dim=MODEL_INPUT_DIM):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # Single output logit for binary classification
        )

    def forward(self, x):
        return self.layers(x)
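
# Illustrative shape check: FastMLP maps (batch, 1024) pooled embeddings to (batch, 1) logits.
#
#     FastMLP()(torch.zeros(2, MODEL_INPUT_DIM)).shape   # -> torch.Size([2, 1])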


def _load_checkpoint(path: pathlib.Path):
    # Accept either raw state_dict (legacy) or structured checkpoint dict.
    obj = torch.load(str(path), map_location="cpu")
    if isinstance(obj, dict) and "state_dict" in obj:
        return obj["state_dict"], obj.get("meta", {})
    if isinstance(obj, dict):
        return obj, {}
    raise ValueError(
        f"Unsupported model checkpoint format at '{path}'. "
        "Expected a PyTorch state_dict or {'state_dict': ..., 'meta': ...}."
    )
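
# Illustrative only: a structured checkpoint that _load_checkpoint accepts could have
# been written during training roughly like this (the "meta" keys shown are hypothetical,
# not something this module requires):
#
#     torch.save(
#         {"state_dict": model.state_dict(),
#          "meta": {"arch": MODEL_ARCH, "input_dim": MODEL_INPUT_DIM}},
#         "MLModels/ampMLModel.pt",
#     )
#
# A bare `torch.save(model.state_dict(), path)` (the legacy format) also loads.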


def _infer_first_layer_input_dim(state_dict: dict) -> int | None:
    # Infer MLP input dim from Linear weight shape (out_features, in_features).
    w = state_dict.get("layers.0.weight")
    if w is None:
        return None
    if hasattr(w, "shape") and len(w.shape) == 2:
        return int(w.shape[1])
    return None


def _normalize_sequence(sequence: str) -> str:
    # Uppercase + strip whitespace so tokenization matches training conventions.
    return "".join(c for c in str(sequence).upper() if not c.isspace())
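
# Illustrative only: _normalize_sequence(" gig gkf \n") would yield "GIGGKF"
# (the input fragment here is a made-up example, not from the app's data).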


@st.cache_resource
def load_model():
    # Load AMP classifier weights + ProtBERT encoder once per Streamlit process.
    streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
    repo_root = streamlitapp_dir.parent
    candidates = [
        repo_root / "MLModels" / "ampMLModel.pt",
        repo_root / "MLModels" / "fast_mlp_amp.pt",
        repo_root / "models" / "ampMLModel.pt",
        streamlitapp_dir / "models" / "ampMLModel.pt",
    ]
    # Prefer the first existing path so local / HF layouts both work.
    model_path = next((p for p in candidates if p.exists()), candidates[0])
    if not model_path.exists():
        raise FileNotFoundError(
            "Classifier checkpoint not found in any of:\n"
            + "\n".join(f"- {p}" for p in candidates)
        )
    state_dict, _meta = _load_checkpoint(model_path)
    inferred_input_dim = _infer_first_layer_input_dim(state_dict)
    # Only flag a mismatch when the first-layer shape could actually be read;
    # a missing "layers.0.weight" key will surface via load_state_dict below.
    if inferred_input_dim is not None and inferred_input_dim != MODEL_INPUT_DIM:
        raise ValueError(
            "Model/input mismatch. Loaded classifier expects "
            f"{inferred_input_dim} input features; ProtBERT pooled embeddings are {MODEL_INPUT_DIM}-dim."
        )
    classifier = FastMLP(input_dim=MODEL_INPUT_DIM)
    classifier.load_state_dict(state_dict)
    classifier.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use an explicit slow tokenizer to avoid fast-backend conversion issues on Spaces.
    tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)
    # Use the explicit BERT class to avoid AutoModel config auto-detection issues.
    encoder = BertModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
    encoder.eval()
    return {
        "classifier": classifier,
        "tokenizer": tokenizer,
        "encoder": encoder,
        "device": device,
        "classifier_path": str(model_path),
    }


def encode_sequence(seq, model_bundle):
    # Convert a peptide sequence to a ProtBERT mean-pooled embedding (1024 dims).
    clean = _normalize_sequence(seq)
    spaced = " ".join(list(clean))  # ProtBERT's tokenizer expects space-separated residues.
    tokenizer = model_bundle["tokenizer"]
    encoder = model_bundle["encoder"]
    device = model_bundle["device"]
    tokens = tokenizer(
        spaced,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)
    with torch.no_grad():
        outputs = encoder(**tokens)
    emb = outputs.last_hidden_state.mean(dim=1).squeeze(0).detach().cpu().numpy()
    return emb.astype(np.float32)
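
# Illustrative only (needs the checkpoint on disk plus downloadable ProtBERT weights;
# the peptide string below is a hypothetical example, not from the app's data):
#
#     bundle = load_model()
#     vec = encode_sequence("GLFDIVKKVVGALGSL", bundle)
#     assert vec.shape == (MODEL_INPUT_DIM,) and vec.dtype == np.float32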


def get_embedding_extractor(model_bundle):
    # Penultimate MLP activations for t-SNE (same depth as training-time "embedding" use).
    classifier = model_bundle["classifier"]
    extractor = torch.nn.Sequential(*list(classifier.layers)[:-1])
    extractor.eval()
    return extractor
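
# Illustrative only: how the extractor could feed a 2-D t-SNE projection.
# `example_embeddings` stands in for an (n_sequences, 1024) array produced by
# encode_sequence; scikit-learn's TSNE is an assumed dependency, not imported above.
#
#     from sklearn.manifold import TSNE
#
#     extractor = get_embedding_extractor(bundle)
#     with torch.no_grad():
#         feats = extractor(torch.tensor(example_embeddings, dtype=torch.float32)).numpy()
#     coords = TSNE(n_components=2).fit_transform(feats)   # (n_sequences, 2)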


def predict_amp(sequence, model_bundle):
    # Run AMP inference and return the predicted label plus AMP probability.
    x = torch.tensor(encode_sequence(sequence, model_bundle), dtype=torch.float32).unsqueeze(0)
    classifier = model_bundle["classifier"]
    with torch.no_grad():
        logits = classifier(x)
        prob = torch.sigmoid(logits).item()
    label = "AMP" if prob >= 0.5 else "Non-AMP"
    return label, round(prob, 3)
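

if __name__ == "__main__":
    # Illustrative smoke test, not part of the Streamlit app: running this module
    # directly assumes the classifier checkpoint is on disk and that the
    # Rostlab/prot_bert weights can be downloaded. The peptide below is a
    # hypothetical example sequence.
    demo_bundle = load_model()
    demo_label, demo_prob = predict_amp("GLFDIVKKVVGALGSL", demo_bundle)
    print(f"{demo_label} (p_AMP={demo_prob})")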