from datetime import datetime
from typing import Any, Dict, Optional
import json

import chainlit as cl
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.step import StepDict
from chainlit.types import PaginatedResponse, Pagination, ThreadDict, ThreadFilter
from chainlit.user import PersistedUser


class CustomPostgresDataLayer(SQLAlchemyDataLayer):
    """
    Custom data layer that fixes bugs in the default SQLAlchemyDataLayer
    when used with PostgreSQL (asyncpg driver).

    - Uses native ``datetime`` objects for database writes (asyncpg rejects
      ISO strings for timestamp columns).
    - Converts ``datetime`` objects read from the database back to ISO
      strings before sending them to the Chainlit frontend.
    - Ensures the parent thread exists before creating a step, avoiding
      foreign-key violations.
    - Truncates over-long thread names before persisting them.
    """

    # Maximum length a thread name is truncated to before being written.
    # Keep in sync with the width of the "name" column in the threads table.
    THREAD_NAME_MAX_LENGTH = 50

    # --- Helpers -----------------------------------------------------------

    def _clean_datetimes(self, data: Any) -> Any:
        """Recursively convert ``datetime`` objects in nested dicts/lists to
        ISO-8601 strings; all other values are returned unchanged."""
        if isinstance(data, dict):
            return {key: self._clean_datetimes(value) for key, value in data.items()}
        if isinstance(data, list):
            return [self._clean_datetimes(item) for item in data]
        if isinstance(data, datetime):
            return data.isoformat()
        return data

    def _step_dict_to_params(self, step_dict: StepDict) -> Dict[str, Any]:
        """
        Convert a ``StepDict`` into SQL-ready parameters.

        Timestamp fields become native ``datetime`` objects (asyncpg
        requirement); JSON-typed fields become JSON strings, defaulting to
        ``'{}'`` when absent or ``None``.
        """
        params: Dict[str, Any] = dict(step_dict)

        # Timestamps arrive from Chainlit as ISO strings, possibly with a
        # trailing 'Z'; fromisoformat() only accepts '+00:00' (pre-3.11).
        for key in ("createdAt", "start", "end"):
            if key in params and isinstance(params[key], str):
                try:
                    params[key] = datetime.fromisoformat(
                        params[key].replace("Z", "+00:00")
                    )
                except (ValueError, TypeError):
                    # Unparseable timestamp: drop it rather than crash the
                    # insert (None values are filtered out before the query).
                    params[key] = None

        # JSON columns must be serialized strings for the driver.
        # NOTE(review): json.dumps will raise if metadata/generation contain
        # non-serializable values (e.g. datetime) — assumed JSON-safe upstream.
        for key in ("metadata", "generation"):
            if params.get(key) is not None:
                params[key] = json.dumps(params[key])
            else:
                params[key] = "{}"  # empty JSON object as default

        return params

    # --- Overridden methods ------------------------------------------------

    async def get_current_timestamp(self) -> datetime:
        """
        FIX 1: Return a native ``datetime`` object, which asyncpg requires
        for timestamp columns (the base class returns an ISO string).

        NOTE(review): returns a naive local-time datetime; if the column is
        ``timestamptz`` this should probably be timezone-aware — confirm
        against the schema before changing.
        """
        return datetime.now()

    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        """
        FIX 3: Re-implementation of ``get_user`` that bypasses the buggy
        parent method and correctly converts the ``createdAt`` datetime
        read from the database into an ISO string.

        Returns ``None`` when no user row matches ``identifier``.
        """
        if self.show_logger:
            self.logger.info(f"CustomSQLAlchemy: get_user, identifier={identifier}")

        # Query directly instead of calling super().
        query = "SELECT * FROM users WHERE identifier = :identifier"
        parameters = {"identifier": identifier}
        result = await self.execute_sql(query=query, parameters=parameters)

        if result and isinstance(result, list) and len(result) > 0:
            user_data = result[0]

            # FIX: a NULL metadata column yields None (the key exists, so a
            # .get() default never applies) — coerce to an empty dict.
            metadata = user_data.get("metadata") or {}
            if isinstance(metadata, str):
                metadata = json.loads(metadata)

            created_at_dt = user_data.get("createdAt")
            created_at_str = created_at_dt.isoformat() if created_at_dt else None

            return PersistedUser(
                id=str(user_data["id"]),
                identifier=str(user_data["identifier"]),
                createdAt=created_at_str,
                metadata=metadata,
            )
        return None

    @cl.data.queue_until_user_message()
    async def create_step(self, step_dict: StepDict):
        """
        FIX: Re-implementation of ``create_step`` that (a) upserts the parent
        thread first to satisfy the foreign-key constraint, and (b) coerces
        all column values to driver-compatible types (native datetimes,
        JSON strings).
        """
        # Ensure the parent thread row exists before inserting the step.
        if step_dict.get("threadId"):
            await self.update_thread(thread_id=step_dict["threadId"])

        parameters = self._step_dict_to_params(step_dict)

        # FIX: the previous code stringified a None showInput into the
        # literal 'none'; treat a None value like a missing key instead.
        show_input = parameters.get("showInput")
        parameters["showInput"] = "" if show_input is None else str(show_input).lower()

        # Drop None values so we never write SQL NULLs for absent fields.
        final_params = {k: v for k, v in parameters.items() if v is not None}

        # NOTE(review): column names are interpolated into the SQL text;
        # safe only because StepDict keys come from Chainlit itself, never
        # from user input.
        columns = ", ".join(f'"{key}"' for key in final_params)
        values = ", ".join(f":{key}" for key in final_params)
        updates = ", ".join(
            f'"{key}" = :{key}' for key in final_params if key != "id"
        )

        # FIX: with only "id" present the SET clause would be empty and the
        # statement would be a syntax error — fall back to DO NOTHING.
        conflict_action = f"DO UPDATE SET {updates}" if updates else "DO NOTHING"

        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
            ON CONFLICT (id) {conflict_action};
        """
        await self.execute_sql(query=query, parameters=final_params)

    @cl.data.queue_until_user_message()
    async def update_step(self, step_dict: StepDict):
        """FIX: Delegate to the fixed ``create_step`` upsert logic."""
        await self.create_step(step_dict)

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        """
        FIX 4: Fetch a thread and convert any ``datetime`` values to ISO
        strings before returning (the frontend expects strings).
        Returns ``None`` when the thread does not exist.
        """
        thread = await super().get_thread(thread_id)
        return self._clean_datetimes(thread)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse[ThreadDict]:
        """Fetch threads and convert ``datetime`` values to ISO strings
        before returning the paginated response."""
        paginated_response = await super().list_threads(pagination, filters)
        paginated_response.data = self._clean_datetimes(paginated_response.data)
        return paginated_response

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[list[str]] = None,
    ):
        """
        Truncate the thread name to ``THREAD_NAME_MAX_LENGTH`` characters
        before delegating to the base implementation, preventing database
        errors caused by long first messages.

        (FIX: the previous docstring claimed a 200-character limit while the
        code enforced 50 — the constant now documents the actual limit.)
        """
        # Prefer the explicit name argument; fall back to metadata["name"].
        effective_name = name
        if not effective_name and metadata and "name" in metadata:
            effective_name = metadata.get("name")

        truncated_name = (
            effective_name[: self.THREAD_NAME_MAX_LENGTH] if effective_name else None
        )

        await super().update_thread(
            thread_id=thread_id,
            name=truncated_name,
            user_id=user_id,
            metadata=metadata,
            tags=tags,
        )