File size: 930 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import logging

from src.training.train_resnet import run_resnet_training
from src.config import CHECKPOINT_DIR

logger = logging.getLogger(__name__)


def test_train_resnet():
    logger.info("Testing ResNet training pipeline...")

    checkpoint_path = CHECKPOINT_DIR / "best_resnet_model.pt"

    if checkpoint_path.exists():
        checkpoint_path.unlink()

    preds, labels = run_resnet_training()

    assert checkpoint_path.exists(), \
        "ResNet checkpoint was not created"

    assert len(preds) > 0, \
        "No predictions returned"

    assert len(labels) > 0, \
        "No labels returned"

    logger.info("ResNet training test passed.")


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )

    test_train_resnet()

    print("ResNet training test completed successfully.")