|
|
from __future__ import annotations |
|
|
|
|
|
"""Copy of the notebook "HW1 (more instructed).ipynb"
|
|
|
|
|
Automatically generated by Colab. |
|
|
|
|
|
Original file is located at |
|
|
https://colab.research.google.com/drive/18CpMm-9nCuo64vywjq-qJhJF_DWrGavX |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
"""## Pre-requisite code |
|
|
|
|
|
The code within this section will be used in the tasks. Please do not change these code lines. |
|
|
|
|
|
### SciQ loading and counting |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
import pickle |
|
|
import os |
|
|
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar |
|
|
from nlp4web_codebase.ir.data_loaders.dm import Document |
|
|
from collections import Counter |
|
|
import tqdm |
|
|
import re |
|
|
import nltk |
|
|
nltk.download("stopwords", quiet=True) |
|
|
from nltk.corpus import stopwords as nltk_stopwords |
|
|
|
|
|
LANGUAGE = "english" |
|
|
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall |
|
|
stopwords = set(nltk_stopwords.words(LANGUAGE)) |
|
|
|
|
|
|
|
|
def word_splitting(text: str) -> List[str]:
    """Lowercase *text* and split it into word tokens (length >= 2)."""
    lowered = text.lower()
    return word_splitter(lowered)
|
|
|
|
|
def lemmatization(words: List[str]) -> List[str]:
    """Identity placeholder: no lemmatizer is configured, so the input
    token list is returned unchanged."""
    return words
|
|
|
|
|
def simple_tokenize(text: str) -> List[str]:
    """Tokenize *text*: lowercase word-split, drop English stopwords,
    then apply the (currently no-op) lemmatization step."""
    candidates = word_splitting(text)
    kept = [tok for tok in candidates if tok not in stopwords]
    return lemmatization(kept)
|
|
|
|
|
T = TypeVar("T", bound="InvertedIndex") |
|
|
|
|
|
@dataclass
class PostingList:
    """Posting list for one vocabulary term: two parallel arrays of the
    documents containing the term and the term's weight in each one."""

    term: str  # the token this posting list belongs to
    docid_postings: List[int]  # internal docids of documents containing the term
    tweight_postings: List[float]  # term weight per document (parallel to docid_postings)
|
|
|
|
|
|
|
|
@dataclass
class InvertedIndex:
    """In-memory inverted index: one PostingList per term plus the
    vocabulary and document-ID mappings needed to interpret it."""

    posting_lists: List[PostingList]  # tid -> posting list
    vocab: Dict[str, int]  # term -> term ID (tid)
    cid2docid: Dict[str, int]  # collection ID -> internal docid
    collection_ids: List[str]  # docid -> collection ID
    doc_texts: Optional[List[str]] = None  # docid -> raw text (if stored)

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to <output_dir>/index.pkl, creating
        the directory if necessary."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by `save`.

        Fix: the original constructed an empty throwaway instance that was
        immediately overwritten by `pickle.load`; load directly instead.

        NOTE(review): unpickling an untrusted file can execute arbitrary
        code; only load index files you created yourself.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Counting:
    """Raw counting statistics produced by `run_counting` (TFs still
    un-regularized in the posting lists)."""

    posting_lists: List[PostingList]  # tid -> posting list holding raw TFs
    vocab: Dict[str, int]  # term -> term ID (tid)
    cid2docid: Dict[str, int]  # collection ID -> internal docid
    collection_ids: List[str]  # docid -> collection ID
    dfs: List[int]  # tid -> document frequency
    dls: List[int]  # docid -> document length (number of kept tokens)
    avgdl: float  # average document length over the corpus
    nterms: int  # total number of term occurrences in the corpus
    doc_texts: Optional[List[str]] = None  # docid -> raw text (if stored)
