# Agents_SDK_Expert / data_layer.py
# Last commit: 933e440 by Aasher
#   "fix: truncate thread name to 50 characters in update_thread method to ensure safe length"
import json
import logging
from datetime import datetime
from typing import Optional, Dict, Any

import chainlit as cl
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.step import StepDict
from chainlit.types import Pagination, ThreadFilter, PaginatedResponse, ThreadDict
from chainlit.user import PersistedUser
class CustomPostgresDataLayer(SQLAlchemyDataLayer):
    """
    SQLAlchemyDataLayer variant that works around type issues with PostgreSQL.

    - Writes timestamps as native ``datetime`` objects instead of ISO strings.
    - Converts ``datetime`` objects read back from the database into ISO strings
      before they are handed to the Chainlit frontend.
    - Upserts the parent thread before writing a step so the step's reference
      to its thread is satisfied.
    - Truncates thread names to ``MAX_THREAD_NAME_LENGTH`` characters.
    """

    # Thread names longer than this are cut before being written to the database.
    MAX_THREAD_NAME_LENGTH = 50

    # Step fields that arrive as ISO-8601 strings and are sent to the DB as datetimes.
    _STEP_TIMESTAMP_FIELDS = ("createdAt", "start", "end")
    # Step fields that are stored as serialized JSON text.
    _STEP_JSON_FIELDS = ("metadata", "generation")

    # --- Helpers ---

    def _clean_datetimes(self, data):
        """Recursively replace ``datetime`` values inside dicts/lists with ISO strings.

        Any other value (including ``None``) is returned unchanged.
        """
        if isinstance(data, dict):
            return {k: self._clean_datetimes(v) for k, v in data.items()}
        if isinstance(data, list):
            return [self._clean_datetimes(item) for item in data]
        if isinstance(data, datetime):
            return data.isoformat()
        return data

    # --- Overridden methods ---

    async def get_current_timestamp(self) -> datetime:
        """Return the current time as a native ``datetime`` object.

        NOTE(review): this is a naive, local-time value, whereas step timestamps
        parsed in ``_step_dict_to_params`` are timezone-aware (UTC). Confirm the
        timestamp column types accept both.
        """
        return datetime.now()

    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        """Fetch a user by identifier without delegating to the parent method.

        The stored ``createdAt`` is converted to an ISO string when it comes back
        as a ``datetime``, and JSON-string metadata is decoded.

        Args:
            identifier: The user's unique identifier.

        Returns:
            The matching ``PersistedUser``, or ``None`` when no row matches.
        """
        if self.show_logger:
            # The parent class is not guaranteed to define ``self.logger``; fall
            # back to Chainlit's named logger instead of raising AttributeError.
            logger = getattr(self, "logger", None) or logging.getLogger("chainlit")
            logger.info(f"CustomSQLAlchemy: get_user, identifier={identifier}")

        result = await self.execute_sql(
            query="SELECT * FROM users WHERE identifier = :identifier",
            parameters={"identifier": identifier},
        )
        if not (isinstance(result, list) and result):
            return None

        user_data = result[0]

        # A NULL metadata column comes back as None; normalise it to an empty dict.
        metadata = user_data.get("metadata") or {}
        if isinstance(metadata, str):
            metadata = json.loads(metadata)

        created_at = user_data.get("createdAt")
        # The column may hold a datetime (timestamp type) or already be a string.
        if isinstance(created_at, datetime):
            created_at = created_at.isoformat()

        return PersistedUser(
            id=str(user_data["id"]),
            identifier=str(user_data["identifier"]),
            createdAt=created_at,
            metadata=metadata,
        )

    def _step_dict_to_params(self, step_dict: StepDict) -> Dict[str, Any]:
        """Convert a StepDict into SQL bind parameters with DB-friendly types.

        - ISO-8601 strings in ``createdAt``/``start``/``end`` become ``datetime``
          objects (a ``Z`` suffix is treated as ``+00:00``); unparseable values
          become ``None``.
        - ``metadata`` and ``generation`` are JSON-encoded, defaulting to ``'{}'``
          when missing or ``None``.

        The input dict is not modified; a shallow copy is returned.
        """
        params = step_dict.copy()

        for key in self._STEP_TIMESTAMP_FIELDS:
            value = params.get(key)
            if isinstance(value, str):
                try:
                    params[key] = datetime.fromisoformat(value.replace("Z", "+00:00"))
                except (ValueError, TypeError):
                    # create_step drops None parameters, so the column is left unset.
                    params[key] = None

        for key in self._STEP_JSON_FIELDS:
            value = params.get(key)
            params[key] = json.dumps(value) if value is not None else "{}"

        return params

    @cl.data.queue_until_user_message()
    async def create_step(self, step_dict: StepDict):
        """Upsert a step row, making sure its parent thread exists first.

        Parameters whose value is ``None`` are omitted, so they neither insert
        NULLs nor overwrite existing columns on conflict.
        """
        thread_id = step_dict.get("threadId")
        if thread_id:
            # Upsert the thread first so the step's thread reference is valid.
            await self.update_thread(thread_id=thread_id)

        parameters = self._step_dict_to_params(step_dict)

        # Only write showInput when the caller supplied a value. Otherwise the
        # ON CONFLICT update would overwrite the stored value with '' / 'none'.
        show_input = parameters.pop("showInput", None)
        if show_input is not None:
            parameters["showInput"] = str(show_input).lower()

        final_params = {k: v for k, v in parameters.items() if v is not None}

        # Column names are taken from the StepDict keys (Chainlit-produced), and
        # values are always passed as bind parameters.
        columns = ", ".join(f'"{key}"' for key in final_params)
        values = ", ".join(f":{key}" for key in final_params)
        updates = ", ".join(
            f'"{key}" = :{key}' for key in final_params if key != "id"
        )
        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
            ON CONFLICT (id) DO UPDATE
            SET {updates};
        """
        await self.execute_sql(query=query, parameters=final_params)

    @cl.data.queue_until_user_message()
    async def update_step(self, step_dict: StepDict):
        """Update a step by reusing the upsert logic of ``create_step``."""
        await self.create_step(step_dict)

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        """Fetch a thread via the parent class and ISO-format any datetimes in it."""
        thread = await super().get_thread(thread_id)
        return self._clean_datetimes(thread)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse[ThreadDict]:
        """List threads via the parent class and ISO-format any datetimes in them."""
        paginated_response = await super().list_threads(pagination, filters)
        paginated_response.data = self._clean_datetimes(paginated_response.data)
        return paginated_response

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[list[str]] = None,
    ):
        """Truncate the thread name to ``MAX_THREAD_NAME_LENGTH`` characters, then
        delegate to the parent implementation.

        This prevents database errors when a long first message becomes the thread
        name. The name is taken from ``name`` or, if that is empty, from
        ``metadata["name"]``. The untruncated metadata is passed through unchanged.
        """
        effective_name = name or (metadata or {}).get("name")

        # str() guards against non-string names coming from metadata.
        truncated_name = (
            str(effective_name)[: self.MAX_THREAD_NAME_LENGTH]
            if effective_name
            else None
        )

        await super().update_thread(
            thread_id=thread_id,
            name=truncated_name,
            user_id=user_id,
            metadata=metadata,
            tags=tags,
        )