import logging

import numpy as np
import pytest

from bertopic._utils import check_documents_type, check_embeddings_shape, MyLogger


@pytest.mark.parametrize(
    "level_name, expected_level",
    [
        ("DEBUG", logging.DEBUG),
        ("WARNING", logging.WARNING),
    ],
)
def test_logger(level_name, expected_level):
    """MyLogger exposes a stdlib ``logging.Logger`` set to the requested level."""
    logger = MyLogger(level_name)
    assert isinstance(logger.logger, logging.Logger)
    # Compare against the named stdlib constants (DEBUG == 10, WARNING == 30)
    # instead of bare magic numbers.
    assert logger.logger.level == expected_level


@pytest.mark.parametrize(
    "docs",
    [
        "A document not in an iterable",
        [None],
        5,
    ],
    ids=["bare-string", "list-with-none", "int"],
)
def test_check_documents_type(docs):
    """Each of these malformed ``docs`` inputs must be rejected with TypeError."""
    with pytest.raises(TypeError):
        check_documents_type(docs)


def test_check_embeddings_shape():
    """Embeddings with one row per document pass validation.

    This is a smoke test: it has no explicit assertion and passes as long as
    ``check_embeddings_shape`` does not raise for matching inputs.
    """
    docs = ["doc_one", "doc_two"]
    # Two rows, matching the two documents above.
    embeddings = np.array([[1, 2, 3], [2, 3, 4]])
    check_embeddings_shape(embeddings, docs)