# Ryoya Awano
# deploy: fix MedLFQA Marginal mode sample matching
# 19fc84f
import json
import sqlite3
import time
import numpy as np
import pickle as pkl
from transformers import RobertaTokenizer
# We copied this code from https://github.com/shmsw25/FActScore and declare
# that it is used for research purposes only.
MAX_LENGTH = 256
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"


class DocDB(object):
    """Sqlite-backed document storage.

    Documents live in a single ``documents`` table keyed by title. Each row's
    text column holds the document's passages joined by SPECIAL_SEPARATOR.
    Retrieval is via get_text_from_title(title).
    """

    def __init__(self, db_path=None, data_path=None):
        """Open the database at `db_path`, building it from `data_path` if empty.

        Args:
            db_path: path to the sqlite file backing this DB.
            data_path: JSONL source used to populate an empty DB; each line is
                a JSON object with "title" and "text" fields. Required only
                when the DB has no tables yet.
        """
        self.db_path = db_path
        # check_same_thread=False allows the connection to be used from
        # multiple threads (callers are responsible for serializing access).
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        if len(cursor.fetchall()) == 0:
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print (f"{self.db_path} is empty. start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        # BUG FIX: previously returned `self.path` — the bound method itself —
        # rather than the stored path string.
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        """Populate the ``documents`` table from a JSONL dump at `data_path`.

        Each document's text is tokenized with the roberta-large tokenizer and
        split into passages of at most MAX_LENGTH token ids; decoded passages
        are joined with SPECIAL_SEPARATOR and stored as a single row per
        title. Duplicate titles keep only their first occurrence. The
        connection is committed and closed when done.
        """
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")
        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                if title in titles:
                    continue  # keep only the first document seen per title
                titles.add(title)
                if isinstance(text, str):
                    text = [text]
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip()) > 0
                    tokens = tokenizer(sent)["input_ids"]
                    # Fill the current passage up to MAX_LENGTH token ids,
                    # then spill any remainder into fresh passages.
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset + MAX_LENGTH])
                            offset += MAX_LENGTH
                # Drop passages that contain only special tokens
                # (ids 0 and 2 are roberta's <s> and </s>).
                psgs = [tokenizer.decode(tokens) for tokens in passages
                        if np.sum([t not in [0, 2] for t in tokens]) > 0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1
                if len(output_lines) == 1000000:
                    # Flush in 1M-row batches to bound memory usage.
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
        self.connection.commit()
        self.connection.close()

    def get_text_from_title(self, title):
        """Fetch the stored passages for `title`.

        Returns:
            A list of {"title": title, "text": passage} dicts, one per
            SPECIAL_SEPARATOR-delimited passage in the stored row.

        Raises:
            AssertionError: if `title` does not match exactly one row.
        """
        cursor = self.connection.cursor()
        cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
        results = cursor.fetchall()
        cursor.close()
        # fetchall() never returns None; the meaningful check is uniqueness.
        assert len(results) == 1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
        results = [{"title": title, "text": para}
                   for para in results[0][0].split(SPECIAL_SEPARATOR)]
        assert len(results) > 0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
        return results