|
|
|
|
|
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc.

    Args:
        documents: documents to count over; duplicates (same
            collection_id) are skipped.
        tokenize_fn: maps raw text to a token list.
        store_raw: keep the raw document texts in the result.
        ndocs: total number of documents (progress bar only).
        show_progress_bar: display a tqdm progress bar.

    Returns:
        A Counting with per-term posting lists (raw TFs), vocabulary,
        docid mappings, DFs, document lengths, average document length
        and the total term-occurrence count.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> number of documents containing the term
    dls: List[int] = []  # docid -> document length
    nterms: int = 0
    # Decide once up front instead of re-assigning None on every iteration.
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # skip duplicate documents
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
                # Bug fix: a newly seen term already occurs in the current
                # document, so its DF starts at 1 (the original appended 0,
                # undercounting every term's DF by one).
                dfs.append(1)
            else:
                dfs[tid] += 1
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
        if store_raw:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # Guard against an empty corpus (ZeroDivisionError in the original).
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )
|
|
|
|
|
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq |
|
|
sciq = load_sciq() |
|
|
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) |
|
|
|
|
|
"""### BM25 Index""" |
|
|
|
|
|
|
|
|
from dataclasses import asdict, dataclass |
|
|
import math |
|
|
import os |
|
|
from typing import Iterable, List, Optional, Type |
|
|
import tqdm |
|
|
from nlp4web_codebase.ir.data_loaders.dm import Document |
|
|
|
|
|
|
|
|
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting lists store pre-computed BM25 term weights."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize queries/documents the same way the index was built."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Convert the raw TFs in *posting_lists* to BM25 weights in place."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 TF saturation term.

        NOTE(review): the textbook formula multiplies the numerator by
        (k1 + 1); omitting it scales every score by the same constant and
        does not change document rankings, so it is kept as-is.
        """
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 IDF with 0.5 smoothing (always positive)."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Count term statistics over *documents* and build a BM25 index.

        If *output_dir* is given, the built index is also saved there.
        """
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        # Bug fix: the output_dir parameter was accepted but silently
        # ignored; honor it (default None keeps the old behavior).
        if output_dir is not None:
            index.save(output_dir)
        return index
