# Scraped viewer header: file size 879 bytes, blob 63089c1.
from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import torch
from src.protomorph.config import ProtoMorphConfig
from src.protomorph.model import ProtoMorphHead
def main() -> None:
    """Run a head-only smoke test of ``ProtoMorphHead`` on random inputs.

    Builds a small ``ProtoMorphConfig``, runs a single forward pass in eval
    mode under ``torch.no_grad()``, verifies the shapes of the ``"logits"``
    and ``"hard_mask"`` outputs, and prints a one-line summary.

    Raises:
        AssertionError: If either output has an unexpected shape. The checks
            raise explicitly instead of using ``assert`` so that they are
            not stripped when Python runs with ``-O``.
    """
    # Restrict torch to a single intra-op thread for this run.
    torch.set_num_threads(1)

    cfg = ProtoMorphConfig(
        num_classes=7,
        embed_dim=32,
        proto_count=8,
        memory_tokens=4,
        rbf_count=16,
        num_heads=4,
    )
    head = ProtoMorphHead(cfg).eval()

    batch_size = 2
    # NOTE(review): presumably an 8x8 grid of patch tokens -- confirm
    # against the backbone this head is paired with.
    num_patches = 8 * 8

    cls = torch.randn(batch_size, cfg.embed_dim)
    patches = torch.randn(batch_size, num_patches, cfg.embed_dim)

    with torch.no_grad():
        out = head(cls, patches)

    expected_logits_shape = (batch_size, cfg.num_classes)
    logits_shape = tuple(out["logits"].shape)
    if logits_shape != expected_logits_shape:
        raise AssertionError(
            f"logits shape {logits_shape} != expected {expected_logits_shape}"
        )

    expected_mask_shape = (batch_size,)
    mask_shape = tuple(out["hard_mask"].shape)
    if mask_shape != expected_mask_shape:
        raise AssertionError(
            f"hard_mask shape {mask_shape} != expected {expected_mask_shape}"
        )

    print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist())
# Run the smoke test only when this file is executed as a script, so that
# importing the module has no side effect beyond the definitions above.
if __name__ == "__main__":
    main()
|