import json
import os
import pickle as pkl
import sqlite3
import time

import numpy as np
from rank_bm25 import BM25Okapi

# Separator used to join a document's passages into a single DB row.
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
# Maximum number of tokens per passage.
MAX_LENGTH = 256


class DocDB(object):
    """SQLite-backed document storage.

    Implements get_text_from_title(title).
    """

    def __init__(self, db_path=None, data_path=None, cache_path=None):
        self.db_path = db_path
        self.cache_file = cache_path
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)

        # Cache of previously fetched documents (title -> list of passages).
        self.cache_dict = self.load_cache()

        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")

        # If the DB has no tables yet, build it from the raw data file.
        if len(cursor.fetchall()) == 0:
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print(f"{self.db_path} is empty. Start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def load_cache(self, allow_retry=True):
        if self.cache_file and os.path.exists(self.cache_file):
            # Another process may be writing the cache; retry until the pickle loads cleanly.
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pkl.load(f)
                    break
                except Exception:
                    if not allow_retry:
                        raise
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif self.cache_file and "s3" in self.cache_file:
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix("s3://")
            bucket_name = s3_path.split("/")[0]
            path_to_file = "/".join(s3_path.split("/")[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pkl.load(fp)
        else:
            cache = {}
        return cache

    def save_cache(self):
        # Merge entries written by other processes since the cache was last loaded,
        # then write the combined cache back to disk.
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pkl.dump(self.cache_dict, f)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")

        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")

        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                if title in titles:
                    continue
                titles.add(title)
                if isinstance(text, str):
                    text = [text]

                # Greedily pack sentences into chunks of at most MAX_LENGTH tokens.
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip()) > 0
                    tokens = tokenizer(sent)["input_ids"]
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset + MAX_LENGTH])
                            offset += MAX_LENGTH

                # Drop passages that only contain special tokens (0 = <s>, 2 = </s> for RoBERTa).
                psgs = [tokenizer.decode(tokens) for tokens in passages
                        if np.sum([t not in [0, 2] for t in tokens]) > 0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1

                # Flush to the DB in batches of 1M documents.
                if len(output_lines) == 1000000:
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time() - start_time) / 60))

        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time() - start_time) / 60))

        self.connection.commit()
        self.connection.close()

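    # A minimal sketch of the input format assumed by build_db above, inferred from how
    # it reads `data_path`: one JSON object per line, with a "title" field and a "text"
    # field that is either a single string or a list of sentence strings. The example
    # below is illustrative only, not taken from any real dump.
    #
    #   {"title": "Example Article", "text": ["First sentence.", "Second sentence."]}
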
    def get_text_from_title(self, title):
        """Fetch the passages of the document whose title is `title`."""
        # Map known title variants to the canonical titles stored in the DB.
        with open("data/wiki_corrections.txt") as fp:
            all_names = fp.readlines()
        all_names = [n.strip() for n in all_names]
        name_converter = {names.split("=")[0]: names.split("=")[1] for names in all_names}
        if title in name_converter:
            title = name_converter[title]

        if title in self.cache_dict:
            results = self.cache_dict[title]
        else:
            print(f"Cache miss for {title}: querying the DB directly.")
            cursor = self.connection.cursor()
            cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
            results = cursor.fetchall()
            results = [r for r in results]
            cursor.close()
            try:
                assert results is not None and len(results) == 1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            except Exception:
                # Fall back to a dummy passage so the pipeline does not crash on missing titles.
                print(f"Retrieval error for {title}: falling back to a placeholder passage.")
                results = [["blah blah blah"]]
            results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
            assert len(results) > 0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            self.cache_dict[title] = results
        return results


class Retrieval(object):
    """Retrieves the top-k passages of a document using BM25 or a GTR encoder."""

    def __init__(self, db, cache_path, embed_cache_path,
                 retrieval_type="gtr-t5-large", batch_size=None):
        self.db = db
        self.cache_path = cache_path
        self.embed_cache_path = embed_cache_path
        self.retrieval_type = retrieval_type
        self.batch_size = batch_size
        assert retrieval_type == "bm25" or retrieval_type.startswith("gtr-")

        self.encoder = None
        self.load_cache()
        self.add_n = 0
        self.add_n_embed = 0

    def load_encoder(self):
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
        encoder = encoder.cuda()
        encoder = encoder.eval()
        self.encoder = encoder
        assert self.batch_size is not None

    def load_cache(self):
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                self.cache = json.load(f)
        else:
            self.cache = {}
        if os.path.exists(self.embed_cache_path):
            with open(self.embed_cache_path, "rb") as f:
                self.embed_cache = pkl.load(f)
        else:
            self.embed_cache = {}

    def save_cache(self):
        if self.add_n > 0:
            # Merge entries written by other processes before overwriting the file.
            if os.path.exists(self.cache_path):
                with open(self.cache_path, "r") as f:
                    new_cache = json.load(f)
                self.cache.update(new_cache)

            with open(self.cache_path, "w") as f:
                json.dump(self.cache, f)

        if self.add_n_embed > 0:
            if os.path.exists(self.embed_cache_path):
                with open(self.embed_cache_path, "rb") as f:
                    new_cache = pkl.load(f)
                self.embed_cache.update(new_cache)

            with open(self.embed_cache_path, "wb") as f:
                pkl.dump(self.embed_cache, f)

    def get_bm25_passages(self, topic, query, passages, k):
        if topic in self.embed_cache:
            bm25 = self.embed_cache[topic]
        else:
            bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages])
            self.embed_cache[topic] = bm25
            self.add_n_embed += 1
        scores = bm25.get_scores(query.split())
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_gtr_passages(self, topic, retrieval_query, passages, k):
        if self.encoder is None:
            self.load_encoder()
        if topic in self.embed_cache:
            passage_vectors = self.embed_cache[topic]
        else:
            inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages]
            passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
            self.embed_cache[topic] = passage_vectors
            self.add_n_embed += 1
        query_vectors = self.encoder.encode([retrieval_query],
                                            batch_size=self.batch_size,
                                            device=self.encoder.device)[0]
        scores = np.inner(query_vectors, passage_vectors)
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_passages(self, topic, question, k):
        retrieval_query = topic + " " + question.strip()
        cache_key = topic + "#" + retrieval_query

        if cache_key not in self.cache:
            passages = self.db.get_text_from_title(topic)
            if self.retrieval_type == "bm25":
                self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
            else:
                self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
            assert len(self.cache[cache_key]) in [k, len(passages)]
            self.add_n += 1

        return self.cache[cache_key]
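

# A minimal usage sketch, not part of the original module. It assumes a prebuilt SQLite
# DB at "data/enwiki.db", cache files under "data/", and the BM25 retriever so no GPU is
# needed; all paths, the topic, and the question below are hypothetical. Note that
# DocDB.get_text_from_title also expects "data/wiki_corrections.txt" to exist.
if __name__ == "__main__":
    db = DocDB(db_path="data/enwiki.db", cache_path="data/doc_cache.pkl")
    retrieval = Retrieval(db,
                          cache_path="data/retrieval_cache.json",
                          embed_cache_path="data/bm25_cache.pkl",
                          retrieval_type="bm25")
    # Retrieve the 3 passages of the "Marie Curie" article that best match the question.
    passages = retrieval.get_passages("Marie Curie", "What did Marie Curie win the Nobel Prize for?", k=3)
    for psg in passages:
        print(psg["title"], psg["text"][:100])
    # Persist both caches so later runs can skip the DB query and BM25 indexing.
    retrieval.save_cache()
    db.save_cache()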