Spaces:
Sleeping
Sleeping
| from datetime import datetime | |
| from typing import Optional, Dict, Any | |
| import json | |
| import chainlit as cl | |
| from chainlit.data.sql_alchemy import SQLAlchemyDataLayer | |
| from chainlit.step import StepDict | |
| from chainlit.user import PersistedUser | |
| from chainlit.types import Pagination, ThreadFilter, PaginatedResponse, ThreadDict | |
class CustomPostgresDataLayer(SQLAlchemyDataLayer):
    """
    PostgreSQL-oriented variant of Chainlit's SQLAlchemyDataLayer.

    - Uses native ``datetime`` objects for timestamp writes (asyncpg binds
      datetime objects to timestamp columns).
    - Converts ``datetime`` objects read back from the database into ISO
      strings before they are sent to the Chainlit frontend.
    - Ensures the parent thread exists before a step is written, to avoid
      foreign-key errors.
    - Truncates thread names to ``MAX_THREAD_NAME_LENGTH`` characters.
    """

    # Maximum number of characters written to a thread's "name" column.
    MAX_THREAD_NAME_LENGTH = 50

    # --- Helpers ---

    def _clean_datetimes(self, data):
        """Recursively convert datetime objects inside dicts/lists to ISO strings.

        Any other value (including None and tuples) is returned unchanged.
        """
        if isinstance(data, dict):
            return {k: self._clean_datetimes(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self._clean_datetimes(item) for item in data]
        if isinstance(data, datetime):
            return data.isoformat()
        return data

    # --- Overridden methods ---

    async def get_current_timestamp(self) -> datetime:
        """Return the current time as a native datetime object (asyncpg requires one).

        NOTE(review): this is a *naive local* time, whereas step timestamps
        parsed in ``_step_dict_to_params`` become timezone-aware (UTC) when the
        incoming string ends in "Z". Confirm the column types (``timestamp``
        vs ``timestamptz``) accept both forms.
        """
        return datetime.now()

    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        """Fetch a user by identifier with a direct SQL query (no ``super()`` call).

        Args:
            identifier: The user's unique identifier.

        Returns:
            A ``PersistedUser`` whose ``createdAt`` is an ISO string (or None),
            or None when no matching row exists.
        """
        if self.show_logger:
            self.logger.info(f"CustomSQLAlchemy: get_user, identifier={identifier}")

        query = "SELECT * FROM users WHERE identifier = :identifier"
        result = await self.execute_sql(
            query=query, parameters={"identifier": identifier}
        )
        if not (isinstance(result, list) and result):
            return None

        user_data = result[0]

        # metadata may arrive as a JSON string, an already-decoded dict, or NULL.
        metadata = user_data.get("metadata")
        if isinstance(metadata, str):
            metadata = json.loads(metadata)
        if metadata is None:
            metadata = {}

        # createdAt is a datetime for timestamp columns; tolerate a plain string
        # (e.g. a TEXT column) instead of crashing on .isoformat().
        created_at = user_data.get("createdAt")
        if isinstance(created_at, datetime):
            created_at_str = created_at.isoformat()
        elif created_at:
            created_at_str = str(created_at)
        else:
            created_at_str = None

        return PersistedUser(
            id=str(user_data["id"]),
            identifier=str(user_data["identifier"]),
            createdAt=created_at_str,
            metadata=metadata,
        )

    def _step_dict_to_params(self, step_dict: StepDict) -> Dict[str, Any]:
        """Convert a StepDict into SQL bind parameters for the ``steps`` table.

        - ``createdAt``/``start``/``end`` strings are parsed into datetime
          objects; a trailing "Z" is treated as UTC. Unparseable values become
          None (and are then dropped by ``create_step``).
        - ``metadata``/``generation`` are JSON-encoded; when missing or None
          they default to the string ``'{}'``.

        The input dict is not mutated; a shallow copy is returned.
        """
        params: Dict[str, Any] = dict(step_dict)

        for key in ("createdAt", "start", "end"):
            value = params.get(key)
            if not isinstance(value, str):
                continue
            # Older fromisoformat() versions reject a literal "Z" suffix.
            ts_str = value[:-1] + "+00:00" if value.endswith("Z") else value
            try:
                params[key] = datetime.fromisoformat(ts_str)
            except ValueError:
                if self.show_logger:
                    self.logger.warning(
                        f"CustomSQLAlchemy: could not parse {key}={value!r}; storing NULL"
                    )
                params[key] = None

        for key in ("metadata", "generation"):
            value = params.get(key)
            params[key] = json.dumps(value) if value is not None else "{}"

        return params

    async def create_step(self, step_dict: StepDict):
        """Insert or update (upsert) a step row.

        Ensures the parent thread exists first (foreign-key constraint), then
        converts the step's values to driver-friendly types.
        """
        thread_id = step_dict.get("threadId")
        if thread_id:
            await self.update_thread(thread_id=thread_id)

        parameters = self._step_dict_to_params(step_dict)

        # Lowercase showInput (e.g. True -> "true") only when a value was given.
        # Writing "" for a missing key or "none" for None would overwrite the
        # existing column on the ON CONFLICT update.
        show_input = parameters.get("showInput")
        if show_input is not None:
            parameters["showInput"] = str(show_input).lower()

        # Drop None values so absent fields do not overwrite existing columns.
        final_params = {k: v for k, v in parameters.items() if v is not None}

        # Keys are interpolated as quoted column identifiers; they presumably
        # come from Chainlit's StepDict rather than from end users.
        columns = ", ".join(f'"{key}"' for key in final_params)
        values = ", ".join(f":{key}" for key in final_params)
        updates = ", ".join(
            f'"{key}" = :{key}' for key in final_params if key != "id"
        )
        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
            ON CONFLICT (id) DO UPDATE
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=final_params)

    async def update_step(self, step_dict: StepDict):
        """Update a step by reusing the upsert logic in ``create_step``."""
        await self.create_step(step_dict)

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        """Fetch a thread via the base class and convert datetimes to ISO strings.

        Returns None unchanged when the base class finds no thread.
        """
        thread = await super().get_thread(thread_id)
        return self._clean_datetimes(thread)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse[ThreadDict]:
        """Fetch threads via the base class and convert datetimes to ISO strings."""
        paginated_response = await super().list_threads(pagination, filters)
        paginated_response.data = self._clean_datetimes(paginated_response.data)
        return paginated_response

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[list[str]] = None,
    ):
        """Truncate the thread name, then delegate to the base class.

        The name is ``name`` or, when that is empty, ``metadata["name"]``.
        It is converted to ``str`` and truncated to ``MAX_THREAD_NAME_LENGTH``
        characters (50 by default) to prevent database errors caused by long
        first messages. An empty name is passed on as None. ``metadata`` itself
        is forwarded unchanged.
        """
        effective_name = name or (metadata.get("name") if metadata else None)
        truncated_name = (
            str(effective_name)[: self.MAX_THREAD_NAME_LENGTH]
            if effective_name
            else None
        )

        await super().update_thread(
            thread_id=thread_id,
            name=truncated_name,
            user_id=user_id,
            metadata=metadata,
            tags=tags,
        )