JanWerth commited on
Commit
75908fc
·
verified ·
1 Parent(s): 4d4b118

Upload 2 files

Browse files
Files changed (2) hide show
  1. hw1_app.py +615 -0
  2. requirements.txt +5 -0
hw1_app.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass
import math
import os
import pickle
import re
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar

import nltk
import numpy as np
import scipy.sparse
from scipy.sparse import csc_matrix
import tqdm

from nlp4web_codebase.ir.data_loaders.dm import Document

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords
12
+
13
# Tokenization configuration: English stopword filtering plus a scikit-learn
# style token pattern.
LANGUAGE = "english"
# (?u)\b\w\w+\b keeps Unicode word tokens of length >= 2 (drops 1-char tokens).
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))
16
+
17
+
18
def word_splitting(text: str) -> List[str]:
    """Lower-case *text* and split it into word tokens via `word_splitter`."""
    lowered = text.lower()
    return word_splitter(lowered)
20
+
21
def lemmatization(words: List[str]) -> List[str]:
    """Identity lemmatizer — lemmatization is intentionally skipped for simplicity."""
    return words
23
+
24
def simple_tokenize(text: str) -> List[str]:
    """Split *text* into lower-cased word tokens, drop stopwords, then lemmatize."""
    kept = [w for w in word_splitting(text) if w not in stopwords]
    return lemmatization(kept)
29
+
30
T = TypeVar("T", bound="InvertedIndex")  # return type of InvertedIndex.from_saved for subclasses
31
+
32
@dataclass
class PostingList:
    """Posting list of one vocabulary term: parallel docid / term-weight arrays."""

    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
37
+
38
+
39
@dataclass
class InvertedIndex:
    """In-memory inverted index, (de)serialized as a single pickle file."""

    posting_lists: List[PostingList]  # tid -> posting list (indexed by term id, not docid)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text (None if raw text not stored)

    def save(self, output_dir: str) -> None:
        """Pickle this index to `<output_dir>/index.pkl`, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by `save`.

        NOTE(review): pickle.load executes arbitrary code — only load index
        files from trusted sources.
        """
        # The original built a throwaway empty instance here that was
        # immediately overwritten by the unpickled object; removed.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
60
+
61
+
62
# The output of the counting function:
@dataclass
class Counting:
    """Raw corpus statistics produced by `run_counting` (input to BM25 weighting)."""

    posting_lists: List[PostingList]  # tid -> posting list (tweights hold raw TFs at this stage)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # average document length over the corpus
    nterms: int  # total number of tokens in the corpus
    doc_texts: Optional[List[str]] = None  # docid -> raw text (None when store_raw=False)
74
+
75
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, doc lengths, etc. over `documents`.

    Documents with a duplicate collection_id are skipped. The returned posting
    lists hold raw term frequencies in `tweight_postings`; BM25 weighting is
    applied later by the index builders.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    # Decide once up front whether raw texts are kept (the original rebound
    # doc_texts = None on every iteration when store_raw was False).
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # skip duplicate documents
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                # First occurrence of this term: allocate a tid and an empty posting list.
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Each (term, doc) pair contributes exactly once to the term's DF.
            # Fixed according to moodle discussion
            # https://moodle.informatik.tu-darmstadt.de/mod/moodleoverflow/discussion.php?d=2097
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)
        if store_raw:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # Guard against an empty corpus (the original divided by zero here).
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )
133
+
134
# Load the SciQ dataset and compute the corpus statistics.
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

# NOTE(review): the original re-imported several modules here (os, tqdm, typing
# names, Document) and included `from __future__ import annotations`, which is
# a SyntaxError anywhere but the very top of a module. The __future__ import
# was dropped and the duplicates pruned; only the names not yet imported at the
# top of the file are kept.
from dataclasses import asdict  # noqa: F401  (kept from the original; currently unused)
import math
145
+
146
+
147
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting-list weights are precomputed BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        # Delegates to the module-level tokenizer (lowercase, stopword filtering).
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching"""
        # Replaces each raw TF in the posting lists, IN PLACE, with its BM25
        # weight: regularized_tf * idf.
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        # BM25 TF saturation with document-length normalization.
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        # BM25 "plus-one" IDF variant (always positive).
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,  # NOTE(review): accepted but never used in this method
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count corpus statistics, convert TFs to BM25 weights, and assemble the index."""
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index
232
+
233
# Build the BM25 index over the full SciQ corpus and persist it.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,  # NOTE(review): hard-coded corpus size; len(sciq.corpus) would be safer
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
239
+
240
# Hyperparameter sweep grids: "X" holds the candidate values, "Y" collects dev-MAP scores.
plots_b: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": []
}
plots_k1: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": []
}
248
+
249
## YOU_CODE_STARTS_HERE
# NOTE(review): BM25Retriever is defined later in this file, and Split /
# evaluate_map are never imported anywhere in it — executed top-to-bottom this
# section raises NameError. Presumably this was a notebook where the defining
# cells ran first; confirm the intended execution order.