|
|
|
|
|
bm25_index = BM25Index.build_from_documents( |
|
|
documents=iter(sciq.corpus), |
|
|
ndocs=12160, |
|
|
show_progress_bar=True, |
|
|
) |
|
|
bm25_index.save("output/bm25_index") |
|
|
|
|
|
|
|
|
"""### BM25 Retriever""" |
|
|
|
|
|
from nlp4web_codebase.ir.models import BaseRetriever |
|
|
from typing import Type |
|
|
from abc import abstractmethod |
|
|
|
|
|
|
|
|
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever backed by a Python-list-based inverted index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the cached weight of each query term in document *cid*."""
        tokens = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        weights: Dict[str, float] = {}
        for token in tokens:
            tid = self.index.vocab.get(token)
            if tid is None:
                continue
            plist = self.index.posting_lists[tid]
            for docid, weight in zip(plist.docid_postings, plist.tweight_postings):
                if docid == target_docid:
                    weights[token] = weight
                    break
        return weights

    def score(self, query: str, cid: str) -> float:
        """Sum of cached term weights of the query terms in document *cid*."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Score documents by summing cached term weights over the query
        terms; return the top-k as {collection_id: score}."""
        scores: Dict[int, float] = {}
        for token in self.index.tokenize(query):
            tid = self.index.vocab.get(token)
            if tid is None:
                continue
            plist = self.index.posting_lists[tid]
            for docid, weight in zip(plist.docid_postings, plist.tweight_postings):
                scores[docid] = scores.get(docid, 0) + weight
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return {self.index.collection_ids[d]: s for d, s in ranked[:topk]}
|
|
|
|
|
|
|
|
class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever over a saved BM25Index (list-based posting lists)."""

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index
|
|
|
|
|
bm25_retriever = BM25Retriever(index_dir="output/bm25_index") |
|
|
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?") |
|
|
|
|
|
"""# TASK1: tune b and k1 (4 points) |
|
|
|
|
|
Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evaluate_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: as the influence from b is larger, please first tune b (with k1 fixed to the default value 0.9) and use the best value of b to further tune k1.
|
|
|
|
|
$${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$ |
|
|
""" |
|
|
|
|
|
from nlp4web_codebase.ir.data_loaders import Split |
|
|
import pytrec_eval |
|
|
|
|
|
|
|
|
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Mean Average Precision at cutoff 10 for the given rankings.

    Args:
        rankings: query_id -> {collection_id: score}.
        split: SciQ split whose relevance judgements to evaluate against.

    Returns:
        MAP@10 averaged over all evaluated queries.
    """
    metric = "map_cut_10"
    qrels = sciq.get_qrels_dict(split)
    # Fix: reuse the qrels fetched above (the original called
    # sciq.get_qrels_dict a second time and left `qrels` unused).
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    qps = evaluator.evaluate(rankings)
    return float(np.mean([qp[metric] for qp in qps.values()]))
|
|
|
|
|
"""Example of using the pre-requisite code:""" |
|
|
|
|
|
|
|
|
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq |
|
|
sciq = load_sciq() |
|
|
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) |
|
|
|
|
|
|
|
|
bm25_index = BM25Index.build_from_documents( |
|
|
documents=iter(sciq.corpus), |
|
|
ndocs=12160, |
|
|
show_progress_bar=True |
|
|
) |
|
|
bm25_index.save("output/bm25_index") |
|
|
|
|
|
|
|
|
bm25_retriever = BM25Retriever(index_dir="output/bm25_index") |
|
|
print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")) |
|
|
|
|
|
import tqdm |
|
|
import numpy as np |
|
|
|
|
|
plots_b: Dict[str, List[float]] = { |
|
|
"X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], |
|
|
"Y": [] |
|
|
} |
|
|
plots_k1: Dict[str, List[float]] = { |
|
|
"X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], |
|
|
"Y": [] |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Greedy hyperparameter tuning on the dev split: first b (k1 fixed at 0.9),
# then k1 using the best b found.
result = {}
best_b = 0.0  # b VALUE achieving the best dev MAP so far
best_map_b = 0.0  # best dev MAP observed while tuning b
for b in plots_b["X"]:
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=12160,
        show_progress_bar=True,
        k1=0.9,  # keep k1 at its default while tuning b
        b=b,
    )
    bm25_index.save("output/bm25_index")
    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    for query in sciq.get_split_queries(Split.dev):
        result[query.query_id] = bm25_retriever.retrieve(query.text)
    # Evaluate once per setting (the original called evaluate_map 3x here).
    map_score = evaluate_map(result)
    if map_score > best_map_b:
        # Bug fix: remember the b VALUE, not the MAP score. The original
        # stored the MAP score in best_b and then passed that score (~0.78)
        # as the b parameter in the k1 loop below.
        best_map_b = map_score
        best_b = b
    plots_b["Y"].append(map_score)

best_k1 = 0.9  # k1 value achieving the best dev MAP so far
best_map_k1 = 0.0
for k1 in plots_k1["X"]:
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=12160,
        show_progress_bar=True,
        k1=k1,
        b=best_b,  # best b value found above
    )
    bm25_index.save("output/bm25_index")
    bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    for query in sciq.get_split_queries(Split.dev):
        result[query.query_id] = bm25_retriever.retrieve(query.text)
    map_score = evaluate_map(result)
    if map_score > best_map_k1:
        best_map_k1 = map_score
        best_k1 = k1
    plots_k1["Y"].append(map_score)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(plots_k1["Y"][9]) |
|
|
print(plots_b["Y"][1]) |
|
|
|
|
|
|
|
|
print(plots_k1) |
|
|
print(plots_b) |
|
|
|
|
|
from matplotlib import pyplot as plt |
|
|
plt.plot(plots_b["X"], plots_b["Y"], label="b") |
|
|
plt.plot(plots_k1["X"], plots_k1["Y"], label="k1") |
|
|
plt.ylabel("MAP") |
|
|
plt.legend() |
|
|
plt.grid() |
|
|
plt.show() |
|
|
|
|
|
"""Let's check the effectiveness gain on test after this tuning on dev""" |
|
|
|
|
|
default_map = 0.7849 |
|
|
best_b = plots_b["X"][np.argmax(plots_b["Y"])] |
|
|
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])] |
|
|
bm25_index = BM25Index.build_from_documents( |
|
|
documents=iter(sciq.corpus), |
|
|
ndocs=12160, |
|
|
show_progress_bar=True, |
|
|
k1=best_k1, |
|
|
b=best_b |
|
|
) |
|
|
bm25_index.save("output/bm25_index") |
|
|
bm25_retriever = BM25Retriever(index_dir="output/bm25_index") |
|
|
rankings = {} |
|
|
for query in sciq.get_split_queries(Split.test): |
|
|
ranking = bm25_retriever.retrieve(query=query.text) |
|
|
rankings[query.query_id] = ranking |
|
|
optimized_map = evaluate_map(rankings, split=Split.test) |
|
|
print(default_map, optimized_map) |
|
|
|
|
|
"""# TASK2: CSC matrix and `CSCBM25Index` (12 points) |
|
|
|
|
|
Recall that we use Python lists to implement posting lists, mapping term IDs to the documents in which they appear. This is inefficient due to its naive design. Actually [Compressed Sparse Column matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html) is very suitable for storing the posting lists and can boost the efficiency. |
|
|
|
|
|
## TASK2.1: learn about `scipy.sparse.csc_matrix` (2 point) |
|
|
|
|
|
Convert the matrix \begin{bmatrix} |
|
|
0 & 1 & 0 & 3 \\ |
|
|
10 & 2 & 1 & 0 \\ |
|
|
0 & 0 & 0 & 9 |
|
|
\end{bmatrix} to a `csc_matrix` by specifying `data`, `indices`, `indptr` and `shape`. |
|
|
""" |
|
|
|
|
|
from scipy.sparse._csc import csc_matrix |
|
|
input_matrix = [[0, 1, 0, 3], [10, 2, 1, 0], [0, 0, 0, 9]] |
|
|
data = None |
|
|
indices = None |
|
|
indptr = None |
|
|
shape = None |
|
|
|
|
|
|
|
|
|
|
|
# CSC ("compressed sparse column") stores the matrix column by column:
data = [10, 1, 2, 1, 3, 9]  # non-zero values scanned column by column
indices = [1, 0, 1, 2, 0, 2]  # row index of each entry in `data`
indptr = [0, 1, 3, 4, 6]  # data[indptr[j]:indptr[j+1]] are column j's values
shape = (3, 4)  # 3 rows x 4 columns

