from __future__ import annotations

import math
import os
import pickle
import re
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Type,
    TypedDict,
    TypeVar,
)

import gradio as gr
import joblib
import nltk
import pandas as pd
import tqdm
from datasets import load_dataset

# Download the NLTK stopword list (a no-op if it is already present).
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords


class BaseRetriever(ABC):
    """Abstract interface that every retriever implementation must satisfy."""

    @property
    @abstractmethod
    def index_class(self) -> Type[Any]:
        pass

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        raise NotImplementedError

    @abstractmethod
    def score(self, query: str, cid: str) -> float:
        pass

    @abstractmethod
    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        pass


@dataclass
class Document:
    collection_id: str
    text: str


@dataclass
class Query:
    query_id: str
    text: str


@dataclass
class QRel:
    query_id: str
    collection_id: str
    relevance: int
    answer: Optional[str] = None


class Split(str, Enum):
    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        qrels_dict = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})
            qrels_dict[qrel.query_id][qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        qrels = self.split2qrels[split]
        qids = {qrel.query_id for qrel in qrels}
        return list(filter(lambda query: query.query_id in qids, self.queries))


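# Shape sketch of the qrels helpers (the query/collection ids are placeholders):
#   dataset.get_qrels_dict(Split.dev)    -> {"<qid>": {"<cid>": 1}, ...}
#   dataset.get_split_queries(Split.dev) -> [Query(query_id="<qid>", text="..."), ...]

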
# Cache the processed dataset on disk so repeated runs do not reprocess SciQ.
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: within each repeated question, supports and answers are all distinct.
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == len(group)
        assert len(set(group["correct_answer"].tolist())) == len(group)

    # Build the corpus, queries and qrels, skipping empty and duplicate entries.
    corpus = []
    queries = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id = {}
    support2id = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                continue
            question = row["question"]
            if len(question.strip()) == 0:
                continue
            if support in support2id:
                continue
            else:
                support2id[support] = example_id
            if question in question2id:
                continue
            else:
                question2id[question] = example_id
            doc = {"collection_id": example_id, "text": support}
            query = {"query_id": example_id, "text": row["question"]}
            qrel = {
                "query_id": example_id,
                "collection_id": example_id,
                "relevance": 1,
                "answer": row["correct_answer"],
            }
            corpus.append(Document(**doc))
            queries.append(Query(**query))
            split2qrels[split].append(QRel(**qrel))

    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)


# Note: InvertedIndex and BM25Index are defined further down in this file; the forward
# references below are fine because `from __future__ import annotations` defers
# annotation evaluation, and the property bodies only run when they are called.
class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index


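# Usage sketch (assumes an index was already built and saved, as done near the bottom
# of this file; the query string is made up):
#   retriever = BM25Retriever(index_dir="output/bm25_index")
#   retriever.retrieve("What do plants need for photosynthesis?", topk=3)
#   # -> {"<cid>": <BM25 score>, ...} for the top-3 passages

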
if __name__ == "__main__":
    # Quick sanity check: load the dataset and print its statistics.
    import ujson
    import time

    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(ujson.dumps(dataset.get_stats(), indent=4))


LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    # Placeholder: no lemmatization is applied.
    return words


def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized


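# Example (the query text is made up):
#   simple_tokenize("What is the powerhouse of the cell?")
#   -> ["powerhouse", "cell"]   (lowercased; one-character tokens and English stopwords dropped)

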
T = TypeVar("T", bound="InvertedIndex")


# One posting list per term: parallel lists of document ids and term weights.
@dataclass
class PostingList:
    term: str
    docid_postings: List[int]
    tweight_postings: List[float]


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    doc_texts: Optional[List[str]] = None

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


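# Persistence round-trip sketch (the directory name is arbitrary):
#   index.save("output/bm25_index")
#   index = BM25Index.from_saved("output/bm25_index")

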
@dataclass
class Counting:
    """Aggregated statistics produced by run_counting."""

    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # document frequency per term id
    dls: List[int]  # document length (number of kept tokens) per doc id
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over a document collection."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []
    dls: List[int] = []
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(1)  # a new term has appeared in exactly one document so far
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )


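# Toy sketch of run_counting (the documents are made up):
#   docs = [Document("d1", "the cell membrane"), Document("d2", "cell biology")]
#   c = run_counting(docs, ndocs=2, show_progress_bar=False)
#   c.vocab  -> {"cell": 0, "membrane": 1, "biology": 2}
#   c.dfs    -> [2, 1, 1]     # "cell" occurs in both documents
#   c.dls    -> [2, 2]        # "the" is a stopword and is dropped
#   c.avgdl  -> 2.0, c.nterms -> 4

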
@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute BM25 term weights and cache them in the posting lists."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

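    # For reference, the weight cached for each (term t, document d) pair is the usual
    # BM25 term score implemented by calc_idf and calc_regularized_tf above:
    #   w(t, d) = log(1 + (N - df_t + 0.5) / (df_t + 0.5))
    #             * tf_td / (tf_td + k1 * (1 - b + b * dl_d / avgdl))
    # A document's score for a query (see BaseInvertedIndexRetriever.retrieve) is the
    # sum of w(t, d) over the query terms.
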
    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Count TFs, DFs, document lengths, etc.
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute the BM25 term weights in place.
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assemble the index.
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


class Hit(TypedDict):
    cid: str
    score: float
    text: str


# Build the BM25 index over the SciQ corpus once, at module load time, so that each
# query only pays for retrieval.
sciq = load_sciq()
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query=query)
    hits: List[Hit] = []
    for cid, score in results.items():
        docid = bm25_retriever.index.cid2docid[cid]
        text = bm25_retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})
    return hits


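# Example of what search() returns (the ids, scores and text are illustrative only):
#   search("What do plants need for photosynthesis?")
#   -> [{"cid": "<cid>", "score": 11.2, "text": "Photosynthesis requires ..."}, ...]

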
demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)
return_type = List[Hit]
demo.launch()