from __future__ import annotations import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parents[1])) import torch from src.protomorph.config import ProtoMorphConfig from src.protomorph.model import ProtoMorphHead def main() -> None: torch.set_num_threads(1) cfg = ProtoMorphConfig( num_classes=7, embed_dim=32, proto_count=8, memory_tokens=4, rbf_count=16, num_heads=4, ) head = ProtoMorphHead(cfg).eval() cls = torch.randn(2, cfg.embed_dim) patches = torch.randn(2, 8 * 8, cfg.embed_dim) with torch.no_grad(): out = head(cls, patches) assert out["logits"].shape == (2, cfg.num_classes) assert out["hard_mask"].shape == (2,) print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist()) if __name__ == "__main__": main()