def _dev_map_for(k1: float, b: float) -> float:
    """Build a BM25 index with the given hyperparameters and return dev-split MAP.

    The two original tuning loops were copy-pasted duplicates; this helper
    deduplicates them.
    """
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    rankings = {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(Split.dev)
    }
    return evaluate_map(rankings, split=Split.dev)


# Tune b with k1 fixed:
k1_default = 0.4
best_b = 0
best_score = 0
for b in plots_b["X"]:
    score = _dev_map_for(k1=k1_default, b=b)
    plots_b["Y"].append(score)
    if score > best_score:
        best_score = score
        best_b = b

print(best_b, best_score)

# Tune k1 with b fixed at its best value:
best_k1 = 0
best_score = 0
for k1 in plots_k1["X"]:
    score = _dev_map_for(k1=k1, b=best_b)
    plots_k1["Y"].append(score)
    if score > best_score:
        best_score = score
        best_k1 = k1

print(best_k1, best_score)
## YOU_CODE_ENDS_HERE
302
+
303
+ from nlp4web_codebase.ir.models import BaseRetriever
304
+ from typing import Type
305
+ from abc import abstractmethod
306
+
307
+
308
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever backed by an `InvertedIndex` loaded from disk."""

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Concrete index class used to load the saved index."""
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {query term -> cached weight} for the document identified by `cid`."""
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in self.index.tokenize(query):
            if tok not in self.index.vocab:
                continue
            plist = self.index.posting_lists[self.index.vocab[tok]]
            for docid, tweight in zip(plist.docid_postings, plist.tweight_postings):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Score one document: sum of its matching query-term weights."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the top-`topk` documents as {collection_id -> score}."""
        docid2score: Dict[int, float] = {}
        for tok in self.index.tokenize(query):
            if tok not in self.index.vocab:
                continue
            plist = self.index.posting_lists[self.index.vocab[tok]]
            for docid, tweight in zip(plist.docid_postings, plist.tweight_postings):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        ranked = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)
        return {
            self.index.collection_ids[docid]: score
            for docid, score in ranked[:topk]
        }
