File size: 2,626 Bytes
cb6f1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Smoke-test src/models/classifier.py.

From project root:
    .\\venv\\Scripts\\python.exe scripts\\test_classifier.py
"""

from __future__ import annotations

import json
import sys
import tempfile
from pathlib import Path

import torch

# Repository root: this script lives one directory below it (``scripts/``).
ROOT = Path(__file__).resolve().parent.parent
# Prepend the root so ``src.*`` imports resolve when the script is run by path.
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

from src.models.classifier import (  # noqa: E402
    ProteinLocalizationClassifier,
    count_parameters,
    load_model,
)


def _expect(condition: bool, message: str) -> None:
    """Raise ``AssertionError`` with *message* when *condition* is false.

    Used instead of bare ``assert`` so the checks still run under ``python -O``.
    """
    if not condition:
        raise AssertionError(message)


def _load_label_names() -> list[str]:
    """Return the ``label_columns`` list stored in the esm2_t33_650M ``label_columns.json``.

    Raises:
        FileNotFoundError: if the JSON file has not been generated yet.
        KeyError: if the file has no ``label_columns`` key.
    """
    label_columns_path = (
        ROOT / "data" / "processed" / "embeddings" / "esm2_t33_650M" / "label_columns.json"
    )
    with label_columns_path.open("r", encoding="utf-8") as f:
        return json.load(f)["label_columns"]


def main() -> None:
    """Build a classifier, exercise its inference API, and round-trip a checkpoint.

    Unlike a print-only smoke test, every stage is checked and a failure raises
    instead of falling through to the final "OK" message.

    Raises:
        FileNotFoundError: if ``label_columns.json`` is missing.
        AssertionError: if output shapes/lengths, probability ranges, or the
            checkpoint round-trip do not match expectations.
    """
    torch.manual_seed(0)

    label_names = _load_label_names()

    embedding_dim = 1280
    batch_size = 4
    num_labels = len(label_names)
    model = ProteinLocalizationClassifier(
        embedding_dim=embedding_dim,
        num_labels=num_labels,
        label_names=label_names,
    )

    print("Model created.")
    count_parameters(model)
    print(f"label_names ({len(model.label_names)}): {model.label_names}")

    x_single = torch.randn(embedding_dim)
    x_batch = torch.randn(batch_size, embedding_dim)

    # Eval mode: BatchNorm (if the model uses it) cannot normalise a batch of
    # size 1 in training mode, and Dropout would make outputs nondeterministic.
    model.eval()
    with torch.no_grad():
        logits_single = model(x_single.unsqueeze(0))
        logits_batch = model(x_batch)
    print(f"\nforward single (as batch): logits shape={tuple(logits_single.shape)}")
    print(f"forward batch: logits shape={tuple(logits_batch.shape)}")
    # Assumes one logit per label (multi-label head) — matches predict_proba's per-label dict.
    _expect(
        tuple(logits_single.shape) == (1, num_labels),
        f"single logits shape {tuple(logits_single.shape)} != {(1, num_labels)}",
    )
    _expect(
        tuple(logits_batch.shape) == (batch_size, num_labels),
        f"batch logits shape {tuple(logits_batch.shape)} != {(batch_size, num_labels)}",
    )

    p_single = model.predict_proba(x_single)
    p_batch = model.predict_proba(x_batch)
    print(f"\npredict_proba single keys={list(p_single.keys())[:3]}... len={len(p_single)}")
    print(f"predict_proba batch len={len(p_batch)}; first item Membrane={p_batch[0]['Membrane']:.4f}")
    _expect(len(p_single) == num_labels, f"predict_proba single len {len(p_single)} != {num_labels}")
    _expect(len(p_batch) == batch_size, f"predict_proba batch len {len(p_batch)} != {batch_size}")
    _expect(
        all(0.0 <= float(p) <= 1.0 for p in p_single.values()),
        "predict_proba returned a probability outside [0, 1]",
    )

    y_single = model.predict(x_single)
    y_batch = model.predict(x_batch, thresholds={"Membrane": 0.4})
    print(f"\npredict single example: Membrane={y_single['Membrane']}, Cytoplasm={y_single['Cytoplasm']}")
    print(f"predict batch len={len(y_batch)}; first item Membrane={y_batch[0]['Membrane']}")
    _expect(len(y_batch) == batch_size, f"predict batch len {len(y_batch)} != {batch_size}")

    # Checkpoint round-trip: the reloaded model must reproduce the original outputs.
    with tempfile.TemporaryDirectory() as td:
        ckpt_path = Path(td) / "classifier_ckpt.pt"
        torch.save({"state_dict": model.state_dict(), "label_names": model.label_names}, ckpt_path)
        loaded = load_model(ckpt_path, embedding_dim=embedding_dim, num_labels=num_labels, device="cpu")
        # Don't rely on load_model to disable training-mode layers; otherwise the
        # comparison below could differ for reasons unrelated to the weights.
        loaded.eval()
        with torch.no_grad():
            diff = (loaded(x_batch) - model(x_batch)).abs().max().item()
        print(f"\nload_model checkpoint round-trip max|diff|={diff:.6g}")
        _expect(diff <= 1e-6, f"checkpoint round-trip mismatch: max|diff|={diff:.6g}")

    print("\nOK — classifier.py smoke test passed.")


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()