"""Smoke test for the forward-pass output shape of ``SafetyClassifier1280``.

Runnable under pytest (collects ``test_forward_shape``) or directly as a
script, in which case it prints a confirmation line on success.
"""
import torch

from utils.adaptive_classifiers import SafetyClassifier1280


def test_forward_shape():
    """Check that a batch of 2 feature maps of shape (1280, 8, 8) yields (2, 5).

    The model is put in eval mode and run without autograd. The 8x8 spatial
    size is arbitrary synthetic input; presumably the classifier pools over
    the spatial dims -- NOTE(review): confirm against the model definition.

    Raises:
        AssertionError: if the output shape is not ``(2, 5)``.
    """
    model = SafetyClassifier1280().eval()

    # Local seeded generator: makes the synthetic input reproducible across
    # runs (so a failure can be replayed) without touching the global torch
    # RNG state that other tests may depend on.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(2, 1280, 8, 8, generator=gen)  # fake mid features

    with torch.no_grad():
        y = model(x)

    assert y.shape == (2, 5), f"Expected (2,5), got {tuple(y.shape)}"


if __name__ == "__main__":
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")