"""SQLite-backed document store (DocDB) and passage retrieval (Retrieval).

DocDB stores documents in a SQLite ``documents(title, text)`` table, with
each document's text pre-chunked into token-bounded passages joined by
``SPECIAL_SEPARATOR``.  Retrieval ranks a document's passages against a
query with BM25 or a GTR sentence encoder, caching both final results and
per-topic indexes/embeddings on disk.
"""

import json
import os
import pickle as pkl
import sqlite3
import time

import numpy as np

# Optional dependency: only needed for BM25 retrieval.  Kept importable
# at module level for any external users of the name.
try:
    from rank_bm25 import BM25Okapi
except ImportError:
    BM25Okapi = None

# Joins one document's passage chunks into a single TEXT column value.
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
# Maximum number of tokenizer input_ids per passage chunk.
MAX_LENGTH = 256


class DocDB(object):
    """Sqlite backed document storage.

    Implements get_text_from_title(title) -> list of passage dicts.
    """

    def __init__(self, db_path=None, data_path=None, cache_path=None):
        """Open (or build) the SQLite DB at ``db_path``.

        db_path: path to the SQLite file.
        data_path: JSONL dump with "title"/"text" fields, used to build
            the DB when ``db_path`` contains no tables.
        cache_path: pickle file caching title -> passages lookups.
        """
        self.db_path = db_path
        self.cache_file = cache_path
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        self.cache_dict = self.load_cache()
        # Title-spelling correction map, loaded lazily on first lookup.
        self._title_corrections = None
        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if len(cursor.fetchall()) == 0:
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print(f"{self.db_path} is empty. start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def load_cache(self, allow_retry=True):
        """Load the pickle cache from disk (or S3 for ``s3://`` paths).

        Returns {} when no local cache exists and the path is not an S3
        URI.  Concurrent writers can leave a half-written pickle, so
        local reads are retried every 5s unless ``allow_retry`` is False,
        in which case the read error is re-raised.
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pkl.load(f)
                    break
                except Exception:
                    # BUG FIX: was `assert False`, which is stripped
                    # under -O (silent infinite retry); re-raise instead.
                    if not allow_retry:
                        raise
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif 's3' in self.cache_file:
            # Remote cache: stream the pickle straight from S3.
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pkl.load(fp)
        else:
            cache = {}
        return cache

    def save_cache(self):
        """Persist the cache to disk.

        Re-reads the on-disk cache first so entries written by concurrent
        processes are preserved (disk entries win on key conflicts).
        """
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v
        with open(self.cache_file, "wb") as f:
            pkl.dump(self.cache_dict, f)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        # BUG FIX: previously returned `self.path` -- the bound method
        # itself -- instead of the database file path.
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        """Populate the ``documents`` table from the JSONL dump at
        ``data_path``.

        Each line must carry "title" and "text" (str or list of str).
        Text is tokenized with RoBERTa, re-chunked into passages of at
        most MAX_LENGTH ids, decoded, and joined with SPECIAL_SEPARATOR.
        Duplicate titles are skipped.  Closes the connection when done.
        """
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")

        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")

        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                if title in titles:
                    continue
                titles.add(title)
                if isinstance(text, str):
                    text = [text]

                # Greedily pack sentence token ids into MAX_LENGTH chunks.
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip()) > 0
                    tokens = tokenizer(sent)["input_ids"]
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset + MAX_LENGTH])
                            offset += MAX_LENGTH

                # Drop chunks containing only BOS (0) / EOS (2) ids.
                psgs = [tokenizer.decode(tokens) for tokens in passages
                        if np.sum([t not in [0, 2] for t in tokens]) > 0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1

                # Flush to SQLite in 1M-row batches to bound memory.
                if len(output_lines) == 1000000:
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time() - start_time) / 60))

        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time() - start_time) / 60))

        self.connection.commit()
        self.connection.close()

    def _load_title_corrections(self):
        """Load (once, then memoize) the title-spelling correction map
        from data/wiki_corrections.txt, formatted as ``bad=good`` lines.
        Previously this file was re-read on every lookup."""
        if getattr(self, "_title_corrections", None) is None:
            with open('data/wiki_corrections.txt') as fp:
                all_names = [n.strip() for n in fp.readlines()]
            self._title_corrections = {
                names.split('=')[0]: names.split('=')[1] for names in all_names
            }
        return self._title_corrections

    def get_text_from_title(self, title):
        """Fetch the passages of the document titled ``title``.

        Returns a list of {"title": ..., "text": passage} dicts, split on
        SPECIAL_SEPARATOR.  Results are memoized in ``self.cache_dict``.
        Raises AssertionError when the title is not in the DB.
        """
        name_converter = self._load_title_corrections()
        if title in name_converter:
            title = name_converter[title]

        if title in self.cache_dict:
            results = self.cache_dict[title]
        else:
            cursor = self.connection.cursor()
            cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
            results = [r for r in cursor.fetchall()]
            cursor.close()
            try:
                assert results is not None and len(results) == 1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            except Exception:
                # Concurrent processes can race the DB: wait and retry
                # once.  BUG FIX: the retry previously discarded its own
                # query result and substituted placeholder text
                # ([['blah blah blah']]); it also printed "Retry in
                # 5sec" without actually sleeping.
                print(f"Retrieval error for {title}: Retry in 5sec...")
                time.sleep(5)
                cursor = self.connection.cursor()
                cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
                results = [r for r in cursor.fetchall()]
                cursor.close()
            results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
            assert len(results) > 0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            self.cache_dict[title] = results
        return results


class Retrieval(object):
    """Rank one document's passages against a query (BM25 or GTR)."""

    def __init__(self, db, cache_path, embed_cache_path, retrieval_type="gtr-t5-large", batch_size=None):
        """db: a DocDB instance.
        cache_path: JSON cache of final top-k results per query.
        embed_cache_path: pickle cache of per-topic BM25 indexes or
            passage embedding matrices.
        retrieval_type: "bm25" or a sentence-transformers "gtr-*" model.
        batch_size: encoder batch size (required for gtr-* retrieval).
        """
        self.db = db
        self.cache_path = cache_path
        self.embed_cache_path = embed_cache_path
        self.retrieval_type = retrieval_type
        self.batch_size = batch_size
        assert retrieval_type == "bm25" or retrieval_type.startswith("gtr-")

        self.encoder = None   # loaded lazily by load_encoder()
        self.load_cache()
        self.add_n = 0        # new result-cache entries since last save
        self.add_n_embed = 0  # new embed-cache entries since last save

    def load_encoder(self):
        """Lazily load the GTR sentence encoder onto the GPU."""
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
        encoder = encoder.cuda()
        encoder = encoder.eval()
        self.encoder = encoder
        assert self.batch_size is not None

    def load_cache(self):
        """Load the JSON result cache and pickle embedding cache,
        defaulting each to {} when the file does not exist."""
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                self.cache = json.load(f)
        else:
            self.cache = {}
        if os.path.exists(self.embed_cache_path):
            with open(self.embed_cache_path, "rb") as f:
                self.embed_cache = pkl.load(f)
        else:
            self.embed_cache = {}

    def save_cache(self):
        """Write both caches back to disk when they have new entries,
        first merging in entries written by concurrent processes (disk
        entries win on key conflicts)."""
        if self.add_n > 0:
            if os.path.exists(self.cache_path):
                with open(self.cache_path, "r") as f:
                    new_cache = json.load(f)
                self.cache.update(new_cache)
            with open(self.cache_path, "w") as f:
                json.dump(self.cache, f)
        if self.add_n_embed > 0:
            if os.path.exists(self.embed_cache_path):
                with open(self.embed_cache_path, "rb") as f:
                    new_cache = pkl.load(f)
                self.embed_cache.update(new_cache)
            with open(self.embed_cache_path, "wb") as f:
                pkl.dump(self.embed_cache, f)

    def get_bm25_passages(self, topic, query, passages, k):
        """Return the top-k of ``passages`` for ``query`` under a BM25
        index (built once per topic, then cached)."""
        if topic in self.embed_cache:
            bm25 = self.embed_cache[topic]
        else:
            assert BM25Okapi is not None, "rank_bm25 must be installed for bm25 retrieval"
            # NOTE(review): the empty-string .replace calls are no-ops;
            # they look like garbled special-token strips ("<s>",
            # "</s>") -- confirm against the upstream source.
            bm25 = BM25Okapi([psg["text"].replace("", "").replace("", "").split() for psg in passages])
            self.embed_cache[topic] = bm25
            self.add_n_embed += 1
        scores = bm25.get_scores(query.split())
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_gtr_passages(self, topic, retrieval_query, passages, k):
        """Return the top-k of ``passages`` by inner product between the
        GTR embedding of ``retrieval_query`` and the (cached) passage
        embedding matrix for ``topic``."""
        if self.encoder is None:
            self.load_encoder()
        if topic in self.embed_cache:
            passage_vectors = self.embed_cache[topic]
        else:
            # NOTE(review): empty-string .replace calls are no-ops; they
            # look like garbled special-token strips -- confirm.
            inputs = [psg["title"] + " " + psg["text"].replace("", "").replace("", "") for psg in passages]
            passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
            self.embed_cache[topic] = passage_vectors
            self.add_n_embed += 1
        query_vectors = self.encoder.encode([retrieval_query], batch_size=self.batch_size, device=self.encoder.device)[0]
        scores = np.inner(query_vectors, passage_vectors)
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_passages(self, topic, question, k):
        """Return (and cache) the top-k passages of document ``topic``
        relevant to ``question``."""
        retrieval_query = topic + " " + question.strip()
        cache_key = topic + "#" + retrieval_query
        if cache_key not in self.cache:
            passages = self.db.get_text_from_title(topic)
            if self.retrieval_type == "bm25":
                self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
            else:
                self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
            # A document with fewer than k passages returns all of them.
            assert len(self.cache[cache_key]) in [k, len(passages)]
            self.add_n += 1
        return self.cache[cache_key]