Tanoj22
Initial commit: ProtLoc-AI project setup and core app
cb6f1ba

src/models/classifier.py — documentation

This module defines a feed-forward (MLP) classifier head for multi-label protein localization that operates on precomputed embeddings.


Purpose

  • Consume fixed-length embedding vectors (for example, from ESM-2).
  • Produce one raw logit per localization label.
  • Provide inference helpers (predict_proba, predict) as methods on the model itself.
  • Store the label-name mapping on the model object so that outputs are keyed by human-readable label names.

Label handling

ProteinLocalizationClassifier supports dynamic label configuration:

  1. If label_names is passed explicitly, those names are used.
  2. Otherwise, it tries to load the names from label_columns.json (the default path points to the esm2_t33_650M embeddings folder).
  3. If that file is missing or invalid, it falls back to a built-in list of DeepLoc label names (including Peroxisome).
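The three-step lookup can be sketched in plain Python. This is an approximation rather than the module's code: the helper name resolve_label_names, the assumption that label_columns.json holds a JSON list of strings, and the three-name FALLBACK_LABELS placeholder are all hypothetical (the real built-in DeepLoc list lives in classifier.py).

```python
import json
from pathlib import Path
from typing import List, Optional, Sequence

# Hypothetical placeholder: the actual built-in DeepLoc names are defined in classifier.py.
FALLBACK_LABELS: List[str] = ["Cytoplasm", "Nucleus", "Peroxisome"]


def resolve_label_names(
    label_names: Optional[Sequence[str]] = None,
    label_file: Optional[Path] = None,
) -> List[str]:
    """Return label names using the documented priority order (hypothetical helper)."""
    # 1. Explicit names always win.
    if label_names is not None:
        return list(label_names)
    # 2. Try the JSON file (assumed format: a non-empty list of strings).
    if label_file is not None:
        try:
            loaded = json.loads(Path(label_file).read_text())
            if isinstance(loaded, list) and loaded and all(isinstance(n, str) for n in loaded):
                return loaded
        except (OSError, ValueError):
            pass  # missing, unreadable, or unparsable file: fall through
    # 3. Fall back to the built-in list.
    return list(FALLBACK_LABELS)


print(resolve_label_names(None, Path("/nonexistent/label_columns.json")))
```

Because a missing file and a malformed file are both caught, the caller never has to distinguish the two cases; either one silently yields the fallback names.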

num_labels is optional:

  • If num_labels=None, it is inferred as len(label_names).
  • If provided, it must equal len(label_names); otherwise a ValueError is raised.
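A minimal sketch of this rule; the helper name resolve_num_labels and the error message wording are hypothetical, not taken from classifier.py.

```python
from typing import Optional, Sequence


def resolve_num_labels(num_labels: Optional[int], label_names: Sequence[str]) -> int:
    """Infer num_labels from label_names, or validate an explicit value (hypothetical helper)."""
    if num_labels is None:
        return len(label_names)  # inferred from the resolved names
    if num_labels != len(label_names):
        raise ValueError(
            f"num_labels={num_labels} does not match len(label_names)={len(label_names)}"
        )
    return num_labels


print(resolve_num_labels(None, ["Cytoplasm", "Nucleus", "Peroxisome"]))  # 3
```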

Architecture

The model is an MLP with three hidden blocks (Linear, BatchNorm1d, ReLU, Dropout) followed by an output Linear layer:

  1. Linear(embedding_dim -> 512)
  2. BatchNorm1d(512)
  3. ReLU
  4. Dropout(0.3)
  5. Linear(512 -> 256)
  6. BatchNorm1d(256)
  7. ReLU
  8. Dropout(0.3)
  9. Linear(256 -> 128)
  10. BatchNorm1d(128)
  11. ReLU
  12. Dropout(0.2)
  13. Linear(128 -> num_labels)
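The layer list above can be rebuilt with torch.nn.Sequential, as sketched below. The builder name build_mlp is hypothetical; 1280 is the embedding width of ESM-2 t33_650M, and 10 labels is an arbitrary example value. The index comments show why, assuming the layers live in an indexed container named net, the output layer's weight appears as net.12.weight in a state dict.

```python
import torch
from torch import nn


def build_mlp(embedding_dim: int, num_labels: int) -> nn.Sequential:
    """Rebuild the documented layer stack (hypothetical builder, not the module's class)."""
    return nn.Sequential(
        nn.Linear(embedding_dim, 512),  # index 0
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, 256),            # index 4
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, 128),            # index 8
        nn.BatchNorm1d(128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, num_labels),     # index 12: output layer
    )


net = build_mlp(embedding_dim=1280, num_labels=10)
net.eval()  # use running BatchNorm statistics and disable dropout
with torch.no_grad():
    logits = net(torch.randn(4, 1280))
print(logits.shape)          # torch.Size([4, 10])
print(net[12].weight.shape)  # torch.Size([10, 128])
```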

Initialization

All Linear layers use He (Kaiming) normal initialization (nn.init.kaiming_normal_ with nonlinearity="relu"), and their biases are set to zero.
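One way to express this initialization is a loop over model.modules(), sketched below. The function name init_linear_layers and the small stand-in model are hypothetical, and the sketch assumes kaiming_normal_'s default mode="fan_in", under which the target weight standard deviation is sqrt(2 / fan_in).

```python
import torch
from torch import nn


def init_linear_layers(model: nn.Module) -> None:
    """Kaiming-normal weights and zero biases for every nn.Linear (hypothetical helper)."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Default mode="fan_in"; the ReLU gain is sqrt(2).
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)


torch.manual_seed(0)
# Small stand-in model; BatchNorm layers are left at their PyTorch defaults.
model = nn.Sequential(nn.Linear(1280, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Linear(512, 10))
init_linear_layers(model)

target_std = (2 / 1280) ** 0.5  # sqrt(2 / fan_in) for the first layer
print(model[0].weight.std().item(), target_std)
```

The empirical standard deviation of the 512 x 1280 weight matrix should land very close to the target, which is a quick sanity check that the ReLU gain was applied.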


Forward and inference methods

forward(x)

  • Input: tensor of shape (B, embedding_dim).
  • Output: raw logits (B, num_labels).
  • No sigmoid is applied; the logits are meant to be passed directly to BCEWithLogitsLoss, which applies the sigmoid internally in a numerically stable way.
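A quick check of why forward returns logits: BCEWithLogitsLoss on raw logits gives the same value as BCELoss on sigmoid outputs, but fuses the two steps for better numerical stability. The random tensors below are stand-ins for real model outputs and multi-hot targets.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 10)                     # stand-in for forward(x): (B, num_labels)
targets = torch.randint(0, 2, (8, 10)).float()  # multi-hot: one 0/1 column per location

# Intended usage: raw logits in, sigmoid handled inside the loss.
loss_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Mathematically equivalent two-step version (less stable for large |logits|).
loss_probs = nn.BCELoss()(torch.sigmoid(logits), targets)

print(torch.allclose(loss_logits, loss_probs, atol=1e-5))  # True
```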

predict_proba(embedding)

  • Accepts either a single vector (embedding_dim,) or a batch (B, embedding_dim).
  • Temporarily switches the model to eval() mode, runs a forward pass, and applies a sigmoid.
  • Returns:
    • single input -> dict[label_name -> probability]
    • batch input -> list[dict[label_name -> probability]]
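The behavior described above can be approximated by a free function, sketched below. It is not the actual method: the name predict_proba_sketch, the torch.no_grad() context, and restoring the caller's train/eval mode afterwards are assumptions about how "temporarily" is implemented. The tiny nn.Sequential is a stand-in for the classifier.

```python
from typing import Dict, List, Sequence, Union

import torch
from torch import nn


def predict_proba_sketch(
    model: nn.Module, embedding: torch.Tensor, label_names: Sequence[str]
) -> Union[Dict[str, float], List[Dict[str, float]]]:
    """Sigmoid probabilities keyed by label name (hypothetical helper)."""
    single = embedding.dim() == 1
    batch = embedding.unsqueeze(0) if single else embedding  # (D,) -> (1, D)
    was_training = model.training
    model.eval()  # BatchNorm uses running stats, dropout is disabled
    try:
        with torch.no_grad():
            probs = torch.sigmoid(model(batch))
    finally:
        model.train(was_training)  # restore the caller's mode
    rows = [dict(zip(label_names, row.tolist())) for row in probs]
    return rows[0] if single else rows


labels = ["Cytoplasm", "Nucleus", "Peroxisome"]  # illustrative subset
model = nn.Sequential(nn.Linear(16, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, len(labels)))
one = predict_proba_sketch(model, torch.randn(16), labels)      # dict
many = predict_proba_sketch(model, torch.randn(5, 16), labels)  # list of 5 dicts
print(type(one).__name__, len(many))  # dict 5
```

Switching to eval() is what makes the single-vector case work at all, since BatchNorm1d cannot compute batch statistics from one sample in training mode.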

predict(embedding, thresholds=None)

  • Also accepts a single embedding or a batch.
  • Applies a sigmoid and compares each probability with a per-label threshold to produce binary outputs.
  • Accepted thresholds values:
    • None -> every threshold is 0.5
    • dict[label_name -> threshold] -> labels missing from the dict default to 0.5
    • tensor of length num_labels -> one threshold per label, in label_names order
  • Returns:
    • single input -> dict[label_name -> 0/1]
    • batch input -> list[dict[label_name -> 0/1]]
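The three thresholds forms can be normalized into a single (num_labels,) tensor and broadcast against a batch of probabilities, as sketched below. The helper name build_thresholds and the use of >= (rather than >) for the comparison are assumptions; the probability values are made up for illustration.

```python
from typing import Dict, Optional, Sequence, Union

import torch


def build_thresholds(
    thresholds: Optional[Union[Dict[str, float], torch.Tensor]],
    label_names: Sequence[str],
) -> torch.Tensor:
    """Normalize the accepted `thresholds` forms to shape (num_labels,) (hypothetical helper)."""
    if thresholds is None:
        return torch.full((len(label_names),), 0.5)
    if isinstance(thresholds, dict):
        # Labels missing from the dict default to 0.5.
        return torch.tensor([float(thresholds.get(name, 0.5)) for name in label_names])
    t = torch.as_tensor(thresholds, dtype=torch.float32)
    if t.shape != (len(label_names),):
        raise ValueError(f"expected {len(label_names)} thresholds, got shape {tuple(t.shape)}")
    return t


labels = ["Cytoplasm", "Nucleus", "Peroxisome"]  # illustrative subset
probs = torch.tensor([[0.70, 0.40, 0.20],
                      [0.45, 0.90, 0.35]])          # (B, num_labels) sigmoid outputs
thr = build_thresholds({"Peroxisome": 0.3}, labels)  # [0.5, 0.5, 0.3]
binary = (probs >= thr).int()                       # broadcasts over the batch dimension
print([dict(zip(labels, row.tolist())) for row in binary])
# [{'Cytoplasm': 1, 'Nucleus': 0, 'Peroxisome': 0}, {'Cytoplasm': 0, 'Nucleus': 1, 'Peroxisome': 1}]
```

Lowering a single label's threshold (here Peroxisome) is a common way to trade precision for recall on rare classes without retraining.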

Utility functions

count_parameters(model)

Prints:

  • total parameter count
  • trainable parameter count
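As a cross-check for count_parameters, the total can be worked out by hand from the Architecture list: each Linear contributes in x out weights plus out biases, each BatchNorm1d contributes an affine weight and bias (its running statistics are buffers, not parameters), and ReLU and Dropout contribute nothing. This sums to 512·embedding_dim + 129·num_labels + 166,528. The helper name expected_param_count and the 10-label example are hypothetical; when no parameters are frozen, the trainable count equals the total.

```python
def expected_param_count(embedding_dim: int, num_labels: int) -> int:
    """Hand-count the parameters of the documented MLP (hypothetical helper)."""

    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out  # weight + bias

    def batchnorm(n: int) -> int:
        return 2 * n  # affine weight + bias; running stats are buffers

    return (
        linear(embedding_dim, 512) + batchnorm(512)
        + linear(512, 256) + batchnorm(256)
        + linear(256, 128) + batchnorm(128)
        + linear(128, num_labels)  # ReLU and Dropout add no parameters
    )


# ESM-2 t33_650M embeddings are 1280-dimensional; 10 labels is an example value.
print(expected_param_count(1280, 10))  # 823178
```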

load_model(path, embedding_dim, num_labels, device)

Loads model weights from a checkpoint file and returns the model in eval() mode on the target device.

Supported checkpoint formats:

  • {"state_dict": ...} (preferred)
  • plain state-dict mapping
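Supporting both layouts comes down to one key check, sketched below. The helper name unwrap_state_dict is hypothetical, and the strings stand in for tensors so the logic can be shown without a real checkpoint.

```python
from typing import Any, Mapping


def unwrap_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the parameter mapping from either checkpoint layout (hypothetical helper)."""
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]  # preferred: {"state_dict": ..., plus metadata}
    return checkpoint  # already a plain state dict


wrapped = {"state_dict": {"net.12.weight": "W"}, "label_names": ["A", "B"]}
plain = {"net.12.weight": "W"}
print(unwrap_state_dict(wrapped) == unwrap_state_dict(plain))  # True
```

The wrapped layout is preferred because it leaves room for metadata such as label_names alongside the weights.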

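If num_labels is None, it infers the label count from:

  1. label_names stored in the checkpoint (if present); otherwise
  2. the first dimension of the final Linear layer's weight (net.12.weight). This is step 13 in the Architecture list, because module indices start at 0, and nn.Linear stores its weight with shape (out_features, in_features), so dimension 0 equals num_labels.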
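The inference order can be sketched as follows. The helper name infer_num_labels is hypothetical, and the zero tensors stand in for trained weights; the key name net.12.weight is taken from the description above.

```python
from typing import Any, Mapping

import torch


def infer_num_labels(checkpoint: Mapping[str, Any]) -> int:
    """Stored label_names first, else the output layer's weight shape (hypothetical helper)."""
    if "label_names" in checkpoint:
        return len(checkpoint["label_names"])
    state = checkpoint.get("state_dict", checkpoint)
    # nn.Linear weight is (out_features, in_features): dim 0 is num_labels.
    return int(state["net.12.weight"].shape[0])


ckpt_with_names = {
    "state_dict": {"net.12.weight": torch.zeros(7, 128)},
    "label_names": list("ABCDEFG"),
}
ckpt_plain = {"net.12.weight": torch.zeros(7, 128), "net.12.bias": torch.zeros(7)}
print(infer_num_labels(ckpt_with_names), infer_num_labels(ckpt_plain))  # 7 7
```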

Notes for training/inference

  • Because the model uses BatchNorm1d, a forward pass with batch size 1 in training mode raises an error (batch statistics cannot be computed from a single sample); call eval() before single-sample smoke tests or inference.
  • Save label_names in checkpoints so the label mapping travels with the weights and stays consistent across runs.
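Both notes can be checked in a few lines. The sketch uses a bare BatchNorm1d layer as a stand-in for the full model and an in-memory buffer instead of a checkpoint file; the three-label list is illustrative only.

```python
import io

import torch
from torch import nn

bn = nn.BatchNorm1d(4)
x = torch.randn(1, 4)  # a single sample

raised = False
bn.train()
try:
    bn(x)
except ValueError as err:  # "Expected more than 1 value per channel when training ..."
    raised = True
    print("train mode:", type(err).__name__)

bn.eval()  # normalizes with running statistics instead of batch statistics
print("eval mode:", tuple(bn(x).shape))  # eval mode: (1, 4)

# Keep the label mapping next to the weights in the preferred checkpoint layout.
label_names = ["Cytoplasm", "Nucleus", "Peroxisome"]  # illustrative subset
buffer = io.BytesIO()
torch.save({"state_dict": bn.state_dict(), "label_names": label_names}, buffer)
buffer.seek(0)
restored = torch.load(buffer)
print(restored["label_names"])
```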