from __future__ import annotations

import math
import os
import pickle
import re
from abc import abstractmethod
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Type, TypedDict, TypeVar

import gradio as gr
import nltk
import tqdm

from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from nlp4web_codebase.ir.models import BaseRetriever

# The stopword list must be downloaded before it can be loaded below.
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

LANGUAGE = "english"
# Keep tokens of at least two word characters (the scikit-learn default token pattern).
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))

sciq = load_sciq()


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    # Placeholder: no lemmatization is applied; the tokens are returned unchanged.
    return words


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized
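
# Illustrative example (lowercasing, stopword removal, no lemmatization):
#   simple_tokenize("What is the boiling point of water?")
#   -> ["boiling", "point", "water"]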


T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # the indexed term
    docid_postings: List[int]  # docids of the documents that contain the term
    tweight_postings: List[float]  # term weight per docid, parallel to docid_postings
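
# For example (illustrative values): if "water" occurs twice in doc 0 and once in
# doc 7, its posting list right after counting (i.e. with raw TFs as weights) is
#   PostingList(term="water", docid_postings=[0, 7], tweight_postings=[2, 1])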


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
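
# Round trip (illustrative): `index.save("output/my_index")` writes
# output/my_index/index.pkl, and `InvertedIndex.from_saved("output/my_index")`
# loads it back.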


@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> document frequency
    dls: List[int]  # docid -> document length (in tokens)
    avgdl: float  # average document length
    nterms: int  # total number of tokens in the collection
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over a document collection."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)
        if store_raw:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )


counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
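
# Sanity checks one might run (illustrative; assumes collection ids are unique):
#   len(counting.cid2docid) == len(sciq.corpus)
#   counting.avgdl == sum(counting.dls) / len(counting.dls)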


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute and cache the BM25 term weights in the posting lists."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                # Replace the raw TF with the final weight: regularized TF * IDF.
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))
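
    # Together these two helpers implement the cached BM25 weight of term t in doc d:
    #   weight(t, d) = IDF(t) * tf(t, d) / (tf(t, d) + k1 * (1 - b + b * dl / avgdl))
    # with IDF(t) = ln(1 + (N - df(t) + 0.5) / (df(t) + 0.5)).
    # Note the numerator omits the textbook factor (k1 + 1); that scales every term
    # weight by the same constant and therefore leaves the ranking unchanged.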

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,  # unused; callers save the index explicitly
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Pass 1: count TFs, DFs and document lengths over the collection.
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Pass 2: turn the raw TFs in the posting lists into cached BM25 weights.
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")


class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the cached weight of each matched query term in document `cid`."""
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        # A document's score is the sum of the cached weights of its matched query terms.
        return sum(self.get_term_weights(query=query, cid=cid).values())
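
    # Illustrative (hypothetical numbers): if the query tokenizes to
    # ["boiling", "point"] with cached weights {"boiling": 2.1, "point": 1.3}
    # in document `cid`, then score(query, cid) == 3.4.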

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            # Accumulate each matched term's cached weight into the document scores.
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        # Keep only the topk highest-scoring documents.
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index


bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
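
# Illustrative only (query and output values are made up):
#   bm25_retriever.retrieve("boiling point of water")
#   -> {"<collection_id>": <score>, ...}  # up to topk entries, highest score first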


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assigned below, once `search` is defined.
return_type = List[Hit]  # The expected return type of `search`.

retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    results = retriever.retrieve(query)
    hits = []
    for cid, score in results.items():
        docid = retriever.index.cid2docid[cid]
        # doc_texts is available because the index was built with store_raw=True.
        text = retriever.index.doc_texts[docid]
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits
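
# Illustrative call (made-up output values):
#   search("what is the boiling point of water")
#   -> [{"cid": "...", "score": 12.3, "text": "..."}, ...]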


demo = gr.Interface(
    fn=search,
    inputs=gr.Textbox(lines=1, placeholder="Enter your query here..."),
    outputs=gr.JSON(),  # `search` returns a list of dicts, which renders best as JSON
    title="BM25 Search Engine",
    description="Enter a query to search the SciQ dataset using BM25.",
)
demo.launch()