Spaces:
Running
feat: add ChatKit migration with SSE streaming
Browse filesThis commit implements the ChatKit migration for the Hugging Face
Spaces deployment, replacing WebSocket with Server-Sent Events (SSE).
New Features:
- ChatKit server (chatkit_server.py) with SSE streaming
- PostgreSQL ChatKit store (services/chatkit_store.py)
- Thread model for ChatKit conversations (models/thread.py)
- MCP tool wrappers for Agents SDK (ai_agent/tool_wrappers.py)
- Database migration for threads table (migrations/migrate_threads.sql)
Backend Changes:
- Updated /api/chatkit endpoint with SSE streaming
- Fixed SSE JSON serialization using json.dumps()
- Added tool_calls JSONB column support
- Enhanced authentication with httpOnly cookie support
Updated Files:
- api/chat.py: Added ChatKit SSE endpoint
- core/config.py: Added get_gemini_client() function
- core/security.py: Enhanced cookie-based authentication
- ai_agent/__init__.py: Updated agent imports
- .env.example: Added GEMINI_BASE_URL configuration
This enables the same ChatKit functionality on Hugging Face Spaces
as in the main SDDRI-Hackathon-2 repository.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .env.example +2 -0
- __init__.py +0 -0
- ai_agent/__init__.py +3 -7
- ai_agent/tool_wrappers.py +366 -0
- api/chat.py +269 -65
- chat.py +0 -478
- chatkit_server.py +358 -0
- config.py +0 -54
- core/config.py +42 -3
- core/security.py +37 -0
- migrations/migrate_threads.sql +76 -0
- models/thread.py +72 -0
- security.py +0 -276
- services/chatkit_store.py +416 -0
|
@@ -14,3 +14,5 @@ ENVIRONMENT=development
|
|
| 14 |
# Get your API key from https://aistudio.google.com
|
| 15 |
GEMINI_API_KEY=your-gemini-api-key-here
|
| 16 |
GEMINI_MODEL=gemini-2.0-flash-exp
|
|
|
|
|
|
|
|
|
| 14 |
# Get your API key from https://aistudio.google.com
|
| 15 |
GEMINI_API_KEY=your-gemini-api-key-here
|
| 16 |
GEMINI_MODEL=gemini-2.0-flash-exp
|
| 17 |
+
# Gemini OpenAI-compatible endpoint for ChatKit integration
|
| 18 |
+
GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
|
|
File without changes
|
|
@@ -2,26 +2,22 @@
|
|
| 2 |
|
| 3 |
[Task]: T014, T072
|
| 4 |
[From]: specs/004-ai-chatbot/tasks.md
|
|
|
|
| 5 |
|
| 6 |
This module provides the AI agent that powers the chatbot functionality.
|
| 7 |
It uses OpenAI SDK with function calling and Gemini via AsyncOpenAI adapter.
|
| 8 |
|
| 9 |
-
|
|
|
|
| 10 |
"""
|
| 11 |
from ai_agent.agent_simple import (
|
| 12 |
get_gemini_client,
|
| 13 |
run_agent,
|
| 14 |
is_gemini_configured
|
| 15 |
)
|
| 16 |
-
from ai_agent.agent_streaming import (
|
| 17 |
-
run_agent_with_streaming,
|
| 18 |
-
execute_tool_with_progress,
|
| 19 |
-
)
|
| 20 |
|
| 21 |
__all__ = [
|
| 22 |
"get_gemini_client",
|
| 23 |
"run_agent",
|
| 24 |
-
"run_agent_with_streaming",
|
| 25 |
-
"execute_tool_with_progress",
|
| 26 |
"is_gemini_configured"
|
| 27 |
]
|
|
|
|
| 2 |
|
| 3 |
[Task]: T014, T072
|
| 4 |
[From]: specs/004-ai-chatbot/tasks.md
|
| 5 |
+
[From]: T045 - Delete agent_streaming.py (ChatKit migration replaces WebSocket streaming)
|
| 6 |
|
| 7 |
This module provides the AI agent that powers the chatbot functionality.
|
| 8 |
It uses OpenAI SDK with function calling and Gemini via AsyncOpenAI adapter.
|
| 9 |
|
| 10 |
+
NOTE: The streaming agent functionality has been migrated to ChatKit SSE endpoint.
|
| 11 |
+
See backend/chatkit_server.py for the new ChatKit-based implementation.
|
| 12 |
"""
|
| 13 |
from ai_agent.agent_simple import (
|
| 14 |
get_gemini_client,
|
| 15 |
run_agent,
|
| 16 |
is_gemini_configured
|
| 17 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
__all__ = [
|
| 20 |
"get_gemini_client",
|
| 21 |
"run_agent",
|
|
|
|
|
|
|
| 22 |
"is_gemini_configured"
|
| 23 |
]
|
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agents SDK function wrappers for MCP task management tools.
|
| 2 |
+
|
| 3 |
+
[Task]: T012-T018
|
| 4 |
+
[From]: specs/010-chatkit-migration/tasks.md - Phase 3 Backend Implementation
|
| 5 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 6 |
+
|
| 7 |
+
This module wraps existing MCP tools as Agents SDK functions using the
|
| 8 |
+
@function_tool decorator. Each wrapper calls the underlying MCP tool function
|
| 9 |
+
and returns the result in a format compatible with the Agents SDK.
|
| 10 |
+
|
| 11 |
+
Tools wrapped:
|
| 12 |
+
1. create_task (T012) - from mcp_server/tools/add_task.py
|
| 13 |
+
2. list_tasks (T013) - from mcp_server/tools/list_tasks.py
|
| 14 |
+
3. update_task (T014) - from mcp_server/tools/update_task.py
|
| 15 |
+
4. delete_task (T015) - from mcp_server/tools/delete_task.py
|
| 16 |
+
5. complete_task (T016) - from mcp_server/tools/complete_task.py
|
| 17 |
+
6. complete_all_tasks (T017) - from mcp_server/tools/complete_all_tasks.py
|
| 18 |
+
7. delete_all_tasks (T018) - from mcp_server/tools/delete_all_tasks.py
|
| 19 |
+
|
| 20 |
+
[From]: specs/010-chatkit-migration/research.md - Section 7 (Tool Visualization Support)
|
| 21 |
+
"""
|
| 22 |
+
import json
|
| 23 |
+
import logging
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
from agents import function_tool, RunContextWrapper
|
| 27 |
+
|
| 28 |
+
# Import MCP tools
|
| 29 |
+
# Note: We import the actual async functions from MCP tools
|
| 30 |
+
import sys
|
| 31 |
+
import os
|
| 32 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
| 33 |
+
|
| 34 |
+
from mcp_server.tools.add_task import add_task as mcp_add_task
|
| 35 |
+
from mcp_server.tools.list_tasks import list_tasks as mcp_list_tasks
|
| 36 |
+
from mcp_server.tools.update_task import update_task as mcp_update_task
|
| 37 |
+
from mcp_server.tools.delete_task import delete_task as mcp_delete_task
|
| 38 |
+
from mcp_server.tools.complete_task import complete_task as mcp_complete_task
|
| 39 |
+
from mcp_server.tools.complete_all_tasks import complete_all_tasks as mcp_complete_all_tasks
|
| 40 |
+
from mcp_server.tools.delete_all_tasks import delete_all_tasks as mcp_delete_all_tasks
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# =============================================================================
|
| 46 |
+
# Agents SDK Function Wrappers
|
| 47 |
+
# =============================================================================
|
| 48 |
+
|
| 49 |
+
@function_tool(
|
| 50 |
+
name="create_task",
|
| 51 |
+
description_override="Create a new task in the user's todo list. Use this when the user wants to create, add, or remind themselves about a task. Parameters: title (required), description (optional), due_date (optional, ISO 8601 or relative), priority (optional: low/medium/high), tags (optional list)."
|
| 52 |
+
)
|
| 53 |
+
async def create_task_tool(
|
| 54 |
+
ctx: RunContextWrapper,
|
| 55 |
+
title: str,
|
| 56 |
+
description: Optional[str] = None,
|
| 57 |
+
due_date: Optional[str] = None,
|
| 58 |
+
priority: Optional[str] = None,
|
| 59 |
+
tags: Optional[list[str]] = None,
|
| 60 |
+
) -> str:
|
| 61 |
+
"""Create a task via MCP tool.
|
| 62 |
+
|
| 63 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 64 |
+
[Task]: T012
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
ctx: Agents SDK run context containing user_id
|
| 68 |
+
title: Task title
|
| 69 |
+
description: Optional task description
|
| 70 |
+
due_date: Optional due date
|
| 71 |
+
priority: Optional priority level
|
| 72 |
+
tags: Optional list of tags
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
JSON string with created task details
|
| 76 |
+
"""
|
| 77 |
+
user_id = ctx.context.user_id
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
result = await mcp_add_task(
|
| 81 |
+
user_id=user_id,
|
| 82 |
+
title=title,
|
| 83 |
+
description=description,
|
| 84 |
+
due_date=due_date,
|
| 85 |
+
priority=priority,
|
| 86 |
+
tags=tags or []
|
| 87 |
+
)
|
| 88 |
+
logger.info(f"Task created: {result['task']['id']} for user {user_id}")
|
| 89 |
+
return json.dumps(result)
|
| 90 |
+
except Exception as e:
|
| 91 |
+
logger.error(f"Failed to create task for user {user_id}: {e}")
|
| 92 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@function_tool(
|
| 96 |
+
name="list_tasks",
|
| 97 |
+
description_override="List all tasks for the user, optionally filtered by completion status, priority, tag, or due date range. Returns a list of tasks with their details."
|
| 98 |
+
)
|
| 99 |
+
async def list_tasks_tool(
|
| 100 |
+
ctx: RunContextWrapper,
|
| 101 |
+
completed: Optional[bool] = None,
|
| 102 |
+
priority: Optional[str] = None,
|
| 103 |
+
tag: Optional[str] = None,
|
| 104 |
+
due_before: Optional[str] = None,
|
| 105 |
+
due_after: Optional[str] = None,
|
| 106 |
+
) -> str:
|
| 107 |
+
"""List tasks via MCP tool.
|
| 108 |
+
|
| 109 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 110 |
+
[Task]: T013
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
ctx: Agents SDK run context containing user_id
|
| 114 |
+
completed: Optional filter by completion status
|
| 115 |
+
priority: Optional filter by priority
|
| 116 |
+
tag: Optional filter by tag
|
| 117 |
+
due_before: Optional due date upper bound
|
| 118 |
+
due_after: Optional due date lower bound
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
JSON string with list of tasks
|
| 122 |
+
"""
|
| 123 |
+
user_id = ctx.context.user_id
|
| 124 |
+
|
| 125 |
+
try:
|
| 126 |
+
result = await mcp_list_tasks(
|
| 127 |
+
user_id=user_id,
|
| 128 |
+
completed=completed,
|
| 129 |
+
priority=priority,
|
| 130 |
+
tag=tag,
|
| 131 |
+
due_before=due_before,
|
| 132 |
+
due_after=due_after
|
| 133 |
+
)
|
| 134 |
+
task_count = len(result.get("tasks", []))
|
| 135 |
+
logger.info(f"Listed {task_count} tasks for user {user_id}")
|
| 136 |
+
return json.dumps(result)
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.error(f"Failed to list tasks for user {user_id}: {e}")
|
| 139 |
+
return json.dumps({"success": False, "error": str(e), "tasks": []})
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@function_tool(
|
| 143 |
+
name="update_task",
|
| 144 |
+
description_override="Update an existing task. Parameters: task_id (required), title (optional), description (optional), due_date (optional), priority (optional), tags (optional)."
|
| 145 |
+
)
|
| 146 |
+
async def update_task_tool(
|
| 147 |
+
ctx: RunContextWrapper,
|
| 148 |
+
task_id: str,
|
| 149 |
+
title: Optional[str] = None,
|
| 150 |
+
description: Optional[str] = None,
|
| 151 |
+
due_date: Optional[str] = None,
|
| 152 |
+
priority: Optional[str] = None,
|
| 153 |
+
tags: Optional[list[str]] = None,
|
| 154 |
+
) -> str:
|
| 155 |
+
"""Update a task via MCP tool.
|
| 156 |
+
|
| 157 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 158 |
+
[Task]: T014
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
ctx: Agents SDK run context containing user_id
|
| 162 |
+
task_id: Task ID to update
|
| 163 |
+
title: Optional new title
|
| 164 |
+
description: Optional new description
|
| 165 |
+
due_date: Optional new due date
|
| 166 |
+
priority: Optional new priority
|
| 167 |
+
tags: Optional new tag list
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
JSON string with updated task details
|
| 171 |
+
"""
|
| 172 |
+
user_id = ctx.context.user_id
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
result = await mcp_update_task(
|
| 176 |
+
user_id=user_id,
|
| 177 |
+
task_id=task_id,
|
| 178 |
+
title=title,
|
| 179 |
+
description=description,
|
| 180 |
+
due_date=due_date,
|
| 181 |
+
priority=priority,
|
| 182 |
+
tags=tags
|
| 183 |
+
)
|
| 184 |
+
logger.info(f"Task updated: {task_id} for user {user_id}")
|
| 185 |
+
return json.dumps(result)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
logger.error(f"Failed to update task {task_id} for user {user_id}: {e}")
|
| 188 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@function_tool(
|
| 192 |
+
name="delete_task",
|
| 193 |
+
description_override="Delete a task permanently. Parameters: task_id (required)."
|
| 194 |
+
)
|
| 195 |
+
async def delete_task_tool(
|
| 196 |
+
ctx: RunContextWrapper,
|
| 197 |
+
task_id: str,
|
| 198 |
+
) -> str:
|
| 199 |
+
"""Delete a task via MCP tool.
|
| 200 |
+
|
| 201 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 202 |
+
[Task]: T015
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
ctx: Agents SDK run context containing user_id
|
| 206 |
+
task_id: Task ID to delete
|
| 207 |
+
|
| 208 |
+
Returns:
|
| 209 |
+
JSON string with deletion confirmation
|
| 210 |
+
"""
|
| 211 |
+
user_id = ctx.context.user_id
|
| 212 |
+
|
| 213 |
+
try:
|
| 214 |
+
result = await mcp_delete_task(
|
| 215 |
+
user_id=user_id,
|
| 216 |
+
task_id=task_id
|
| 217 |
+
)
|
| 218 |
+
logger.info(f"Task deleted: {task_id} for user {user_id}")
|
| 219 |
+
return json.dumps(result)
|
| 220 |
+
except Exception as e:
|
| 221 |
+
logger.error(f"Failed to delete task {task_id} for user {user_id}: {e}")
|
| 222 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
@function_tool(
|
| 226 |
+
name="complete_task",
|
| 227 |
+
description_override="Mark a task as completed or incomplete. Parameters: task_id (required), completed (boolean, required)."
|
| 228 |
+
)
|
| 229 |
+
async def complete_task_tool(
|
| 230 |
+
ctx: RunContextWrapper,
|
| 231 |
+
task_id: str,
|
| 232 |
+
completed: bool,
|
| 233 |
+
) -> str:
|
| 234 |
+
"""Complete/uncomplete a task via MCP tool.
|
| 235 |
+
|
| 236 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 237 |
+
[Task]: T016
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
ctx: Agents SDK run context containing user_id
|
| 241 |
+
task_id: Task ID to toggle
|
| 242 |
+
completed: Whether task is completed (true) or not (false)
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
JSON string with updated task details
|
| 246 |
+
"""
|
| 247 |
+
user_id = ctx.context.user_id
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
result = await mcp_complete_task(
|
| 251 |
+
user_id=user_id,
|
| 252 |
+
task_id=task_id,
|
| 253 |
+
completed=completed
|
| 254 |
+
)
|
| 255 |
+
logger.info(f"Task completion updated: {task_id} -> {completed} for user {user_id}")
|
| 256 |
+
return json.dumps(result)
|
| 257 |
+
except Exception as e:
|
| 258 |
+
logger.error(f"Failed to update completion for task {task_id} for user {user_id}: {e}")
|
| 259 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@function_tool(
|
| 263 |
+
name="complete_all_tasks",
|
| 264 |
+
description_override="Mark all tasks as completed. Parameters: confirm (boolean, required - must be true to execute)."
|
| 265 |
+
)
|
| 266 |
+
async def complete_all_tasks_tool(
|
| 267 |
+
ctx: RunContextWrapper,
|
| 268 |
+
confirm: bool,
|
| 269 |
+
) -> str:
|
| 270 |
+
"""Complete all tasks via MCP tool.
|
| 271 |
+
|
| 272 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 273 |
+
[Task]: T017
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
ctx: Agents SDK run context containing user_id
|
| 277 |
+
confirm: Must be true to execute this destructive operation
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
JSON string with bulk completion results
|
| 281 |
+
"""
|
| 282 |
+
user_id = ctx.context.user_id
|
| 283 |
+
|
| 284 |
+
if not confirm:
|
| 285 |
+
return json.dumps({
|
| 286 |
+
"success": False,
|
| 287 |
+
"error": "Confirmation required. Set confirm=true to complete all tasks."
|
| 288 |
+
})
|
| 289 |
+
|
| 290 |
+
try:
|
| 291 |
+
result = await mcp_complete_all_tasks(
|
| 292 |
+
user_id=user_id
|
| 293 |
+
)
|
| 294 |
+
completed_count = result.get("completed_count", 0)
|
| 295 |
+
logger.info(f"Completed {completed_count} tasks for user {user_id}")
|
| 296 |
+
return json.dumps(result)
|
| 297 |
+
except Exception as e:
|
| 298 |
+
logger.error(f"Failed to complete all tasks for user {user_id}: {e}")
|
| 299 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@function_tool(
|
| 303 |
+
name="delete_all_tasks",
|
| 304 |
+
description_override="Delete all tasks permanently. Parameters: confirm (boolean, required - must be true to execute)."
|
| 305 |
+
)
|
| 306 |
+
async def delete_all_tasks_tool(
|
| 307 |
+
ctx: RunContextWrapper,
|
| 308 |
+
confirm: bool,
|
| 309 |
+
) -> str:
|
| 310 |
+
"""Delete all tasks via MCP tool.
|
| 311 |
+
|
| 312 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 313 |
+
[Task]: T018
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
ctx: Agents SDK run context containing user_id
|
| 317 |
+
confirm: Must be true to execute this destructive operation
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
JSON string with bulk deletion results
|
| 321 |
+
"""
|
| 322 |
+
user_id = ctx.context.user_id
|
| 323 |
+
|
| 324 |
+
if not confirm:
|
| 325 |
+
return json.dumps({
|
| 326 |
+
"success": False,
|
| 327 |
+
"error": "Confirmation required. Set confirm=true to delete all tasks."
|
| 328 |
+
})
|
| 329 |
+
|
| 330 |
+
try:
|
| 331 |
+
result = await mcp_delete_all_tasks(
|
| 332 |
+
user_id=user_id
|
| 333 |
+
)
|
| 334 |
+
deleted_count = result.get("deleted_count", 0)
|
| 335 |
+
logger.info(f"Deleted {deleted_count} tasks for user {user_id}")
|
| 336 |
+
return json.dumps(result)
|
| 337 |
+
except Exception as e:
|
| 338 |
+
logger.error(f"Failed to delete all tasks for user {user_id}: {e}")
|
| 339 |
+
return json.dumps({"success": False, "error": str(e)})
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# =============================================================================
|
| 343 |
+
# Tool List for Agent Configuration
|
| 344 |
+
# =============================================================================
|
| 345 |
+
|
| 346 |
+
# Export all tool functions for easy import
|
| 347 |
+
TOOL_FUNCTIONS = [
|
| 348 |
+
create_task_tool,
|
| 349 |
+
list_tasks_tool,
|
| 350 |
+
update_task_tool,
|
| 351 |
+
delete_task_tool,
|
| 352 |
+
complete_task_tool,
|
| 353 |
+
complete_all_tasks_tool,
|
| 354 |
+
delete_all_tasks_tool,
|
| 355 |
+
]
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def get_tool_names() -> list[str]:
|
| 359 |
+
"""Get list of all tool names.
|
| 360 |
+
|
| 361 |
+
[From]: specs/010-chatkit-migration/tasks.md - T019
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
List of tool function names
|
| 365 |
+
"""
|
| 366 |
+
return [tool.name for tool in TOOL_FUNCTIONS]
|
|
@@ -13,7 +13,7 @@ import logging
|
|
| 13 |
import asyncio
|
| 14 |
from datetime import datetime
|
| 15 |
from typing import Annotated, Optional
|
| 16 |
-
from fastapi import APIRouter, HTTPException, status, Depends,
|
| 17 |
from pydantic import BaseModel, Field, field_validator, ValidationError
|
| 18 |
from sqlmodel import Session
|
| 19 |
from sqlalchemy.exc import SQLAlchemyError
|
|
@@ -24,14 +24,13 @@ from core.security import decode_access_token
|
|
| 24 |
from models.message import Message, MessageRole
|
| 25 |
from services.security import sanitize_message
|
| 26 |
from models.conversation import Conversation
|
| 27 |
-
from ai_agent import
|
| 28 |
from services.conversation import (
|
| 29 |
get_or_create_conversation,
|
| 30 |
load_conversation_history,
|
| 31 |
update_conversation_timestamp
|
| 32 |
)
|
| 33 |
from services.rate_limiter import check_rate_limit
|
| 34 |
-
from ws_manager.manager import manager
|
| 35 |
|
| 36 |
|
| 37 |
# Configure error logger
|
|
@@ -290,12 +289,12 @@ async def chat(
|
|
| 290 |
{"role": "user", "content": sanitized_message}
|
| 291 |
]
|
| 292 |
|
| 293 |
-
# Run AI agent
|
| 294 |
# [From]: T014 - Initialize OpenAI Agents SDK with Gemini
|
| 295 |
-
#
|
| 296 |
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 297 |
try:
|
| 298 |
-
ai_response_text = await
|
| 299 |
messages=messages_for_agent,
|
| 300 |
user_id=user_id
|
| 301 |
)
|
|
@@ -407,72 +406,277 @@ async def chat(
|
|
| 407 |
)
|
| 408 |
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
user_id: str,
|
| 414 |
-
db: Session = Depends(get_db)
|
| 415 |
-
):
|
| 416 |
-
"""WebSocket endpoint for real-time chat progress updates.
|
| 417 |
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
- connection_established: Confirmation of successful connection
|
| 424 |
-
- agent_thinking: AI agent is processing
|
| 425 |
-
- tool_starting: A tool is about to execute
|
| 426 |
-
- tool_progress: Tool execution progress (e.g., "Found 3 tasks")
|
| 427 |
-
- tool_complete: Tool finished successfully
|
| 428 |
-
- tool_error: Tool execution failed
|
| 429 |
-
- agent_done: AI agent finished processing
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
db: Database session (for any future DB operations)
|
| 442 |
-
|
| 443 |
-
The connection is kept alive and can receive messages from the client,
|
| 444 |
-
though currently it's primarily used for server-to-client progress updates.
|
| 445 |
-
"""
|
| 446 |
-
# Connect the WebSocket (manager handles accept)
|
| 447 |
-
# [From]: specs/004-ai-chatbot/research.md - Section 4
|
| 448 |
-
await manager.connect(user_id, websocket)
|
| 449 |
|
|
|
|
| 450 |
try:
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
data
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
# Echo back a simple acknowledgment
|
| 463 |
-
# (optional - mainly for debugging)
|
| 464 |
-
pass
|
| 465 |
-
|
| 466 |
-
except WebSocketDisconnect:
|
| 467 |
-
# Normal disconnect - clean up
|
| 468 |
-
manager.disconnect(user_id, websocket)
|
| 469 |
-
error_logger.info(f"WebSocket disconnected normally for user {user_id}")
|
| 470 |
|
|
|
|
|
|
|
|
|
|
| 471 |
except Exception as e:
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import asyncio
|
| 14 |
from datetime import datetime
|
| 15 |
from typing import Annotated, Optional
|
| 16 |
+
from fastapi import APIRouter, HTTPException, status, Depends, BackgroundTasks, Request
|
| 17 |
from pydantic import BaseModel, Field, field_validator, ValidationError
|
| 18 |
from sqlmodel import Session
|
| 19 |
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
| 24 |
from models.message import Message, MessageRole
|
| 25 |
from services.security import sanitize_message
|
| 26 |
from models.conversation import Conversation
|
| 27 |
+
from ai_agent import run_agent, is_gemini_configured
|
| 28 |
from services.conversation import (
|
| 29 |
get_or_create_conversation,
|
| 30 |
load_conversation_history,
|
| 31 |
update_conversation_timestamp
|
| 32 |
)
|
| 33 |
from services.rate_limiter import check_rate_limit
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
# Configure error logger
|
|
|
|
| 289 |
{"role": "user", "content": sanitized_message}
|
| 290 |
]
|
| 291 |
|
| 292 |
+
# Run AI agent (non-streaming for legacy endpoint)
|
| 293 |
# [From]: T014 - Initialize OpenAI Agents SDK with Gemini
|
| 294 |
+
# NOTE: Streaming is now handled by ChatKit SSE endpoint
|
| 295 |
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 296 |
try:
|
| 297 |
+
ai_response_text = await run_agent(
|
| 298 |
messages=messages_for_agent,
|
| 299 |
user_id=user_id
|
| 300 |
)
|
|
|
|
| 406 |
)
|
| 407 |
|
| 408 |
|
| 409 |
+
# ============================================================================
|
| 410 |
+
# ChatKit SSE Endpoint (Phase 010-chatkit-migration)
|
| 411 |
+
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
+
@router.post("/chatkit")
|
| 414 |
+
async def chatkit_endpoint(
|
| 415 |
+
request: Request, # Starlette Request object for raw body access
|
| 416 |
+
background_tasks: BackgroundTasks,
|
| 417 |
+
):
|
| 418 |
+
"""ChatKit SSE endpoint for streaming chat with Gemini LLM.
|
| 419 |
+
|
| 420 |
+
[Task]: T011
|
| 421 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - ChatKit SSE Endpoint
|
| 422 |
+
|
| 423 |
+
This endpoint implements the ChatKit protocol using Server-Sent Events (SSE).
|
| 424 |
+
It replaces the WebSocket-based streaming with a simpler HTTP-based approach.
|
| 425 |
+
|
| 426 |
+
Endpoint: POST /api/chatkit
|
| 427 |
+
Response: Server-Sent Events (text/event-stream)
|
| 428 |
+
|
| 429 |
+
Authentication: JWT via httpOnly cookie (auth_token)
|
| 430 |
+
|
| 431 |
+
Request Body (ChatKit protocol):
|
| 432 |
+
{
|
| 433 |
+
"event": "conversation_item_created",
|
| 434 |
+
"conversation_id": "<thread_uuid>",
|
| 435 |
+
"item": {
|
| 436 |
+
"type": "message",
|
| 437 |
+
"role": "user",
|
| 438 |
+
"content": [{"type": "text", "text": "Your message here"}]
|
| 439 |
+
}
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
SSE Event Types:
|
| 443 |
+
- message_delta: Streaming text content
|
| 444 |
+
- tool_call_created: Tool invocation started
|
| 445 |
+
- tool_call_done: Tool execution completed
|
| 446 |
+
- message_done: Message fully streamed
|
| 447 |
+
- error: Error occurred
|
| 448 |
+
|
| 449 |
+
[From]: specs/010-chatkit-migration/research.md - Section 4
|
| 450 |
+
"""
|
| 451 |
+
from fastapi import Response
|
| 452 |
+
from fastapi.responses import StreamingResponse
|
| 453 |
+
from starlette.requests import Request as StarletteRequest
|
| 454 |
|
| 455 |
+
# Import for authentication
|
| 456 |
+
from core.security import get_current_user_id_from_cookie
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
|
| 458 |
+
# Get authenticated user ID from JWT cookie
|
| 459 |
+
# [From]: specs/010-chatkit-migration/contracts/backend.md - Authentication
|
| 460 |
+
try:
|
| 461 |
+
user_id = await get_current_user_id_from_cookie(request)
|
| 462 |
+
if not user_id:
|
| 463 |
+
# Return error as SSE event
|
| 464 |
+
async def error_stream():
|
| 465 |
+
yield "event: error\n"
|
| 466 |
+
yield 'data: {"detail": "Invalid authentication"}\n\n'
|
| 467 |
+
return StreamingResponse(
|
| 468 |
+
error_stream(),
|
| 469 |
+
media_type="text/event-stream",
|
| 470 |
+
status_code=401
|
| 471 |
+
)
|
| 472 |
+
except Exception as e:
|
| 473 |
+
error_logger.error(f"Auth error in ChatKit endpoint: {e}")
|
| 474 |
+
async def error_stream():
|
| 475 |
+
yield "event: error\n"
|
| 476 |
+
yield f'data: {{"detail": "Authentication failed"}}\n\n'
|
| 477 |
+
return StreamingResponse(
|
| 478 |
+
error_stream(),
|
| 479 |
+
media_type="text/event-stream",
|
| 480 |
+
status_code=401
|
| 481 |
+
)
|
| 482 |
|
| 483 |
+
# Check rate limit before processing
|
| 484 |
+
# [From]: specs/010-chatkit-migration/tasks.md - T020
|
| 485 |
+
# [From]: specs/010-chatkit-migration/spec.md - FR-015
|
| 486 |
+
try:
|
| 487 |
+
from uuid import UUID
|
| 488 |
+
from core.database import engine
|
| 489 |
+
from sqlmodel import Session
|
| 490 |
+
|
| 491 |
+
# Create synchronous session for rate limit check
|
| 492 |
+
with Session(engine) as db:
|
| 493 |
+
allowed, remaining, reset_time = check_rate_limit(db, UUID(user_id))
|
| 494 |
+
|
| 495 |
+
if not allowed:
|
| 496 |
+
# Rate limit exceeded
|
| 497 |
+
async def rate_limit_stream():
|
| 498 |
+
yield "event: error\n"
|
| 499 |
+
import json
|
| 500 |
+
yield f'data: {json.dumps({"detail": "Daily message limit reached", "limit": 100, "resets_at": reset_time.isoformat() if reset_time else None})}\n\n'
|
| 501 |
+
return StreamingResponse(
|
| 502 |
+
rate_limit_stream(),
|
| 503 |
+
media_type="text/event-stream",
|
| 504 |
+
status_code=429
|
| 505 |
+
)
|
| 506 |
+
except HTTPException:
|
| 507 |
+
# Re-raise HTTP exceptions (rate limit errors)
|
| 508 |
+
raise
|
| 509 |
+
except Exception as e:
|
| 510 |
+
# Log unexpected errors but don't block the request
|
| 511 |
+
error_logger.error(f"Rate limit check failed for ChatKit endpoint: {e}")
|
| 512 |
+
# Continue processing - fail open for rate limit errors
|
| 513 |
|
| 514 |
+
# Create ChatKit server with synchronous database operations
|
| 515 |
+
# [From]: specs/010-chatkit-migration/contracts/backend.md - Store Interface Implementation
|
| 516 |
+
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
|
| 518 |
+
# Parse request body
|
| 519 |
try:
|
| 520 |
+
body = await request.body()
|
| 521 |
+
except Exception as e:
|
| 522 |
+
error_logger.error(f"Failed to read ChatKit request body: {e}")
|
| 523 |
+
async def error_stream():
|
| 524 |
+
yield "event: error\n"
|
| 525 |
+
yield f'data: {{"detail": "Invalid request format"}}\n\n'
|
| 526 |
+
return StreamingResponse(
|
| 527 |
+
error_stream(),
|
| 528 |
+
media_type="text/event-stream",
|
| 529 |
+
status_code=400
|
| 530 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
+
# Parse ChatKit protocol request
|
| 533 |
+
try:
|
| 534 |
+
request_data = json.loads(body.decode('utf-8'))
|
| 535 |
except Exception as e:
|
| 536 |
+
error_logger.error(f"Failed to parse ChatKit request: {e}")
|
| 537 |
+
async def error_stream():
|
| 538 |
+
yield "event: error\n"
|
| 539 |
+
yield f'data: {{"detail": "Invalid JSON format"}}\n\n'
|
| 540 |
+
return StreamingResponse(
|
| 541 |
+
error_stream(),
|
| 542 |
+
media_type="text/event-stream",
|
| 543 |
+
status_code=400
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
# Extract thread ID and message content
|
| 547 |
+
conversation_id = request_data.get("conversation_id")
|
| 548 |
+
item = request_data.get("item", {})
|
| 549 |
+
event_type = request_data.get("event", "conversation_item_created")
|
| 550 |
+
|
| 551 |
+
# Extract user message
|
| 552 |
+
def extract_message_content(item_dict):
|
| 553 |
+
content_array = item_dict.get("content", [])
|
| 554 |
+
for content_block in content_array:
|
| 555 |
+
if content_block.get("type") == "text":
|
| 556 |
+
return content_block.get("text", "")
|
| 557 |
+
return ""
|
| 558 |
+
|
| 559 |
+
user_message = extract_message_content(item)
|
| 560 |
+
if not user_message:
|
| 561 |
+
async def error_stream():
|
| 562 |
+
yield "event: error\n"
|
| 563 |
+
yield f'data: {{"detail": "No message content provided"}}\n\n'
|
| 564 |
+
return StreamingResponse(
|
| 565 |
+
error_stream(),
|
| 566 |
+
media_type="text/event-stream",
|
| 567 |
+
status_code=400
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Use synchronous database session for thread/message operations
|
| 571 |
+
from core.database import engine
|
| 572 |
+
from sqlmodel import Session
|
| 573 |
+
|
| 574 |
+
# Get or create thread
|
| 575 |
+
thread_id = conversation_id
|
| 576 |
+
if not thread_id:
|
| 577 |
+
# Create new thread for first message
|
| 578 |
+
with Session(engine) as db:
|
| 579 |
+
from models.thread import Thread
|
| 580 |
+
import uuid
|
| 581 |
+
new_thread = Thread(
|
| 582 |
+
user_id=UUID(user_id),
|
| 583 |
+
title=None,
|
| 584 |
+
thread_metadata={}
|
| 585 |
+
)
|
| 586 |
+
db.add(new_thread)
|
| 587 |
+
db.commit()
|
| 588 |
+
db.refresh(new_thread)
|
| 589 |
+
thread_id = str(new_thread.id)
|
| 590 |
+
error_logger.info(f"Created new thread: {thread_id}")
|
| 591 |
+
|
| 592 |
+
# Save user message to database
|
| 593 |
+
with Session(engine) as db:
|
| 594 |
+
from models.message import Message, MessageRole
|
| 595 |
+
import uuid
|
| 596 |
+
user_msg = Message(
|
| 597 |
+
thread_id=UUID(thread_id) if thread_id else None,
|
| 598 |
+
user_id=UUID(user_id),
|
| 599 |
+
role=MessageRole.USER,
|
| 600 |
+
content=user_message,
|
| 601 |
+
)
|
| 602 |
+
db.add(user_msg)
|
| 603 |
+
db.commit()
|
| 604 |
+
|
| 605 |
+
# Run AI agent and stream response
|
| 606 |
+
# [Task]: T033 - Add SSE error handling for connection drops
|
| 607 |
+
async def stream_chat_response():
|
| 608 |
+
"""Stream ChatKit events as SSE with timeout protection.
|
| 609 |
+
|
| 610 |
+
[Task]: T034 - Timeout handling for long-running tool executions
|
| 611 |
+
"""
|
| 612 |
+
try:
|
| 613 |
+
# Import agent components
|
| 614 |
+
from ai_agent import run_agent
|
| 615 |
+
|
| 616 |
+
# Run agent with timeout protection
|
| 617 |
+
async with asyncio.timeout(120):
|
| 618 |
+
ai_response = await run_agent(
|
| 619 |
+
messages=[{"role": "user", "content": user_message}],
|
| 620 |
+
user_id=user_id
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
# Save assistant message to database
|
| 624 |
+
with Session(engine) as db:
|
| 625 |
+
from models.message import Message, MessageRole
|
| 626 |
+
import uuid
|
| 627 |
+
assistant_msg = Message(
|
| 628 |
+
thread_id=UUID(thread_id) if thread_id else None,
|
| 629 |
+
user_id=UUID(user_id),
|
| 630 |
+
role=MessageRole.ASSISTANT,
|
| 631 |
+
content=ai_response,
|
| 632 |
+
)
|
| 633 |
+
db.add(assistant_msg)
|
| 634 |
+
db.commit()
|
| 635 |
+
|
| 636 |
+
# Stream the response
|
| 637 |
+
import json
|
| 638 |
+
yield "event: message_delta\n"
|
| 639 |
+
yield f'data: {json.dumps({"type": "text", "text": ai_response})}\n\n'
|
| 640 |
+
|
| 641 |
+
# Send message done event
|
| 642 |
+
yield "event: message_done\n"
|
| 643 |
+
yield f'data: {json.dumps({"message_id": thread_id, "role": "assistant", "thread_id": thread_id})}\n\n'
|
| 644 |
+
|
| 645 |
+
except TimeoutError:
|
| 646 |
+
error_logger.error(f"Agent execution timeout for thread {thread_id}")
|
| 647 |
+
import json
|
| 648 |
+
yield "event: error\n"
|
| 649 |
+
yield f'data: {json.dumps({"detail": "Request timed out", "message": "The AI assistant took too long to respond. Please try again."})}\n\n'
|
| 650 |
+
except Exception as e:
|
| 651 |
+
error_logger.error(f"Agent execution error: {e}", exc_info=True)
|
| 652 |
+
import json
|
| 653 |
+
yield "event: error\n"
|
| 654 |
+
yield f'data: {json.dumps({"detail": "Processing error", "message": str(e)})}\n\n'
|
| 655 |
|
| 656 |
+
# [Task]: T033 - Wrap with connection-aware streaming
|
| 657 |
+
async def connection_aware_stream():
|
| 658 |
+
"""Stream SSE events with connection drop detection.
|
| 659 |
+
|
| 660 |
+
[Task]: T033 - SSE error handling for connection drops
|
| 661 |
+
"""
|
| 662 |
+
try:
|
| 663 |
+
async for chunk in stream_chat_response():
|
| 664 |
+
yield chunk
|
| 665 |
+
except (ConnectionError, OSError) as e:
|
| 666 |
+
# Client disconnected during streaming
|
| 667 |
+
error_logger.info(f"Client disconnected during ChatKit streaming: {e}")
|
| 668 |
+
except Exception as e:
|
| 669 |
+
# Unexpected streaming error
|
| 670 |
+
error_logger.error(f"Unexpected error during ChatKit streaming: {e}", exc_info=True)
|
| 671 |
+
yield "event: error\n"
|
| 672 |
+
yield f'data: {{"detail": "Streaming error", "message": str(e)}}\n\n'
|
| 673 |
+
|
| 674 |
+
return StreamingResponse(
|
| 675 |
+
connection_aware_stream(),
|
| 676 |
+
media_type="text/event-stream",
|
| 677 |
+
headers={
|
| 678 |
+
"Cache-Control": "no-cache",
|
| 679 |
+
"Connection": "keep-alive",
|
| 680 |
+
"X-Accel-Buffering": "no",
|
| 681 |
+
}
|
| 682 |
+
)
|
|
@@ -1,478 +0,0 @@
|
|
| 1 |
-
"""Chat API endpoint for AI-powered task management.
|
| 2 |
-
|
| 3 |
-
[Task]: T015, T071
|
| 4 |
-
[From]: specs/004-ai-chatbot/tasks.md
|
| 5 |
-
|
| 6 |
-
This endpoint provides a conversational interface for task management.
|
| 7 |
-
Users can create, list, update, complete, and delete tasks through natural language.
|
| 8 |
-
|
| 9 |
-
Also includes WebSocket endpoint for real-time progress streaming.
|
| 10 |
-
"""
|
| 11 |
-
import uuid
|
| 12 |
-
import logging
|
| 13 |
-
import asyncio
|
| 14 |
-
from datetime import datetime
|
| 15 |
-
from typing import Annotated, Optional
|
| 16 |
-
from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect, BackgroundTasks
|
| 17 |
-
from pydantic import BaseModel, Field, field_validator, ValidationError
|
| 18 |
-
from sqlmodel import Session
|
| 19 |
-
from sqlalchemy.exc import SQLAlchemyError
|
| 20 |
-
|
| 21 |
-
from core.database import get_db
|
| 22 |
-
from core.validators import validate_message_length
|
| 23 |
-
from core.security import decode_access_token
|
| 24 |
-
from models.message import Message, MessageRole
|
| 25 |
-
from services.security import sanitize_message
|
| 26 |
-
from models.conversation import Conversation
|
| 27 |
-
from ai_agent import run_agent_with_streaming, is_gemini_configured
|
| 28 |
-
from services.conversation import (
|
| 29 |
-
get_or_create_conversation,
|
| 30 |
-
load_conversation_history,
|
| 31 |
-
update_conversation_timestamp
|
| 32 |
-
)
|
| 33 |
-
from services.rate_limiter import check_rate_limit
|
| 34 |
-
from ws_manager.manager import manager
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
# Configure error logger
|
| 38 |
-
error_logger = logging.getLogger("api.errors")
|
| 39 |
-
error_logger.setLevel(logging.ERROR)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
# Request/Response models
|
| 43 |
-
class ChatRequest(BaseModel):
|
| 44 |
-
"""Request model for chat endpoint.
|
| 45 |
-
|
| 46 |
-
[From]: specs/004-ai-chatbot/plan.md - API Contract
|
| 47 |
-
"""
|
| 48 |
-
message: str = Field(
|
| 49 |
-
...,
|
| 50 |
-
description="User message content",
|
| 51 |
-
min_length=1,
|
| 52 |
-
max_length=10000 # FR-042
|
| 53 |
-
)
|
| 54 |
-
conversation_id: Optional[str] = Field(
|
| 55 |
-
None,
|
| 56 |
-
description="Optional conversation ID to continue existing conversation"
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
@field_validator('message')
|
| 60 |
-
@classmethod
|
| 61 |
-
def validate_message(cls, v: str) -> str:
|
| 62 |
-
"""Validate message content."""
|
| 63 |
-
if not v or not v.strip():
|
| 64 |
-
raise ValueError("Message content cannot be empty")
|
| 65 |
-
if len(v) > 10000:
|
| 66 |
-
raise ValueError("Message content exceeds maximum length of 10,000 characters")
|
| 67 |
-
return v.strip()
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
class TaskReference(BaseModel):
|
| 71 |
-
"""Reference to a task created or modified by AI."""
|
| 72 |
-
id: str
|
| 73 |
-
title: str
|
| 74 |
-
description: Optional[str] = None
|
| 75 |
-
due_date: Optional[str] = None
|
| 76 |
-
priority: Optional[str] = None
|
| 77 |
-
completed: bool = False
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class ChatResponse(BaseModel):
|
| 81 |
-
"""Response model for chat endpoint.
|
| 82 |
-
|
| 83 |
-
[From]: specs/004-ai-chatbot/plan.md - API Contract
|
| 84 |
-
"""
|
| 85 |
-
response: str = Field(
|
| 86 |
-
...,
|
| 87 |
-
description="AI assistant's text response"
|
| 88 |
-
)
|
| 89 |
-
conversation_id: str = Field(
|
| 90 |
-
...,
|
| 91 |
-
description="Conversation ID (new or existing)"
|
| 92 |
-
)
|
| 93 |
-
tasks: list[TaskReference] = Field(
|
| 94 |
-
default_factory=list,
|
| 95 |
-
description="List of tasks created or modified in this interaction"
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# Create API router
|
| 100 |
-
router = APIRouter(prefix="/api", tags=["chat"])
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
@router.post("/{user_id}/chat", response_model=ChatResponse, status_code=status.HTTP_200_OK)
|
| 104 |
-
async def chat(
|
| 105 |
-
user_id: str,
|
| 106 |
-
request: ChatRequest,
|
| 107 |
-
background_tasks: BackgroundTasks,
|
| 108 |
-
db: Session = Depends(get_db)
|
| 109 |
-
):
|
| 110 |
-
"""Process user message through AI agent and return response.
|
| 111 |
-
|
| 112 |
-
[From]: specs/004-ai-chatbot/spec.md - US1
|
| 113 |
-
|
| 114 |
-
This endpoint:
|
| 115 |
-
1. Validates user input and rate limits
|
| 116 |
-
2. Gets or creates conversation
|
| 117 |
-
3. Runs AI agent with WebSocket progress streaming
|
| 118 |
-
4. Returns AI response immediately
|
| 119 |
-
5. Saves messages to DB in background (non-blocking)
|
| 120 |
-
|
| 121 |
-
Args:
|
| 122 |
-
user_id: User ID (UUID string from path)
|
| 123 |
-
request: Chat request with message and optional conversation_id
|
| 124 |
-
background_tasks: FastAPI background tasks for non-blocking DB saves
|
| 125 |
-
db: Database session
|
| 126 |
-
|
| 127 |
-
Returns:
|
| 128 |
-
ChatResponse with AI response, conversation_id, and task references
|
| 129 |
-
|
| 130 |
-
Raises:
|
| 131 |
-
HTTPException 400: Invalid message content
|
| 132 |
-
HTTPException 503: AI service unavailable
|
| 133 |
-
"""
|
| 134 |
-
# Check if Gemini API is configured
|
| 135 |
-
# [From]: specs/004-ai-chatbot/tasks.md - T022
|
| 136 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 137 |
-
if not is_gemini_configured():
|
| 138 |
-
raise HTTPException(
|
| 139 |
-
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 140 |
-
detail={
|
| 141 |
-
"error": "AI service unavailable",
|
| 142 |
-
"message": "The AI service is currently not configured. Please ensure GEMINI_API_KEY is set in the environment.",
|
| 143 |
-
"suggestion": "Contact your administrator or check your API key configuration."
|
| 144 |
-
}
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
# Validate user_id format
|
| 148 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 149 |
-
try:
|
| 150 |
-
user_uuid = uuid.UUID(user_id)
|
| 151 |
-
except ValueError:
|
| 152 |
-
raise HTTPException(
|
| 153 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 154 |
-
detail={
|
| 155 |
-
"error": "Invalid user ID",
|
| 156 |
-
"message": f"User ID '{user_id}' is not a valid UUID format.",
|
| 157 |
-
"expected_format": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
|
| 158 |
-
"suggestion": "Ensure you are using a valid UUID for the user_id path parameter."
|
| 159 |
-
}
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
# Validate message content
|
| 163 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 164 |
-
try:
|
| 165 |
-
validated_message = validate_message_length(request.message)
|
| 166 |
-
except ValueError as e:
|
| 167 |
-
raise HTTPException(
|
| 168 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 169 |
-
detail={
|
| 170 |
-
"error": "Message validation failed",
|
| 171 |
-
"message": str(e),
|
| 172 |
-
"max_length": 10000,
|
| 173 |
-
"suggestion": "Keep your message under 10,000 characters and ensure it contains meaningful content."
|
| 174 |
-
}
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# Sanitize message to prevent prompt injection
|
| 178 |
-
# [From]: T057 - Implement prompt injection sanitization
|
| 179 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 180 |
-
try:
|
| 181 |
-
sanitized_message = sanitize_message(validated_message)
|
| 182 |
-
except ValueError as e:
|
| 183 |
-
raise HTTPException(
|
| 184 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 185 |
-
detail={
|
| 186 |
-
"error": "Message content blocked",
|
| 187 |
-
"message": str(e),
|
| 188 |
-
"suggestion": "Please rephrase your message without attempting to manipulate system instructions."
|
| 189 |
-
}
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
# Check rate limit
|
| 193 |
-
# [From]: specs/004-ai-chatbot/spec.md - NFR-011
|
| 194 |
-
# [From]: T021 - Implement daily message limit enforcement (100/day)
|
| 195 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 196 |
-
try:
|
| 197 |
-
allowed, remaining, reset_time = check_rate_limit(db, user_uuid)
|
| 198 |
-
|
| 199 |
-
if not allowed:
|
| 200 |
-
raise HTTPException(
|
| 201 |
-
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
| 202 |
-
detail={
|
| 203 |
-
"error": "Rate limit exceeded",
|
| 204 |
-
"message": "You have reached the daily message limit. Please try again later.",
|
| 205 |
-
"limit": 100,
|
| 206 |
-
"resets_at": reset_time.isoformat() if reset_time else None,
|
| 207 |
-
"suggestion": "Free tier accounts are limited to 100 messages per day. Upgrade for unlimited access."
|
| 208 |
-
}
|
| 209 |
-
)
|
| 210 |
-
except HTTPException:
|
| 211 |
-
# Re-raise HTTP exceptions (rate limit errors)
|
| 212 |
-
raise
|
| 213 |
-
except Exception as e:
|
| 214 |
-
# Log unexpected errors but don't block the request
|
| 215 |
-
error_logger.error(f"Rate limit check failed for user {user_id}: {e}")
|
| 216 |
-
# Continue processing - fail open for rate limit errors
|
| 217 |
-
|
| 218 |
-
# Get or create conversation
|
| 219 |
-
# [From]: T016 - Implement conversation history loading
|
| 220 |
-
# [From]: T035 - Handle auto-deleted conversations gracefully
|
| 221 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 222 |
-
conversation_id: uuid.UUID
|
| 223 |
-
|
| 224 |
-
if request.conversation_id:
|
| 225 |
-
# Load existing conversation using service
|
| 226 |
-
try:
|
| 227 |
-
conv_uuid = uuid.UUID(request.conversation_id)
|
| 228 |
-
except ValueError:
|
| 229 |
-
# Invalid conversation_id format
|
| 230 |
-
raise HTTPException(
|
| 231 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 232 |
-
detail={
|
| 233 |
-
"error": "Invalid conversation ID",
|
| 234 |
-
"message": f"Conversation ID '{request.conversation_id}' is not a valid UUID format.",
|
| 235 |
-
"suggestion": "Provide a valid UUID or omit the conversation_id to start a new conversation."
|
| 236 |
-
}
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
try:
|
| 240 |
-
conversation = get_or_create_conversation(
|
| 241 |
-
db=db,
|
| 242 |
-
user_id=user_uuid,
|
| 243 |
-
conversation_id=conv_uuid
|
| 244 |
-
)
|
| 245 |
-
conversation_id = conversation.id
|
| 246 |
-
except (KeyError, ValueError) as e:
|
| 247 |
-
# Conversation may have been auto-deleted (90-day policy) or otherwise not found
|
| 248 |
-
# [From]: T035 - Handle auto-deleted conversations gracefully
|
| 249 |
-
# Create a new conversation instead of failing
|
| 250 |
-
conversation = get_or_create_conversation(
|
| 251 |
-
db=db,
|
| 252 |
-
user_id=user_uuid
|
| 253 |
-
)
|
| 254 |
-
conversation_id = conversation.id
|
| 255 |
-
else:
|
| 256 |
-
# Create new conversation using service
|
| 257 |
-
conversation = get_or_create_conversation(
|
| 258 |
-
db=db,
|
| 259 |
-
user_id=user_uuid
|
| 260 |
-
)
|
| 261 |
-
conversation_id = conversation.id
|
| 262 |
-
|
| 263 |
-
# Load conversation history using service
|
| 264 |
-
# [From]: T016 - Implement conversation history loading
|
| 265 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 266 |
-
try:
|
| 267 |
-
conversation_history = load_conversation_history(
|
| 268 |
-
db=db,
|
| 269 |
-
conversation_id=conversation_id
|
| 270 |
-
)
|
| 271 |
-
except SQLAlchemyError as e:
|
| 272 |
-
error_logger.error(f"Database error loading conversation history for {conversation_id}: {e}")
|
| 273 |
-
# Continue with empty history if load fails
|
| 274 |
-
conversation_history = []
|
| 275 |
-
|
| 276 |
-
# Prepare user message for background save
|
| 277 |
-
user_message_id = uuid.uuid4()
|
| 278 |
-
user_message_data = {
|
| 279 |
-
"id": user_message_id,
|
| 280 |
-
"conversation_id": conversation_id,
|
| 281 |
-
"user_id": user_uuid,
|
| 282 |
-
"role": MessageRole.USER,
|
| 283 |
-
"content": sanitized_message,
|
| 284 |
-
"created_at": datetime.utcnow()
|
| 285 |
-
}
|
| 286 |
-
|
| 287 |
-
# Add current user message to conversation history for AI processing
|
| 288 |
-
# This is critical - the agent needs the user's current message in context
|
| 289 |
-
messages_for_agent = conversation_history + [
|
| 290 |
-
{"role": "user", "content": sanitized_message}
|
| 291 |
-
]
|
| 292 |
-
|
| 293 |
-
# Run AI agent with streaming (broadcasts WebSocket events)
|
| 294 |
-
# [From]: T014 - Initialize OpenAI Agents SDK with Gemini
|
| 295 |
-
# [From]: T072 - Use streaming agent for real-time progress
|
| 296 |
-
# [From]: T060 - Add comprehensive error messages for edge cases
|
| 297 |
-
try:
|
| 298 |
-
ai_response_text = await run_agent_with_streaming(
|
| 299 |
-
messages=messages_for_agent,
|
| 300 |
-
user_id=user_id
|
| 301 |
-
)
|
| 302 |
-
except ValueError as e:
|
| 303 |
-
# Configuration errors (missing API key, invalid model)
|
| 304 |
-
# [From]: T022 - Add error handling for Gemini API unavailability
|
| 305 |
-
error_logger.error(f"AI configuration error for user {user_id}: {e}")
|
| 306 |
-
raise HTTPException(
|
| 307 |
-
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 308 |
-
detail={
|
| 309 |
-
"error": "AI service configuration error",
|
| 310 |
-
"message": str(e),
|
| 311 |
-
"suggestion": "Verify GEMINI_API_KEY and GEMINI_MODEL are correctly configured."
|
| 312 |
-
}
|
| 313 |
-
)
|
| 314 |
-
except ConnectionError as e:
|
| 315 |
-
# Network/connection issues
|
| 316 |
-
# [From]: T022 - Add error handling for Gemini API unavailability
|
| 317 |
-
error_logger.error(f"AI connection error for user {user_id}: {e}")
|
| 318 |
-
raise HTTPException(
|
| 319 |
-
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 320 |
-
detail={
|
| 321 |
-
"error": "AI service unreachable",
|
| 322 |
-
"message": "Could not connect to the AI service. Please check your network connection.",
|
| 323 |
-
"suggestion": "If the problem persists, the AI service may be temporarily down."
|
| 324 |
-
}
|
| 325 |
-
)
|
| 326 |
-
except TimeoutError as e:
|
| 327 |
-
# Timeout errors
|
| 328 |
-
# [From]: T022 - Add error handling for Gemini API unavailability
|
| 329 |
-
error_logger.error(f"AI timeout error for user {user_id}: {e}")
|
| 330 |
-
raise HTTPException(
|
| 331 |
-
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
|
| 332 |
-
detail={
|
| 333 |
-
"error": "AI service timeout",
|
| 334 |
-
"message": "The AI service took too long to respond. Please try again.",
|
| 335 |
-
"suggestion": "Your message may be too complex. Try breaking it into smaller requests."
|
| 336 |
-
}
|
| 337 |
-
)
|
| 338 |
-
except Exception as e:
|
| 339 |
-
# Other errors (rate limits, authentication, context, etc.)
|
| 340 |
-
# [From]: T022 - Add error handling for Gemini API unavailability
|
| 341 |
-
error_logger.error(f"Unexpected AI error for user {user_id}: {type(e).__name__}: {e}")
|
| 342 |
-
raise HTTPException(
|
| 343 |
-
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
| 344 |
-
detail={
|
| 345 |
-
"error": "AI service error",
|
| 346 |
-
"message": f"An unexpected error occurred: {str(e)}",
|
| 347 |
-
"suggestion": "Please try again later or contact support if the problem persists."
|
| 348 |
-
}
|
| 349 |
-
)
|
| 350 |
-
|
| 351 |
-
# Prepare AI response for background save
|
| 352 |
-
ai_message_data = {
|
| 353 |
-
"id": uuid.uuid4(),
|
| 354 |
-
"conversation_id": conversation_id,
|
| 355 |
-
"user_id": user_uuid,
|
| 356 |
-
"role": MessageRole.ASSISTANT,
|
| 357 |
-
"content": ai_response_text,
|
| 358 |
-
"created_at": datetime.utcnow()
|
| 359 |
-
}
|
| 360 |
-
|
| 361 |
-
# Save messages to DB in background (non-blocking)
|
| 362 |
-
# This significantly improves response time
|
| 363 |
-
def save_messages_to_db():
|
| 364 |
-
"""Background task to save messages to database."""
|
| 365 |
-
try:
|
| 366 |
-
from core.database import engine
|
| 367 |
-
from sqlmodel import Session
|
| 368 |
-
|
| 369 |
-
# Create a new session for background task
|
| 370 |
-
bg_db = Session(engine)
|
| 371 |
-
|
| 372 |
-
try:
|
| 373 |
-
# Save user message
|
| 374 |
-
user_msg = Message(**user_message_data)
|
| 375 |
-
bg_db.add(user_msg)
|
| 376 |
-
|
| 377 |
-
# Save AI response
|
| 378 |
-
ai_msg = Message(**ai_message_data)
|
| 379 |
-
bg_db.add(ai_msg)
|
| 380 |
-
|
| 381 |
-
bg_db.commit()
|
| 382 |
-
|
| 383 |
-
# Update conversation timestamp
|
| 384 |
-
try:
|
| 385 |
-
update_conversation_timestamp(db=bg_db, conversation_id=conversation_id)
|
| 386 |
-
except SQLAlchemyError as e:
|
| 387 |
-
error_logger.error(f"Database error updating conversation timestamp for {conversation_id}: {e}")
|
| 388 |
-
|
| 389 |
-
except SQLAlchemyError as e:
|
| 390 |
-
error_logger.error(f"Background task: Database error saving messages for user {user_id}: {e}")
|
| 391 |
-
bg_db.rollback()
|
| 392 |
-
finally:
|
| 393 |
-
bg_db.close()
|
| 394 |
-
except Exception as e:
|
| 395 |
-
error_logger.error(f"Background task: Unexpected error saving messages for user {user_id}: {e}")
|
| 396 |
-
|
| 397 |
-
background_tasks.add_task(save_messages_to_db)
|
| 398 |
-
|
| 399 |
-
# TODO: Parse AI response for task references
|
| 400 |
-
# This will be enhanced in future tasks to extract task IDs from AI responses
|
| 401 |
-
task_references: list[TaskReference] = []
|
| 402 |
-
|
| 403 |
-
return ChatResponse(
|
| 404 |
-
response=ai_response_text,
|
| 405 |
-
conversation_id=str(conversation_id),
|
| 406 |
-
tasks=task_references
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
@router.websocket("/ws/{user_id}/chat")
|
| 411 |
-
async def websocket_chat(
|
| 412 |
-
websocket: WebSocket,
|
| 413 |
-
user_id: str,
|
| 414 |
-
db: Session = Depends(get_db)
|
| 415 |
-
):
|
| 416 |
-
"""WebSocket endpoint for real-time chat progress updates.
|
| 417 |
-
|
| 418 |
-
[From]: specs/004-ai-chatbot/research.md - Section 4
|
| 419 |
-
[Task]: T071
|
| 420 |
-
|
| 421 |
-
This endpoint provides a WebSocket connection for receiving real-time
|
| 422 |
-
progress events during AI agent execution. Events include:
|
| 423 |
-
- connection_established: Confirmation of successful connection
|
| 424 |
-
- agent_thinking: AI agent is processing
|
| 425 |
-
- tool_starting: A tool is about to execute
|
| 426 |
-
- tool_progress: Tool execution progress (e.g., "Found 3 tasks")
|
| 427 |
-
- tool_complete: Tool finished successfully
|
| 428 |
-
- tool_error: Tool execution failed
|
| 429 |
-
- agent_done: AI agent finished processing
|
| 430 |
-
|
| 431 |
-
Note: Authentication is handled implicitly by the frontend - users must
|
| 432 |
-
be logged in to access the chat page. The WebSocket only broadcasts
|
| 433 |
-
progress updates (not sensitive data), so strict auth is bypassed here.
|
| 434 |
-
|
| 435 |
-
Connection URL format:
|
| 436 |
-
ws://localhost:8000/ws/{user_id}/chat
|
| 437 |
-
|
| 438 |
-
Args:
|
| 439 |
-
websocket: The WebSocket connection instance
|
| 440 |
-
user_id: User ID from URL path (used to route progress events)
|
| 441 |
-
db: Database session (for any future DB operations)
|
| 442 |
-
|
| 443 |
-
The connection is kept alive and can receive messages from the client,
|
| 444 |
-
though currently it's primarily used for server-to-client progress updates.
|
| 445 |
-
"""
|
| 446 |
-
# Connect the WebSocket (manager handles accept)
|
| 447 |
-
# [From]: specs/004-ai-chatbot/research.md - Section 4
|
| 448 |
-
await manager.connect(user_id, websocket)
|
| 449 |
-
|
| 450 |
-
try:
|
| 451 |
-
# Keep connection alive and listen for client messages
|
| 452 |
-
# Currently, we don't expect many client messages, but we
|
| 453 |
-
# maintain the connection to receive any control messages
|
| 454 |
-
while True:
|
| 455 |
-
# Wait for message from client (with timeout)
|
| 456 |
-
data = await websocket.receive_text()
|
| 457 |
-
|
| 458 |
-
# Handle client messages if needed
|
| 459 |
-
# For now, we just acknowledge receipt
|
| 460 |
-
# Future: could handle ping/pong for connection health
|
| 461 |
-
if data:
|
| 462 |
-
# Echo back a simple acknowledgment
|
| 463 |
-
# (optional - mainly for debugging)
|
| 464 |
-
pass
|
| 465 |
-
|
| 466 |
-
except WebSocketDisconnect:
|
| 467 |
-
# Normal disconnect - clean up
|
| 468 |
-
manager.disconnect(user_id, websocket)
|
| 469 |
-
error_logger.info(f"WebSocket disconnected normally for user {user_id}")
|
| 470 |
-
|
| 471 |
-
except Exception as e:
|
| 472 |
-
# Unexpected error - clean up and log
|
| 473 |
-
error_logger.error(f"WebSocket error for user {user_id}: {e}")
|
| 474 |
-
manager.disconnect(user_id, websocket)
|
| 475 |
-
|
| 476 |
-
finally:
|
| 477 |
-
# Ensure disconnect is always called
|
| 478 |
-
manager.disconnect(user_id, websocket)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ChatKit Server implementation for task management with Gemini LLM.
|
| 2 |
+
|
| 3 |
+
[Task]: T010
|
| 4 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - ChatKitServer Implementation
|
| 5 |
+
[From]: specs/010-chatkit-migration/research.md - Section 3
|
| 6 |
+
|
| 7 |
+
This module implements the ChatKitServer class which handles ChatKit protocol
|
| 8 |
+
requests and streams responses using Server-Sent Events (SSE).
|
| 9 |
+
|
| 10 |
+
The server integrates:
|
| 11 |
+
- ChatKit Python SDK for protocol handling
|
| 12 |
+
- OpenAI Agents SDK for agent orchestration
|
| 13 |
+
- Gemini LLM via OpenAI-compatible endpoint
|
| 14 |
+
- MCP tools wrapped as Agents SDK functions
|
| 15 |
+
|
| 16 |
+
Architecture:
|
| 17 |
+
Frontend (ChatKit.js)
|
| 18 |
+
↓ SSE with custom fetch
|
| 19 |
+
ChatKitServer (this module)
|
| 20 |
+
↓ Agents SDK
|
| 21 |
+
Gemini API (via AsyncOpenAI with base_url)
|
| 22 |
+
"""
|
| 23 |
+
import asyncio
|
| 24 |
+
import logging
|
| 25 |
+
from typing import Any, AsyncIterator, Optional
|
| 26 |
+
from uuid import UUID
|
| 27 |
+
from openai import AsyncOpenAI
|
| 28 |
+
from agents import Agent, set_default_openai_client, RunContextWrapper, Runner
|
| 29 |
+
|
| 30 |
+
from core.config import get_gemini_client, get_settings
|
| 31 |
+
from services.chatkit_store import PostgresChatKitStore
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AgentContext:
|
| 37 |
+
"""Context object passed to agent during execution.
|
| 38 |
+
|
| 39 |
+
Contains:
|
| 40 |
+
- thread_id: Current thread/conversation ID
|
| 41 |
+
- user_id: Authenticated user ID
|
| 42 |
+
- store: Database store for persistence
|
| 43 |
+
- request_context: Additional request metadata
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
thread_id: str,
|
| 49 |
+
user_id: str,
|
| 50 |
+
store: PostgresChatKitStore,
|
| 51 |
+
request_context: Optional[dict] = None,
|
| 52 |
+
):
|
| 53 |
+
self.thread_id = thread_id
|
| 54 |
+
self.user_id = user_id
|
| 55 |
+
self.store = store
|
| 56 |
+
self.request_context = request_context or {}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TaskManagerChatKitServer:
|
| 60 |
+
"""ChatKit Server for task management with Gemini LLM.
|
| 61 |
+
|
| 62 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - ChatKitServer Implementation
|
| 63 |
+
|
| 64 |
+
This server extends the ChatKit protocol to work with:
|
| 65 |
+
- Custom authentication (JWT cookies)
|
| 66 |
+
- Gemini LLM via OpenAI-compatible endpoint
|
| 67 |
+
- Server-side tool execution (MCP tools)
|
| 68 |
+
- PostgreSQL thread/message persistence
|
| 69 |
+
|
| 70 |
+
Usage:
|
| 71 |
+
from fastapi import FastAPI, Request
|
| 72 |
+
from chatkit_server import TaskManagerChatKitServer
|
| 73 |
+
|
| 74 |
+
app = FastAPI()
|
| 75 |
+
server = TaskManagerChatKitServer(store=postgres_store)
|
| 76 |
+
|
| 77 |
+
@app.post("/api/chatkit")
|
| 78 |
+
async def chatkit_endpoint(request: Request):
|
| 79 |
+
body = await request.body()
|
| 80 |
+
user_id = get_current_user_id(request)
|
| 81 |
+
result = await server.process(body, {"user_id": user_id})
|
| 82 |
+
if hasattr(result, '__aiter__'):
|
| 83 |
+
return StreamingResponse(result, media_type="text/event-stream")
|
| 84 |
+
return Response(content=result.json, media_type="application/json")
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, store: PostgresChatKitStore):
|
| 88 |
+
"""Initialize the ChatKit server.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
store: PostgresChatKitStore instance for persistence
|
| 92 |
+
"""
|
| 93 |
+
self.store = store
|
| 94 |
+
|
| 95 |
+
# Configure Gemini client as default for Agents SDK
|
| 96 |
+
# [From]: specs/010-chatkit-migration/research.md - Section 2
|
| 97 |
+
try:
|
| 98 |
+
gemini_client = get_gemini_client()
|
| 99 |
+
set_default_openai_client(gemini_client)
|
| 100 |
+
logger.info("Gemini client configured for Agents SDK")
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.warning(f"Gemini client not configured: {e}")
|
| 103 |
+
|
| 104 |
+
# Assistant agent will be configured after tools are wrapped
|
| 105 |
+
# This placeholder will be replaced in T019 with actual tools
|
| 106 |
+
self.assistant_agent: Optional[Agent] = None
|
| 107 |
+
|
| 108 |
+
def set_agent(self, agent: Agent) -> None:
|
| 109 |
+
"""Set the assistant agent with tools.
|
| 110 |
+
|
| 111 |
+
[From]: specs/010-chatkit-migration/tasks.md - T019
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
agent: Configured Agent with tools and instructions
|
| 115 |
+
"""
|
| 116 |
+
self.assistant_agent = agent
|
| 117 |
+
logger.info(f"Agent configured: {agent.name} with model {agent.model}")
|
| 118 |
+
|
| 119 |
+
async def process(
|
| 120 |
+
self,
|
| 121 |
+
body: bytes,
|
| 122 |
+
context: dict[str, Any]
|
| 123 |
+
) -> Any:
|
| 124 |
+
"""Process ChatKit request and return streaming or non-streaming result.
|
| 125 |
+
|
| 126 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - ChatKit SSE Endpoint
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
body: Raw request body bytes from ChatKit.js
|
| 130 |
+
context: Request context containing user_id and auth info
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
StreamingResult for SSE responses or dict for JSON responses
|
| 134 |
+
|
| 135 |
+
Note: This is a placeholder implementation. The actual implementation
|
| 136 |
+
would use the ChatKit Python SDK's process() method which handles
|
| 137 |
+
protocol parsing, event routing, and response formatting.
|
| 138 |
+
"""
|
| 139 |
+
import json
|
| 140 |
+
from fastapi.responses import StreamingResponse
|
| 141 |
+
|
| 142 |
+
# Parse ChatKit protocol request
|
| 143 |
+
try:
|
| 144 |
+
request_data = json.loads(body.decode('utf-8'))
|
| 145 |
+
except Exception as e:
|
| 146 |
+
logger.error(f"Failed to parse ChatKit request: {e}")
|
| 147 |
+
return {"error": "Invalid request format"}
|
| 148 |
+
|
| 149 |
+
# Extract thread ID and message content
|
| 150 |
+
conversation_id = request_data.get("conversation_id")
|
| 151 |
+
item = request_data.get("item", {})
|
| 152 |
+
event_type = request_data.get("event", "conversation_item_created")
|
| 153 |
+
|
| 154 |
+
logger.info(f"ChatKit request: event={event_type}, conversation_id={conversation_id}")
|
| 155 |
+
|
| 156 |
+
# Get or create thread
|
| 157 |
+
thread_id = conversation_id
|
| 158 |
+
if not thread_id:
|
| 159 |
+
# Create new thread for first message
|
| 160 |
+
user_id = context.get("user_id")
|
| 161 |
+
if not user_id:
|
| 162 |
+
return {"error": "Unauthorized: no user_id in context"}
|
| 163 |
+
|
| 164 |
+
thread_meta = await self.store.create_thread(
|
| 165 |
+
user_id=user_id,
|
| 166 |
+
title=None,
|
| 167 |
+
metadata={}
|
| 168 |
+
)
|
| 169 |
+
thread_id = thread_meta["id"]
|
| 170 |
+
logger.info(f"Created new thread: {thread_id}")
|
| 171 |
+
|
| 172 |
+
# Extract user message
|
| 173 |
+
user_message = self._extract_message_content(item)
|
| 174 |
+
if not user_message:
|
| 175 |
+
return {"error": "No message content provided"}
|
| 176 |
+
|
| 177 |
+
# Build agent context
|
| 178 |
+
user_id = context.get("user_id")
|
| 179 |
+
agent_context = AgentContext(
|
| 180 |
+
thread_id=thread_id,
|
| 181 |
+
user_id=user_id,
|
| 182 |
+
store=self.store,
|
| 183 |
+
request_context=context,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Create user message in database
|
| 187 |
+
await self.store.create_message(
|
| 188 |
+
thread_id=thread_id,
|
| 189 |
+
item={
|
| 190 |
+
"type": "message",
|
| 191 |
+
"role": "user",
|
| 192 |
+
"content": [{"type": "text", "text": user_message}],
|
| 193 |
+
}
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Stream agent response
|
| 197 |
+
# [From]: specs/010-chatkit-migration/research.md - Section 3
|
| 198 |
+
# [Task]: T034 - Add timeout handling for long-running tool executions
|
| 199 |
+
async def stream_response():
|
| 200 |
+
"""Stream ChatKit events as SSE with timeout protection.
|
| 201 |
+
|
| 202 |
+
[Task]: T034 - Timeout handling for long-running tool executions
|
| 203 |
+
|
| 204 |
+
Implements a 120-second timeout for agent execution to prevent
|
| 205 |
+
indefinite hangs from slow tools or network issues.
|
| 206 |
+
"""
|
| 207 |
+
if not self.assistant_agent:
|
| 208 |
+
yield self._sse_event("error", {"message": "Agent not configured"})
|
| 209 |
+
return
|
| 210 |
+
|
| 211 |
+
try:
|
| 212 |
+
# Run agent with streaming and timeout
|
| 213 |
+
# [Task]: T034 - 120 second timeout for entire agent execution
|
| 214 |
+
# This covers LLM calls, tool executions, and any delays
|
| 215 |
+
async with asyncio.timeout(120):
|
| 216 |
+
result = Runner.run_streamed(
|
| 217 |
+
self.assistant_agent,
|
| 218 |
+
[{"role": "user", "content": user_message}],
|
| 219 |
+
context=agent_context,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Collect assistant response
|
| 223 |
+
full_response = ""
|
| 224 |
+
async for chunk in result:
|
| 225 |
+
if hasattr(chunk, 'content'):
|
| 226 |
+
content = chunk.content
|
| 227 |
+
if content:
|
| 228 |
+
full_response += content
|
| 229 |
+
# Stream text delta
|
| 230 |
+
yield self._sse_event("message_delta", {
|
| 231 |
+
"type": "text",
|
| 232 |
+
"text": content
|
| 233 |
+
})
|
| 234 |
+
|
| 235 |
+
# Create assistant message in database
|
| 236 |
+
await self.store.create_message(
|
| 237 |
+
thread_id=thread_id,
|
| 238 |
+
item={
|
| 239 |
+
"type": "message",
|
| 240 |
+
"role": "assistant",
|
| 241 |
+
"content": [{"type": "text", "text": full_response}],
|
| 242 |
+
}
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
# Send message done event
|
| 246 |
+
yield self._sse_event("message_done", {
|
| 247 |
+
"message_id": thread_id,
|
| 248 |
+
"role": "assistant"
|
| 249 |
+
})
|
| 250 |
+
|
| 251 |
+
except TimeoutError:
|
| 252 |
+
# [Task]: T034 - Handle timeout gracefully
|
| 253 |
+
logger.error(f"Agent execution timeout for thread {thread_id}")
|
| 254 |
+
yield self._sse_event("error", {
|
| 255 |
+
"message": "Request timed out",
|
| 256 |
+
"detail": "The AI assistant took too long to respond. Please try again."
|
| 257 |
+
})
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logger.error(f"Agent execution error: {e}", exc_info=True)
|
| 260 |
+
yield self._sse_event("error", {"message": str(e)})
|
| 261 |
+
|
| 262 |
+
# Return streaming result
|
| 263 |
+
class StreamingResult:
|
| 264 |
+
def __init__(self, generator):
|
| 265 |
+
self.generator = generator
|
| 266 |
+
self.json = json.dumps({"thread_id": thread_id})
|
| 267 |
+
|
| 268 |
+
def __aiter__(self):
|
| 269 |
+
return self.generator()
|
| 270 |
+
|
| 271 |
+
return StreamingResult(stream_response())
|
| 272 |
+
|
| 273 |
+
def _extract_message_content(self, item: dict) -> str:
|
| 274 |
+
"""Extract text content from ChatKit item.
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
item: ChatKit message item
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
Extracted text content
|
| 281 |
+
"""
|
| 282 |
+
content_array = item.get("content", [])
|
| 283 |
+
for content_block in content_array:
|
| 284 |
+
if content_block.get("type") == "text":
|
| 285 |
+
return content_block.get("text", "")
|
| 286 |
+
return ""
|
| 287 |
+
|
| 288 |
+
def _sse_event(self, event_type: str, data: dict) -> str:
|
| 289 |
+
"""Format data as Server-Sent Event.
|
| 290 |
+
|
| 291 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - SSE Event Types
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
event_type: Event type name
|
| 295 |
+
data: Event data payload
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Formatted SSE string
|
| 299 |
+
"""
|
| 300 |
+
import json
|
| 301 |
+
return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Create singleton instance (will be configured with tools in T019)
|
| 305 |
+
_server_instance: Optional[TaskManagerChatKitServer] = None
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_chatkit_server(store: PostgresChatKitStore) -> TaskManagerChatKitServer:
|
| 309 |
+
"""Get or create ChatKit server singleton with agent configuration.
|
| 310 |
+
|
| 311 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md
|
| 312 |
+
|
| 313 |
+
[Task]: T019 - Configure TaskAssistant agent with Gemini model and wrapped tools
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
store: PostgresChatKitStore instance
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
ChatKit server instance with configured agent
|
| 320 |
+
"""
|
| 321 |
+
global _server_instance
|
| 322 |
+
if _server_instance is None:
|
| 323 |
+
_server_instance = TaskManagerChatKitServer(store)
|
| 324 |
+
|
| 325 |
+
# Configure the assistant agent with tools
|
| 326 |
+
# [From]: specs/010-chatkit-migration/tasks.md - T019
|
| 327 |
+
# [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 328 |
+
from ai_agent.tool_wrappers import TOOL_FUNCTIONS
|
| 329 |
+
from core.config import get_settings
|
| 330 |
+
|
| 331 |
+
settings = get_settings()
|
| 332 |
+
|
| 333 |
+
# Create the TaskAssistant agent with Gemini model
|
| 334 |
+
# [From]: specs/010-chatkit-migration/research.md - Section 3
|
| 335 |
+
assistant_agent = Agent[AgentContext](
|
| 336 |
+
name="TaskAssistant",
|
| 337 |
+
model=settings.gemini_model or "gemini-2.0-flash-exp",
|
| 338 |
+
instructions="""You are a helpful task management assistant. You help users create, list, update, complete, and delete tasks through natural language.
|
| 339 |
+
|
| 340 |
+
Available tools:
|
| 341 |
+
- create_task: Create a new task with title, description, due date, priority, tags
|
| 342 |
+
- list_tasks: List all tasks with optional filters
|
| 343 |
+
- update_task: Update an existing task
|
| 344 |
+
- delete_task: Delete a task
|
| 345 |
+
- complete_task: Mark a task as completed or incomplete
|
| 346 |
+
- complete_all_tasks: Mark all tasks as completed (requires confirmation)
|
| 347 |
+
- delete_all_tasks: Delete all tasks (requires confirmation)
|
| 348 |
+
|
| 349 |
+
When users ask about tasks, use the appropriate tool. Always confirm destructive actions (complete_all_tasks, delete_all_tasks) by requiring the confirm parameter.
|
| 350 |
+
|
| 351 |
+
Be concise and helpful. If a user's request is unclear, ask for clarification.""",
|
| 352 |
+
tools=TOOL_FUNCTIONS,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
_server_instance.set_agent(assistant_agent)
|
| 356 |
+
logger.info(f"ChatKit server initialized with {len(TOOL_FUNCTIONS)} tools and model {settings.gemini_model}")
|
| 357 |
+
|
| 358 |
+
return _server_instance
|
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
"""Application configuration and settings.
|
| 2 |
-
|
| 3 |
-
[Task]: T009
|
| 4 |
-
[From]: specs/001-user-auth/plan.md
|
| 5 |
-
|
| 6 |
-
[Task]: T003
|
| 7 |
-
[From]: specs/004-ai-chatbot/plan.md
|
| 8 |
-
"""
|
| 9 |
-
import os
|
| 10 |
-
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 11 |
-
from functools import lru_cache
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class Settings(BaseSettings):
|
| 15 |
-
"""Application settings loaded from environment variables."""
|
| 16 |
-
|
| 17 |
-
# Database
|
| 18 |
-
database_url: str
|
| 19 |
-
|
| 20 |
-
# JWT
|
| 21 |
-
jwt_secret: str
|
| 22 |
-
jwt_algorithm: str = "HS256"
|
| 23 |
-
jwt_expiration_days: int = 7
|
| 24 |
-
|
| 25 |
-
# CORS
|
| 26 |
-
frontend_url: str
|
| 27 |
-
|
| 28 |
-
# Environment
|
| 29 |
-
environment: str = "development"
|
| 30 |
-
|
| 31 |
-
# Gemini API (Phase III: AI Chatbot)
|
| 32 |
-
gemini_api_key: str | None = None # Optional for migration/setup
|
| 33 |
-
gemini_model: str = "gemini-2.0-flash-exp"
|
| 34 |
-
|
| 35 |
-
model_config = SettingsConfigDict(
|
| 36 |
-
env_file=".env",
|
| 37 |
-
case_sensitive=False,
|
| 38 |
-
# Support legacy Better Auth environment variables
|
| 39 |
-
env_prefix="",
|
| 40 |
-
extra="ignore"
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
@lru_cache()
|
| 45 |
-
def get_settings() -> Settings:
|
| 46 |
-
"""Get cached settings instance.
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
Settings: Application settings
|
| 50 |
-
|
| 51 |
-
Raises:
|
| 52 |
-
ValueError: If required environment variables are not set
|
| 53 |
-
"""
|
| 54 |
-
return Settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -5,6 +5,9 @@
|
|
| 5 |
|
| 6 |
[Task]: T003
|
| 7 |
[From]: specs/004-ai-chatbot/plan.md
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
import os
|
| 10 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
@@ -18,12 +21,12 @@ class Settings(BaseSettings):
|
|
| 18 |
database_url: str
|
| 19 |
|
| 20 |
# JWT
|
| 21 |
-
jwt_secret: str
|
| 22 |
jwt_algorithm: str = "HS256"
|
| 23 |
jwt_expiration_days: int = 7
|
| 24 |
|
| 25 |
-
# CORS
|
| 26 |
-
frontend_url: str
|
| 27 |
|
| 28 |
# Environment
|
| 29 |
environment: str = "development"
|
|
@@ -31,6 +34,7 @@ class Settings(BaseSettings):
|
|
| 31 |
# Gemini API (Phase III: AI Chatbot)
|
| 32 |
gemini_api_key: str | None = None # Optional for migration/setup
|
| 33 |
gemini_model: str = "gemini-2.0-flash-exp"
|
|
|
|
| 34 |
|
| 35 |
model_config = SettingsConfigDict(
|
| 36 |
env_file=".env",
|
|
@@ -52,3 +56,38 @@ def get_settings() -> Settings:
|
|
| 52 |
ValueError: If required environment variables are not set
|
| 53 |
"""
|
| 54 |
return Settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
[Task]: T003
|
| 7 |
[From]: specs/004-ai-chatbot/plan.md
|
| 8 |
+
|
| 9 |
+
Extended for ChatKit migration with Gemini OpenAI-compatible endpoint.
|
| 10 |
+
[From]: specs/010-chatkit-migration/tasks.md - T008
|
| 11 |
"""
|
| 12 |
import os
|
| 13 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| 21 |
database_url: str
|
| 22 |
|
| 23 |
# JWT
|
| 24 |
+
jwt_secret: str
|
| 25 |
jwt_algorithm: str = "HS256"
|
| 26 |
jwt_expiration_days: int = 7
|
| 27 |
|
| 28 |
+
# CORS
|
| 29 |
+
frontend_url: str
|
| 30 |
|
| 31 |
# Environment
|
| 32 |
environment: str = "development"
|
|
|
|
| 34 |
# Gemini API (Phase III: AI Chatbot)
|
| 35 |
gemini_api_key: str | None = None # Optional for migration/setup
|
| 36 |
gemini_model: str = "gemini-2.0-flash-exp"
|
| 37 |
+
gemini_base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/" # ChatKit migration
|
| 38 |
|
| 39 |
model_config = SettingsConfigDict(
|
| 40 |
env_file=".env",
|
|
|
|
| 56 |
ValueError: If required environment variables are not set
|
| 57 |
"""
|
| 58 |
return Settings()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_gemini_client():
|
| 62 |
+
"""Create and return an AsyncOpenAI client configured for Gemini.
|
| 63 |
+
|
| 64 |
+
[From]: specs/010-chatkit-migration/research.md - Section 2
|
| 65 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
|
| 66 |
+
|
| 67 |
+
This client uses Gemini's OpenAI-compatible endpoint, allowing us to use
|
| 68 |
+
the OpenAI SDK and Agents SDK with Gemini as the LLM provider.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
AsyncOpenAI: OpenAI client configured for Gemini
|
| 72 |
+
|
| 73 |
+
Example:
|
| 74 |
+
from openai import AsyncOpenAI
|
| 75 |
+
from agents import set_default_openai_client
|
| 76 |
+
|
| 77 |
+
client = get_gemini_client()
|
| 78 |
+
set_default_openai_client(client)
|
| 79 |
+
"""
|
| 80 |
+
from openai import AsyncOpenAI
|
| 81 |
+
|
| 82 |
+
settings = get_settings()
|
| 83 |
+
|
| 84 |
+
if not settings.gemini_api_key:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
"GEMINI_API_KEY is not set. Please set it in your environment or .env file. "
|
| 87 |
+
"Get your API key from https://aistudio.google.com"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
return AsyncOpenAI(
|
| 91 |
+
api_key=settings.gemini_api_key,
|
| 92 |
+
base_url=settings.gemini_base_url,
|
| 93 |
+
)
|
|
@@ -145,3 +145,40 @@ def decode_access_token(token: str) -> dict:
|
|
| 145 |
detail="Could not validate credentials",
|
| 146 |
headers={"WWW-Authenticate": "Bearer"},
|
| 147 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
detail="Could not validate credentials",
|
| 146 |
headers={"WWW-Authenticate": "Bearer"},
|
| 147 |
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
async def get_current_user_id_from_cookie(request) -> Optional[str]:
|
| 151 |
+
"""Extract and validate user ID from JWT token in httpOnly cookie.
|
| 152 |
+
|
| 153 |
+
[Task]: T011
|
| 154 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Authentication Contracts
|
| 155 |
+
|
| 156 |
+
This function extracts the JWT token from the auth_token httpOnly cookie,
|
| 157 |
+
decodes it, and returns the user_id (sub claim).
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
request: FastAPI/Starlette request object
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
User ID (UUID string) or None if authentication fails
|
| 164 |
+
|
| 165 |
+
Raises:
|
| 166 |
+
HTTPException: If token is invalid (only if raise_on_error=True)
|
| 167 |
+
"""
|
| 168 |
+
# Try httpOnly cookie first
|
| 169 |
+
# [From]: specs/010-chatkit-migration/contracts/backend.md
|
| 170 |
+
auth_token = request.cookies.get("auth_token")
|
| 171 |
+
if not auth_token:
|
| 172 |
+
return None
|
| 173 |
+
|
| 174 |
+
try:
|
| 175 |
+
payload = decode_access_token(auth_token)
|
| 176 |
+
user_id = payload.get("sub")
|
| 177 |
+
return user_id
|
| 178 |
+
except HTTPException:
|
| 179 |
+
return None
|
| 180 |
+
except Exception as e:
|
| 181 |
+
# Log but don't raise for non-critical operations
|
| 182 |
+
import logging
|
| 183 |
+
logging.getLogger("api.auth").warning(f"Failed to decode token from cookie: {e}")
|
| 184 |
+
return None
|
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ChatKit Migration: Create threads table and update messages table
|
| 2 |
+
--
|
| 3 |
+
-- [From]: specs/010-chatkit-migration/data-model.md - Database Schema Migration
|
| 4 |
+
-- [Task]: T006
|
| 5 |
+
--
|
| 6 |
+
-- This migration:
|
| 7 |
+
-- 1. Creates the threads table for ChatKit conversation management
|
| 8 |
+
-- 2. Adds thread_id column to messages table
|
| 9 |
+
-- 3. Migrates existing conversation_id data to thread_id
|
| 10 |
+
-- 4. Creates indexes for query optimization
|
| 11 |
+
--
|
| 12 |
+
-- IMPORTANT: Run this migration after deploying the Thread model
|
| 13 |
+
--
|
| 14 |
+
-- To run: psql $DATABASE_URL < migrations/migrate_threads.sql
|
| 15 |
+
|
| 16 |
+
-- BEGIN;
|
| 17 |
+
|
| 18 |
+
-- 1. Create threads table
|
| 19 |
+
CREATE TABLE IF NOT EXISTS threads (
|
| 20 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 21 |
+
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
| 22 |
+
title VARCHAR(255),
|
| 23 |
+
metadata JSONB,
|
| 24 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 25 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 26 |
+
);
|
| 27 |
+
|
| 28 |
+
-- 2. Create indexes for threads table
|
| 29 |
+
CREATE INDEX IF NOT EXISTS idx_thread_user_id ON threads(user_id);
|
| 30 |
+
CREATE INDEX IF NOT EXISTS idx_thread_updated_at ON threads(user_id, updated_at DESC);
|
| 31 |
+
|
| 32 |
+
-- 3. Add thread_id column to messages table (nullable initially)
|
| 33 |
+
ALTER TABLE message ADD COLUMN IF NOT EXISTS thread_id UUID REFERENCES threads(id) ON DELETE CASCADE;
|
| 34 |
+
|
| 35 |
+
-- 4. Create index for thread_id in messages table
|
| 36 |
+
CREATE INDEX IF NOT EXISTS idx_message_thread_id ON message(thread_id, created_at ASC);
|
| 37 |
+
|
| 38 |
+
-- 5. Migrate existing conversation data to threads
|
| 39 |
+
-- This creates a thread for each unique conversation and links messages to it
|
| 40 |
+
-- Skip this step if starting fresh (no existing conversations)
|
| 41 |
+
INSERT INTO threads (id, user_id, created_at, updated_at)
|
| 42 |
+
SELECT DISTINCT
|
| 43 |
+
c.id as id, -- Use same ID as conversation for easy mapping
|
| 44 |
+
c.user_id,
|
| 45 |
+
c.created_at,
|
| 46 |
+
c.updated_at
|
| 47 |
+
FROM conversation c
|
| 48 |
+
WHERE NOT EXISTS (SELECT 1 FROM threads t WHERE t.id = c.id);
|
| 49 |
+
|
| 50 |
+
-- 6. Update messages to point to the new thread_id
|
| 51 |
+
-- This maps existing messages to their corresponding threads
|
| 52 |
+
UPDATE message m
|
| 53 |
+
SET thread_id = m.conversation_id
|
| 54 |
+
WHERE m.conversation_id IS NOT NULL
|
| 55 |
+
AND m.thread_id IS NULL;
|
| 56 |
+
|
| 57 |
+
-- 7. After migration, make thread_id NOT NULL (only after validating data)
|
| 58 |
+
-- Uncomment these lines after verifying successful migration:
|
| 59 |
+
-- ALTER TABLE message ALTER COLUMN thread_id SET NOT NULL;
|
| 60 |
+
-- ALTER TABLE message ADD CONSTRAINT message_thread_id_fkey FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;
|
| 61 |
+
|
| 62 |
+
-- Optional: Drop old conversation_id column after full migration
|
| 63 |
+
-- Uncomment ONLY after confirming ChatKit is working correctly:
|
| 64 |
+
-- ALTER TABLE message DROP COLUMN IF EXISTS conversation_id;
|
| 65 |
+
-- DROP INDEX IF EXISTS idx_message_conversation_created;
|
| 66 |
+
|
| 67 |
+
-- COMMIT;
|
| 68 |
+
|
| 69 |
+
-- Rollback commands (if needed):
|
| 70 |
+
-- BEGIN;
|
| 71 |
+
-- ALTER TABLE message DROP COLUMN IF EXISTS thread_id;
|
| 72 |
+
-- DROP INDEX IF EXISTS idx_message_thread_id;
|
| 73 |
+
-- DROP INDEX IF EXISTS idx_thread_user_id;
|
| 74 |
+
-- DROP INDEX IF EXISTS idx_thread_updated_at;
|
| 75 |
+
-- DROP TABLE IF EXISTS threads;
|
| 76 |
+
-- COMMIT;
|
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Thread model for ChatKit integration.
|
| 2 |
+
|
| 3 |
+
[Task]: T004
|
| 4 |
+
[From]: specs/010-chatkit-migration/data-model.md
|
| 5 |
+
|
| 6 |
+
This model implements ChatKit's thread abstraction for grouping messages
|
| 7 |
+
into conversations. It extends the existing Conversation model with
|
| 8 |
+
ChatKit-specific fields.
|
| 9 |
+
|
| 10 |
+
Migration Note: The existing Conversation model serves a similar purpose.
|
| 11 |
+
During migration, we can either:
|
| 12 |
+
1. Use Thread alongside Conversation (dual model approach)
|
| 13 |
+
2. Migrate Conversation to Thread (single model approach)
|
| 14 |
+
|
| 15 |
+
This implementation uses Thread as the primary model for ChatKit integration.
|
| 16 |
+
"""
|
| 17 |
+
import uuid
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from typing import Optional
|
| 20 |
+
from sqlmodel import Field, SQLModel
|
| 21 |
+
from sqlalchemy import Column, DateTime, JSON as SQLJSON, String as SQLString, Index
|
| 22 |
+
from sqlalchemy.dialects.postgresql import JSONB
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Thread(SQLModel, table=True):
|
| 26 |
+
"""Thread model representing a ChatKit conversation session.
|
| 27 |
+
|
| 28 |
+
ChatKit uses "thread" terminology for conversation grouping.
|
| 29 |
+
A thread contains multiple messages and belongs to a single user.
|
| 30 |
+
|
| 31 |
+
[From]: specs/010-chatkit-migration/data-model.md - Thread Entity
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
__tablename__ = "threads"
|
| 35 |
+
|
| 36 |
+
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
| 37 |
+
user_id: uuid.UUID = Field(foreign_key="users.id", index=True, nullable=False)
|
| 38 |
+
|
| 39 |
+
# Optional thread title (ChatKit allows threads to have titles)
|
| 40 |
+
title: Optional[str] = Field(
|
| 41 |
+
default=None,
|
| 42 |
+
max_length=255,
|
| 43 |
+
sa_column=Column(SQLString(255), nullable=True)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Thread metadata (tags, settings, etc.) stored as JSONB
|
| 47 |
+
# Note: Using 'thread_metadata' instead of 'metadata' to avoid SQLAlchemy reserved attribute
|
| 48 |
+
thread_metadata: Optional[dict] = Field(
|
| 49 |
+
default=None,
|
| 50 |
+
sa_column=Column("metadata", JSONB, nullable=True), # Column name in DB is 'metadata'
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Thread creation timestamp
|
| 54 |
+
created_at: datetime = Field(
|
| 55 |
+
default_factory=datetime.utcnow,
|
| 56 |
+
sa_column=Column(DateTime(timezone=True), nullable=False)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Last message/update timestamp (auto-updated by application logic)
|
| 60 |
+
updated_at: datetime = Field(
|
| 61 |
+
default_factory=datetime.utcnow,
|
| 62 |
+
sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Table indexes for query optimization
|
| 66 |
+
__table_args__ = (
|
| 67 |
+
Index('idx_thread_user_id', 'user_id'),
|
| 68 |
+
Index('idx_thread_updated_at', 'user_id', 'updated_at'), # For sorting conversations by recent activity
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def __repr__(self) -> str:
|
| 72 |
+
return f"<Thread(id={self.id}, user_id={self.user_id}, title={self.title})>"
|
|
@@ -1,276 +0,0 @@
|
|
| 1 |
-
"""Security utilities for the AI chatbot.
|
| 2 |
-
|
| 3 |
-
[Task]: T057
|
| 4 |
-
[From]: specs/004-ai-chatbot/tasks.md
|
| 5 |
-
|
| 6 |
-
This module provides security functions including prompt injection sanitization,
|
| 7 |
-
input validation, and content filtering.
|
| 8 |
-
"""
|
| 9 |
-
import re
|
| 10 |
-
import html
|
| 11 |
-
from typing import Optional, List
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
# Known prompt injection patterns
|
| 15 |
-
PROMPT_INJECTION_PATTERNS = [
|
| 16 |
-
# Direct instructions to ignore previous context
|
| 17 |
-
r"(?i)ignore\s+(all\s+)?(previous|above|prior)",
|
| 18 |
-
r"(?i)disregard\s+(all\s+)?(previous|above|prior)",
|
| 19 |
-
r"(?i)forget\s+(everything|all\s+instructions|previous)",
|
| 20 |
-
r"(?i)override\s+(your\s+)?programming",
|
| 21 |
-
r"(?i)new\s+(instruction|direction|rule)s?",
|
| 22 |
-
r"(?i)change\s+(your\s+)?(behavior|role|persona)",
|
| 23 |
-
|
| 24 |
-
# Jailbreak attempts
|
| 25 |
-
r"(?i)(jailbreak|jail\s*break)",
|
| 26 |
-
r"(?i)(developer|admin|root|privileged)\s+mode",
|
| 27 |
-
r"(?i)act\s+as\s+(a\s+)?(developer|admin|root)",
|
| 28 |
-
r"(?i)roleplay\s+as",
|
| 29 |
-
r"(?i)pretend\s+(to\s+be|you're)",
|
| 30 |
-
r"(?i)simulate\s+being",
|
| 31 |
-
|
| 32 |
-
# System prompt extraction
|
| 33 |
-
r"(?i)show\s+(your\s+)?(instructions|system\s+prompt|prompt)",
|
| 34 |
-
r"(?i)print\s+(your\s+)?(instructions|system\s+prompt)",
|
| 35 |
-
r"(?i)reveal\s+(your\s+)?(instructions|system\s+prompt)",
|
| 36 |
-
r"(?i)what\s+(are\s+)?your\s+instructions",
|
| 37 |
-
r"(?i)tell\s+me\s+how\s+you\s+work",
|
| 38 |
-
|
| 39 |
-
# DAN and similar jailbreaks
|
| 40 |
-
r"(?i)do\s+anything\s+now",
|
| 41 |
-
r"(?i)unrestricted\s+mode",
|
| 42 |
-
r"(?i)no\s+limitations?",
|
| 43 |
-
r"(?i)bypass\s+(safety|filters|restrictions)",
|
| 44 |
-
r"(?i)\bDAN\b", # Do Anything Now
|
| 45 |
-
]
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def sanitize_message(message: str, max_length: int = 10000) -> str:
|
| 49 |
-
"""Sanitize a user message to prevent prompt injection attacks.
|
| 50 |
-
|
| 51 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
message: The raw user message
|
| 55 |
-
max_length: Maximum allowed message length
|
| 56 |
-
|
| 57 |
-
Returns:
|
| 58 |
-
Sanitized message safe for processing by AI
|
| 59 |
-
|
| 60 |
-
Raises:
|
| 61 |
-
ValueError: If message contains severe injection attempts
|
| 62 |
-
"""
|
| 63 |
-
if not message:
|
| 64 |
-
return ""
|
| 65 |
-
|
| 66 |
-
# Trim to max length
|
| 67 |
-
message = message[:max_length]
|
| 68 |
-
|
| 69 |
-
# Check for severe injection patterns
|
| 70 |
-
detected = detect_prompt_injection(message)
|
| 71 |
-
if detected:
|
| 72 |
-
# For severe attacks, reject the message
|
| 73 |
-
if detected["severity"] == "high":
|
| 74 |
-
raise ValueError(
|
| 75 |
-
"This message contains content that cannot be processed. "
|
| 76 |
-
"Please rephrase your request."
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
# Apply sanitization
|
| 80 |
-
sanitized = _apply_sanitization(message)
|
| 81 |
-
|
| 82 |
-
return sanitized
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def detect_prompt_injection(message: str) -> Optional[dict]:
|
| 86 |
-
"""Detect potential prompt injection attempts in a message.
|
| 87 |
-
|
| 88 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
message: The message to check
|
| 92 |
-
|
| 93 |
-
Returns:
|
| 94 |
-
Dictionary with detection info if injection detected, None otherwise:
|
| 95 |
-
{
|
| 96 |
-
"detected": True,
|
| 97 |
-
"severity": "low" | "medium" | "high",
|
| 98 |
-
"pattern": "matched pattern",
|
| 99 |
-
"confidence": 0.0-1.0
|
| 100 |
-
}
|
| 101 |
-
"""
|
| 102 |
-
message_lower = message.lower()
|
| 103 |
-
|
| 104 |
-
for pattern in PROMPT_INJECTION_PATTERNS:
|
| 105 |
-
match = re.search(pattern, message_lower)
|
| 106 |
-
|
| 107 |
-
if match:
|
| 108 |
-
# Determine severity based on pattern type
|
| 109 |
-
severity = _get_severity_for_pattern(pattern)
|
| 110 |
-
|
| 111 |
-
# Check for context that might indicate legitimate use
|
| 112 |
-
is_legitimate = _check_legitimate_context(message, match.group())
|
| 113 |
-
|
| 114 |
-
if not is_legitimate:
|
| 115 |
-
return {
|
| 116 |
-
"detected": True,
|
| 117 |
-
"severity": severity,
|
| 118 |
-
"pattern": match.group(),
|
| 119 |
-
"confidence": 0.8
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
return None
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def _get_severity_for_pattern(pattern: str) -> str:
|
| 126 |
-
"""Determine severity level for a matched pattern.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
pattern: The regex pattern that matched
|
| 130 |
-
|
| 131 |
-
Returns:
|
| 132 |
-
"low", "medium", or "high"
|
| 133 |
-
"""
|
| 134 |
-
pattern_lower = pattern.lower()
|
| 135 |
-
|
| 136 |
-
# High severity: direct jailbreak attempts
|
| 137 |
-
if any(word in pattern_lower for word in ["jailbreak", "dan", "unrestricted", "bypass"]):
|
| 138 |
-
return "high"
|
| 139 |
-
|
| 140 |
-
# High severity: system prompt extraction
|
| 141 |
-
if any(word in pattern_lower for word in ["show", "print", "reveal", "instructions"]):
|
| 142 |
-
return "high"
|
| 143 |
-
|
| 144 |
-
# Medium severity: role/persona manipulation
|
| 145 |
-
if any(word in pattern_lower for word in ["act as", "pretend", "roleplay", "override"]):
|
| 146 |
-
return "medium"
|
| 147 |
-
|
| 148 |
-
# Low severity: ignore instructions
|
| 149 |
-
if any(word in pattern_lower for word in ["ignore", "disregard", "forget"]):
|
| 150 |
-
return "low"
|
| 151 |
-
|
| 152 |
-
return "low"
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def _check_legitimate_context(message: str, matched_text: str) -> bool:
|
| 156 |
-
"""Check if a matched pattern might be legitimate user content.
|
| 157 |
-
|
| 158 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
message: The full message
|
| 162 |
-
matched_text: The text that matched a pattern
|
| 163 |
-
|
| 164 |
-
Returns:
|
| 165 |
-
True if this appears to be legitimate context, False otherwise
|
| 166 |
-
"""
|
| 167 |
-
message_lower = message.lower()
|
| 168 |
-
matched_lower = matched_text.lower()
|
| 169 |
-
|
| 170 |
-
# Check if the matched text is part of a task description (legitimate)
|
| 171 |
-
legitimate_contexts = [
|
| 172 |
-
# Common task-related phrases
|
| 173 |
-
"task to ignore",
|
| 174 |
-
"mark as complete",
|
| 175 |
-
"disregard this",
|
| 176 |
-
"role in the project",
|
| 177 |
-
"change status",
|
| 178 |
-
"update the role",
|
| 179 |
-
"priority change",
|
| 180 |
-
]
|
| 181 |
-
|
| 182 |
-
for context in legitimate_contexts:
|
| 183 |
-
if context in message_lower:
|
| 184 |
-
return True
|
| 185 |
-
|
| 186 |
-
# Check if matched text is very short (likely false positive)
|
| 187 |
-
if len(matched_text) <= 3:
|
| 188 |
-
return True
|
| 189 |
-
|
| 190 |
-
return False
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
def _apply_sanitization(message: str) -> str:
|
| 194 |
-
"""Apply sanitization transformations to a message.
|
| 195 |
-
|
| 196 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 197 |
-
|
| 198 |
-
Args:
|
| 199 |
-
message: The message to sanitize
|
| 200 |
-
|
| 201 |
-
Returns:
|
| 202 |
-
Sanitized message
|
| 203 |
-
"""
|
| 204 |
-
# Remove excessive whitespace
|
| 205 |
-
message = re.sub(r"\s+", " ", message)
|
| 206 |
-
|
| 207 |
-
# Remove control characters except newlines and tabs
|
| 208 |
-
message = re.sub(r"[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]", "", message)
|
| 209 |
-
|
| 210 |
-
# Normalize line endings
|
| 211 |
-
message = message.replace("\r\n", "\n").replace("\r", "\n")
|
| 212 |
-
|
| 213 |
-
# Limit consecutive newlines to 2
|
| 214 |
-
message = re.sub(r"\n{3,}", "\n\n", message)
|
| 215 |
-
|
| 216 |
-
return message.strip()
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def validate_task_input(task_data: dict) -> tuple[bool, Optional[str]]:
|
| 220 |
-
"""Validate task-related input for security issues.
|
| 221 |
-
|
| 222 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 223 |
-
|
| 224 |
-
Args:
|
| 225 |
-
task_data: Dictionary containing task fields
|
| 226 |
-
|
| 227 |
-
Returns:
|
| 228 |
-
Tuple of (is_valid, error_message)
|
| 229 |
-
"""
|
| 230 |
-
if not isinstance(task_data, dict):
|
| 231 |
-
return False, "Invalid task data format"
|
| 232 |
-
|
| 233 |
-
# Check for SQL injection patterns in string fields
|
| 234 |
-
sql_patterns = [
|
| 235 |
-
r"(?i)(\bunion\b.*\bselect\b)",
|
| 236 |
-
r"(?i)(\bselect\b.*\bfrom\b)",
|
| 237 |
-
r"(?i)(\binsert\b.*\binto\b)",
|
| 238 |
-
r"(?i)(\bupdate\b.*\bset\b)",
|
| 239 |
-
r"(?i)(\bdelete\b.*\bfrom\b)",
|
| 240 |
-
r"(?i)(\bdrop\b.*\btable\b)",
|
| 241 |
-
r";\s*(union|select|insert|update|delete|drop)",
|
| 242 |
-
]
|
| 243 |
-
|
| 244 |
-
for key, value in task_data.items():
|
| 245 |
-
if isinstance(value, str):
|
| 246 |
-
for pattern in sql_patterns:
|
| 247 |
-
if re.search(pattern, value):
|
| 248 |
-
return False, f"Invalid characters in {key}"
|
| 249 |
-
|
| 250 |
-
# Check for script injection
|
| 251 |
-
if re.search(r"<script[^>]*>.*?</script>", value, re.IGNORECASE):
|
| 252 |
-
return False, f"Invalid content in {key}"
|
| 253 |
-
|
| 254 |
-
return True, None
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def sanitize_html_content(content: str) -> str:
|
| 258 |
-
"""Sanitize HTML content by escaping potentially dangerous elements.
|
| 259 |
-
|
| 260 |
-
[From]: specs/004-ai-chatbot/spec.md - NFR-017
|
| 261 |
-
|
| 262 |
-
Args:
|
| 263 |
-
content: Content that may contain HTML
|
| 264 |
-
|
| 265 |
-
Returns:
|
| 266 |
-
Escaped HTML string
|
| 267 |
-
"""
|
| 268 |
-
return html.escape(content, quote=False)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
__all__ = [
|
| 272 |
-
"sanitize_message",
|
| 273 |
-
"detect_prompt_injection",
|
| 274 |
-
"validate_task_input",
|
| 275 |
-
"sanitize_html_content",
|
| 276 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PostgreSQL Store implementation for ChatKit SDK.
|
| 2 |
+
|
| 3 |
+
[Task]: T009
|
| 4 |
+
[From]: specs/010-chatkit-migration/data-model.md - ChatKit SDK Interface Requirements
|
| 5 |
+
[From]: specs/010-chatkit-migration/contracts/backend.md - Store Interface Implementation
|
| 6 |
+
|
| 7 |
+
This module implements the ChatKit Store interface using SQLModel and PostgreSQL.
|
| 8 |
+
The Store interface is required by ChatKit's Python SDK for thread and message persistence.
|
| 9 |
+
|
| 10 |
+
ChatKit Store Protocol Methods:
|
| 11 |
+
- list_threads: List threads for a user with pagination
|
| 12 |
+
- get_thread: Get a specific thread by ID
|
| 13 |
+
- create_thread: Create a new thread
|
| 14 |
+
- update_thread: Update thread metadata
|
| 15 |
+
- delete_thread: Delete a thread
|
| 16 |
+
- list_messages: List messages in a thread
|
| 17 |
+
- get_message: Get a specific message
|
| 18 |
+
- create_message: Create a new message
|
| 19 |
+
- update_message: Update a message
|
| 20 |
+
- delete_message: Delete a message
|
| 21 |
+
"""
|
| 22 |
+
import uuid
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
from typing import Any, Optional
|
| 25 |
+
from sqlmodel import Session, select, col
|
| 26 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 27 |
+
|
| 28 |
+
from models.thread import Thread
|
| 29 |
+
from models.message import Message, MessageRole
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class PostgresChatKitStore:
|
| 33 |
+
"""PostgreSQL implementation of ChatKit Store interface.
|
| 34 |
+
|
| 35 |
+
[From]: specs/010-chatkit-migration/data-model.md - Store Interface
|
| 36 |
+
|
| 37 |
+
This store provides thread and message persistence for ChatKit using
|
| 38 |
+
the existing SQLModel models and PostgreSQL database.
|
| 39 |
+
|
| 40 |
+
Note: The ChatKit SDK uses a Protocol-based interface. The actual
|
| 41 |
+
protocol types (ThreadMetadata, MessageItem, etc.) would be imported
|
| 42 |
+
from the openai_chatkit package. For this implementation, we use
|
| 43 |
+
dictionary-based representations until the SDK is installed.
|
| 44 |
+
|
| 45 |
+
Usage:
|
| 46 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 47 |
+
from services.chatkit_store import PostgresChatKitStore
|
| 48 |
+
|
| 49 |
+
engine = create_async_engine(database_url)
|
| 50 |
+
async with AsyncSession(engine) as session:
|
| 51 |
+
store = PostgresChatKitStore(session)
|
| 52 |
+
thread = await store.create_thread(
|
| 53 |
+
user_id="user-123",
|
| 54 |
+
title="My Conversation",
|
| 55 |
+
metadata={"tag": "important"}
|
| 56 |
+
)
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, session: AsyncSession):
|
| 60 |
+
"""Initialize the store with a database session.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
session: SQLAlchemy async session for database operations
|
| 64 |
+
"""
|
| 65 |
+
self.session = session
|
| 66 |
+
|
| 67 |
+
async def list_threads(
|
| 68 |
+
self,
|
| 69 |
+
user_id: str,
|
| 70 |
+
limit: int = 50,
|
| 71 |
+
offset: int = 0
|
| 72 |
+
) -> list[dict]:
|
| 73 |
+
"""List threads for a user with pagination.
|
| 74 |
+
|
| 75 |
+
[From]: specs/010-chatkit-migration/data-model.md - Retrieve User's Conversations
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
user_id: User ID to filter threads
|
| 79 |
+
limit: Maximum number of threads to return (default: 50)
|
| 80 |
+
offset: Number of threads to skip (default: 0)
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
List of thread metadata dictionaries
|
| 84 |
+
"""
|
| 85 |
+
stmt = (
|
| 86 |
+
select(Thread)
|
| 87 |
+
.where(Thread.user_id == uuid.UUID(user_id))
|
| 88 |
+
.order_by(Thread.updated_at.desc())
|
| 89 |
+
.limit(limit)
|
| 90 |
+
.offset(offset)
|
| 91 |
+
)
|
| 92 |
+
result = await self.session.execute(stmt)
|
| 93 |
+
threads = result.scalars().all()
|
| 94 |
+
|
| 95 |
+
return [
|
| 96 |
+
{
|
| 97 |
+
"id": str(thread.id),
|
| 98 |
+
"user_id": str(thread.user_id),
|
| 99 |
+
"title": thread.title,
|
| 100 |
+
"metadata": thread.thread_metadata or {},
|
| 101 |
+
"created_at": thread.created_at.isoformat(),
|
| 102 |
+
"updated_at": thread.updated_at.isoformat(),
|
| 103 |
+
}
|
| 104 |
+
for thread in threads
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
async def get_thread(self, thread_id: str) -> Optional[dict]:
|
| 108 |
+
"""Get a specific thread by ID.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
thread_id: Thread UUID as string
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Thread metadata dictionary or None if not found
|
| 115 |
+
"""
|
| 116 |
+
stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
|
| 117 |
+
result = await self.session.execute(stmt)
|
| 118 |
+
thread = result.scalar_one_or_none()
|
| 119 |
+
|
| 120 |
+
if thread is None:
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
return {
|
| 124 |
+
"id": str(thread.id),
|
| 125 |
+
"user_id": str(thread.user_id),
|
| 126 |
+
"title": thread.title,
|
| 127 |
+
"metadata": thread.thread_metadata or {},
|
| 128 |
+
"created_at": thread.created_at.isoformat(),
|
| 129 |
+
"updated_at": thread.updated_at.isoformat(),
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
async def create_thread(
|
| 133 |
+
self,
|
| 134 |
+
user_id: str,
|
| 135 |
+
title: Optional[str] = None,
|
| 136 |
+
metadata: Optional[dict] = None
|
| 137 |
+
) -> dict:
|
| 138 |
+
"""Create a new thread.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
user_id: User ID who owns the thread
|
| 142 |
+
title: Optional thread title
|
| 143 |
+
metadata: Optional thread metadata
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Created thread metadata dictionary
|
| 147 |
+
"""
|
| 148 |
+
thread = Thread(
|
| 149 |
+
user_id=uuid.UUID(user_id),
|
| 150 |
+
title=title,
|
| 151 |
+
thread_metadata=metadata or {},
|
| 152 |
+
)
|
| 153 |
+
self.session.add(thread)
|
| 154 |
+
await self.session.commit()
|
| 155 |
+
await self.session.refresh(thread)
|
| 156 |
+
|
| 157 |
+
return {
|
| 158 |
+
"id": str(thread.id),
|
| 159 |
+
"user_id": str(thread.user_id),
|
| 160 |
+
"title": thread.title,
|
| 161 |
+
"metadata": thread.thread_metadata or {},
|
| 162 |
+
"created_at": thread.created_at.isoformat(),
|
| 163 |
+
"updated_at": thread.updated_at.isoformat(),
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
async def update_thread(
|
| 167 |
+
self,
|
| 168 |
+
thread_id: str,
|
| 169 |
+
title: Optional[str] = None,
|
| 170 |
+
metadata: Optional[dict] = None
|
| 171 |
+
) -> Optional[dict]:
|
| 172 |
+
"""Update a thread.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
thread_id: Thread UUID as string
|
| 176 |
+
title: New title (optional)
|
| 177 |
+
metadata: New metadata (optional)
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Updated thread metadata dictionary or None if not found
|
| 181 |
+
"""
|
| 182 |
+
stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
|
| 183 |
+
result = await self.session.execute(stmt)
|
| 184 |
+
thread = result.scalar_one_or_none()
|
| 185 |
+
|
| 186 |
+
if thread is None:
|
| 187 |
+
return None
|
| 188 |
+
|
| 189 |
+
if title is not None:
|
| 190 |
+
thread.title = title
|
| 191 |
+
if metadata is not None:
|
| 192 |
+
thread.thread_metadata = metadata
|
| 193 |
+
thread.updated_at = datetime.utcnow()
|
| 194 |
+
|
| 195 |
+
await self.session.commit()
|
| 196 |
+
await self.session.refresh(thread)
|
| 197 |
+
|
| 198 |
+
return {
|
| 199 |
+
"id": str(thread.id),
|
| 200 |
+
"user_id": str(thread.user_id),
|
| 201 |
+
"title": thread.title,
|
| 202 |
+
"metadata": thread.thread_metadata or {},
|
| 203 |
+
"created_at": thread.created_at.isoformat(),
|
| 204 |
+
"updated_at": thread.updated_at.isoformat(),
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
async def delete_thread(self, thread_id: str) -> bool:
|
| 208 |
+
"""Delete a thread.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
thread_id: Thread UUID as string
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
True if deleted, False if not found
|
| 215 |
+
"""
|
| 216 |
+
stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
|
| 217 |
+
result = await self.session.execute(stmt)
|
| 218 |
+
thread = result.scalar_one_or_none()
|
| 219 |
+
|
| 220 |
+
if thread is None:
|
| 221 |
+
return False
|
| 222 |
+
|
| 223 |
+
await self.session.delete(thread)
|
| 224 |
+
await self.session.commit()
|
| 225 |
+
return True
|
| 226 |
+
|
| 227 |
+
async def list_messages(
|
| 228 |
+
self,
|
| 229 |
+
thread_id: str,
|
| 230 |
+
limit: int = 50,
|
| 231 |
+
offset: int = 0
|
| 232 |
+
) -> list[dict]:
|
| 233 |
+
"""List messages in a thread.
|
| 234 |
+
|
| 235 |
+
[From]: specs/010-chatkit-migration/data-model.md - Retrieve Conversation Messages
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
thread_id: Thread UUID as string
|
| 239 |
+
limit: Maximum number of messages to return
|
| 240 |
+
offset: Number of messages to skip
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
List of message item dictionaries
|
| 244 |
+
"""
|
| 245 |
+
stmt = (
|
| 246 |
+
select(Message)
|
| 247 |
+
.where(Message.thread_id == uuid.UUID(thread_id))
|
| 248 |
+
.order_by(Message.created_at.asc())
|
| 249 |
+
.limit(limit)
|
| 250 |
+
.offset(offset)
|
| 251 |
+
)
|
| 252 |
+
result = await self.session.execute(stmt)
|
| 253 |
+
messages = result.scalars().all()
|
| 254 |
+
|
| 255 |
+
return [
|
| 256 |
+
{
|
| 257 |
+
"id": str(msg.id),
|
| 258 |
+
"type": "message",
|
| 259 |
+
"role": msg.role.value,
|
| 260 |
+
"content": [{"type": "text", "text": msg.content}],
|
| 261 |
+
"tool_calls": msg.tool_calls,
|
| 262 |
+
"created_at": msg.created_at.isoformat(),
|
| 263 |
+
}
|
| 264 |
+
for msg in messages
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
async def get_message(self, message_id: str) -> Optional[dict]:
|
| 268 |
+
"""Get a specific message by ID.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
message_id: Message UUID as string
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
Message item dictionary or None if not found
|
| 275 |
+
"""
|
| 276 |
+
stmt = select(Message).where(Message.id == uuid.UUID(message_id))
|
| 277 |
+
result = await self.session.execute(stmt)
|
| 278 |
+
message = result.scalar_one_or_none()
|
| 279 |
+
|
| 280 |
+
if message is None:
|
| 281 |
+
return None
|
| 282 |
+
|
| 283 |
+
return {
|
| 284 |
+
"id": str(message.id),
|
| 285 |
+
"type": "message",
|
| 286 |
+
"role": message.role.value,
|
| 287 |
+
"content": [{"type": "text", "text": message.content}],
|
| 288 |
+
"tool_calls": message.tool_calls,
|
| 289 |
+
"created_at": message.created_at.isoformat(),
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
async def create_message(
|
| 293 |
+
self,
|
| 294 |
+
thread_id: str,
|
| 295 |
+
item: dict
|
| 296 |
+
) -> dict:
|
| 297 |
+
"""Create a new message in a thread.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
thread_id: Thread UUID as string
|
| 301 |
+
item: Message item from ChatKit (UserMessageItem or ClientToolCallOutputItem)
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
Created message item dictionary
|
| 305 |
+
|
| 306 |
+
Raises:
|
| 307 |
+
ValueError: If item format is invalid
|
| 308 |
+
"""
|
| 309 |
+
# Extract content from ChatKit item format
|
| 310 |
+
# ChatKit uses: {"type": "message", "role": "user", "content": [{"type": "text", "text": "..."}]}
|
| 311 |
+
item_type = item.get("type", "message")
|
| 312 |
+
role = item.get("role", "user")
|
| 313 |
+
|
| 314 |
+
# Extract text content from content array
|
| 315 |
+
content_array = item.get("content", [])
|
| 316 |
+
text_content = ""
|
| 317 |
+
for content_block in content_array:
|
| 318 |
+
if content_block.get("type") == "text":
|
| 319 |
+
text_content = content_block.get("text", "")
|
| 320 |
+
break
|
| 321 |
+
|
| 322 |
+
# Handle client tool output items
|
| 323 |
+
if item_type == "client_tool_call_output":
|
| 324 |
+
text_content = item.get("output", "")
|
| 325 |
+
|
| 326 |
+
message = Message(
|
| 327 |
+
thread_id=uuid.UUID(thread_id),
|
| 328 |
+
role=MessageRole(role),
|
| 329 |
+
content=text_content,
|
| 330 |
+
tool_calls=item.get("tool_calls"),
|
| 331 |
+
)
|
| 332 |
+
self.session.add(message)
|
| 333 |
+
await self.session.commit()
|
| 334 |
+
await self.session.refresh(message)
|
| 335 |
+
|
| 336 |
+
# Update thread's updated_at timestamp
|
| 337 |
+
thread_stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
|
| 338 |
+
thread_result = await self.session.execute(thread_stmt)
|
| 339 |
+
thread = thread_result.scalar_one_or_none()
|
| 340 |
+
if thread:
|
| 341 |
+
thread.updated_at = datetime.utcnow()
|
| 342 |
+
await self.session.commit()
|
| 343 |
+
|
| 344 |
+
return {
|
| 345 |
+
"id": str(message.id),
|
| 346 |
+
"type": "message",
|
| 347 |
+
"role": message.role.value,
|
| 348 |
+
"content": [{"type": "text", "text": message.content}],
|
| 349 |
+
"tool_calls": message.tool_calls,
|
| 350 |
+
"created_at": message.created_at.isoformat(),
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
async def update_message(
|
| 354 |
+
self,
|
| 355 |
+
message_id: str,
|
| 356 |
+
item: dict
|
| 357 |
+
) -> Optional[dict]:
|
| 358 |
+
"""Update a message.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
message_id: Message UUID as string
|
| 362 |
+
item: Updated message item
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
Updated message item dictionary or None if not found
|
| 366 |
+
"""
|
| 367 |
+
stmt = select(Message).where(Message.id == uuid.UUID(message_id))
|
| 368 |
+
result = await self.session.execute(stmt)
|
| 369 |
+
message = result.scalar_one_or_none()
|
| 370 |
+
|
| 371 |
+
if message is None:
|
| 372 |
+
return None
|
| 373 |
+
|
| 374 |
+
# Update content if provided
|
| 375 |
+
content_array = item.get("content", [])
|
| 376 |
+
if content_array:
|
| 377 |
+
for content_block in content_array:
|
| 378 |
+
if content_block.get("type") == "text":
|
| 379 |
+
message.content = content_block.get("text", message.content)
|
| 380 |
+
break
|
| 381 |
+
|
| 382 |
+
# Update tool_calls if provided
|
| 383 |
+
if "tool_calls" in item:
|
| 384 |
+
message.tool_calls = item["tool_calls"]
|
| 385 |
+
|
| 386 |
+
await self.session.commit()
|
| 387 |
+
await self.session.refresh(message)
|
| 388 |
+
|
| 389 |
+
return {
|
| 390 |
+
"id": str(message.id),
|
| 391 |
+
"type": "message",
|
| 392 |
+
"role": message.role.value,
|
| 393 |
+
"content": [{"type": "text", "text": message.content}],
|
| 394 |
+
"tool_calls": message.tool_calls,
|
| 395 |
+
"created_at": message.created_at.isoformat(),
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
async def delete_message(self, message_id: str) -> bool:
|
| 399 |
+
"""Delete a message.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
message_id: Message UUID as string
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
True if deleted, False if not found
|
| 406 |
+
"""
|
| 407 |
+
stmt = select(Message).where(Message.id == uuid.UUID(message_id))
|
| 408 |
+
result = await self.session.execute(stmt)
|
| 409 |
+
message = result.scalar_one_or_none()
|
| 410 |
+
|
| 411 |
+
if message is None:
|
| 412 |
+
return False
|
| 413 |
+
|
| 414 |
+
await self.session.delete(message)
|
| 415 |
+
await self.session.commit()
|
| 416 |
+
return True
|