# Source: DINO-Protomorph/scripts/smoke_test_head_only.py
# Author: shiowo
# Commit 63089c1 (verified): "Upload ProtoMorph-DINO scaffold and random head checkpoint"
from __future__ import annotations
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
import torch
from src.protomorph.config import ProtoMorphConfig
from src.protomorph.model import ProtoMorphHead
def _expect_shape(out, key: str, expected: tuple[int, ...]) -> None:
    """Raise ``AssertionError`` unless ``out[key].shape`` equals ``expected``.

    ``out`` is the string-keyed output returned by the head. The check is an
    explicit ``raise`` (not an ``assert`` statement) so it survives ``python -O``.
    """
    actual = tuple(out[key].shape)
    if actual != expected:
        raise AssertionError(f"{key!r} has shape {actual}, expected {expected}")


def main(*, batch_size: int = 2, grid_size: int = 8, seed: int = 0) -> None:
    """Run one forward pass of ``ProtoMorphHead`` on random inputs and check shapes.

    A deliberately small config is built, then a random CLS embedding and a
    square grid of random patch embeddings are fed through the head in eval
    mode with gradients disabled. The ``logits`` and ``hard_mask`` outputs are
    checked for the expected shapes.

    Args:
        batch_size: Number of samples in the random batch (default 2).
        grid_size: Side length of the square patch grid; each sample gets
            ``grid_size * grid_size`` patch tokens (default 8, i.e. 64 tokens).
        seed: Value passed to ``torch.manual_seed`` so repeated runs use the
            same random inputs and print the same ``hard_mask``.

    Raises:
        ValueError: If ``batch_size`` or ``grid_size`` is less than 1.
        AssertionError: If an output tensor has an unexpected shape.
    """
    if batch_size < 1 or grid_size < 1:
        raise ValueError(
            f"batch_size and grid_size must be >= 1, got {batch_size} and {grid_size}"
        )

    torch.manual_seed(seed)
    # Restrict intra-op parallelism to one thread; presumably to keep the
    # smoke test cheap and predictable on shared machines.
    torch.set_num_threads(1)

    # Tiny dimensions: this exercises wiring and output shapes, not model quality.
    cfg = ProtoMorphConfig(
        num_classes=7,
        embed_dim=32,
        proto_count=8,
        memory_tokens=4,
        rbf_count=16,
        num_heads=4,
    )
    head = ProtoMorphHead(cfg).eval()

    cls = torch.randn(batch_size, cfg.embed_dim)
    patches = torch.randn(batch_size, grid_size * grid_size, cfg.embed_dim)

    with torch.no_grad():
        out = head(cls, patches)

    _expect_shape(out, "logits", (batch_size, cfg.num_classes))
    _expect_shape(out, "hard_mask", (batch_size,))

    print("OK head-only smoke test", out["logits"].shape, out["hard_mask"].tolist())
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    main()