import torch

from utils.adaptive_classifiers import SafetyClassifier1280


def test_forward_shape():
    """Smoke-test the output shape of ``SafetyClassifier1280``.

    Steps:
    1. Builds the classifier and switches it to eval mode.
    2. Feeds it a random tensor of shape (2, 1280, 8, 8) with gradient
       tracking disabled.
    3. Asserts the output shape is (2, 5).

    Only the shape is checked; the output values are not inspected.
    """
    expected_shape = (2, 5)
    classifier = SafetyClassifier1280().eval()
    # Random stand-in for the "mid features" the classifier consumes.
    fake_features = torch.randn(2, 1280, 8, 8)
    with torch.no_grad():
        output = classifier(fake_features)
    assert output.shape == expected_shape, f"Expected (2,5), got {tuple(output.shape)}"


if __name__ == "__main__":
    # Allows running this file directly as a script, without pytest.
    # A failed assertion propagates and aborts before the success message.
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")