# protloc-ai / scripts / test_classifier.py
# Tanoj22
# Initial commit: ProtLoc-AI project setup and core app
# cb6f1ba
"""
Smoke-test src/models/classifier.py.
From project root:
.\\venv\\Scripts\\python.exe scripts\\test_classifier.py
"""
from __future__ import annotations
import json
import sys
import tempfile
from pathlib import Path
import torch
# Project root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parent.parent

# Running this file directly puts its own directory (not the project root) at
# the front of sys.path; prepend the root once so `src.*` imports resolve.
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry
from src.models.classifier import ( # noqa: E402
ProteinLocalizationClassifier,
count_parameters,
load_model,
)
def _check(condition: bool, message: str) -> None:
    """Fail the smoke test with *message* when *condition* is false.

    Raises explicitly instead of using ``assert`` so the checks still run
    under ``python -O``.
    """
    if not condition:
        raise RuntimeError(f"classifier smoke test failed: {message}")


def main() -> None:
    """Smoke-test ``ProteinLocalizationClassifier`` end to end.

    Reads the label names from the esm2_t33_650M ``label_columns.json``, builds
    a classifier, and runs forward / ``predict_proba`` / ``predict`` on one
    random embedding and on a small random batch. It then saves a checkpoint
    to a temporary directory and reloads it with ``load_model``, checking that
    the reloaded model reproduces the original outputs.

    Raises:
        FileNotFoundError: if ``label_columns.json`` does not exist.
        KeyError: if the JSON has no ``"label_columns"`` entry.
        RuntimeError: if any sanity check fails, so the script exits non-zero
            instead of reporting success unconditionally.
    """
    torch.manual_seed(0)

    label_columns_path = (
        ROOT / "data" / "processed" / "embeddings" / "esm2_t33_650M" / "label_columns.json"
    )
    with label_columns_path.open("r", encoding="utf-8") as f:
        label_names = json.load(f)["label_columns"]

    embedding_dim = 1280  # hidden size of ESM-2 t33_650M
    num_labels = len(label_names)
    batch_size = 4
    # Save/load on CPU should reproduce the weights exactly; allow float noise.
    roundtrip_atol = 1e-6

    model = ProteinLocalizationClassifier(
        embedding_dim=embedding_dim,
        num_labels=num_labels,
        label_names=label_names,
    )
    print("Model created.")
    count_parameters(model)
    print(f"label_names ({len(model.label_names)}): {model.label_names}")

    x1 = torch.randn(embedding_dim)
    xB = torch.randn(batch_size, embedding_dim)

    # BatchNorm requires eval mode for batch size 1 in inference-like checks;
    # eval() also disables any dropout so repeated forwards are deterministic.
    model.eval()

    with torch.no_grad():
        logits1 = model(x1.unsqueeze(0))
        logitsB = model(xB)
    print(f"\nforward single (as batch): logits shape={tuple(logits1.shape)}")
    print(f"forward batch: logits shape={tuple(logitsB.shape)}")
    # NOTE(review): assumes the head emits one logit per label.
    _check(
        tuple(logits1.shape) == (1, num_labels),
        f"single logits shape {tuple(logits1.shape)} != (1, {num_labels})",
    )
    _check(
        tuple(logitsB.shape) == (batch_size, num_labels),
        f"batch logits shape {tuple(logitsB.shape)} != ({batch_size}, {num_labels})",
    )

    p1 = model.predict_proba(x1)
    pB = model.predict_proba(xB)
    print(f"\npredict_proba single keys={list(p1.keys())[:3]}... len={len(p1)}")
    print(f"predict_proba batch len={len(pB)}; first item Membrane={pB[0]['Membrane']:.4f}")
    _check(len(p1) == num_labels, f"predict_proba single len={len(p1)} != {num_labels}")
    _check(len(pB) == batch_size, f"predict_proba batch len={len(pB)} != {batch_size}")
    _check(
        all(0.0 <= float(p) <= 1.0 for p in p1.values()),
        "predict_proba returned a value outside [0, 1]",
    )

    y1 = model.predict(x1)
    yB = model.predict(xB, thresholds={"Membrane": 0.4})
    print(f"\npredict single example: Membrane={y1['Membrane']}, Cytoplasm={y1['Cytoplasm']}")
    print(f"predict batch len={len(yB)}; first item Membrane={yB[0]['Membrane']}")
    _check(len(yB) == batch_size, f"predict batch len={len(yB)} != {batch_size}")

    # Checkpoint round-trip (always runs): save weights + label names, reload
    # through load_model, and require matching outputs on the same batch.
    with tempfile.TemporaryDirectory() as td:
        ckpt_path = Path(td) / "classifier_ckpt.pt"
        torch.save({"state_dict": model.state_dict(), "label_names": model.label_names}, ckpt_path)
        loaded = load_model(
            ckpt_path, embedding_dim=embedding_dim, num_labels=num_labels, device="cpu"
        )
        # Whether load_model sets eval mode isn't visible here; force it so the
        # comparison is not perturbed by dropout or batch statistics.
        loaded.eval()
        with torch.no_grad():
            diff = (loaded(xB) - model(xB)).abs().max().item()
    print(f"\nload_model checkpoint round-trip max|diff|={diff:.6g}")
    _check(
        diff <= roundtrip_atol,
        f"checkpoint round-trip max|diff|={diff:.6g} exceeds {roundtrip_atol:g}",
    )

    print("\nOK — classifier.py smoke test passed.")
# Execute the smoke test only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()