import json
import os
import time
import unittest
from pathlib import Path
import warnings

import numpy as np

import bm25s


# Benchmark helpers: save_scores persists timing results as JSON artifacts;
# BM25TestCase below compares bm25s against rank-bm25 and bm25-pt on BEIR
# datasets.
def save_scores(scores, artifact_dir="tests/artifacts"):
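    """Write one benchmark-result dict to a JSON artifact file.

    The output directory is resolved from the BM25_ARTIFACTS_DIR environment
    variable, then the `artifact_dir` argument, then `tests/artifacts` next to
    this file. `scores` must contain 'dataset' and 'model' keys; the file name
    gets a random hex suffix so repeated runs never overwrite each other.
    """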
    if os.getenv("ARTIFACTS_DIR"):
        artifacts_dir = Path(os.getenv("BM25_ARTIFACTS_DIR"))
    elif artifact_dir is not None:
        artifacts_dir = Path(artifact_dir)
    else:
        artifacts_dir = Path(__file__).parent / "artifacts"

    if "dataset" not in scores:
        raise ValueError("scores must contain a 'dataset' key.")
    if "model" not in scores:
        raise ValueError("scores must contain a 'model' key.")
    
    artifacts_dir = artifacts_dir / scores["model"]
    artifacts_dir.mkdir(exist_ok=True, parents=True)

    filename = f"{scores['dataset']}-{os.urandom(8).hex()}.json"
    with open(artifacts_dir / filename, "w") as f:
        json.dump(scores, f, indent=2)


class BM25TestCase(unittest.TestCase):
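    """Helpers that benchmark bm25s against reference implementations
    (rank-bm25, bm25-pt) on BEIR datasets and assert that the scores agree.
    """
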
    def compare_with_rank_bm25(
        self,
        dataset,
        artifact_dir="tests/artifacts",
        rel_save_dir="datasets",
        corpus_subsample=None,
        queries_subsample=None,
        method="rank",
    ):
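        """Index and score a BEIR dataset with both bm25s and rank-bm25,
        assert the score matrices are numerically close, print timings, and
        persist them via save_scores. `method` selects the variant: 'rank'
        (rank_bm25.BM25Okapi vs. bm25s ATIRE with Robertson IDF), 'bm25+',
        or 'bm25l'.
        """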
        from beir.datasets.data_loader import GenericDataLoader
        from beir.util import download_and_unzip
        import rank_bm25
        import Stemmer

        warnings.filterwarnings("ignore", category=ResourceWarning)

        if method not in ["rank", "bm25+", "bm25l"]:
            raise ValueError("method must be either 'rank' or 'bm25+'.")

        # Download and prepare dataset
        base_url = (
            "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
        )
        url = base_url.format(dataset)
        out_dir = Path(__file__).parent / rel_save_dir
        data_path = download_and_unzip(url, str(out_dir))

        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(
            split="test"
        )

        # Convert corpus and queries to lists
        corpus_lst = [val["title"] + " " + val["text"] for val in corpus.values()]
        queries_lst = list(queries.values())

        if corpus_subsample is not None:
            corpus_lst = corpus_lst[:corpus_subsample]

        if queries_subsample is not None:
            queries_lst = queries_lst[:queries_subsample]

        # Tokenize with bm25s (English stopword removal) + PyStemmer
        stemmer = Stemmer.Stemmer("english")

        corpus_token_strs = bm25s.tokenize(
            corpus_lst, stopwords="en", stemmer=stemmer, return_ids=False
        )
        queries_token_strs = bm25s.tokenize(
            queries_lst, stopwords="en", stemmer=stemmer, return_ids=False
        )
        print()
        print(f"Dataset:              {dataset}\n")
        # print corpus and queries size
        print(f"Corpus size:          {len(corpus_lst)}")
        print(f"Queries size:         {len(queries_lst)}")
        print()

        # Initialize and index bm25s with atire + robertson idf (to match rank-bm25)
        if method == "rank":
            bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, method="atire", idf_method="robertson")
        elif method in ["bm25+", "bm25l"]:
            bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, delta=0.5, method=method)
        else:
            raise ValueError("invalid method")
        
        start_time = time.monotonic()
        bm25_sparse.index(corpus_token_strs)
        bm25_sparse_index_time = time.monotonic() - start_time
        print(f"bm25s index time:     {bm25_sparse_index_time:.4f}s")

        # Scoring with bm25-sparse
        start_time = time.monotonic()
        bm25_sparse_scores = [bm25_sparse.get_scores(q) for q in queries_token_strs]
        bm25_sparse_score_time = time.monotonic() - start_time
        print(f"bm25s score time:     {bm25_sparse_score_time:.4f}s")

        # Initialize and index rank-bm25
        start_time = time.monotonic()
        if method == "rank":
            bm25_rank = rank_bm25.BM25Okapi(corpus_token_strs, k1=1.5, b=0.75, epsilon=0.0)
        elif method == "bm25+":
            bm25_rank = rank_bm25.BM25Plus(corpus_token_strs, k1=1.5, b=0.75, delta=0.5)
        elif method == "bm25l":
            bm25_rank = rank_bm25.BM25L(corpus_token_strs, k1=1.5, b=0.75, delta=0.5)
        else:
            raise ValueError("invalid method")
    
        bm25_rank_index_time = time.monotonic() - start_time
        print(f"rank-bm25 index time: {bm25_rank_index_time:.4f}s")

        # Scoring with rank-bm25
        start_time = time.monotonic()
        bm25_rank_scores = [bm25_rank.get_scores(q) for q in queries_token_strs]
        bm25_rank_score_time = time.monotonic() - start_time
        print(f"rank-bm25 score time: {bm25_rank_score_time:.4f}s")

        # print difference in time
        print(
            f"Index Time: BM25S is {bm25_rank_index_time / bm25_sparse_index_time:.2f}x faster than rank-bm25."
        )
        print(
            f"Score Time: BM25S is {bm25_rank_score_time / bm25_sparse_score_time:.2f}x faster than rank-bm25."
        )

        # Check that the scores match within floating-point tolerance
        sparse_scores = np.array(bm25_sparse_scores)
        rank_scores = np.array(bm25_rank_scores)

        error_msg = f"\nScores between bm25s and rank-bm25 do not match on dataset {dataset}."
        almost_equal = np.allclose(sparse_scores, rank_scores)
        self.assertTrue(almost_equal, error_msg)

        general_info = {
            "date": time.strftime("%Y-%m-%d %H:%M:%S"),
            "num_jobs": 1,
            "dataset": dataset,
            "corpus_size": len(corpus_lst),
            "queries_size": len(queries_lst),
            "corpus_subsampled": corpus_subsample is not None,
            "queries_subsampled": queries_subsample is not None,
        }
        # Save metrics
        res = {
            "model": "bm25s",
            "index_time": bm25_sparse_index_time,
            "score_time": bm25_sparse_score_time,
        }
        res.update(general_info)
        save_scores(res, artifact_dir=artifact_dir)

        res = {
            "model": "rank-bm25",
            "score_time": bm25_rank_score_time,
            "index_time": bm25_rank_index_time,
        }
        res.update(general_info)
        save_scores(res, artifact_dir=artifact_dir)

    def compare_with_bm25_pt(
        self,
        dataset,
        artifact_dir="tests/artifacts",
        rel_save_dir="datasets",
        corpus_subsample=None,
        queries_subsample=None,
    ):
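        """Index and score a BEIR dataset with both bm25s and bm25-pt using
        the same bert-base-uncased tokenizer, assert the scores are close
        within atol=1e-4, print timings (tokenization time is excluded from
        the bm25-pt timings), and persist them via save_scores.
        """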
        from beir.datasets.data_loader import GenericDataLoader
        from beir.util import download_and_unzip
        import bm25_pt
        import bm25s.hf

        from transformers import AutoTokenizer

        warnings.filterwarnings("ignore", category=ResourceWarning)
        warnings.filterwarnings("ignore", category=UserWarning)

        # Download and prepare dataset
        base_url = (
            "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
        )
        url = base_url.format(dataset)
        out_dir = Path(__file__).parent / rel_save_dir
        data_path = download_and_unzip(url, str(out_dir))

        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(
            split="test"
        )

        # Convert corpus and queries to lists
        corpus_lst = [val["title"] + " " + val["text"] for val in corpus.values()]
        queries_lst = list(queries.values())

        if corpus_subsample is not None:
            corpus_lst = corpus_lst[:corpus_subsample]

        if queries_subsample is not None:
            queries_lst = queries_lst[:queries_subsample]

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        t0 = time.monotonic()
        tokenized_corpus = bm25s.hf.batch_tokenize(tokenizer, corpus_lst)
        time_corpus_tok = time.monotonic() - t0

        t0 = time.monotonic()
        queries_tokenized = bm25s.hf.batch_tokenize(tokenizer, queries_lst)
        time_query_tok = time.monotonic() - t0

        print()
        print(f"Dataset:              {dataset}\n")
        # print corpus and queries size
        print(f"Corpus size:          {len(corpus_lst)}")
        print(f"Queries size:         {len(queries_lst)}")
        print()

        # Initialize and index bm25-sparse
        bm25_sparse = bm25s.BM25(k1=1.5, b=0.75, method="atire", idf_method="lucene")
        start_time = time.monotonic()
        bm25_sparse.index(tokenized_corpus)
        bm25s_index_time = time.monotonic() - start_time
        print(f"bm25s index time:     {bm25s_index_time:.4f}s")

        # Scoring with bm25-sparse
        start_time = time.monotonic()
        bm25_sparse_scores = [bm25_sparse.get_scores(q) for q in queries_tokenized]
        bm25s_score_time = time.monotonic() - start_time
        print(f"bm25s score time:     {bm25s_score_time:.4f}s")

        # Initialize and index bm25-pt
        start_time = time.monotonic()
        model_pt = bm25_pt.BM25(tokenizer=tokenizer, device="cpu", k1=1.5, b=0.75)
        model_pt.index(corpus_lst)
        bm25_pt_index_time = time.monotonic() - start_time
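        # Subtract the corpus tokenization time measured above so the
        # comparison is fair: the bm25s index timing also excludes tokenization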
        bm25_pt_index_time -= time_corpus_tok
        print(f"bm25-pt index time:   {bm25_pt_index_time:.4f}s")

        # Scoring with bm25-pt
        start_time = time.monotonic()
        bm25_pt_scores = model_pt.score_batch(queries_lst)
        bm25_pt_scores = bm25_pt_scores.cpu().numpy()
        bm25_pt_score_time = time.monotonic() - start_time
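        # Likewise exclude the query tokenization time from the scoring timing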
        bm25_pt_score_time -= time_query_tok
        print(f"bm25-pt score time: {bm25_pt_score_time:.4f}s")

        # print difference in time
        print(
            f"Index Time: BM25S is {bm25_pt_index_time / bm25s_index_time:.2f}x faster than bm25-pt."
        )
        print(
            f"Score Time: BM25S is {bm25_pt_score_time / bm25s_score_time:.2f}x faster than bm25-pt."
        )

        # Check that the scores are close within tolerance
        bm25_sparse_scores = np.array(bm25_sparse_scores)
        bm25_pt_scores = np.array(bm25_pt_scores)

        error_msg = f"\nScores between bm25s and bm25-pt are not close on dataset {dataset}."
        almost_equal = np.allclose(bm25_sparse_scores, bm25_pt_scores, atol=1e-4)
        self.assertTrue(almost_equal, error_msg)

        general_info = {
            "date": time.strftime("%Y-%m-%d %H:%M:%S"),
            "num_jobs": 1,
            "dataset": dataset,
            "corpus_size": len(corpus_lst),
            "queries_size": len(queries_lst),
            "corpus_was_subsampled": corpus_subsample is not None,
            "queries_was_subsampled": queries_subsample is not None,
        }
        # Save metrics
        res = {
            "model": "bm25s",
            "index_time": bm25s_index_time,
            "score_time": bm25s_score_time,
        }
        res.update(general_info)
        save_scores(res, artifact_dir=artifact_dir)

        res = {
            "model": "bm25-pt",
            "score_time": bm25_pt_score_time,
            "index_time": bm25_pt_index_time,
        }
        res.update(general_info)
        save_scores(res, artifact_dir=artifact_dir)
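

if __name__ == "__main__":
    # Minimal usage sketch: subclass BM25TestCase and invoke one of the
    # comparison helpers. The dataset name "scifact" and the subsample sizes
    # below are illustrative assumptions, not fixed anywhere in this file.
    class ExampleScifactTest(BM25TestCase):
        def test_rank_bm25(self):
            self.compare_with_rank_bm25(
                "scifact", corpus_subsample=500, queries_subsample=20
            )

    unittest.main()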