# (removed non-code page-scrape residue: "Spaces: Sleeping Sleeping")
| import json | |
| import sqlite3 | |
| import time | |
| import numpy as np | |
| import pickle as pkl | |
| from transformers import RobertaTokenizer | |
| #we copied this code from https://github.com/shmsw25/FActScore and we declare this is only used for research purpose | |
# Maximum number of tokens per passage when chunking a document in build_db.
MAX_LENGTH = 256
# Sentinel string used to join passages into one DB text column; split on it
# again in get_text_from_title. Must never occur in real document text.
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
class DocDB(object):
    """Sqlite-backed document storage.

    Documents are stored in a single ``documents(title, text)`` table, where
    ``text`` is a string of token-chunked passages joined by
    ``SPECIAL_SEPARATOR``. Implements ``get_text_from_title(title)``.
    """

    def __init__(self, db_path=None, data_path=None):
        """Open (or build) the database at ``db_path``.

        If the database has no tables yet, ``data_path`` (a JSONL file with
        ``title``/``text`` fields) is required and the DB is built from it.
        """
        self.db_path = db_path
        # check_same_thread=False: connection may be used from other threads.
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        is_empty = len(cursor.fetchall()) == 0
        cursor.close()
        if is_empty:
            # NOTE(review): assert used for input validation (stripped under -O);
            # kept because callers may rely on AssertionError being raised.
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print (f"{self.db_path} is empty. start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        # BUG FIX: the original returned ``self.path`` — the bound method
        # object itself — instead of the stored database path.
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        """Build the ``documents`` table from a JSONL dump at ``data_path``.

        Each input line is ``{"title": ..., "text": str | list[str]}``.
        Sentences are tokenized with RoBERTa and greedily packed into
        passages of at most ``MAX_LENGTH`` tokens; the decoded passages are
        joined with ``SPECIAL_SEPARATOR`` into one text column per title.
        Duplicate titles are skipped. Closes the connection when done.
        """
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")

        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                if title in titles:
                    continue  # skip duplicate titles: title is the PRIMARY KEY
                titles.add(title)
                if isinstance(text, str):
                    text = [text]
                # Greedily pack token ids into passages of <= MAX_LENGTH tokens.
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip()) > 0
                    tokens = tokenizer(sent)["input_ids"]
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        # Fill the current passage, then spill the remainder
                        # into fresh MAX_LENGTH-sized passages.
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset + MAX_LENGTH])
                            offset += MAX_LENGTH

                # Drop passages that contain only special tokens (0 = <s>, 2 = </s>).
                psgs = [tokenizer.decode(tokens) for tokens in passages
                        if np.sum([t not in [0, 2] for t in tokens]) > 0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1

                # Flush to sqlite in batches of 1M rows to bound memory use.
                if len(output_lines) == 1000000:
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print ("Finish saving %dM documents (%dmin)" % (tot // 1000000, (time.time() - start_time) / 60))

        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print ("Finish saving %dM documents (%dmin)" % (tot // 1000000, (time.time() - start_time) / 60))

        self.connection.commit()
        # NOTE(review): the connection is closed after building, matching the
        # original behavior — the instance must be re-created to query.
        self.connection.close()

    def get_text_from_title(self, title):
        """Fetch the passages of the document stored under ``title``.

        Returns a list of ``{"title": ..., "text": passage}`` dicts, one per
        ``SPECIAL_SEPARATOR``-delimited passage. Raises AssertionError when
        the title is missing (or, impossibly for a primary key, duplicated).
        """
        cursor = self.connection.cursor()
        cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
        results = cursor.fetchall()
        cursor.close()
        # NOTE(review): asserts kept (not converted to ValueError) so callers
        # catching AssertionError keep working.
        assert results is not None and len(results) == 1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
        results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
        assert len(results) > 0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
        return results