output_matrix = csc_matrix((data, indices, indptr), shape=shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print((output_matrix.indices + output_matrix.data).tolist()[2]) |
|
|
print((output_matrix.indices + output_matrix.data).tolist()[-1]) |
|
|
|
|
|
|
|
|
print((output_matrix.indices + output_matrix.data).tolist()) |
|
|
|
|
|
"""## TASK2.2: implement `CSCBM25Index` (4 points) |
|
|
|
|
|
Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to `InvertedIndex` which we talked about during the class. The main difference is posting lists are represented by a CSC sparse matrix. |
|
|
""" |
|
|
|
|
|
@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are stored in a single CSC
    sparse matrix (rows = term IDs, columns = document IDs)."""

    posting_lists_matrix: csc_matrix  # term weights, shape (n_terms, n_docs)
    vocab: Dict[str, int]  # term -> term ID (tid)
    cid2docid: Dict[str, int]  # collection ID -> internal docid
    collection_ids: List[str]  # docid -> collection ID
    doc_texts: Optional[List[str]] = None  # docid -> raw text (if stored)

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to <output_dir>/index.pkl, creating
        the directory if necessary."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: "Type[T]", saved_dir: str) -> "T":
        """Load an index previously written by `save`.

        Fix: the original built an empty throwaway instance that was
        immediately overwritten by `pickle.load`; load directly instead.

        NOTE(review): unpickling an untrusted file can execute arbitrary
        code; only load index files you created yourself.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
|
|
|
|
|
def convert_to_csc(posting_lists, num_docs):
    """Pack posting lists into a CSC matrix of shape (n_terms, num_docs):
    entry (tid, docid) holds the term's cached weight in that document."""
    values = []
    row_ids = []  # term IDs (matrix rows)
    col_ids = []  # document IDs (matrix columns)
    for tid, posting in enumerate(posting_lists):
        for docid, weight in zip(posting.docid_postings, posting.tweight_postings):
            values.append(weight)
            row_ids.append(tid)
            col_ids.append(docid)
    return csc_matrix(
        (values, (row_ids, col_ids)),
        shape=(len(posting_lists), num_docs),
        dtype=np.float32,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index whose cached term weights live in a CSC sparse matrix."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize queries/documents the same way the index was built."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute BM25 term weights in place, then pack them into a CSC matrix."""
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf
        return convert_to_csc(posting_lists, N)

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 TF saturation term (constant (k1+1) numerator factor
        omitted; this rescales scores uniformly and preserves rankings)."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 IDF with 0.5 smoothing (always positive)."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Build a CSC-backed BM25 index over *documents*.

        If *output_dir* is given, the built index is also saved there.
        """
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        # Bug fix: the output_dir parameter was accepted but silently
        # ignored; honor it (default None keeps the old behavior).
        if output_dir is not None:
            index.save(output_dir)
        return index
|
|
|
|
|
csc_bm25_index = CSCBM25Index.build_from_documents( |
|
|
documents=iter(sciq.corpus), |
|
|
ndocs=12160, |
|
|
show_progress_bar=True, |
|
|
k1=0.9, |
|
|
b=0.8 |
|
|
) |
|
|
csc_bm25_index.save("output/csc_bm25_index") |
|
|
|
|
|
print(len(str(os.path.getsize("output/csc_bm25_index/index.pkl")))) |
|
|
print(os.path.getsize("output/csc_bm25_index/index.pkl") // int(1e5)) |
|
|
|
|
|
|
|
|
print(len(str(os.path.getsize("output/csc_bm25_index/index.pkl")))) |
|
|
print(os.path.getsize("output/csc_bm25_index/index.pkl") // int(1e5)) |
|
|
|
|
|
|
|
|
print(os.path.getsize("output/csc_bm25_index/index.pkl")) |
|
|
|
|
|
"""We can compare the size of the CSC-based index with the Python-list-based index:""" |
|
|
|
|
|
print(os.path.getsize("output/bm25_index/index.pkl")) |
|
|
|
|
|
"""## TASK2.3: implement `CSCInvertedIndexRetriever` (6 points) |
|
|
|
|
|
Implement `CSCInvertedIndexRetriever` by completing the missing code. |
|
|
""" |
|
|
|
|
|
from nlp4web_codebase.ir.models import BaseRetriever |
|
|
from typing import Type |
|
|
from abc import abstractmethod |
|
|
|
|
|
|
|
|
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever backed by a list-based inverted index.

    NOTE(review): this is a byte-identical re-execution of the class
    defined earlier in the notebook (duplicated cell); kept unchanged.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the cached weight of each query term in document *cid*."""
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Sum of cached query-term weights in document *cid*."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Score all matching documents and return the top-k as
        {collection_id: score}."""
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }
|
|
|
|
|
|
|
|
class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever over a saved BM25Index (duplicate of the earlier
    notebook cell; kept unchanged)."""

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index
|
|
|
|
|
bm25_retriever = BM25Retriever(index_dir="output/bm25_index") |
|
|
bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bm25_retriever = BM25Retriever(index_dir="output/bm25_index") |
|
|
query = "What type of diseases occur when the immune system attacks normal body cells?" |
|
|
print(bm25_retriever.get_term_weights(query=query, cid="train-2006")) |
|
|
print(bm25_retriever.retrieve(query)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever backed by a CSC-matrix inverted index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Look up each query term's cached weight in document *cid*
        directly via CSC matrix indexing."""
        tokens = self.index.tokenize(query)
        target_docid = self.index.cid2docid.get(cid, None)
        if target_docid is None:
            return {}
        weights: Dict[str, float] = {}
        for token in tokens:
            tid = self.index.vocab.get(token)
            if tid is None:
                continue
            weight = self.index.posting_lists_matrix[tid, target_docid]
            if weight != 0:
                weights[token] = weight
        return weights

    def score(self, query: str, cid: str) -> float:
        """Sum of cached query-term weights in document *cid*."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Score documents by accumulating cached term weights over the
        query terms; return the top-k as {collection_id: score}."""
        scores: Dict[int, float] = {}
        for token in self.index.tokenize(query):
            tid = self.index.vocab.get(token)
            if tid is None:
                continue
            # One matrix row holds the whole posting list of this term.
            term_row = self.index.posting_lists_matrix.getrow(tid)
            for docid, weight in zip(term_row.indices, term_row.data):
                scores[docid] = scores.get(docid, 0) + weight
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:topk]
        return {self.index.collection_ids[d]: s for d, s in ranked}
|
|
|
|
|
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever over a saved CSCBM25Index."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index") |
|
|
query = "Who proposed the theory of evolution by natural selection?" |
|
|
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-2006")) |
|
|
print(csc_bm25_retriever.retrieve(query)) |
|
|
|
|
|
|
|
|
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index") |
|
|
query = "What are the differences between immunodeficiency and autoimmune diseases?" |
|
|
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-1691")) |
|
|
print(csc_bm25_retriever.retrieve("What are the differences between immunodeficiency and autoimmune diseases?")) |
|
|
|
|
|
"""# TASK3: a search-engine demo based on Huggingface space (4 points) |
|
|
|
|
|
## TASK3.1: create the gradio app (2 point) |
|
|
|
|
|
Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with default k1 and b values. |
|
|
|
|
|
Hint: it should use a "search" function of signature: |
|
|
|
|
|
```python |
|
|
def search(query: str) -> List[Hit]: |
|
|
... |
|
|
``` |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from typing import TypedDict |
|
|
|
|
|
class Hit(TypedDict):
    """One search result returned by the demo app."""

    cid: str  # collection ID of the hit document
    score: float  # retrieval score
    text: str  # raw document text (or a placeholder if not stored)

# Placeholder for the gradio app; assigned further below.
demo: Optional[gr.Interface] = None
# Expected return type of the `search` function.
return_type = List[Hit]
|
|
|
|
|
|
|
|
def search(query: str) -> List[Hit]:
    """Run BM25 retrieval for *query* and return the hits with their
    collection ID, score and raw text."""
    hits: List[Hit] = []
    for cid, score in bm25_retriever.retrieve(query).items():
        docid = bm25_retriever.index.cid2docid.get(cid, None)
        if docid is None:
            continue
        if bm25_retriever.index.doc_texts:
            text = bm25_retriever.index.doc_texts[docid]
        else:
            text = "Text not available."
        hits.append({"cid": cid, "score": score, "text": text})
    return hits
|
|
|
|
|
demo = gr.Interface( |
|
|
fn=search, |
|
|
inputs=gr.Textbox(lines=2, placeholder="Enter your query here...", label="Search Query"), |
|
|
outputs=gr.Textbox(label="Search Results", lines=10), |
|
|
title="BM25 Search Engine", |
|
|
description=""" |
|
|
BM25 |
|
|
""" |
|
|
) |
|
|
|
|
|
|
|
|
demo.launch() |
|
|
|
|
|
|
|
|
import requests
import json

headers = {"Content-Type": "application/json"}
data = {"data": ["What type of organism is commonly used in preparation of foods such as cheese and yogurt?"]}
# Submit the query to the locally running gradio app's predict endpoint.
response = requests.post(f"{demo.local_api_url.strip('/')}/call/predict", headers=headers, data=json.dumps(data))
event_id = response.json()["event_id"]
# Stream the result for the submitted event.
response = requests.get(f"{demo.local_api_url.strip('/')}/call/predict/{event_id}", stream=True)
lines = list(response.iter_lines())
# SECURITY NOTE(review): eval() on data received over the network can
# execute arbitrary code; ast.literal_eval would be a safer way to parse
# the repr'd hit list. Left as-is to match the notebook's grading cells.
print(eval(json.loads(lines[1].decode("UTF-8").replace("data:", ""))[0]))
|
|
|
|
|
|
|
|
import requests |
|
|
import json |
|
|
|
|
|
headers = {"Content-Type": "application/json"} |
|
|
data = {"data": ["What are the differences between immunodeficiency and autoimmune diseases?"]} |
|
|
response = requests.post(f"{demo.local_api_url.strip('/')}/call/predict", headers=headers, data=json.dumps(data)) |
|
|
event_id = response.json()["event_id"] |
|
|
response = requests.get(f"{demo.local_api_url.strip('/')}/call/predict/{event_id}", stream=True) |
|
|
lines = list(response.iter_lines()) |
|
|
print(eval(json.loads(lines[1].decode("UTF-8").replace("data:", ""))[0])) |
|
|
|
|
|
"""## TASK3.2: upload it to Huggingface Space (2 point) |
|
|
|
|
|
Upload your gradio app to Huggingface Space. Put your URL to the Space app in the variable `hf_space_url`. |
|
|
|
|
|
IMPORTANT!!! You can get this URL from: |
|
|
|
|
|
*Your Space page* -> *"three dots" on the top right* -> "embedd this space" -> "Direct URL" |
|
|
|
|
|
An example URL (not for our task) is: https://stabilityai-stable-diffusion-3-5-large.hf.space (from https://huggingface.co/spaces/stabilityai/stable-diffusion-3.5-large) |
|
|
""" |
|
|
|
|
|
hf_space_url: Optional[str] = None |
|
|
|
|
|
hf_space_url: Optional[str] = "https://intelava-nlp4web.hf.space" |
|
|
|
|
|
|
|
|
|
|
|
import requests |
|
|
import json |
|
|
|
|
|
print(hf_space_url) |
|
|
headers = {"Content-Type": "application/json"} |
|
|
data = {"data": ["What are the differences between immunodeficiency and autoimmune diseases?"]} |
|
|
response = requests.post(f"{hf_space_url.strip('/')}/call/predict", headers=headers, data=json.dumps(data)) |
|
|
event_id = response.json()["event_id"] |
|
|
response = requests.get(f"{hf_space_url.strip('/')}/call/predict/{event_id}", stream=True) |
|
|
lines = list(response.iter_lines()) |
|
|
print(eval(json.loads(lines[1].decode("UTF-8").replace("data:", ""))[0])) |
|
|
|
|
|
|
|
|
import requests |
|
|
import json |
|
|
|
|
|
headers = {"Content-Type": "application/json"} |
|
|
data = {"data": ["Changes from a less-ordered state to a more-ordered state (such as a liquid to a solid) are always what?"]} |
|
|
response = requests.post(f"{hf_space_url.strip('/')}/call/predict", headers=headers, data=json.dumps(data)) |
|
|
event_id = response.json()["event_id"] |
|
|
response = requests.get(f"{hf_space_url.strip('/')}/call/predict/{event_id}", stream=True) |
|
|
lines = list(response.iter_lines()) |
|
|
print(eval(json.loads(lines[1].decode("UTF-8").replace("data:", ""))[0])) |
|
|
|
|
|
|