| import unittest | |
| import os | |
| from image_classification_model.predict import predict | |
| from image_classification_model.utils import ROOT_DIR | |
# Directory containing the image fixtures used by these tests, resolved
# against ROOT_DIR (imported from image_classification_model.utils;
# presumably the repository root — confirm). The components are passed
# separately so os.path.join inserts the platform's native separator instead
# of a hard-coded "/".
DATA_DIR = os.path.join(ROOT_DIR, "tests", "data")
class TestPrediction(unittest.TestCase):
    """Check that ``predict`` returns the expected label for a known fixture image."""

    def test_prediction_label_3(self):
        """The fixture ``number_3.jpg`` in DATA_DIR must be predicted as label ``3``."""
        test_image_path = os.path.join(DATA_DIR, "number_3.jpg")
        # Fail with an explicit, actionable message when the fixture is absent,
        # rather than relying on how predict() behaves for a nonexistent path.
        self.assertTrue(
            os.path.isfile(test_image_path),
            f"Missing test fixture: {test_image_path}",
        )
        predicted_label = predict(test_image_path)
        self.assertEqual(
            predicted_label, 3, f"Expected label 3, but got {predicted_label}"
        )
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()