# File size: 432 Bytes
# b8877ca
import torch
from utils.adaptive_classifiers import SafetyClassifier1280
def test_forward_shape():
    """Smoke-test the output shape of ``SafetyClassifier1280``.

    Builds the classifier, calls ``.eval()`` on it, runs one forward pass
    on a random ``(2, 1280, 8, 8)`` tensor with gradient tracking disabled,
    and asserts that the result has shape ``(2, 5)``.
    """
    classifier = SafetyClassifier1280()
    classifier = classifier.eval()

    # Random stand-in for the mid-level features the model is fed.
    fake_features = torch.randn(2, 1280, 8, 8)

    with torch.no_grad():
        out = classifier(fake_features)

    assert out.shape == (2, 5), f"Expected (2,5), got {tuple(out.shape)}"
if __name__ == "__main__":
    # Allow running this file directly (without a test runner): execute the
    # shape check and report success if the assertion did not fire.
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")
|