Spaces:
Sleeping
Sleeping
File size: 933 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import logging
import torch
from src.models.fusion_model import FusionClassifier
from src.config import NUM_CLASSES
logger = logging.getLogger(__name__)


def test_fusion_model():
    """Smoke-test a single forward pass through ``FusionClassifier``.

    Builds the classifier for ``NUM_CLASSES`` outputs and puts it in eval
    mode. It then feeds one random batch to each of its two inputs: a
    260x260 batch for the ``eff`` branch and a 224x224 batch for the ``cnx``
    branch. Finally it asserts that the output shape is
    ``(batch_size, NUM_CLASSES)``.
    """
    logger.info("Testing Fusion model architecture...")

    batch_size = 2
    classifier = FusionClassifier(num_classes=NUM_CLASSES)
    # Eval mode disables train-only behaviour (e.g. dropout) for the check.
    classifier.eval()

    # One random image batch per branch. Assumes (batch, channels, height, width)
    # layout; the sizes are presumably the native resolutions of the two
    # backbones -- TODO confirm against FusionClassifier.
    eff_batch = torch.randn(batch_size, 3, 260, 260)
    cnx_batch = torch.randn(batch_size, 3, 224, 224)

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        prediction = classifier(eff_batch, cnx_batch)

    expected_shape = (batch_size, NUM_CLASSES)
    assert prediction.shape == expected_shape, f"Unexpected output shape: {prediction.shape}"

    logger.info("Fusion model test passed.")
if __name__ == "__main__":
    # Allow running this file directly (outside pytest) with INFO-level logs
    # visible on the console.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    # A failed shape check raises AssertionError here, so the final message
    # below is only printed when the test passes.
    test_fusion_model()
    print("Fusion model test completed successfully.")