Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| from typing import List | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| message_to_dict, | |
| messages_from_dict, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| DEFAULT_CONNECTION_STRING = "postgresql://postgres:mypassword@localhost/chat_history" | |
| class PostgresChatMessageHistory(BaseChatMessageHistory): | |
| """Chat message history stored in a Postgres database.""" | |
| def __init__( | |
| self, | |
| session_id: str, | |
| connection_string: str = DEFAULT_CONNECTION_STRING, | |
| table_name: str = "message_store", | |
| ): | |
| import psycopg | |
| from psycopg.rows import dict_row | |
| try: | |
| self.connection = psycopg.connect(connection_string) | |
| self.cursor = self.connection.cursor(row_factory=dict_row) | |
| except psycopg.OperationalError as error: | |
| logger.error(error) | |
| self.session_id = session_id | |
| self.table_name = table_name | |
| self._create_table_if_not_exists() | |
| def _create_table_if_not_exists(self) -> None: | |
| create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} ( | |
| id SERIAL PRIMARY KEY, | |
| session_id TEXT NOT NULL, | |
| message JSONB NOT NULL | |
| );""" | |
| self.cursor.execute(create_table_query) | |
| self.connection.commit() | |
| def messages(self) -> List[BaseMessage]: # type: ignore | |
| """Retrieve the messages from PostgreSQL""" | |
| query = ( | |
| f"SELECT message FROM {self.table_name} WHERE session_id = %s ORDER BY id;" | |
| ) | |
| self.cursor.execute(query, (self.session_id,)) | |
| items = [record["message"] for record in self.cursor.fetchall()] | |
| messages = messages_from_dict(items) | |
| return messages | |
| def add_message(self, message: BaseMessage) -> None: | |
| """Append the message to the record in PostgreSQL""" | |
| from psycopg import sql | |
| query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format( | |
| sql.Identifier(self.table_name) | |
| ) | |
| self.cursor.execute( | |
| query, (self.session_id, json.dumps(message_to_dict(message))) | |
| ) | |
| self.connection.commit() | |
| def clear(self) -> None: | |
| """Clear session memory from PostgreSQL""" | |
| query = f"DELETE FROM {self.table_name} WHERE session_id = %s;" | |
| self.cursor.execute(query, (self.session_id,)) | |
| self.connection.commit() | |
| def __del__(self) -> None: | |
| if self.cursor: | |
| self.cursor.close() | |
| if self.connection: | |
| self.connection.close() | |