Spaces:
Sleeping
Sleeping
File size: 850 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 | import logging
import torch
from src.models.resnet_model import CarClassifierResNet
from src.config import NUM_CLASSES
logger = logging.getLogger(__name__)
def test_resnet_model():
    """Smoke-test the CarClassifierResNet forward pass.

    Builds the model with ``NUM_CLASSES`` outputs and switches it to eval mode.
    It then feeds the model a random batch of two 3-channel 128x128 tensors
    with gradient tracking disabled. Finally it checks that the output shape
    is ``(2, NUM_CLASSES)``.

    Raises:
        AssertionError: if the model output does not have the expected shape.
    """
    logger.info("Testing ResNet model architecture...")

    # Module.eval() returns the module itself, so construction and the switch
    # to eval mode can be chained. Eval mode disables any train-only behaviour
    # (e.g. dropout, batch-norm statistic updates) the model may contain.
    classifier = CarClassifierResNet(num_classes=NUM_CLASSES).eval()

    # One name for the batch size, shared by the input and the expected shape.
    batch_size = 2
    # NOTE(review): assumes the model takes (batch, channels, height, width)
    # input — confirm against CarClassifierResNet.
    sample_batch = torch.randn(batch_size, 3, 128, 128)

    # No backward pass is needed, so skip building the autograd graph.
    with torch.no_grad():
        model_output = classifier(sample_batch)

    expected_shape = (batch_size, NUM_CLASSES)
    assert model_output.shape == expected_shape, (
        f"Unexpected output shape: {model_output.shape}"
    )

    logger.info("ResNet model test passed.")
if __name__ == "__main__":
    # When run directly as a script, configure root logging at INFO so this
    # module's logger messages are emitted. Then run the smoke test. The final
    # print is only reached if the test did not raise.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    test_resnet_model()
    print("ResNet model test completed successfully.")