JanWerth committed on
Commit
51016f6
·
verified ·
1 Parent(s): 3f34044

Delete app.py

Browse files

wrong retriever

Files changed (1) hide show
  1. app.py +0 -615
app.py DELETED
@@ -1,615 +0,0 @@
1
from __future__ import annotations

import os
import pickle
import re
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar

import nltk
import numpy as np
import scipy
import tqdm
from scipy.sparse import csc_matrix

from nlp4web_codebase.ir.data_loaders.dm import Document

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords
12
-
13
# Shared tokenization resources for the helpers below.
LANGUAGE = "english"
# Unicode-aware matcher for runs of >= 2 word characters; .findall tokenizes.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
# NLTK stopword list for LANGUAGE, held as a set for O(1) membership tests.
stopwords = set(nltk_stopwords.words(LANGUAGE))
16
-
17
-
18
def word_splitting(text: str) -> List[str]:
    """Lowercase *text* and split it into word tokens (>= 2 word chars each)."""
    lowered = text.lower()
    return word_splitter(lowered)
20
-
21
def lemmatization(words: List[str]) -> List[str]:
    """Identity placeholder: lemmatization is intentionally skipped for simplicity."""
    return words
23
-
24
def simple_tokenize(text: str) -> List[str]:
    """Split *text* into lowercased word tokens, drop stopwords, lemmatize."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
29
-
30
# Type variable bound to InvertedIndex so from_saved() can be typed to return
# the subclass it is called on.
T = TypeVar("T", bound="InvertedIndex")

@dataclass
class PostingList:
    """One term's posting list: parallel arrays of docids and term weights."""
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
37
-
38
-
39
@dataclass
class InvertedIndex:
    """In-memory inverted index, persisted as a single pickle file."""

    # NOTE(review): the original comment said "docid -> posting_list", but the
    # retrievers below index this list by term id (posting_lists[tid]).
    posting_lists: List[PostingList]  # tid -> posting list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to <output_dir>/index.pkl, creating the dir."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load a pickled index from <saved_dir>/index.pkl.

        Fix: the original constructed an empty throwaway instance and then
        immediately overwrote it with the unpickled object; the dead
        construction is removed.
        """
        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
60
-
61
-
62
# The output of the counting function:
@dataclass
class Counting:
    """Raw corpus statistics produced by run_counting()."""
    posting_lists: List[PostingList]  # tid -> (term, docids, raw tfs)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # average document length over the corpus
    nterms: int  # total number of term occurrences (sum of all tfs)
    doc_texts: Optional[List[str]] = None  # docid -> raw text; None if store_raw=False
74
-
75
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc.

    Documents with a collection_id already seen are skipped. Term weights in
    the returned posting lists are raw term frequencies; BM25 weighting is
    applied later by the index builders.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    # Fix: decide once whether raw texts are kept; the original reassigned
    # doc_texts = None inside the loop on every non-store_raw iteration.
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # duplicate document; keep first occurrence only
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok)
            if tid is None:
                # First sighting of this term: allocate its posting list.
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Each (term, doc) pair contributes exactly 1 to the df; new tids
            # extend the list. (Fixed per moodle discussion
            # https://moodle.informatik.tu-darmstadt.de/mod/moodleoverflow/discussion.php?d=2097)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)
        if doc_texts is not None:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # Fix: guard against an empty corpus (original raised ZeroDivisionError).
        avgdl=(sum(dls) / len(dls)) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )
133
-
134
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

# Load the SciQ dataset and run the counting pass over its whole corpus.
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
137
-
138
# Fix: `from __future__ import annotations` appeared here originally, but a
# __future__ import is only legal at the very top of a module -- mid-file it
# is a SyntaxError. It has been hoisted to the file-level import block.
from dataclasses import asdict, dataclass
import math
import os
from typing import Iterable, List, Optional, Type

import tqdm

from nlp4web_codebase.ir.data_loaders.dm import Document
145
-
146
-
147
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting weights are precomputed BM25 term scores."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        # Delegates to the shared stopword-filtering tokenizer.
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching.

        Mutates `posting_lists` in place: each posting's raw tf is replaced by
        its final BM25 weight (length-regularized tf times idf).
        """
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        # BM25 tf saturation with document-length normalization (k1, b knobs).
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        # BM25 idf with +1 inside the log, so the value is always positive.
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,  # NOTE(review): accepted but never used; callers save explicitly
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count corpus statistics, cache BM25 weights, and assemble the index."""
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index
232
-
233
# Build the BM25 index over the SciQ corpus with default k1/b and persist it
# for the retrievers below.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,  # NOTE(review): hard-coded corpus size; presumably len(sciq.corpus) -- confirm
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
239
-
240
# Hyperparameter sweep grids: "X" holds the candidate b (resp. k1) values;
# "Y" collects the corresponding dev-set MAP scores, filled by the tuning
# loops below.
plots_b: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": []
}
plots_k1: Dict[str, List[float]] = {
    "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    "Y": []
}
248
-
249
## YOU_CODE_STARTS_HERE
# NOTE(review): as written, this section cannot run top-to-bottom in this
# file: BM25Retriever is defined further down, and `Split` / `evaluate_map`
# are never defined or imported here at all (this looks copied from a
# notebook where they existed). Confirm and add the imports / reorder.

def _evaluate_dev_map(k1: float, b: float) -> float:
    """Build a BM25 index with (k1, b), retrieve on the dev split, return MAP.

    Extracted because the original duplicated this ~20-line body verbatim in
    both the b-sweep and the k1-sweep.
    """
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    rankings = {
        query.query_id: retriever.retrieve(query=query.text)
        for query in sciq.get_split_queries(Split.dev)
    }
    return evaluate_map(rankings, split=Split.dev)

# Tune b with k1 fixed.
k1_default = 0.4
best_b = 0
best_score = 0
for b in plots_b["X"]:
    score = _evaluate_dev_map(k1=k1_default, b=b)
    plots_b["Y"].append(score)
    if score > best_score:
        best_score = score
        best_b = b

print(best_b, best_score)

# Tune k1 with b fixed at the best value found above.
best_k1 = 0
best_score = 0
for k1 in plots_k1["X"]:
    score = _evaluate_dev_map(k1=k1, b=best_b)
    plots_k1["Y"].append(score)
    if score > best_score:
        best_score = score
        best_k1 = k1

print(best_k1, best_score)
## YOU_CODE_ENDS_HERE
302
-
303
- from nlp4web_codebase.ir.models import BaseRetriever
304
- from typing import Type
305
- from abc import abstractmethod
306
-
307
-
308
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever over an InvertedIndex: a document's score for a query is the
    sum of the cached weights of the query terms it contains."""

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Concrete index type to load in __init__; set by subclasses."""
        pass

    def __init__(self, index_dir: str) -> None:
        # Load the pickled index produced by index_class.save().
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {query term -> cached weight} for the document `cid`."""
        target = self.index.cid2docid[cid]
        weights: Dict[str, float] = {}
        for token in self.index.tokenize(query):
            tid = self.index.vocab.get(token)
            if tid is None:
                continue  # token never seen in the corpus
            plist = self.index.posting_lists[tid]
            for doc, weight in zip(plist.docid_postings, plist.tweight_postings):
                if doc == target:
                    weights[token] = weight
                    break
        return weights

    def score(self, query: str, cid: str) -> float:
        """Score one document: sum of its matching term weights."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the topk documents as {collection_id -> score}."""
        accumulated: Dict[int, float] = {}
        for token in self.index.tokenize(query):
            tid = self.index.vocab.get(token)
            if tid is None:
                continue
            plist = self.index.posting_lists[tid]
            for doc, weight in zip(plist.docid_postings, plist.tweight_postings):
                accumulated[doc] = accumulated.get(doc, 0) + weight
        ranked = sorted(accumulated.items(), key=lambda pair: pair[1], reverse=True)
        return {
            self.index.collection_ids[doc]: value
            for doc, value in ranked[:topk]
        }
