Spaces:
Runtime error
Runtime error
| import json | |
| from typing import List | |
| from langchain_core.chat_history import BaseChatMessageHistory | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| message_to_dict, | |
| messages_from_dict, | |
| ) | |
class XataChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Xata database.

    Every message is written as one record of ``table_name`` and tagged with
    the ``sessionId`` supplied at construction time; reads and deletes are
    filtered on that same ``sessionId``.
    """

    def __init__(
        self,
        session_id: str,
        db_url: str,
        api_key: str,
        branch_name: str = "main",
        table_name: str = "messages",
        create_table: bool = True,
    ) -> None:
        """Initialize with Xata client.

        Args:
            session_id: Value stored in the ``sessionId`` column of every
                record written by this instance and used to filter reads.
            db_url: Database URL handed to ``XataClient``.
            api_key: API key handed to ``XataClient``.
            branch_name: Branch name handed to ``XataClient``.
            table_name: Table that holds the message records.
            create_table: When true, create the table and its schema if the
                schema lookup answers 404.

        Raises:
            ValueError: If the ``xata`` package cannot be imported.
            Exception: If ``create_table`` is true and a Xata call made while
                checking or creating the table returns an unexpected status.
        """
        try:
            from xata.client import XataClient  # noqa: F401
        except ImportError as exc:
            # Keep ValueError for existing callers, but chain the real cause.
            raise ValueError(
                "Could not import xata python package. "
                "Please install it with `pip install xata`."
            ) from exc
        self._client = XataClient(
            api_key=api_key, db_url=db_url, branch_name=branch_name
        )
        self._table_name = table_name
        self._session_id = session_id

        if create_table:
            self._create_table_if_not_exists()

    def _create_table_if_not_exists(self) -> None:
        """Create the message table and set its columns unless it exists.

        Raises:
            Exception: If the schema lookup fails with anything other than
                404, or if creating the table or setting its schema fails.
        """
        r = self._client.table().get_schema(self._table_name)
        if r.status_code <= 299:
            # Schema lookup succeeded, so the table is already there.
            return
        if r.status_code != 404:
            raise Exception(
                f"Error checking if table exists in Xata: {r.status_code} {r}"
            )

        # 404: the table is missing, so create it and then define its columns.
        r = self._client.table().create(self._table_name)
        if r.status_code > 299:
            raise Exception(f"Error creating table in Xata: {r.status_code} {r}")

        r = self._client.table().set_schema(
            self._table_name,
            payload={
                "columns": [
                    {"name": "sessionId", "type": "string"},
                    {"name": "type", "type": "string"},
                    {"name": "role", "type": "string"},
                    {"name": "content", "type": "text"},
                    {"name": "name", "type": "string"},
                    {"name": "additionalKwargs", "type": "json"},
                ]
            },
        )
        if r.status_code > 299:
            raise Exception(f"Error setting table schema in Xata: {r.status_code} {r}")

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the Xata table.

        Args:
            message: Message to store. Its ``additional_kwargs`` are
                serialized with ``json.dumps`` before insertion.

        Raises:
            Exception: If the insert returns a status code above 299.
        """
        msg = message_to_dict(message)
        r = self._client.records().insert(
            self._table_name,
            {
                "sessionId": self._session_id,
                "type": msg["type"],
                "content": message.content,
                "additionalKwargs": json.dumps(message.additional_kwargs),
                "role": msg["data"].get("role"),
                "name": msg["data"].get("name"),
            },
        )
        if r.status_code > 299:
            raise Exception(f"Error adding message to Xata: {r.status_code} {r}")

    # Exposed as a property because the base class declares ``messages`` as
    # an attribute; a plain method would hand consumers a bound method.
    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Return this session's messages, sorted by ``xata.createdAt`` ascending.

        Raises:
            Exception: If the query does not return status 200.
        """
        # NOTE(review): only a single query is issued here; if Xata pages its
        # results, messages beyond the first page are not returned - confirm.
        r = self._client.data().query(
            self._table_name,
            payload={
                "filter": {
                    "sessionId": self._session_id,
                },
                "sort": {"xata.createdAt": "asc"},
            },
        )
        if r.status_code != 200:
            raise Exception(f"Error running query: {r.status_code} {r}")
        msgs = messages_from_dict(
            [
                {
                    "type": m["type"],
                    "data": {
                        "content": m["content"],
                        "role": m.get("role"),
                        "name": m.get("name"),
                        # Stored as a JSON string by add_message.
                        "additional_kwargs": json.loads(m["additionalKwargs"]),
                    },
                }
                for m in r["records"]
            ]
        )
        return msgs

    def clear(self) -> None:
        """Delete session from Xata table.

        Queries the ids of this session's records and deletes them in one
        transaction, repeating until a query returns no records.

        Raises:
            Exception: If a query does not return status 200, or if a delete
                transaction returns a status code above 299.
        """
        while True:
            r = self._client.data().query(
                self._table_name,
                payload={
                    "columns": ["id"],
                    "filter": {
                        "sessionId": self._session_id,
                    },
                },
            )
            if r.status_code != 200:
                raise Exception(f"Error running query: {r.status_code} {r}")
            ids = [rec["id"] for rec in r["records"]]
            if not ids:
                break
            operations = [
                {"delete": {"table": self._table_name, "id": record_id}}
                for record_id in ids
            ]
            r = self._client.records().transaction(payload={"operations": operations})
            # Without this check a failed delete would leave the records in
            # place and the loop would fetch the same ids again forever.
            # Assumes the transaction response exposes ``status_code`` like
            # the other client responses used in this class.
            if r.status_code > 299:
                raise Exception(
                    f"Error deleting messages from Xata: {r.status_code} {r}"
                )