| | import pytest |
| | import logging |
| | import numpy as np |
| | from bertopic._utils import check_documents_type, check_embeddings_shape, MyLogger |
| |
|
| |
|
def test_logger():
    """Check that MyLogger exposes a standard logger set to the requested level."""
    # Each level name is paired with the numeric value the logging module
    # assigns to it (DEBUG == 10, WARNING == 30).
    expected_levels = (
        ("DEBUG", logging.DEBUG),
        ("WARNING", logging.WARNING),
    )
    for level_name, level_number in expected_levels:
        wrapper = MyLogger(level_name)
        assert isinstance(wrapper.logger, logging.Logger)
        assert wrapper.logger.level == level_number
| |
|
| |
|
# Cases: a bare string, a list holding None, and a plain integer.
@pytest.mark.parametrize("docs", ["A document not in an iterable", [None], 5])
def test_check_documents_type(docs):
    """Expect check_documents_type to raise TypeError for each parametrized input."""
    with pytest.raises(TypeError):
        check_documents_type(docs)
| |
|
| |
|
def test_check_embeddings_shape():
    """Call check_embeddings_shape with two documents and a two-row embedding array.

    There is no explicit assertion: the test passes as long as the call
    completes without raising.
    """
    documents = ["doc_one", "doc_two"]
    vectors = np.array([[1, 2, 3], [2, 3, 4]])
    check_embeddings_shape(vectors, documents)
| |
|