File size: 933 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import logging
import torch

from src.models.fusion_model import FusionClassifier
from src.config import NUM_CLASSES

logger = logging.getLogger(__name__)


def test_fusion_model():
    """Smoke-test a single forward pass of ``FusionClassifier``.

    Builds the model with ``NUM_CLASSES`` outputs, feeds it one random batch
    per input branch, and checks that the result has shape
    ``(batch_size, NUM_CLASSES)``.

    Raises:
        AssertionError: if the model output does not have the expected shape.
    """
    logger.info("Testing Fusion model architecture...")

    batch_size = 2
    # Side lengths of the two square inputs. Presumably these are the native
    # resolutions of the EfficientNet ("eff") and ConvNeXt ("cnx") branches,
    # with tensors laid out as (N, C, H, W). TODO confirm against the model.
    eff_side, cnx_side = 260, 224

    classifier = FusionClassifier(num_classes=NUM_CLASSES)
    # Switch to evaluation mode before the forward pass. This assumes
    # FusionClassifier is a torch nn.Module; eval() only affects layers such
    # as dropout / batch-norm if the model contains them.
    classifier.eval()

    # Draw the random inputs in the same order as before: eff first, then cnx.
    eff_batch = torch.randn(batch_size, 3, eff_side, eff_side)
    cnx_batch = torch.randn(batch_size, 3, cnx_side, cnx_side)

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        result = classifier(eff_batch, cnx_batch)

    expected_shape = (batch_size, NUM_CLASSES)
    assert result.shape == expected_shape, \
        f"Unexpected output shape: {result.shape}"

    logger.info("Fusion model test passed.")


if __name__ == "__main__":
    # When run as a script (rather than collected by a test runner), send
    # INFO-level log records to the console so the test's progress is visible.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    # Any failure surfaces as an AssertionError raised from the test itself,
    # so reaching the print below means the check passed.
    test_fusion_model()
    print("Fusion model test completed successfully.")