|
|
import logging

import numpy as np
import pytest

from bertopic._utils import check_documents_type, check_embeddings_shape, MyLogger
|
|
|
|
|
|
|
|
def test_logger():
    """Check that ``MyLogger`` wraps a stdlib logger set to the requested level.

    The level name given to the constructor must show up as the matching
    ``logging`` constant on the wrapped ``logger`` attribute
    (``logging.DEBUG`` == 10, ``logging.WARNING`` == 30).
    """
    logger = MyLogger("DEBUG")
    assert isinstance(logger.logger, logging.Logger)
    assert logger.logger.level == logging.DEBUG

    # Build a second instance to confirm a different level name is honoured.
    logger = MyLogger("WARNING")
    assert isinstance(logger.logger, logging.Logger)
    assert logger.logger.level == logging.WARNING
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "docs",
    [
        "A document not in an iterable",
        [None],
        5,
    ],
)
def test_check_documents_type(docs):
    """Check that ``check_documents_type`` raises ``TypeError`` on bad input.

    The parametrized cases are a bare string, a list holding ``None``, and
    an integer; each one must be rejected.
    """
    with pytest.raises(TypeError):
        check_documents_type(docs)
|
|
|
|
|
|
|
|
def test_check_embeddings_shape():
    """Check that a 2x3 embedding array is accepted for two documents."""
    docs = ["doc_one", "doc_two"]
    embeddings = np.array(
        [
            [1, 2, 3],
            [2, 3, 4],
        ]
    )
    # No explicit assertion: the test passes as long as this call does not raise.
    check_embeddings_shape(embeddings, docs)
|
|
|