KangjunNoh's picture
Upload 47 files
906e061 verified
import json
import time
import os
import sqlite3
import numpy as np
import pickle as pkl
from rank_bm25 import BM25Okapi
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
MAX_LENGTH = 256
class DocDB(object):
"""Sqlite backed document storage.
Implements get_doc_text(doc_id).
"""
def __init__(self, db_path=None, data_path=None, cache_path=None):
self.db_path = db_path
self.cache_file = cache_path
self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
self.cache_dict = self.load_cache()
cursor = self.connection.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
if len(cursor.fetchall())==0:
assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
print (f"{self.db_path} is empty. start building DB from {data_path}...")
self.build_db(self.db_path, data_path)
def load_cache(self, allow_retry=True):
if os.path.exists(self.cache_file):
while True:
try:
with open(self.cache_file, "rb") as f:
cache = pkl.load(f)
break
except Exception: # if there are concurent processes, things can fail
if not allow_retry:
assert False
print ("Pickle Error: Retry in 5sec...")
time.sleep(5)
elif 's3' in self.cache_file:
from aws_utils import s3_open
s3_path = self.cache_file.removeprefix('s3://')
bucket_name = s3_path.split('/')[0]
path_to_file = '/'.join(s3_path.split('/')[1:])
with s3_open(bucket_name, path_to_file) as fp:
cache = pkl.load(fp)
else:
cache = {}
return cache
def save_cache(self):
# load the latest cache first, since if there were other processes running in parallel, cache might have been updated
for k, v in self.load_cache().items():
self.cache_dict[k] = v
with open(self.cache_file, "wb") as f:
pkl.dump(self.cache_dict, f)
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def path(self):
"""Return the path to the file that backs this database."""
return self.path
def close(self):
"""Close the connection to the database."""
self.connection.close()
def build_db(self, db_path, data_path):
from transformers import RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
titles = set()
output_lines = []
tot = 0
start_time = time.time()
c = self.connection.cursor()
c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")
with open(data_path, "r") as f:
for line in f:
dp = json.loads(line)
title = dp["title"]
text = dp["text"]
if title in titles:
continue
titles.add(title)
if type(text)==str:
text = [text]
passages = [[]]
for sent_idx, sent in enumerate(text):
assert len(sent.strip())>0
tokens = tokenizer(sent)["input_ids"]
max_length = MAX_LENGTH - len(passages[-1])
if len(tokens) <= max_length:
passages[-1].extend(tokens)
else:
passages[-1].extend(tokens[:max_length])
offset = max_length
while offset < len(tokens):
passages.append(tokens[offset:offset+MAX_LENGTH])
offset += MAX_LENGTH
psgs = [tokenizer.decode(tokens) for tokens in passages if np.sum([t not in [0, 2] for t in tokens])>0]
text = SPECIAL_SEPARATOR.join(psgs)
output_lines.append((title, text))
tot += 1
if len(output_lines) == 1000000:
c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
output_lines = []
print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
if len(output_lines) > 0:
c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
self.connection.commit()
self.connection.close()
def get_text_from_title(self, title):
"""Fetch the raw text of the doc for 'doc_id'."""
with open('data/wiki_corrections.txt') as fp:
all_names = fp.readlines()
all_names = [n.strip() for n in all_names]
name_converter = {names.split('=')[0]:names.split('=')[1] for names in all_names}
if title in name_converter:
title = name_converter[title]
if title in self.cache_dict:
results = self.cache_dict[title]
else:
print("I SHOULD NOT BE HERE.")
cursor = self.connection.cursor()
cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
results = cursor.fetchall()
results = [r for r in results]
cursor.close()
try:
assert results is not None and len(results)==1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
except Exception: # if there are concurent processes, things can fail
print (f"Retrieval error for {title}: Retry in 5sec...")
# time.sleep(5)
cursor = self.connection.cursor()
cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
results = cursor.fetchall()
results = [r for r in results]
results = [['blah blah blah']]
cursor.close()
results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
assert len(results)>0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
self.cache_dict[title] = results
return results
class Retrieval(object):
def __init__(self, db, cache_path, embed_cache_path,
retrieval_type="gtr-t5-large", batch_size=None):
self.db = db
self.cache_path = cache_path
self.embed_cache_path = embed_cache_path
self.retrieval_type = retrieval_type
self.batch_size = batch_size
assert retrieval_type=="bm25" or retrieval_type.startswith("gtr-")
self.encoder = None
self.load_cache()
self.add_n = 0
self.add_n_embed = 0
def load_encoder(self):
from sentence_transformers import SentenceTransformer
encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
encoder = encoder.cuda()
encoder = encoder.eval()
self.encoder = encoder
assert self.batch_size is not None
def load_cache(self):
if os.path.exists(self.cache_path):
with open(self.cache_path, "r") as f:
self.cache = json.load(f)
else:
self.cache = {}
if os.path.exists(self.embed_cache_path):
with open(self.embed_cache_path, "rb") as f:
self.embed_cache = pkl.load(f)
else:
self.embed_cache = {}
def save_cache(self):
if self.add_n > 0:
if os.path.exists(self.cache_path):
with open(self.cache_path, "r") as f:
new_cache = json.load(f)
self.cache.update(new_cache)
with open(self.cache_path, "w") as f:
json.dump(self.cache, f)
if self.add_n_embed > 0:
if os.path.exists(self.embed_cache_path):
with open(self.embed_cache_path, "rb") as f:
new_cache = pkl.load(f)
self.embed_cache.update(new_cache)
with open(self.embed_cache_path, "wb") as f:
pkl.dump(self.embed_cache, f)
def get_bm25_passages(self, topic, query, passages, k):
if topic in self.embed_cache:
bm25 = self.embed_cache[topic]
else:
bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages])
self.embed_cache[topic] = bm25
self.add_n_embed += 1
scores = bm25.get_scores(query.split())
indices = np.argsort(-scores)[:k]
return [passages[i] for i in indices]
def get_gtr_passages(self, topic, retrieval_query, passages, k):
if self.encoder is None:
self.load_encoder()
if topic in self.embed_cache:
passage_vectors = self.embed_cache[topic]
else:
inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages]
passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
self.embed_cache[topic] = passage_vectors
self.add_n_embed += 1
query_vectors = self.encoder.encode([retrieval_query],
batch_size=self.batch_size,
device=self.encoder.device)[0]
scores = np.inner(query_vectors, passage_vectors)
indices = np.argsort(-scores)[:k]
return [passages[i] for i in indices]
def get_passages(self, topic, question, k):
retrieval_query = topic + " " + question.strip()
cache_key = topic + "#" + retrieval_query
if cache_key not in self.cache:
passages = self.db.get_text_from_title(topic)
if self.retrieval_type=="bm25":
self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
else:
self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
assert len(self.cache[cache_key]) in [k, len(passages)]
self.add_n += 1
return self.cache[cache_key]