File size: 1,648 Bytes
d992912 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import numpy as np
from backend.app.engine.bm25 import SimpleBM25
class TestSimpleBM25:
    """Behavioral tests for ``SimpleBM25`` fitting and candidate scoring."""

    def test_fit_and_score(self):
        """A query sharing terms with one document ranks that document first."""
        bm25 = SimpleBM25()
        docs = [
            "black leather jacket mens",
            "red floral dress womens",
            "blue denim jeans casual",
        ]
        bm25.fit(docs)
        assert bm25.n_docs == 3
        scores = bm25.score_candidates("black leather", [0, 1, 2])
        assert scores[0] > scores[1]
        assert scores[0] > scores[2]

    def test_empty_query(self):
        """An empty query yields one zero score per candidate."""
        bm25 = SimpleBM25()
        bm25.fit(["black dress", "red shoes"])
        # asarray keeps the elementwise comparison valid even if a plain
        # list is returned (list == 0.0 would be a single False).
        scores = np.asarray(bm25.score_candidates("", [0, 1]))
        # Pin the length first: np.all() over an empty array is vacuously
        # True, so without this check a result with no scores would pass.
        assert len(scores) == 2
        assert np.all(scores == 0.0)

    def test_unknown_terms(self):
        """A query made only of out-of-vocabulary terms scores zero."""
        bm25 = SimpleBM25()
        bm25.fit(["black dress"])
        scores = bm25.score_candidates("xyznotaword", [0])
        assert scores[0] == 0.0

    def test_out_of_range_index(self):
        """Candidate indices beyond the fitted corpus score zero, not raise."""
        bm25 = SimpleBM25()
        bm25.fit(["black dress"])
        scores = bm25.score_candidates("black", [0, 999])
        assert scores[0] > 0
        assert scores[1] == 0.0

    def test_exact_match_scores_higher(self):
        """A query equal to a document's full text ranks that document first."""
        bm25 = SimpleBM25()
        bm25.fit(["black leather jacket", "red silk dress", "blue cotton shirt"])
        scores = bm25.score_candidates("black leather jacket", [0, 1, 2])
        assert scores[0] > scores[1]
        assert scores[0] > scores[2]

    def test_doc_frequency_computed(self):
        """``fit`` records a per-term document frequency in ``df``."""
        bm25 = SimpleBM25()
        bm25.fit(["black dress", "black shoes", "red dress"])
        assert bm25.df["black"] == 2
        assert bm25.df["dress"] == 2
        assert bm25.df["red"] == 1
|