Spaces:
Runtime error
Runtime error
| import sqlite3 | |
| import warnings | |
| import zlib | |
| import numpy as np | |
| import pandas as pd | |
# DDL for the `documents` table: one row per stored version of a document.
#   source    - name of the origin the document was collected from
#   n_tokens  - optional token count, written together with `embedding`
#   embedding - optional BLOB; DocumentsDB stores it as zlib-compressed float32 bytes
#   current   - 1 for the most recently written version of a source, 0 for older ones
documents_table = """CREATE TABLE IF NOT EXISTS documents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    title TEXT NOT NULL,
    url TEXT NOT NULL,
    content TEXT NOT NULL,
    n_tokens INTEGER,
    embedding BLOB,
    current INTEGER
)"""
# DDL for the `qa` table: one row per question/answer pair, with up to three
# references to the `documents` rows that were used to produce the answer.
# NOTE: SQLite only enforces the FOREIGN KEY clauses on connections where
# `PRAGMA foreign_keys = ON` has been issued; otherwise they are informational.
qa_table = """CREATE TABLE IF NOT EXISTS qa (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source TEXT NOT NULL,
    prompt TEXT NOT NULL,
    answer TEXT NOT NULL,
    document_id_1 INTEGER,
    document_id_2 INTEGER,
    document_id_3 INTEGER,
    label_question INTEGER,
    label_answer INTEGER,
    testset INTEGER,
    FOREIGN KEY (document_id_1) REFERENCES documents (id),
    FOREIGN KEY (document_id_2) REFERENCES documents (id),
    FOREIGN KEY (document_id_3) REFERENCES documents (id)
)"""
class DocumentsDB:
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store multiple
    versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path):
        """Open (or create) the database at `db_path` and make sure the tables exist."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.__initialize()

    def __del__(self):
        """Close the connection, if one was ever opened."""
        # `conn` is missing when sqlite3.connect() raised inside __init__;
        # the finalizer still runs in that case and must not raise itself.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def __initialize(self):
        """Initialize the database."""
        self.cursor.execute(documents_table)
        self.cursor.execute(qa_table)
        self.conn.commit()

    @staticmethod
    def _compress_embedding(embedding) -> memoryview:
        """Return one embedding as a zlib-compressed float32 BLOB."""
        as_float32 = np.asarray(embedding, dtype=np.float32)
        return sqlite3.Binary(zlib.compress(as_float32.tobytes()))

    @staticmethod
    def _decompress_embedding(blob):
        """Return the embedding stored in `blob` as a list of floats, or None for a NULL column."""
        if blob is None:
            # Documents written without an `embedding` column have NULL here.
            return None
        return np.frombuffer(zlib.decompress(blob), dtype=np.float32).tolist()

    def write_documents(self, source: str, df: pd.DataFrame):
        """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`.

        Args:
            source: Name under which the documents are stored.
            df: Documents to store. Must contain the columns `title`, `url` and `content`. If it has an
                `embedding` column it must also have `n_tokens`; both are then stored as well. The caller's
                dataframe is not modified.

        Raises:
            sqlite3.Error: If the update or insert fails. The transaction is rolled back first, so the
                previously stored documents keep their `current` flag.
        """
        df = df.copy()

        # Prepare the rows
        df["source"] = source
        df["current"] = 1
        columns = ["source", "title", "url", "content", "current"]
        if "embedding" in df.columns:
            columns.extend(["n_tokens", "embedding"])

            if not df.empty:
                # Embeddings are read back as float32, so every one of them (not
                # just the first) has to be stored as float32.
                wrong_dtype = next(
                    (dt for dt in (np.asarray(e).dtype for e in df["embedding"]) if dt != np.float32),
                    None,
                )
                if wrong_dtype is not None:
                    warnings.warn(
                        f"Embeddings are not float32, converting them to float32 from {wrong_dtype}.",
                        RuntimeWarning,
                    )

                # Convert to float32 and ZLIB compress the embeddings
                df["embedding"] = df["embedding"].apply(self._compress_embedding)

        data = df[columns].values.tolist()

        insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
        try:
            # Set `current` to 0 for all previous documents from that source
            self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))
            # Insert the new documents
            self.cursor.executemany(insert_statement, data)
        except sqlite3.Error:
            # Undo the UPDATE as well; otherwise a later commit would retire the
            # old documents without any replacement having been written.
            self.conn.rollback()
            raise
        self.conn.commit()

    def get_documents(self, source: str) -> pd.DataFrame:
        """Get all current documents from a given source.

        Args:
            source: Name the documents were written under.

        Returns:
            One row per current document with every column of the `documents` table except `current`.
            `embedding` holds a list of floats, or None for documents stored without an embedding.
        """
        # Execute the SQL statement and fetch the results
        results = self.cursor.execute("SELECT * FROM documents WHERE source = ? AND current = 1", (source,))
        rows = results.fetchall()

        # Convert the results to a pandas DataFrame
        df = pd.DataFrame(rows, columns=[description[0] for description in results.description])

        # ZLIB decompress the embeddings (NULL embeddings are passed through as None)
        df["embedding"] = df["embedding"].apply(self._decompress_embedding)

        # Drop the `current` column
        df.drop(columns=["current"], inplace=True)

        return df