# File: safe-diffusion-guidance/tests/test_classifier.py
# Uploader: basimazam
# Commit message: Upload folder using huggingface_hub
# Commit: b8877ca (verified)
import torch
from utils.adaptive_classifiers import SafetyClassifier1280
def test_forward_shape():
    """Smoke-test the output shape of ``SafetyClassifier1280``.

    Builds the classifier in eval mode, feeds it a random tensor of shape
    (2, 1280, 8, 8) with gradient tracking disabled, and asserts that the
    result has shape (2, 5).
    """
    # Instantiate the model first, then the input, so RNG consumption order
    # is the same as it has always been for this test.
    classifier = SafetyClassifier1280()
    classifier = classifier.eval()

    fake_features = torch.randn(2, 1280, 8, 8)  # fake mid features

    # No gradients are needed for a pure shape check.
    with torch.no_grad():
        out = classifier(fake_features)

    assert out.shape == (2, 5), f"Expected (2,5), got {tuple(out.shape)}"
# Script entry point: lets the check be run directly with
# `python test_classifier.py`, without going through pytest.
if __name__ == "__main__":
    # An AssertionError from the test propagates and aborts before the
    # success message is printed.
    test_forward_shape()
    print("OK: classifier forward shape is (B,5)")