Spaces:
Sleeping
Sleeping
File size: 930 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | import logging
from src.training.train_resnet import run_resnet_training
from src.config import CHECKPOINT_DIR
# Module-scoped logger named after this module; records propagate to the
# root logger, which the __main__ guard configures when run as a script.
logger = logging.getLogger(name=__name__)
def test_train_resnet():
    """Smoke-test the ResNet training pipeline end to end.

    Deletes any existing ``best_resnet_model.pt`` under ``CHECKPOINT_DIR``,
    calls ``run_resnet_training()`` and unpacks its two return values, then
    asserts that:

    * the checkpoint file exists again after training, and
    * both returned values (predictions and labels) are non-empty.

    Warning: this is destructive. A previously saved checkpoint at that path
    is removed before training starts.
    """
    logger.info("Testing ResNet training pipeline...")

    ckpt_file = CHECKPOINT_DIR / "best_resnet_model.pt"
    # Start from a clean slate so the existence check below can only pass
    # if this training run produced the file.
    # NOTE(review): assumes run_resnet_training() saves its best model to
    # exactly this path; confirm against src.training.train_resnet.
    if ckpt_file.exists():
        ckpt_file.unlink()

    outputs = run_resnet_training()
    predictions, targets = outputs

    assert ckpt_file.exists(), "ResNet checkpoint was not created"
    # Same order and messages as two separate asserts: the first empty
    # value found raises.
    for returned, failure_msg in (
        (predictions, "No predictions returned"),
        (targets, "No labels returned"),
    ):
        assert len(returned) > 0, failure_msg

    logger.info("ResNet training test passed.")
if __name__ == "__main__":
    # Configure the root logger (stderr handler, INFO level) so the test's
    # progress messages are visible when this file is run directly.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    test_train_resnet()
    print("ResNet training test completed successfully.")