JanWerth committed on
Commit
c40b4ab
·
verified ·
1 Parent(s): 51016f6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -0
app.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import asdict, dataclass
3
+ import math
4
+ import os
5
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar, TypedDict
6
+ import tqdm
7
+ from nlp4web_codebase.ir.data_loaders.dm import Document
8
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
9
+ from nlp4web_codebase.ir.models import BaseRetriever
10
+ from abc import abstractmethod
11
+
12
+ import pickle
13
+ from collections import Counter
14
+ import re
15
+ import gradio as gr
16
+
17
+ import nltk
18
+
19
+ nltk.download("stopwords", quiet=True)
20
+ from nltk.corpus import stopwords as nltk_stopwords
21
+
22
+ LANGUAGE = "english"
23
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
24
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
25
+
26
+ sciq = load_sciq()
27
+
28
+
29
+ def word_splitting(text: str) -> List[str]:
30
+ return word_splitter(text.lower())
31
+
32
+
33
+ def lemmatization(words: List[str]) -> List[str]:
34
+ return words # We ignore lemmatization here for simplicity
35
+
36
+
37
+ def simple_tokenize(text: str) -> List[str]:
38
+ words = word_splitting(text)
39
+ tokenized = list(filter(lambda w: w not in stopwords, words))
40
+ tokenized = lemmatization(tokenized)
41
+ return tokenized
42
+
43
+
44
+ T = TypeVar("T", bound="InvertedIndex")
45
+
46
+
47
+ @dataclass
48
+ class PostingList:
49
+ term: str # The term
50
+ docid_postings: List[int] # docid_postings[i] means the docid (int) of the i-th associated posting
51
+ tweight_postings: List[float] # tweight_postings[i] means the term weight (float) of the i-th associated posting
52
+
53
+
54
+ @dataclass
55
+ class InvertedIndex:
56
+ posting_lists: List[PostingList] # docid -> posting_list
57
+ vocab: Dict[str, int]
58
+ cid2docid: Dict[str, int] # collection_id -> docid
59
+ collection_ids: List[str] # docid -> collection_id
60
+ doc_texts: Optional[List[str]] = None # docid -> document text
61
+
62
+ def save(self, output_dir: str) -> None:
63
+ os.makedirs(output_dir, exist_ok=True)
64
+ with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
65
+ pickle.dump(self, f)
66
+
67
+ @classmethod
68
+ def from_saved(cls: Type[T], saved_dir: str) -> T:
69
+ index = cls(
70
+ posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
71
+ )
72
+ with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
73
+ index = pickle.load(f)
74
+ return index
75
+
76
+
77
+ # The output of the counting function:
78
+ @dataclass
79
+ class Counting:
80
+ posting_lists: List[PostingList]
81
+ vocab: Dict[str, int]
82
+ cid2docid: Dict[str, int]
83
+ collection_ids: List[str]
84
+ dfs: List[int] # tid -> df
85
+ dls: List[int] # docid -> doc length
86
+ avgdl: float
87
+ nterms: int
88
+ doc_texts: Optional[List[str]] = None
89
+
90
+
91
+ def run_counting(
92
+ documents: Iterable[Document],
93
+ tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
94
+ store_raw: bool = True, # store the document text in doc_texts
95
+ ndocs: Optional[int] = None,
96
+ show_progress_bar: bool = True,
97
+ ) -> Counting:
98
+ """Counting TFs, DFs, doc_lengths, etc."""
99
+ posting_lists: List[PostingList] = []
100
+ vocab: Dict[str, int] = {}
101
+ cid2docid: Dict[str, int] = {}
102
+ collection_ids: List[str] = []
103
+ dfs: List[int] = [] # tid -> df
104
+ dls: List[int] = [] # docid -> doc length
105
+ nterms: int = 0
106
+ doc_texts: Optional[List[str]] = []
107
+ for doc in tqdm.tqdm(
108
+ documents,
109
+ desc="Counting",
110
+ total=ndocs,
111
+ disable=not show_progress_bar,
112
+ ):
113
+ if doc.collection_id in cid2docid:
114
+ continue
115
+ collection_ids.append(doc.collection_id)
116
+ docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
117
+ toks = tokenize_fn(doc.text)
118
+ tok2tf = Counter(toks)
119
+ dls.append(sum(tok2tf.values()))
120
+ for tok, tf in tok2tf.items():
121
+ nterms += tf
122
+ tid = vocab.get(tok, None)
123
+ if tid is None:
124
+ posting_lists.append(
125
+ PostingList(term=tok, docid_postings=[], tweight_postings=[])
126
+ )
127
+ tid = vocab.setdefault(tok, len(vocab))
128
+ posting_lists[tid].docid_postings.append(docid)
129
+ posting_lists[tid].tweight_postings.append(tf)
130
+ if tid < len(dfs):
131
+ dfs[tid] += 1
132
+ else:
133
+ dfs.append(
134
+ 1) # Fixed according to moodle discussion https://moodle.informatik.tu-darmstadt.de/mod/moodleoverflow/discussion.php?d=2097
135
+ if store_raw:
136
+ doc_texts.append(doc.text)
137
+ else:
138
+ doc_texts = None
139
+ return Counting(
140
+ posting_lists=posting_lists,
141
+ vocab=vocab,
142
+ cid2docid=cid2docid,
143
+ collection_ids=collection_ids,
144
+ dfs=dfs,
145
+ dls=dls,
146
+ avgdl=sum(dls) / len(dls),
147
+ nterms=nterms,
148
+ doc_texts=doc_texts,
149
+ )
150
+
151
+
152
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
153
+
154
+
155
+ @dataclass
156
+ class BM25Index(InvertedIndex):
157
+
158
+ @staticmethod
159
+ def tokenize(text: str) -> List[str]:
160
+ return simple_tokenize(text)
161
+
162
+ @staticmethod
163
+ def cache_term_weights(
164
+ posting_lists: List[PostingList],
165
+ total_docs: int,
166
+ avgdl: float,
167
+ dfs: List[int],
168
+ dls: List[int],
169
+ k1: float,
170
+ b: float,
171
+ ) -> None:
172
+ """Compute term weights and caching"""
173
+
174
+ N = total_docs
175
+ for tid, posting_list in enumerate(
176
+ tqdm.tqdm(posting_lists, desc="Regularizing TFs")
177
+ ):
178
+ idf = BM25Index.calc_idf(df=dfs[tid], N=N)
179
+ for i in range(len(posting_list.docid_postings)):
180
+ docid = posting_list.docid_postings[i]
181
+ tf = posting_list.tweight_postings[i]
182
+ dl = dls[docid]
183
+ regularized_tf = BM25Index.calc_regularized_tf(
184
+ tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
185
+ )
186
+ posting_list.tweight_postings[i] = regularized_tf * idf
187
+
188
+ @staticmethod
189
+ def calc_regularized_tf(
190
+ tf: int, dl: float, avgdl: float, k1: float, b: float
191
+ ) -> float:
192
+ return tf / (tf + k1 * (1 - b + b * dl / avgdl))
193
+
194
+ @staticmethod
195
+ def calc_idf(df: int, N: int):
196
+ return math.log(1 + (N - df + 0.5) / (df + 0.5))
197
+
198
+ @classmethod
199
+ def build_from_documents(
200
+ cls: Type[BM25Index],
201
+ documents: Iterable[Document],
202
+ store_raw: bool = True,
203
+ output_dir: Optional[str] = None,
204
+ ndocs: Optional[int] = None,
205
+ show_progress_bar: bool = True,
206
+ k1: float = 0.9,
207
+ b: float = 0.4,
208
+ ) -> BM25Index:
209
+ # Counting TFs, DFs, doc_lengths, etc.:
210
+ counting = run_counting(
211
+ documents=documents,
212
+ tokenize_fn=BM25Index.tokenize,
213
+ store_raw=store_raw,
214
+ ndocs=ndocs,
215
+ show_progress_bar=show_progress_bar,
216
+ )
217
+
218
+ # Compute term weights and caching:
219
+ posting_lists = counting.posting_lists
220
+ total_docs = len(counting.cid2docid)
221
+ BM25Index.cache_term_weights(
222
+ posting_lists=posting_lists,
223
+ total_docs=total_docs,
224
+ avgdl=counting.avgdl,
225
+ dfs=counting.dfs,
226
+ dls=counting.dls,
227
+ k1=k1,
228
+ b=b,
229
+ )
230
+
231
+ # Assembly and save:
232
+ index = BM25Index(
233
+ posting_lists=posting_lists,
234
+ vocab=counting.vocab,
235
+ cid2docid=counting.cid2docid,
236
+ collection_ids=counting.collection_ids,
237
+ doc_texts=counting.doc_texts,
238
+ )
239
+ return index
240
+
241
+
242
+ bm25_index = BM25Index.build_from_documents(
243
+ documents=iter(sciq.corpus),
244
+ ndocs=12160,
245
+ show_progress_bar=True,
246
+ )
247
+
248
+ bm25_index.save("output/bm25_index")
249
+
250
+
251
+ class BaseInvertedIndexRetriever(BaseRetriever):
252
+
253
+ @property
254
+ @abstractmethod
255
+ def index_class(self) -> Type[InvertedIndex]:
256
+ pass
257
+
258
+ def __init__(self, index_dir: str) -> None:
259
+ self.index = self.index_class.from_saved(index_dir)
260
+
261
+ def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
262
+ toks = self.index.tokenize(query)
263
+ target_docid = self.index.cid2docid[cid]
264
+ term_weights = {}
265
+ for tok in toks:
266
+ if tok not in self.index.vocab:
267
+ continue
268
+ tid = self.index.vocab[tok]
269
+ posting_list = self.index.posting_lists[tid]
270
+ for docid, tweight in zip(
271
+ posting_list.docid_postings, posting_list.tweight_postings
272
+ ):
273
+ if docid == target_docid:
274
+ term_weights[tok] = tweight
275
+ break
276
+ return term_weights
277
+
278
+ def score(self, query: str, cid: str) -> float:
279
+ return sum(self.get_term_weights(query=query, cid=cid).values())
280
+
281
+ def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
282
+ toks = self.index.tokenize(query)
283
+ docid2score: Dict[int, float] = {}
284
+ for tok in toks:
285
+ if tok not in self.index.vocab:
286
+ continue
287
+ tid = self.index.vocab[tok]
288
+ posting_list = self.index.posting_lists[tid]
289
+ for docid, tweight in zip(
290
+ posting_list.docid_postings, posting_list.tweight_postings
291
+ ):
292
+ docid2score.setdefault(docid, 0)
293
+ docid2score[docid] += tweight
294
+ docid2score = dict(
295
+ sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
296
+ )
297
+ return {
298
+ self.index.collection_ids[docid]: score
299
+ for docid, score in docid2score.items()
300
+ }
301
+
302
+
303
+ class BM25Retriever(BaseInvertedIndexRetriever):
304
+
305
+ @property
306
+ def index_class(self) -> Type[BM25Index]:
307
+ return BM25Index
308
+
309
+
310
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
311
+
312
+
313
+ class Hit(TypedDict):
314
+ cid: str
315
+ score: float
316
+ text: str
317
+
318
+
319
+ demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
320
+ return_type = List[Hit]
321
+
322
+ ## YOUR_CODE_STARTS_HERE
323
+ retriever = BM25Retriever(index_dir="output/bm25_index")
324
+
325
+
326
+ # retriever = CSCBM25Retriever(index_dir='output/csc_bm25_index')
327
+
328
+ def search(query: str) -> List[Hit]:
329
+ results = retriever.retrieve(query)
330
+ hits = []
331
+ print(results)
332
+ for cid, score in results.items():
333
+ docid = retriever.index.cid2docid[cid]
334
+ text = retriever.index.doc_texts[docid]
335
+ hit = Hit(cid=cid, score=score, text=text)
336
+ hits.append(hit)
337
+ return hits
338
+
339
+
340
+ # Create the Gradio interface
341
+ demo = gr.Interface(
342
+ fn=search,
343
+ inputs=gr.Textbox(lines=1, placeholder="Enter your query here..."),
344
+ outputs=gr.Textbox(),
345
+ title="BM25 Search Engine",
346
+ description="Enter a query to search the SciQ dataset using BM25.",
347
+ )
348
+ ## YOUR_CODE_ENDS_HERE
349
+ demo.launch()