| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| import torch | |
| from src.protomorph.config import ProtoMorphConfig | |
| from src.protomorph.model import ProtoMorphHead | |
def _check_shape(out, key: str, expected: tuple[int, ...]) -> None:
    """Raise ``AssertionError`` if ``out[key].shape`` differs from *expected*.

    An explicit ``raise`` is used instead of an ``assert`` statement so the
    check still runs when Python is started with ``-O``.
    """
    actual = tuple(out[key].shape)
    if actual != expected:
        raise AssertionError(f"{key!r}: expected shape {expected}, got {actual}")


def main() -> None:
    """Run a head-only smoke test of ``ProtoMorphHead`` on random inputs.

    Builds a small ``ProtoMorphConfig``, puts the head in eval mode, runs one
    forward pass under ``torch.no_grad()`` on random ``cls`` and ``patches``
    tensors, checks the shapes of the ``"logits"`` and ``"hard_mask"`` outputs,
    and prints a one-line summary.

    Raises:
        AssertionError: if an output tensor has an unexpected shape.
    """
    # One thread keeps the run cheap; the fixed seed makes it reproducible.
    torch.set_num_threads(1)
    torch.manual_seed(0)

    cfg = ProtoMorphConfig(
        num_classes=7,
        embed_dim=32,
        proto_count=8,
        memory_tokens=4,
        rbf_count=16,
        num_heads=4,
    )
    head = ProtoMorphHead(cfg).eval()

    batch_size = 2
    # 8 * 8 patch tokens per sample; presumably an 8x8 patch grid — confirm
    # against the head's expectations if this ever changes.
    num_patches = 8 * 8
    cls = torch.randn(batch_size, cfg.embed_dim)
    patches = torch.randn(batch_size, num_patches, cfg.embed_dim)

    with torch.no_grad():
        out = head(cls, patches)

    _check_shape(out, "logits", (batch_size, cfg.num_classes))
    _check_shape(out, "hard_mask", (batch_size,))
    print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist())


if __name__ == "__main__":
    main()