358
-
359
-
360
class BM25Retriever(BaseInvertedIndexRetriever):
    """Inverted-index retriever backed by a saved BM25Index."""

    @property
    def index_class(self) -> Type[BM25Index]:
        """Concrete index type loaded by the base-class constructor."""
        return BM25Index
365
-
366
# Smoke test: load the saved index and run one example query.
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")
368
-
369
@dataclass
class CSCInvertedIndex:
    """Inverted index whose cached term weights live in a scipy CSC sparse
    matrix instead of per-term posting lists; persisted as one pickle file."""

    posting_lists_matrix: csc_matrix  # (tid, docid) -> cached term weight
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to <output_dir>/index.pkl, creating the dir."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load a pickled index from <saved_dir>/index.pkl.

        Fix: the original built an empty throwaway instance first and then
        discarded it by rebinding to the unpickled object; removed.
        """
        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
390
-
391
- import scipy
392
-
393
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index storing the cached term weights as a (nterms x ndocs) CSC
    sparse matrix."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        # Delegates to the shared stopword-filtering tokenizer.
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute BM25 term weights and return them as a CSC matrix.

        Fixes: `np` was used here but numpy was never imported anywhere in
        the original file (NameError at runtime) -- it is now imported at the
        top; the locals that shadowed `scipy.sparse.coo_matrix` /
        `csc_matrix` have been renamed.
        """
        ## YOUR_CODE_STARTS_HERE
        data: List[float] = []
        row_indices: List[int] = []
        col_indices: List[int] = []

        # One matrix entry per (term, document) pair in the posting lists.
        for term_id, posting_list in enumerate(posting_lists):
            idf = BM25Index.calc_idf(dfs[term_id], total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf, dls[docid], avgdl, k1, b
                )
                data.append(regularized_tf * idf)
                row_indices.append(term_id)
                col_indices.append(docid)

        # Assemble once in COO form (natural for triplet input), then convert
        # to CSC for fast column (per-document) slicing.
        weights_coo = scipy.sparse.coo_matrix(
            (np.array(data), (np.array(row_indices), np.array(col_indices))),
            shape=(len(posting_lists), total_docs),
        )
        return weights_coo.tocsc()
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        # BM25 tf saturation with document-length normalization.
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        # BM25 idf with +1 inside the log, so the value is always positive.
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,  # NOTE(review): accepted but never used; callers save explicitly
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count corpus statistics, cache BM25 weights, and assemble the index."""
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index
499
-
500
# Build the CSC-backed BM25 index with the tuned hyperparameters (best_k1 /
# best_b from the sweeps above) and persist it for the demo retriever.
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,  # NOTE(review): hard-coded; presumably len(sciq.corpus) -- confirm
    show_progress_bar=True,
    k1=best_k1,
    b=best_b
)
csc_bm25_index.save("output/csc_bm25_index")
508
-
509
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever over a CSCInvertedIndex: scores come from summing matrix
    columns of cached term weights."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Concrete index type to load in __init__; set by subclasses."""
        pass

    def __init__(self, index_dir: str) -> None:
        # Load the pickled index produced by index_class.save().
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return {query term -> cached weight} for the document `cid`.

        Fixes: out-of-vocabulary query terms raised KeyError (the original
        indexed vocab directly, unlike retrieve() which guards); the original
        `if docid is None` check was dead because `cid2docid[cid]` raises
        KeyError instead of returning None -- .get() makes it effective.
        """
        ## YOUR_CODE_STARTS_HERE
        terms = self.index.tokenize(query)
        docid = self.index.cid2docid.get(cid)
        term_weights: Dict[str, float] = {}

        # In case the document ID is not found, return the empty dictionary.
        if docid is None:
            return term_weights

        for term in terms:
            term_index = self.index.vocab.get(term)
            if term_index is None:
                continue  # skip terms never seen in the corpus
            weight = self.index.posting_lists_matrix[term_index, docid]

            if weight > 0:
                term_weights[term] = weight

        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Score one document: sum of its matching term weights."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the topk documents as {collection_id -> score}."""
        ## YOUR_CODE_STARTS_HERE
        terms = self.index.tokenize(query)

        term_indices = [self.index.vocab[term] for term in terms if term in self.index.vocab]

        if not term_indices:
            return {}

        # Slice the rows for the query terms and sum per document (column).
        term_matrix = self.index.posting_lists_matrix[term_indices, :]
        scores = term_matrix.sum(axis=0)
        scores = np.asarray(scores).flatten()

        # Only documents that matched at least one term are candidates.
        non_zero_doc_indices = np.nonzero(scores)[0]
        if non_zero_doc_indices.size == 0:
            return {}

        non_zero_scores = scores[non_zero_doc_indices]
        sorted_indices = np.argsort(-non_zero_scores)

        topk_indices = non_zero_doc_indices[sorted_indices[:topk]]
        topk_scores = non_zero_scores[sorted_indices[:topk]]

        topk_cids = [self.index.collection_ids[docid] for docid in topk_indices]
        return dict(zip(topk_cids, topk_scores))
        ## YOUR_CODE_ENDS_HERE
574
-
575
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Sparse-matrix retriever backed by a saved CSCBM25Index."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Concrete index type loaded by the base-class constructor."""
        return CSCBM25Index
580
-
581
- import gradio as gr
582
- from typing import TypedDict
583
-
584
class Hit(TypedDict):
    """One search result row returned by the demo's search() function."""
    cid: str  # collection id of the retrieved document
    score: float  # summed BM25 score for the query
    text: str  # raw document text (from the index's doc_texts)
588
-
589
demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Retriever backing the demo; loads the tuned CSC BM25 index saved above.
retriever = CSCBM25Retriever(index_dir='output/csc_bm25_index')

def search(query: str) -> List[Hit]:
    """Retrieve the top documents for `query` and attach their raw texts.

    Fix: removed a leftover debug `print(results)`.
    """
    results = retriever.retrieve(query)
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = retriever.index.cid2docid[cid]
        # Requires the index to have been built with store_raw=True.
        text = retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits

# Create the Gradio interface: free-text query in, stringified hits out.
demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=1, placeholder="Enter your query here..."),
    outputs=gr.Textbox(),
    title="BM25 Search Engine",
    description="Enter a query to search the SciQ dataset using BM25.",
)
## YOUR_CODE_ENDS_HERE
demo.launch()