File size: 879 Bytes
63089c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

import torch

from src.protomorph.config import ProtoMorphConfig
from src.protomorph.model import ProtoMorphHead


def _check_shape(name: str, tensor: torch.Tensor, expected: tuple[int, ...]) -> None:
    """Raise ``AssertionError`` if ``tensor`` does not have shape ``expected``.

    An explicit ``raise`` is used instead of ``assert`` so the check is not
    stripped when Python runs with ``-O``.
    """
    actual = tuple(tensor.shape)
    if actual != expected:
        raise AssertionError(f"{name}: expected shape {expected}, got {actual}")


def main() -> None:
    """Run a head-only smoke test of ``ProtoMorphHead`` on random inputs.

    Builds a small ``ProtoMorphConfig``, puts a ``ProtoMorphHead`` in eval
    mode, and feeds it a random CLS tensor plus random patch tokens under
    ``torch.no_grad()``. It then checks the shapes of the ``"logits"`` and
    ``"hard_mask"`` outputs and prints a one-line summary.

    Raises:
        AssertionError: if either output does not have the expected shape.
    """
    # Single-threaded for cheap, stable runs. The seed makes the random
    # inputs below reproducible, and presumably any random parameter init
    # inside ProtoMorphHead too (it is seeded before construction).
    torch.set_num_threads(1)
    torch.manual_seed(0)

    batch_size = 2
    # 8 * 8 patch tokens, presumably an 8x8 patch grid flattened to one axis.
    num_patches = 8 * 8

    cfg = ProtoMorphConfig(
        num_classes=7,
        embed_dim=32,
        proto_count=8,
        memory_tokens=4,
        rbf_count=16,
        num_heads=4,
    )
    head = ProtoMorphHead(cfg).eval()
    cls = torch.randn(batch_size, cfg.embed_dim)
    patches = torch.randn(batch_size, num_patches, cfg.embed_dim)
    with torch.no_grad():
        out = head(cls, patches)

    _check_shape("logits", out["logits"], (batch_size, cfg.num_classes))
    _check_shape("hard_mask", out["hard_mask"], (batch_size,))
    print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist())


if __name__ == "__main__":
    main()