File size: 432 Bytes
b8877ca
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from utils.adaptive_classifiers import SafetyClassifier1280

def test_forward_shape():
    """Run one no-grad forward pass and verify the classifier's output shape.

    A random ``(2, 1280, 8, 8)`` tensor stands in for real mid-level
    features. The classifier, switched to eval mode, must return a tensor
    of shape ``(2, 5)``.
    """
    classifier = SafetyClassifier1280()
    classifier.eval()  # put layers into inference mode before the pass
    # Synthetic batch of 2 "mid features"; values don't matter for a shape check.
    features = torch.randn(2, 1280, 8, 8)
    with torch.no_grad():
        output = classifier(features)
    got = tuple(output.shape)
    assert output.shape == (2, 5), f"Expected (2,5), got {got}"

if __name__ == "__main__":
    # Lets the file be executed directly as a quick smoke check without pytest;
    # a shape mismatch surfaces as an AssertionError before the OK line prints.
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")