|
|
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional

import pytz
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)
from langchain_core.utils import get_from_env

if TYPE_CHECKING:
    from supabase import Client
|
|
|
|
|
# Logger named after this module (standard `logging.getLogger(__name__)` pattern).
logger = logging.getLogger(__name__)
|
|
|
|
|
class SupabaseChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Supabase project database.

    Every message is one row of ``table_name``:

    * ``{session_name}_id`` holds the ``session_id`` this history is bound to;
    * ``message`` holds the serialized message (``message_to_dict``);
    * ``query_id`` optionally holds the ``id`` of another row in the table;
    * ``error_log`` is set (via ``update_message``) when producing the
      message failed;
    * ``created_at`` orders the history and ``updated_at`` is refreshed on
      every ``update_message`` call.
    """

    def __init__(
        self,
        session_id: str,
        table_name: str = "message_store",
        session_name: str = "session",
        client: Optional["Client"] = None,
        supabase_url: Optional[str] = None,
        supabase_key: Optional[str] = None,
    ):
        """Create a history bound to a single session.

        Args:
            session_id: Value matched against the ``{session_name}_id`` column.
            table_name: Name of the table that stores the messages.
            session_name: Prefix of the session column (the default
                ``"session"`` selects the ``session_id`` column).
            client: Pre-built Supabase client. When given, ``supabase_url``
                and ``supabase_key`` are not used.
            supabase_url: Project URL. When ``None``, read from the
                ``SUPABASE_URL`` environment variable.
            supabase_key: Project API key. When ``None``, read from the
                ``SUPABASE_KEY`` environment variable.

        Raises:
            ImportError: If the ``supabase`` package is not installed.
            ValueError: If ``session_id`` is empty, or if no client is given
                and the URL/key is neither passed nor set in the environment.
        """
        try:
            from supabase import create_client
        except ImportError as err:
            raise ImportError(
                "Could not import supabase python package. "
                "Please install it with `pip install supabase`."
            ) from err

        if not session_id:
            raise ValueError("Please ensure that the session_id parameter is provided")

        if client is None:
            # Explicit arguments must win over the environment. ``get_from_env``
            # consults the environment variable *before* its ``default``
            # argument, so the explicit value is only used as-is and the
            # environment is consulted solely when nothing was passed.
            if supabase_url is None:
                supabase_url = get_from_env("supabase_url", "SUPABASE_URL")
            if supabase_key is None:
                supabase_key = get_from_env("supabase_key", "SUPABASE_KEY")
            client = create_client(
                supabase_url=supabase_url,
                supabase_key=supabase_key,
            )

        self.client = client
        self.session_id = session_id
        self.table_name = table_name
        self.session_name = session_name

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve this session's messages, oldest ``created_at`` first.

        Failed exchanges are hidden. A row counts as failed when its message
        content is the empty string or its ``error_log`` is non-null. Both that
        row and the row whose ``id`` equals its ``query_id`` are dropped.
        Each access queries the database.
        """
        response = (
            self.client.table(self.table_name)
            .select("id", "query_id", "message", "error_log")
            .eq(f"{self.session_name}_id", self.session_id)
            .order("created_at", desc=False)
            .execute()
        )
        records = response.data

        excluded_ids = self._failed_ids(records)
        items = [
            record["message"]
            for record in records
            if record["id"] not in excluded_ids
        ]
        return messages_from_dict(items)

    @staticmethod
    def _is_failed(record: dict) -> bool:
        """Return True if ``record`` has empty content or a recorded error."""
        # Tolerate rows whose ``message`` lacks ``data``/``content`` so that a
        # row already flagged through ``error_log`` is filtered, not a KeyError.
        message = record.get("message") or {}
        content = (message.get("data") or {}).get("content")
        return content == "" or record.get("error_log") is not None

    @classmethod
    def _failed_ids(cls, records: List[dict]) -> set:
        """Collect the ids of failed rows and of the rows they reference."""
        excluded = set()  # set: O(1) membership tests in ``messages``
        for record in records:
            if not cls._is_failed(record):
                continue
            excluded.add(record["id"])
            query_id = record.get("query_id")
            if query_id is not None:
                excluded.add(query_id)
        return excluded

    def add_message(self, message: BaseMessage, query_id: Optional[str] = None) -> Any:
        """Insert ``message`` as a new row of this session.

        Args:
            message: Message to store (serialized with ``message_to_dict``).
            query_id: Optional ``id`` of another row to link this one to.

        Returns:
            The ``id`` of the inserted row, taken from the insert response.
        """
        response = (
            self.client.table(self.table_name)
            .insert(
                {
                    f"{self.session_name}_id": self.session_id,
                    "message": message_to_dict(message),
                    "query_id": query_id,
                }
            )
            .execute()
        )
        return response.data[0]["id"]

    def update_message(
        self,
        message_id: str,
        message: Optional[BaseMessage] = None,
        error_log: Optional[dict] = None,
    ) -> None:
        """Update an existing row and refresh its ``updated_at`` timestamp.

        Args:
            message_id: ``id`` of the row to update.
            message: Replacement message; the stored one is kept when ``None``.
            error_log: Value for the ``error_log`` column; kept when ``None``.
                Any non-null ``error_log`` hides the row (and the row it
                references) from ``messages``.
        """
        updated_dict = {"updated_at": datetime.now(pytz.utc).isoformat()}
        if message is not None:
            updated_dict["message"] = message_to_dict(message)
        if error_log is not None:
            updated_dict["error_log"] = error_log

        (
            self.client.table(self.table_name)
            .update(updated_dict)
            .eq("id", message_id)
            .execute()
        )

    def clear(self) -> None:
        """Delete every row belonging to this session from the table."""
        (
            self.client.table(self.table_name)
            .delete()
            .eq(f"{self.session_name}_id", self.session_id)
            .execute()
        )
|
|
|