358
+
359
+
360
class BM25Retriever(BaseInvertedIndexRetriever):
    """Inverted-index retriever specialized to `BM25Index`."""

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index
365
+
366
# Smoke test: load the saved index and run an example query.
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
368
+
369
@dataclass
class CSCInvertedIndex:
    """Inverted index whose term weights live in a sparse CSC matrix (terms x docs)."""

    posting_lists_matrix: csc_matrix  # shape (n_terms, n_docs); row = tid, column = docid
    vocab: Dict[str, int]  # term -> tid (row index)
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to `<output_dir>/index.pkl`, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by `save`.

        NOTE(review): pickle.load executes arbitrary code — only load index
        files from trusted sources.
        """
        # The original pre-built a dummy instance here that was immediately
        # overwritten by the unpickled object; that dead construction is removed.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
390
+
391
+ import scipy
392
+
393
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index cached as a sparse CSC matrix of shape (n_terms, n_docs)."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute BM25 term weights and cache them in a CSC sparse matrix.

        Row i corresponds to term id i, column j to doc id j.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []
        row_indices: List[int] = []
        col_indices: List[int] = []

        # Loop over each term:
        for term_id, posting_list in enumerate(posting_lists):
            idf = CSCBM25Index.calc_idf(dfs[term_id], total_docs)

            # Loop over each document in which the term appears. The posting
            # lists still hold raw TFs at this point (see run_counting).
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf, dls[docid], avgdl, k1, b
                )
                data.append(regularized_tf * idf)
                row_indices.append(term_id)
                col_indices.append(docid)

        # Build a COO matrix and convert to CSC. coo_matrix accepts plain
        # lists, so the intermediate np.array conversions of the original are
        # unnecessary. The original also shadowed the `csc_matrix` class name
        # with a local variable and used `np`, which was never imported; both
        # issues are fixed here.
        num_terms = len(posting_lists)
        coo = scipy.sparse.coo_matrix(
            (data, (row_indices, col_indices)), shape=(num_terms, total_docs)
        )
        return coo.tocsc()
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        # BM25 TF saturation with document-length normalization.
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        # BM25 "plus-one" IDF variant (always positive).
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,  # NOTE(review): accepted but never used in this method
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count corpus statistics, compute BM25 weights, and assemble the CSC index."""
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index
499
+
500
# Build and persist the CSC-backed BM25 index using the tuned hyperparameters.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,  # NOTE(review): hard-coded corpus size; len(sciq.corpus) would be safer
    show_progress_bar=True,
    k1=best_k1,  # from the tuning section above
    b=best_b
)
csc_bm25_index.save("output/csc_bm25_index")
508
+
509
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever backed by a `CSCInvertedIndex` loaded from disk."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {query term -> cached weight} for the document identified by `cid`.

        Unknown cids and unknown query terms yield no entries instead of raising.
        """
        ## YOUR_CODE_STARTS_HERE
        terms = self.index.tokenize(query)
        term_weights: Dict[str, float] = {}

        # In case the document ID is not found, return the empty dictionary.
        # (The original indexed cid2docid[cid] directly, which raised KeyError
        # before its `if docid is None` check could ever run.)
        docid = self.index.cid2docid.get(cid)
        if docid is None:
            return term_weights

        for term in terms:
            # Skip query terms absent from the vocabulary (the original raised
            # KeyError here, inconsistent with `retrieve`, which skips them).
            term_index = self.index.vocab.get(term)
            if term_index is None:
                continue
            weight = self.index.posting_lists_matrix[term_index, docid]

            if weight > 0:
                term_weights[term] = weight

        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Score one document: sum of its matching query-term weights."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the top-`topk` documents as {collection_id -> score}."""
        ## YOUR_CODE_STARTS_HERE
        terms = self.index.tokenize(query)

        term_indices = [
            self.index.vocab[term] for term in terms if term in self.index.vocab
        ]
        if not term_indices:
            return {}

        # Sum the BM25 weights of all query terms per document:
        # slice the query-term rows, then collapse them column-wise.
        term_matrix = self.index.posting_lists_matrix[term_indices, :]
        scores = np.asarray(term_matrix.sum(axis=0)).flatten()

        # Only documents containing at least one query term can score > 0.
        non_zero_doc_indices = np.nonzero(scores)[0]
        if non_zero_doc_indices.size == 0:
            return {}

        non_zero_scores = scores[non_zero_doc_indices]
        sorted_indices = np.argsort(-non_zero_scores)

        topk_indices = non_zero_doc_indices[sorted_indices[:topk]]
        topk_scores = non_zero_scores[sorted_indices[:topk]]

        topk_cids = [self.index.collection_ids[docid] for docid in topk_indices]
        return dict(zip(topk_cids, topk_scores))
        ## YOUR_CODE_ENDS_HERE
574
+
575
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """CSC-matrix retriever specialized to `CSCBM25Index`."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index
580
+
581
import gradio as gr
from typing import TypedDict


class Hit(TypedDict):
    """One search result: collection id, BM25 score, and the raw document text."""
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
retriever = CSCBM25Retriever(index_dir='output/csc_bm25_index')


def search(query: str) -> List[Hit]:
    """Retrieve the top BM25 hits for `query` and attach each document's text."""
    results = retriever.retrieve(query)
    hits: List[Hit] = []
    # (A leftover debug `print(results)` was removed here.)
    for cid, score in results.items():
        docid = retriever.index.cid2docid[cid]
        # Assumes the index was built with store_raw=True, so doc_texts is populated.
        text = retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits


# Create the Gradio interface
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=1, placeholder="Enter your query here..."),
    outputs=gr.Textbox(),  # hits are rendered via their string representation
    title="BM25 Search Engine",
    description="Enter a query to search the SciQ dataset using BM25.",
)
## YOUR_CODE_ENDS_HERE
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
+ tqdm
+ nltk
+ numpy
+ scipy
+ git+https://github.com/kwang2049/nlp4web-codebase.git