src/models/classifier.py — documentation
This module defines a feed-forward classifier head for multilabel protein localization on top of precomputed embeddings.
Purpose
- Consume fixed-length embedding vectors (for example from ESM-2).
- Produce raw logits per localization label.
- Keep inference helpers close to the model (predict_proba, predict).
- Keep label-name mapping inside the model object so outputs are human-readable.
Label handling
ProteinLocalizationClassifier supports dynamic label configuration:
- If label_names is passed explicitly, those names are used.
- Else, it tries to load names from label_columns.json (default path points to the esm2_t33_650M embeddings folder).
- If that file is missing/invalid, it falls back to built-in DeepLoc label names (including Peroxisome).
num_labels is optional:
- If num_labels=None, it is inferred as len(label_names).
- If provided, it must match len(label_names), otherwise ValueError is raised.
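The precedence rules above can be sketched roughly as follows. This is a minimal illustration, not the module's code: resolve_labels and FALLBACK_LABELS are hypothetical names, the fallback list is a shortened stand-in for the real built-in names, and the file is assumed to hold a flat JSON list of strings.

```python
import json
from pathlib import Path

# Stand-in fallback list (assumption); the module's real built-in
# DeepLoc names are not reproduced here.
FALLBACK_LABELS = ["Cytoplasm", "Nucleus", "Peroxisome"]


def resolve_labels(label_names=None, num_labels=None, json_path="label_columns.json"):
    """Return (label_names, num_labels) using the precedence described above."""
    if label_names is None:
        try:
            # Assumption: the file holds a flat JSON list of strings.
            loaded = json.loads(Path(json_path).read_text())
            if not (isinstance(loaded, list) and loaded
                    and all(isinstance(name, str) for name in loaded)):
                raise ValueError("unexpected label file layout")
            label_names = loaded
        except (OSError, ValueError):
            # Missing or invalid file: fall back to the built-in names.
            label_names = list(FALLBACK_LABELS)
    label_names = list(label_names)

    if num_labels is None:
        num_labels = len(label_names)
    elif num_labels != len(label_names):
        raise ValueError(
            f"num_labels={num_labels} does not match {len(label_names)} label names"
        )
    return label_names, num_labels
```

Explicit names win over the file, and a mismatched num_labels fails early instead of silently producing a head whose outputs no longer line up with the label names.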
Architecture
The model is an MLP with batch norm and dropout:
- Linear(embedding_dim -> 512)
- BatchNorm1d(512)
- ReLU
- Dropout(0.3)
- Linear(512 -> 256)
- BatchNorm1d(256)
- ReLU
- Dropout(0.3)
- Linear(256 -> 128)
- BatchNorm1d(128)
- ReLU
- Dropout(0.2)
- Linear(128 -> num_labels)
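As a rough sketch, the layer stack above could be expressed as a single nn.Sequential. The helper name build_mlp is hypothetical, and embedding_dim=1280 with num_labels=10 are example values chosen for the demo, not values confirmed by the module.

```python
import torch
from torch import nn


def build_mlp(embedding_dim: int, num_labels: int) -> nn.Sequential:
    """Three Linear/BatchNorm/ReLU/Dropout blocks followed by a linear output layer."""
    return nn.Sequential(
        nn.Linear(embedding_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(128, num_labels),  # index 12 in the sequence
    )


net = build_mlp(embedding_dim=1280, num_labels=10)  # example sizes (assumption)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(4, 1280))
print(logits.shape)  # torch.Size([4, 10])
```

With this layout the output layer sits at index 12 of the sequence, which is consistent with the net.12.weight key that load_model inspects.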
Initialization
All linear layers use He/Kaiming normal initialization (kaiming_normal_ with ReLU nonlinearity), and biases are set to zero.
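A minimal sketch of that initialization scheme, assuming it is applied with Module.apply; the function name init_weights and the small demo network are illustrative only.

```python
from torch import nn


def init_weights(module: nn.Module) -> None:
    """He/Kaiming normal init for Linear weights, zeros for biases."""
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Small stand-in network; apply() visits every submodule recursively.
layer_stack = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
layer_stack.apply(init_weights)
```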
Forward and inference methods
forward(x)
- Input: tensor of shape (B, embedding_dim).
- Output: raw logits (B, num_labels).
- No sigmoid is applied; this is intended for BCEWithLogitsLoss.
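The following sketch shows why raw logits pair with BCEWithLogitsLoss: the loss applies the sigmoid internally, so it matches BCELoss computed on sigmoid(logits). The single linear layer stands in for the classifier, and all shapes and targets are made up for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(8, 3)  # stand-in for the classifier (assumption)
embeddings = torch.randn(4, 8)
# Multilabel targets: each row may have several positive labels.
targets = torch.tensor(
    [[1, 0, 1], [0, 0, 1], [1, 1, 0], [0, 1, 0]], dtype=torch.float32
)

logits = model(embeddings)                       # raw scores, no sigmoid
loss = nn.BCEWithLogitsLoss()(logits, targets)   # sigmoid happens inside the loss

# Same quantity computed the explicit (less numerically stable) way.
reference = nn.BCELoss()(torch.sigmoid(logits), targets)
```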
predict_proba(embedding)
- Accepts either a single vector (embedding_dim,) or a batch (B, embedding_dim).
- Temporarily switches model to eval(), runs forward pass, applies sigmoid.
- Returns:
  - single input -> dict[label_name -> probability]
  - batch input -> list[dict[label_name -> probability]]
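A standalone sketch of this behavior, written as a free function rather than a method. It assumes the previous train/eval mode is restored afterwards (one reading of "temporarily") and uses a small stand-in model with two made-up labels.

```python
import torch
from torch import nn


def predict_proba(model, label_names, embedding):
    """Return per-label probabilities for one embedding or a batch."""
    single = embedding.dim() == 1
    batch = embedding.unsqueeze(0) if single else embedding

    was_training = model.training
    model.eval()  # BatchNorm/Dropout must be in inference mode
    try:
        with torch.no_grad():
            probs = torch.sigmoid(model(batch))
    finally:
        model.train(was_training)  # restore the previous mode (assumption)

    rows = [dict(zip(label_names, row.tolist())) for row in probs]
    return rows[0] if single else rows


# Stand-in model and labels for illustration only.
demo_model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU(), nn.Linear(4, 2))
demo_labels = ["Nucleus", "Peroxisome"]
one = predict_proba(demo_model, demo_labels, torch.randn(8))
many = predict_proba(demo_model, demo_labels, torch.randn(3, 8))
```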
predict(embedding, thresholds=None)
- Also accepts single or batch embeddings.
- Uses sigmoid probabilities and thresholds to produce binary outputs.
- thresholds options:
  - None -> all thresholds = 0.5
  - dict[label_name -> threshold] -> missing labels default to 0.5
  - tensor of length num_labels
- Returns:
  - single input -> dict[label_name -> 0/1]
  - batch input -> list[dict[label_name -> 0/1]]
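The three thresholds forms can be normalized to one per-label tensor before comparison. The sketch below is illustrative: apply_thresholds is a hypothetical helper, and whether the module compares with >= or > is an assumption.

```python
import torch


def apply_thresholds(probs, label_names, thresholds=None):
    """Turn (B, num_labels) probabilities into 0/1 decisions."""
    n = len(label_names)
    if thresholds is None:
        cutoffs = torch.full((n,), 0.5)
    elif isinstance(thresholds, dict):
        # Labels missing from the dict default to 0.5.
        cutoffs = torch.tensor([thresholds.get(name, 0.5) for name in label_names])
    else:
        cutoffs = torch.as_tensor(thresholds, dtype=probs.dtype)
        if cutoffs.shape != (n,):
            raise ValueError(f"expected {n} thresholds, got shape {tuple(cutoffs.shape)}")
    # Assumption: a probability equal to the cutoff counts as positive.
    return (probs >= cutoffs).int()


labels = ["A", "B", "C"]
probs = torch.tensor([[0.9, 0.4, 0.6]])
default_out = apply_thresholds(probs, labels)
dict_out = apply_thresholds(probs, labels, {"C": 0.7})
tensor_out = apply_thresholds(probs, labels, torch.tensor([0.95, 0.3, 0.5]))
```

Per-label thresholds matter in multilabel settings because rare classes often need a cutoff other than 0.5 to reach a useful precision/recall balance.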
Utility functions
count_parameters(model)
Prints:
- total parameter count
- trainable parameter count
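A minimal sketch of such a helper. Returning the two counts in addition to printing them is an assumption made here so the result can be checked; the text above only states that the function prints.

```python
from torch import nn


def count_parameters(model: nn.Module):
    """Print total and trainable parameter counts; also return them (assumption)."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    return total, trainable


layer = nn.Linear(4, 2)            # 4*2 weights + 2 biases = 10 parameters
layer.bias.requires_grad_(False)   # freeze the bias to make the counts differ
total, trainable = count_parameters(layer)
```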
load_model(path, embedding_dim, num_labels, device)
Loads model weights from checkpoint and returns an eval-mode model on target device.
Supported checkpoint formats:
- {"state_dict": ...} (preferred)
- plain state-dict mapping
If num_labels is None, it infers label count from:
- label_names in checkpoint (if present), else
- final layer weight shape (net.12.weight).
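The format handling and label-count inference can be sketched as below. unpack_checkpoint is a hypothetical helper operating on an already-loaded checkpoint object; the demo checkpoints are built in memory rather than read from disk.

```python
import torch


def unpack_checkpoint(checkpoint, num_labels=None):
    """Return (state_dict, num_labels) for either supported checkpoint format."""
    # Wrapped format carries the weights under "state_dict"; otherwise the
    # mapping itself is treated as the state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    if num_labels is None:
        label_names = checkpoint.get("label_names")
        if label_names is not None:
            num_labels = len(label_names)
        else:
            # Linear weights have shape (out_features, in_features).
            num_labels = state_dict["net.12.weight"].shape[0]
    return state_dict, num_labels


plain = {"net.12.weight": torch.zeros(10, 128), "net.12.bias": torch.zeros(10)}
wrapped = {
    "state_dict": {"net.12.weight": torch.zeros(3, 128), "net.12.bias": torch.zeros(3)},
    "label_names": ["Cytoplasm", "Nucleus", "Peroxisome"],
}
_, n_plain = unpack_checkpoint(plain)
_, n_wrapped = unpack_checkpoint(wrapped)
```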
Notes for training/inference
- Because the model uses BatchNorm1d, batch size 1 in training mode can raise errors; use eval() for single-sample smoke/inference checks.
- Save label_names in checkpoints to keep label mapping portable and deterministic across runs.
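The BatchNorm1d caveat can be reproduced with a small stand-in block. In training mode a (1, C) input gives no batch statistics to compute, while eval() uses the stored running statistics instead. Layer sizes here are arbitrary.

```python
import torch
from torch import nn

block = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4))  # stand-in block
single = torch.randn(1, 8)  # batch of one sample

block.train()
try:
    block(single)
    train_error = None
except ValueError as exc:
    # PyTorch refuses to normalize over a single value per channel.
    train_error = str(exc)

block.eval()
with torch.no_grad():
    out = block(single)  # works: running mean/var are used
```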