GrowWithTalha Claude Opus 4.6 committed on
Commit
a57a50a
·
1 Parent(s): 84c328d

feat: add ChatKit migration with SSE streaming

Browse files

This commit implements the ChatKit migration for the Hugging Face
Spaces deployment, replacing WebSocket with Server-Sent Events (SSE).

New Features:
- ChatKit server (chatkit_server.py) with SSE streaming
- PostgreSQL ChatKit store (services/chatkit_store.py)
- Thread model for ChatKit conversations (models/thread.py)
- MCP tool wrappers for Agents SDK (ai_agent/tool_wrappers.py)
- Database migration for threads table (migrations/migrate_threads.sql)

Backend Changes:
- Updated /api/chatkit endpoint with SSE streaming
- Fixed SSE JSON serialization using json.dumps()
- Added tool_calls JSONB column support
- Enhanced authentication with httpOnly cookie support

Updated Files:
- api/chat.py: Added ChatKit SSE endpoint
- core/config.py: Added get_gemini_client() function
- core/security.py: Enhanced cookie-based authentication
- ai_agent/__init__.py: Updated agent imports
- .env.example: Added GEMINI_BASE_URL configuration

This enables the same ChatKit functionality on Hugging Face Spaces
as in the main SDDRI-Hackathon-2 repository.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

.env.example CHANGED
@@ -14,3 +14,5 @@ ENVIRONMENT=development
14
  # Get your API key from https://aistudio.google.com
15
  GEMINI_API_KEY=your-gemini-api-key-here
16
  GEMINI_MODEL=gemini-2.0-flash-exp
 
 
 
14
  # Get your API key from https://aistudio.google.com
15
  GEMINI_API_KEY=your-gemini-api-key-here
16
  GEMINI_MODEL=gemini-2.0-flash-exp
17
+ # Gemini OpenAI-compatible endpoint for ChatKit integration
18
+ GEMINI_BASE_URL=https://generativelanguage.googleapis.com/v1beta/openai/
__init__.py DELETED
File without changes
ai_agent/__init__.py CHANGED
@@ -2,26 +2,22 @@
2
 
3
  [Task]: T014, T072
4
  [From]: specs/004-ai-chatbot/tasks.md
 
5
 
6
  This module provides the AI agent that powers the chatbot functionality.
7
  It uses OpenAI SDK with function calling and Gemini via AsyncOpenAI adapter.
8
 
9
- Includes streaming support for real-time WebSocket progress events.
 
10
  """
11
  from ai_agent.agent_simple import (
12
  get_gemini_client,
13
  run_agent,
14
  is_gemini_configured
15
  )
16
- from ai_agent.agent_streaming import (
17
- run_agent_with_streaming,
18
- execute_tool_with_progress,
19
- )
20
 
21
  __all__ = [
22
  "get_gemini_client",
23
  "run_agent",
24
- "run_agent_with_streaming",
25
- "execute_tool_with_progress",
26
  "is_gemini_configured"
27
  ]
 
2
 
3
  [Task]: T014, T072
4
  [From]: specs/004-ai-chatbot/tasks.md
5
+ [From]: T045 - Delete agent_streaming.py (ChatKit migration replaces WebSocket streaming)
6
 
7
  This module provides the AI agent that powers the chatbot functionality.
8
  It uses OpenAI SDK with function calling and Gemini via AsyncOpenAI adapter.
9
 
10
+ NOTE: The streaming agent functionality has been migrated to ChatKit SSE endpoint.
11
+ See backend/chatkit_server.py for the new ChatKit-based implementation.
12
  """
13
  from ai_agent.agent_simple import (
14
  get_gemini_client,
15
  run_agent,
16
  is_gemini_configured
17
  )
 
 
 
 
18
 
19
  __all__ = [
20
  "get_gemini_client",
21
  "run_agent",
 
 
22
  "is_gemini_configured"
23
  ]
ai_agent/tool_wrappers.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agents SDK function wrappers for MCP task management tools.
2
+
3
+ [Task]: T012-T018
4
+ [From]: specs/010-chatkit-migration/tasks.md - Phase 3 Backend Implementation
5
+ [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
6
+
7
+ This module wraps existing MCP tools as Agents SDK functions using the
8
+ @function_tool decorator. Each wrapper calls the underlying MCP tool function
9
+ and returns the result in a format compatible with the Agents SDK.
10
+
11
+ Tools wrapped:
12
+ 1. create_task (T012) - from mcp_server/tools/add_task.py
13
+ 2. list_tasks (T013) - from mcp_server/tools/list_tasks.py
14
+ 3. update_task (T014) - from mcp_server/tools/update_task.py
15
+ 4. delete_task (T015) - from mcp_server/tools/delete_task.py
16
+ 5. complete_task (T016) - from mcp_server/tools/complete_task.py
17
+ 6. complete_all_tasks (T017) - from mcp_server/tools/complete_all_tasks.py
18
+ 7. delete_all_tasks (T018) - from mcp_server/tools/delete_all_tasks.py
19
+
20
+ [From]: specs/010-chatkit-migration/research.md - Section 7 (Tool Visualization Support)
21
+ """
22
+ import json
23
+ import logging
24
+ from typing import Optional
25
+
26
+ from agents import function_tool, RunContextWrapper
27
+
28
+ # Import MCP tools
29
+ # Note: We import the actual async functions from MCP tools
30
+ import sys
31
+ import os
32
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
33
+
34
+ from mcp_server.tools.add_task import add_task as mcp_add_task
35
+ from mcp_server.tools.list_tasks import list_tasks as mcp_list_tasks
36
+ from mcp_server.tools.update_task import update_task as mcp_update_task
37
+ from mcp_server.tools.delete_task import delete_task as mcp_delete_task
38
+ from mcp_server.tools.complete_task import complete_task as mcp_complete_task
39
+ from mcp_server.tools.complete_all_tasks import complete_all_tasks as mcp_complete_all_tasks
40
+ from mcp_server.tools.delete_all_tasks import delete_all_tasks as mcp_delete_all_tasks
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ # =============================================================================
46
+ # Agents SDK Function Wrappers
47
+ # =============================================================================
48
+
49
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="create_task",
    description_override="Create a new task in the user's todo list. Use this when the user wants to create, add, or remind themselves about a task. Parameters: title (required), description (optional), due_date (optional, ISO 8601 or relative), priority (optional: low/medium/high), tags (optional list)."
)
async def create_task_tool(
    ctx: RunContextWrapper,
    title: str,
    description: Optional[str] = None,
    due_date: Optional[str] = None,
    priority: Optional[str] = None,
    tags: Optional[list[str]] = None,
) -> str:
    """Create a task via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T012

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        title: Task title
        description: Optional task description
        due_date: Optional due date
        priority: Optional priority level
        tags: Optional list of tags (``None`` is forwarded as an empty list)

    Returns:
        JSON string with the MCP tool result, or
        ``{"success": false, "error": "<message>"}`` if the call raised.
    """
    user_id = ctx.context.user_id

    try:
        result = await mcp_add_task(
            user_id=user_id,
            title=title,
            description=description,
            due_date=due_date,
            priority=priority,
            tags=tags or [],
        )
        # Look the id up defensively: a missing key here must not turn an
        # already-persisted task into a reported failure.
        task = result.get("task") if isinstance(result, dict) else None
        task_id = task.get("id") if isinstance(task, dict) else None
        logger.info(f"Task created: {task_id} for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Failed to create task for user %s: %s", user_id, e)
        return json.dumps({"success": False, "error": str(e)})
93
+
94
+
95
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="list_tasks",
    description_override="List all tasks for the user, optionally filtered by completion status, priority, tag, or due date range. Returns a list of tasks with their details."
)
async def list_tasks_tool(
    ctx: RunContextWrapper,
    completed: Optional[bool] = None,
    priority: Optional[str] = None,
    tag: Optional[str] = None,
    due_before: Optional[str] = None,
    due_after: Optional[str] = None,
) -> str:
    """List tasks via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T013

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        completed: Optional filter by completion status
        priority: Optional filter by priority
        tag: Optional filter by tag
        due_before: Optional due date upper bound
        due_after: Optional due date lower bound

    Returns:
        JSON string with the MCP tool result, or
        ``{"success": false, "error": "<message>", "tasks": []}`` if the call raised.
    """
    user_id = ctx.context.user_id

    try:
        result = await mcp_list_tasks(
            user_id=user_id,
            completed=completed,
            priority=priority,
            tag=tag,
            due_before=due_before,
            due_after=due_after,
        )
        task_count = len(result.get("tasks", []))
        logger.info(f"Listed {task_count} tasks for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Failed to list tasks for user %s: %s", user_id, e)
        # Keep an empty "tasks" list so consumers can iterate unconditionally.
        return json.dumps({"success": False, "error": str(e), "tasks": []})
140
+
141
+
142
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="update_task",
    description_override="Update an existing task. Parameters: task_id (required), title (optional), description (optional), due_date (optional), priority (optional), tags (optional)."
)
async def update_task_tool(
    ctx: RunContextWrapper,
    task_id: str,
    title: Optional[str] = None,
    description: Optional[str] = None,
    due_date: Optional[str] = None,
    priority: Optional[str] = None,
    tags: Optional[list[str]] = None,
) -> str:
    """Update a task via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T014

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        task_id: Task ID to update
        title: Optional new title
        description: Optional new description
        due_date: Optional new due date
        priority: Optional new priority
        tags: Optional new tag list (forwarded unchanged, ``None`` included)

    Returns:
        JSON string with the MCP tool result, or
        ``{"success": false, "error": "<message>"}`` if the call raised.
    """
    user_id = ctx.context.user_id

    try:
        result = await mcp_update_task(
            user_id=user_id,
            task_id=task_id,
            title=title,
            description=description,
            due_date=due_date,
            priority=priority,
            tags=tags,
        )
        logger.info(f"Task updated: {task_id} for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(
            "Failed to update task %s for user %s: %s", task_id, user_id, e
        )
        return json.dumps({"success": False, "error": str(e)})
189
+
190
+
191
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="delete_task",
    description_override="Delete a task permanently. Parameters: task_id (required)."
)
async def delete_task_tool(
    ctx: RunContextWrapper,
    task_id: str,
) -> str:
    """Delete a task via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T015

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        task_id: Task ID to delete

    Returns:
        JSON string with the MCP tool result, or
        ``{"success": false, "error": "<message>"}`` if the call raised.
    """
    user_id = ctx.context.user_id

    try:
        result = await mcp_delete_task(
            user_id=user_id,
            task_id=task_id,
        )
        logger.info(f"Task deleted: {task_id} for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(
            "Failed to delete task %s for user %s: %s", task_id, user_id, e
        )
        return json.dumps({"success": False, "error": str(e)})
223
+
224
+
225
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="complete_task",
    description_override="Mark a task as completed or incomplete. Parameters: task_id (required), completed (boolean, required)."
)
async def complete_task_tool(
    ctx: RunContextWrapper,
    task_id: str,
    completed: bool,
) -> str:
    """Complete/uncomplete a task via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T016

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        task_id: Task ID whose completion flag is set
        completed: Whether task is completed (true) or not (false)

    Returns:
        JSON string with the MCP tool result, or
        ``{"success": false, "error": "<message>"}`` if the call raised.
    """
    user_id = ctx.context.user_id

    try:
        result = await mcp_complete_task(
            user_id=user_id,
            task_id=task_id,
            completed=completed,
        )
        logger.info(f"Task completion updated: {task_id} -> {completed} for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(
            "Failed to update completion for task %s for user %s: %s",
            task_id,
            user_id,
            e,
        )
        return json.dumps({"success": False, "error": str(e)})
260
+
261
+
262
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="complete_all_tasks",
    description_override="Mark all tasks as completed. Parameters: confirm (boolean, required - must be true to execute)."
)
async def complete_all_tasks_tool(
    ctx: RunContextWrapper,
    confirm: bool,
) -> str:
    """Complete all tasks via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T017

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        confirm: Must be true to execute this bulk operation

    Returns:
        JSON string with the MCP tool result; a ``success: false`` payload
        when ``confirm`` is falsy or the call raised.
    """
    user_id = ctx.context.user_id

    # Guard: the bulk update only runs when explicitly confirmed.
    if not confirm:
        return json.dumps({
            "success": False,
            "error": "Confirmation required. Set confirm=true to complete all tasks."
        })

    try:
        result = await mcp_complete_all_tasks(
            user_id=user_id,
        )
        completed_count = result.get("completed_count", 0)
        logger.info(f"Completed {completed_count} tasks for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(
            "Failed to complete all tasks for user %s: %s", user_id, e
        )
        return json.dumps({"success": False, "error": str(e)})
300
+
301
+
302
# NOTE: function_tool() takes ``name_override=``; it has no ``name=`` keyword,
# so passing ``name=`` fails with TypeError when this module is imported.
@function_tool(
    name_override="delete_all_tasks",
    description_override="Delete all tasks permanently. Parameters: confirm (boolean, required - must be true to execute)."
)
async def delete_all_tasks_tool(
    ctx: RunContextWrapper,
    confirm: bool,
) -> str:
    """Delete all tasks via MCP tool.

    [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    [Task]: T018

    Args:
        ctx: Agents SDK run context; ``ctx.context.user_id`` identifies the owner
        confirm: Must be true to execute this destructive operation

    Returns:
        JSON string with the MCP tool result; a ``success: false`` payload
        when ``confirm`` is falsy or the call raised.
    """
    user_id = ctx.context.user_id

    # Guard: the bulk delete only runs when explicitly confirmed.
    if not confirm:
        return json.dumps({
            "success": False,
            "error": "Confirmation required. Set confirm=true to delete all tasks."
        })

    try:
        result = await mcp_delete_all_tasks(
            user_id=user_id,
        )
        deleted_count = result.get("deleted_count", 0)
        logger.info(f"Deleted {deleted_count} tasks for user {user_id}")
        return json.dumps(result)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception(
            "Failed to delete all tasks for user %s: %s", user_id, e
        )
        return json.dumps({"success": False, "error": str(e)})
340
+
341
+
342
+ # =============================================================================
343
+ # Tool List for Agent Configuration
344
+ # =============================================================================
345
+
346
+ # Export all tool functions for easy import
347
+ TOOL_FUNCTIONS = [
348
+ create_task_tool,
349
+ list_tasks_tool,
350
+ update_task_tool,
351
+ delete_task_tool,
352
+ complete_task_tool,
353
+ complete_all_tasks_tool,
354
+ delete_all_tasks_tool,
355
+ ]
356
+
357
+
358
def get_tool_names() -> list[str]:
    """Collect the ``name`` attribute of every entry in ``TOOL_FUNCTIONS``.

    [From]: specs/010-chatkit-migration/tasks.md - T019

    Returns:
        Tool names, in the same order as ``TOOL_FUNCTIONS``
    """
    names: list[str] = []
    for registered_tool in TOOL_FUNCTIONS:
        names.append(registered_tool.name)
    return names
api/chat.py CHANGED
@@ -13,7 +13,7 @@ import logging
13
  import asyncio
14
  from datetime import datetime
15
  from typing import Annotated, Optional
16
- from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect, BackgroundTasks
17
  from pydantic import BaseModel, Field, field_validator, ValidationError
18
  from sqlmodel import Session
19
  from sqlalchemy.exc import SQLAlchemyError
@@ -24,14 +24,13 @@ from core.security import decode_access_token
24
  from models.message import Message, MessageRole
25
  from services.security import sanitize_message
26
  from models.conversation import Conversation
27
- from ai_agent import run_agent_with_streaming, is_gemini_configured
28
  from services.conversation import (
29
  get_or_create_conversation,
30
  load_conversation_history,
31
  update_conversation_timestamp
32
  )
33
  from services.rate_limiter import check_rate_limit
34
- from ws_manager.manager import manager
35
 
36
 
37
  # Configure error logger
@@ -290,12 +289,12 @@ async def chat(
290
  {"role": "user", "content": sanitized_message}
291
  ]
292
 
293
- # Run AI agent with streaming (broadcasts WebSocket events)
294
  # [From]: T014 - Initialize OpenAI Agents SDK with Gemini
295
- # [From]: T072 - Use streaming agent for real-time progress
296
  # [From]: T060 - Add comprehensive error messages for edge cases
297
  try:
298
- ai_response_text = await run_agent_with_streaming(
299
  messages=messages_for_agent,
300
  user_id=user_id
301
  )
@@ -407,72 +406,277 @@ async def chat(
407
  )
408
 
409
 
410
- @router.websocket("/ws/{user_id}/chat")
411
- async def websocket_chat(
412
- websocket: WebSocket,
413
- user_id: str,
414
- db: Session = Depends(get_db)
415
- ):
416
- """WebSocket endpoint for real-time chat progress updates.
417
 
418
- [From]: specs/004-ai-chatbot/research.md - Section 4
419
- [Task]: T071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- This endpoint provides a WebSocket connection for receiving real-time
422
- progress events during AI agent execution. Events include:
423
- - connection_established: Confirmation of successful connection
424
- - agent_thinking: AI agent is processing
425
- - tool_starting: A tool is about to execute
426
- - tool_progress: Tool execution progress (e.g., "Found 3 tasks")
427
- - tool_complete: Tool finished successfully
428
- - tool_error: Tool execution failed
429
- - agent_done: AI agent finished processing
430
 
431
- Note: Authentication is handled implicitly by the frontend - users must
432
- be logged in to access the chat page. The WebSocket only broadcasts
433
- progress updates (not sensitive data), so strict auth is bypassed here.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
- Connection URL format:
436
- ws://localhost:8000/ws/{user_id}/chat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
- Args:
439
- websocket: The WebSocket connection instance
440
- user_id: User ID from URL path (used to route progress events)
441
- db: Database session (for any future DB operations)
442
-
443
- The connection is kept alive and can receive messages from the client,
444
- though currently it's primarily used for server-to-client progress updates.
445
- """
446
- # Connect the WebSocket (manager handles accept)
447
- # [From]: specs/004-ai-chatbot/research.md - Section 4
448
- await manager.connect(user_id, websocket)
449
 
 
450
  try:
451
- # Keep connection alive and listen for client messages
452
- # Currently, we don't expect many client messages, but we
453
- # maintain the connection to receive any control messages
454
- while True:
455
- # Wait for message from client (with timeout)
456
- data = await websocket.receive_text()
457
-
458
- # Handle client messages if needed
459
- # For now, we just acknowledge receipt
460
- # Future: could handle ping/pong for connection health
461
- if data:
462
- # Echo back a simple acknowledgment
463
- # (optional - mainly for debugging)
464
- pass
465
-
466
- except WebSocketDisconnect:
467
- # Normal disconnect - clean up
468
- manager.disconnect(user_id, websocket)
469
- error_logger.info(f"WebSocket disconnected normally for user {user_id}")
470
 
 
 
 
471
  except Exception as e:
472
- # Unexpected error - clean up and log
473
- error_logger.error(f"WebSocket error for user {user_id}: {e}")
474
- manager.disconnect(user_id, websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
- finally:
477
- # Ensure disconnect is always called
478
- manager.disconnect(user_id, websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  import asyncio
14
  from datetime import datetime
15
  from typing import Annotated, Optional
16
+ from fastapi import APIRouter, HTTPException, status, Depends, BackgroundTasks, Request
17
  from pydantic import BaseModel, Field, field_validator, ValidationError
18
  from sqlmodel import Session
19
  from sqlalchemy.exc import SQLAlchemyError
 
24
  from models.message import Message, MessageRole
25
  from services.security import sanitize_message
26
  from models.conversation import Conversation
27
+ from ai_agent import run_agent, is_gemini_configured
28
  from services.conversation import (
29
  get_or_create_conversation,
30
  load_conversation_history,
31
  update_conversation_timestamp
32
  )
33
  from services.rate_limiter import check_rate_limit
 
34
 
35
 
36
  # Configure error logger
 
289
  {"role": "user", "content": sanitized_message}
290
  ]
291
 
292
+ # Run AI agent (non-streaming for legacy endpoint)
293
  # [From]: T014 - Initialize OpenAI Agents SDK with Gemini
294
+ # NOTE: Streaming is now handled by ChatKit SSE endpoint
295
  # [From]: T060 - Add comprehensive error messages for edge cases
296
  try:
297
+ ai_response_text = await run_agent(
298
  messages=messages_for_agent,
299
  user_id=user_id
300
  )
 
406
  )
407
 
408
 
409
+ # ============================================================================
410
+ # ChatKit SSE Endpoint (Phase 010-chatkit-migration)
411
+ # ============================================================================
 
 
 
 
412
 
413
def _sse_event(event: str, payload: dict) -> str:
    """Serialize one Server-Sent Event frame with a JSON ``data`` line."""
    import json

    return f"event: {event}\ndata: {json.dumps(payload)}\n\n"


def _sse_error_response(detail: str, status_code: int, **extra):
    """Build a StreamingResponse that emits a single SSE ``error`` event."""
    from fastapi.responses import StreamingResponse

    frame = _sse_event("error", {"detail": detail, **extra})

    async def error_stream():
        yield frame

    return StreamingResponse(
        error_stream(),
        media_type="text/event-stream",
        status_code=status_code,
    )


def _extract_message_text(item) -> str:
    """Return the text of the first ``text`` content block of *item*, or ""."""
    if not isinstance(item, dict):
        return ""
    content = item.get("content", [])
    if not isinstance(content, list):
        return ""
    for block in content:
        if isinstance(block, dict) and block.get("type") == "text":
            text = block.get("text", "")
            return text if isinstance(text, str) else ""
    return ""


@router.post("/chatkit")
async def chatkit_endpoint(
    request: Request,  # Starlette Request object for raw body access
    background_tasks: BackgroundTasks,
):
    """ChatKit SSE endpoint for streaming chat with Gemini LLM.

    [Task]: T011
    [From]: specs/010-chatkit-migration/contracts/backend.md - ChatKit SSE Endpoint

    This endpoint implements the ChatKit protocol using Server-Sent Events (SSE).
    It replaces the WebSocket-based streaming with a simpler HTTP-based approach.

    Endpoint: POST /api/chatkit
    Response: Server-Sent Events (text/event-stream)

    Authentication: JWT via httpOnly cookie (auth_token)

    Request Body (ChatKit protocol):
        {
            "event": "conversation_item_created",
            "conversation_id": "<thread_uuid>",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "text", "text": "Your message here"}]
            }
        }

    SSE events emitted by this implementation:
        - message_delta: The full assistant reply, sent as one event
        - message_done: Reply finished
        - error: Error occurred (auth, rate limit, bad request, timeout, agent failure)

    Args:
        request: Incoming request; the cookie is read for auth and the raw
            body is parsed as JSON.
        background_tasks: Currently unused; kept so the route signature is stable.

    Returns:
        StreamingResponse with ``text/event-stream`` content. Validation
        failures use status 400/401/429 and carry a single ``error`` event.

    [From]: specs/010-chatkit-migration/research.md - Section 4
    """
    import json
    from uuid import UUID

    from fastapi.responses import StreamingResponse
    from sqlmodel import Session

    from core.database import engine
    from core.security import get_current_user_id_from_cookie
    from models.message import Message, MessageRole

    # Get authenticated user ID from JWT cookie
    # [From]: specs/010-chatkit-migration/contracts/backend.md - Authentication
    try:
        user_id = await get_current_user_id_from_cookie(request)
    except Exception as e:
        error_logger.error(f"Auth error in ChatKit endpoint: {e}")
        return _sse_error_response("Authentication failed", 401)
    if not user_id:
        return _sse_error_response("Invalid authentication", 401)

    # Parse the id once; a malformed id is an auth problem, not a server error.
    try:
        user_uuid = UUID(str(user_id))
    except ValueError:
        error_logger.error("ChatKit endpoint received a non-UUID user id")
        return _sse_error_response("Invalid authentication", 401)

    # Check rate limit before processing
    # [From]: specs/010-chatkit-migration/tasks.md - T020
    # [From]: specs/010-chatkit-migration/spec.md - FR-015
    try:
        with Session(engine) as db:
            allowed, _remaining, reset_time = check_rate_limit(db, user_uuid)

        if not allowed:
            return _sse_error_response(
                "Daily message limit reached",
                429,
                limit=100,
                resets_at=reset_time.isoformat() if reset_time else None,
            )
    except HTTPException:
        # Re-raise HTTP exceptions (rate limit errors)
        raise
    except Exception as e:
        # Fail open: a broken rate-limit check must not block the request.
        error_logger.error(f"Rate limit check failed for ChatKit endpoint: {e}")

    # Read and parse the ChatKit protocol request
    try:
        body = await request.body()
    except Exception as e:
        error_logger.error(f"Failed to read ChatKit request body: {e}")
        return _sse_error_response("Invalid request format", 400)

    try:
        # UnicodeDecodeError and json.JSONDecodeError both subclass ValueError.
        request_data = json.loads(body.decode("utf-8"))
    except ValueError as e:
        error_logger.error(f"Failed to parse ChatKit request: {e}")
        return _sse_error_response("Invalid JSON format", 400)

    if not isinstance(request_data, dict):
        # Valid JSON, but not the object the protocol requires.
        return _sse_error_response("Invalid request format", 400)

    conversation_id = request_data.get("conversation_id")

    user_message = _extract_message_text(request_data.get("item", {}))
    if not user_message:
        return _sse_error_response("No message content provided", 400)

    # Get or create thread
    if conversation_id:
        try:
            thread_uuid = UUID(str(conversation_id))
        except ValueError:
            return _sse_error_response("Invalid conversation_id", 400)
        thread_id = str(thread_uuid)
        # NOTE(review): nothing here verifies that this thread belongs to
        # user_uuid - confirm whether ownership must be enforced.
    else:
        # Create new thread for first message
        from models.thread import Thread

        with Session(engine) as db:
            new_thread = Thread(
                user_id=user_uuid,
                title=None,
                thread_metadata={},
            )
            db.add(new_thread)
            db.commit()
            db.refresh(new_thread)
            thread_id = str(new_thread.id)
        thread_uuid = UUID(thread_id)
        error_logger.info(f"Created new thread: {thread_id}")

    # Save user message to database
    # NOTE(review): these are synchronous Session calls inside an async
    # handler; they block the event loop while the database responds.
    with Session(engine) as db:
        db.add(
            Message(
                thread_id=thread_uuid,
                user_id=user_uuid,
                role=MessageRole.USER,
                content=user_message,
            )
        )
        db.commit()

    # [Task]: T033 - Add SSE error handling for connection drops
    async def stream_chat_response():
        """Run the agent and yield ChatKit SSE frames, with timeout protection.

        [Task]: T034 - Timeout handling for long-running tool executions
        """
        try:
            from ai_agent import run_agent

            # asyncio.timeout() requires Python 3.11+.
            async with asyncio.timeout(120):
                ai_response = await run_agent(
                    messages=[{"role": "user", "content": user_message}],
                    user_id=user_id,
                )

            # Persist the reply before emitting it.
            with Session(engine) as db:
                db.add(
                    Message(
                        thread_id=thread_uuid,
                        user_id=user_uuid,
                        role=MessageRole.ASSISTANT,
                        content=ai_response,
                    )
                )
                db.commit()

            yield _sse_event("message_delta", {"type": "text", "text": ai_response})
            # NOTE(review): message_id carries the thread id, as before;
            # confirm whether clients expect the assistant message's own id.
            yield _sse_event(
                "message_done",
                {"message_id": thread_id, "role": "assistant", "thread_id": thread_id},
            )

        except TimeoutError:
            error_logger.error(f"Agent execution timeout for thread {thread_id}")
            yield _sse_event(
                "error",
                {
                    "detail": "Request timed out",
                    "message": "The AI assistant took too long to respond. Please try again.",
                },
            )
        except Exception as e:
            error_logger.error(f"Agent execution error: {e}", exc_info=True)
            yield _sse_event(
                "error", {"detail": "Processing error", "message": str(e)}
            )

    # [Task]: T033 - Wrap with connection-aware streaming
    async def connection_aware_stream():
        """Relay SSE frames, logging connection drops instead of raising.

        [Task]: T033 - SSE error handling for connection drops
        """
        try:
            async for chunk in stream_chat_response():
                yield chunk
        except (ConnectionError, OSError) as e:
            # Client disconnected during streaming
            error_logger.info(f"Client disconnected during ChatKit streaming: {e}")
        except Exception as e:
            # Unexpected streaming error
            error_logger.error(
                f"Unexpected error during ChatKit streaming: {e}", exc_info=True
            )
            # Build the payload with json.dumps: the previous doubled-brace
            # f-string emitted the literal text ``str(e)`` as invalid JSON.
            yield _sse_event(
                "error", {"detail": "Streaming error", "message": str(e)}
            )

    return StreamingResponse(
        connection_aware_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
chat.py DELETED
@@ -1,478 +0,0 @@
1
- """Chat API endpoint for AI-powered task management.
2
-
3
- [Task]: T015, T071
4
- [From]: specs/004-ai-chatbot/tasks.md
5
-
6
- This endpoint provides a conversational interface for task management.
7
- Users can create, list, update, complete, and delete tasks through natural language.
8
-
9
- Also includes WebSocket endpoint for real-time progress streaming.
10
- """
11
- import uuid
12
- import logging
13
- import asyncio
14
- from datetime import datetime
15
- from typing import Annotated, Optional
16
- from fastapi import APIRouter, HTTPException, status, Depends, WebSocket, WebSocketDisconnect, BackgroundTasks
17
- from pydantic import BaseModel, Field, field_validator, ValidationError
18
- from sqlmodel import Session
19
- from sqlalchemy.exc import SQLAlchemyError
20
-
21
- from core.database import get_db
22
- from core.validators import validate_message_length
23
- from core.security import decode_access_token
24
- from models.message import Message, MessageRole
25
- from services.security import sanitize_message
26
- from models.conversation import Conversation
27
- from ai_agent import run_agent_with_streaming, is_gemini_configured
28
- from services.conversation import (
29
- get_or_create_conversation,
30
- load_conversation_history,
31
- update_conversation_timestamp
32
- )
33
- from services.rate_limiter import check_rate_limit
34
- from ws_manager.manager import manager
35
-
36
-
37
- # Configure error logger
38
- error_logger = logging.getLogger("api.errors")
39
- error_logger.setLevel(logging.ERROR)
40
-
41
-
42
- # Request/Response models
43
class ChatRequest(BaseModel):
    """Body of a chat request: the user's text plus an optional conversation id.

    [From]: specs/004-ai-chatbot/plan.md - API Contract
    """

    # Required user text, bounded to 1..10000 characters.
    message: str = Field(
        ...,
        min_length=1,
        max_length=10000,  # FR-042
        description="User message content",
    )
    # When omitted, the endpoint starts a new conversation.
    conversation_id: Optional[str] = Field(
        None,
        description="Optional conversation ID to continue existing conversation"
    )

    @field_validator('message')
    @classmethod
    def validate_message(cls, v: str) -> str:
        """Reject blank or over-long messages and return the stripped text."""
        trimmed = v.strip()
        if not trimmed:
            raise ValueError("Message content cannot be empty")
        # The length limit applies to the raw value, before stripping.
        if len(v) > 10000:
            raise ValueError("Message content exceeds maximum length of 10,000 characters")
        return trimmed
68
-
69
-
70
class TaskReference(BaseModel):
    """Summary of a task that the AI created or modified during a chat turn."""

    # Required identifiers
    id: str
    title: str

    # Optional details; all default to "not provided"
    description: Optional[str] = Field(default=None)
    due_date: Optional[str] = Field(default=None)
    priority: Optional[str] = Field(default=None)

    # Completion flag, false unless stated otherwise
    completed: bool = Field(default=False)
78
-
79
-
80
class ChatResponse(BaseModel):
    """Body returned by the chat endpoint.

    [From]: specs/004-ai-chatbot/plan.md - API Contract
    """

    # Text produced by the assistant (required).
    response: str = Field(description="AI assistant's text response")
    # Id of the conversation the exchange belongs to (required).
    conversation_id: str = Field(description="Conversation ID (new or existing)")
    # Tasks touched by this turn; a fresh empty list per instance by default.
    tasks: list[TaskReference] = Field(
        description="List of tasks created or modified in this interaction",
        default_factory=list,
    )
97
-
98
-
99
- # Create API router
100
- router = APIRouter(prefix="/api", tags=["chat"])
101
-
102
-
103
def _resolve_conversation_id(db, user_uuid, raw_conversation_id):
    """Return the id of the conversation to use, creating one when needed.

    [From]: T016, T035, T060

    Raises:
        HTTPException 400: ``raw_conversation_id`` is not a valid UUID string.
    """
    if not raw_conversation_id:
        # No id supplied: start a new conversation using the service.
        return get_or_create_conversation(db=db, user_id=user_uuid).id

    try:
        conv_uuid = uuid.UUID(raw_conversation_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "Invalid conversation ID",
                "message": f"Conversation ID '{raw_conversation_id}' is not a valid UUID format.",
                "suggestion": "Provide a valid UUID or omit the conversation_id to start a new conversation."
            }
        )

    try:
        return get_or_create_conversation(
            db=db,
            user_id=user_uuid,
            conversation_id=conv_uuid
        ).id
    except (KeyError, ValueError):
        # Conversation may have been auto-deleted (90-day policy) or otherwise
        # not found; fall back to a new conversation instead of failing.
        return get_or_create_conversation(db=db, user_id=user_uuid).id


async def _run_agent_or_raise(messages, user_id):
    """Run the streaming agent and translate its failures into HTTP errors.

    [From]: T014, T022, T060, T072

    Raises:
        HTTPException 503: configuration, connection or unexpected agent error.
        HTTPException 504: the agent timed out.
    """
    try:
        return await run_agent_with_streaming(messages=messages, user_id=user_id)
    except ValueError as e:
        # Configuration errors (missing API key, invalid model)
        error_logger.error(f"AI configuration error for user {user_id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "error": "AI service configuration error",
                "message": str(e),
                "suggestion": "Verify GEMINI_API_KEY and GEMINI_MODEL are correctly configured."
            }
        )
    except ConnectionError as e:
        # Network/connection issues
        error_logger.error(f"AI connection error for user {user_id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "error": "AI service unreachable",
                "message": "Could not connect to the AI service. Please check your network connection.",
                "suggestion": "If the problem persists, the AI service may be temporarily down."
            }
        )
    except TimeoutError as e:
        error_logger.error(f"AI timeout error for user {user_id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail={
                "error": "AI service timeout",
                "message": "The AI service took too long to respond. Please try again.",
                "suggestion": "Your message may be too complex. Try breaking it into smaller requests."
            }
        )
    except Exception as e:
        # Other errors (rate limits, authentication, context, etc.)
        error_logger.error(f"Unexpected AI error for user {user_id}: {type(e).__name__}: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "error": "AI service error",
                "message": f"An unexpected error occurred: {str(e)}",
                "suggestion": "Please try again later or contact support if the problem persists."
            }
        )


def _save_exchange(user_id, conversation_id, user_message_data, ai_message_data):
    """Background task: persist both messages and bump the conversation timestamp.

    Uses its own session rather than the request-scoped one. Every failure is
    logged and swallowed; nothing propagates to the caller.
    """
    try:
        from core.database import engine

        # The session is closed when the ``with`` block exits.
        with Session(engine) as bg_db:
            try:
                bg_db.add(Message(**user_message_data))
                bg_db.add(Message(**ai_message_data))
                bg_db.commit()

                # A timestamp failure must not undo the already committed messages.
                try:
                    update_conversation_timestamp(db=bg_db, conversation_id=conversation_id)
                except SQLAlchemyError as e:
                    error_logger.error(f"Database error updating conversation timestamp for {conversation_id}: {e}")
            except SQLAlchemyError as e:
                error_logger.error(f"Background task: Database error saving messages for user {user_id}: {e}")
                bg_db.rollback()
    except Exception as e:
        error_logger.error(f"Background task: Unexpected error saving messages for user {user_id}: {e}")


@router.post("/{user_id}/chat", response_model=ChatResponse, status_code=status.HTTP_200_OK)
async def chat(
    user_id: str,
    request: ChatRequest,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db)
):
    """Process a user message through the AI agent and return its reply.

    [From]: specs/004-ai-chatbot/spec.md - US1

    Steps:
    1. Check AI configuration, validate and sanitize the input, apply the rate limit
    2. Get or create the conversation and load its history
    3. Run the AI agent (which streams progress over WebSocket)
    4. Schedule the message save as a background task (non-blocking)
    5. Return the AI response

    Args:
        user_id: User ID (UUID string from path)
        request: Chat request with message and optional conversation_id
        background_tasks: FastAPI background tasks for non-blocking DB saves
        db: Database session

    Returns:
        ChatResponse with AI response, conversation_id, and task references

    Raises:
        HTTPException 400: Invalid user id, conversation id or message content
        HTTPException 429: Daily message limit reached
        HTTPException 503: AI service unavailable or failing
        HTTPException 504: AI service timed out
    """
    # [From]: specs/004-ai-chatbot/tasks.md - T022
    if not is_gemini_configured():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "error": "AI service unavailable",
                "message": "The AI service is currently not configured. Please ensure GEMINI_API_KEY is set in the environment.",
                "suggestion": "Contact your administrator or check your API key configuration."
            }
        )

    # Validate user_id format
    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "Invalid user ID",
                "message": f"User ID '{user_id}' is not a valid UUID format.",
                "expected_format": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
                "suggestion": "Ensure you are using a valid UUID for the user_id path parameter."
            }
        )

    # Validate message content
    try:
        validated_message = validate_message_length(request.message)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "Message validation failed",
                "message": str(e),
                "max_length": 10000,
                "suggestion": "Keep your message under 10,000 characters and ensure it contains meaningful content."
            }
        )

    # Sanitize message to prevent prompt injection
    # [From]: T057 - Implement prompt injection sanitization
    try:
        sanitized_message = sanitize_message(validated_message)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": "Message content blocked",
                "message": str(e),
                "suggestion": "Please rephrase your message without attempting to manipulate system instructions."
            }
        )

    # Rate limit (100/day). A failing limiter is logged and the request is
    # allowed through (fail open).
    # [From]: specs/004-ai-chatbot/spec.md - NFR-011, T021
    try:
        allowed, _remaining, reset_time = check_rate_limit(db, user_uuid)
        if not allowed:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail={
                    "error": "Rate limit exceeded",
                    "message": "You have reached the daily message limit. Please try again later.",
                    "limit": 100,
                    "resets_at": reset_time.isoformat() if reset_time else None,
                    "suggestion": "Free tier accounts are limited to 100 messages per day. Upgrade for unlimited access."
                }
            )
    except HTTPException:
        # The 429 raised just above must not be swallowed by the handler below.
        raise
    except Exception as e:
        error_logger.error(f"Rate limit check failed for user {user_id}: {e}")

    conversation_id: uuid.UUID = _resolve_conversation_id(db, user_uuid, request.conversation_id)

    # Load conversation history; continue with an empty history if it fails.
    try:
        conversation_history = load_conversation_history(
            db=db,
            conversation_id=conversation_id
        )
    except SQLAlchemyError as e:
        error_logger.error(f"Database error loading conversation history for {conversation_id}: {e}")
        conversation_history = []

    # Row for the user's message, captured before the agent runs.
    user_message_data = {
        "id": uuid.uuid4(),
        "conversation_id": conversation_id,
        "user_id": user_uuid,
        "role": MessageRole.USER,
        "content": sanitized_message,
        "created_at": datetime.utcnow()
    }

    # The agent needs the current message appended to the stored history.
    messages_for_agent = conversation_history + [
        {"role": "user", "content": sanitized_message}
    ]
    ai_response_text = await _run_agent_or_raise(messages_for_agent, user_id)

    ai_message_data = {
        "id": uuid.uuid4(),
        "conversation_id": conversation_id,
        "user_id": user_uuid,
        "role": MessageRole.ASSISTANT,
        "content": ai_response_text,
        "created_at": datetime.utcnow()
    }

    # Save messages to DB in background (non-blocking)
    background_tasks.add_task(
        _save_exchange,
        user_id,
        conversation_id,
        user_message_data,
        ai_message_data,
    )

    # TODO: Parse AI response for task references
    # This will be enhanced in future tasks to extract task IDs from AI responses
    task_references: list[TaskReference] = []

    return ChatResponse(
        response=ai_response_text,
        conversation_id=str(conversation_id),
        tasks=task_references
    )
408
-
409
-
410
@router.websocket("/ws/{user_id}/chat")
async def websocket_chat(
    websocket: WebSocket,
    user_id: str,
    db: Session = Depends(get_db)
):
    """WebSocket endpoint for real-time chat progress updates.

    [From]: specs/004-ai-chatbot/research.md - Section 4
    [Task]: T071

    This endpoint provides a WebSocket connection for receiving real-time
    progress events during AI agent execution. Events include:
    - connection_established: Confirmation of successful connection
    - agent_thinking: AI agent is processing
    - tool_starting: A tool is about to execute
    - tool_progress: Tool execution progress (e.g., "Found 3 tasks")
    - tool_complete: Tool finished successfully
    - tool_error: Tool execution failed
    - agent_done: AI agent finished processing

    Note: No authentication is performed here; the ``user_id`` from the URL
    is trusted as-is and only used to route progress events.
    NOTE(review): confirm that nothing sensitive is broadcast on this channel.

    Connection URL format:
        ws://localhost:8000/ws/{user_id}/chat

    Args:
        websocket: The WebSocket connection instance
        user_id: User ID from URL path (used to route progress events)
        db: Database session (currently unused; kept for the route signature)
    """
    # Connect the WebSocket (manager handles accept)
    await manager.connect(user_id, websocket)

    try:
        # Keep the connection open. Incoming client messages are read and
        # discarded; the channel is used for server-to-client updates.
        while True:
            await websocket.receive_text()

    except WebSocketDisconnect:
        # Normal disconnect
        error_logger.info(f"WebSocket disconnected normally for user {user_id}")

    except Exception as e:
        # Unexpected error
        error_logger.error(f"WebSocket error for user {user_id}: {e}")

    finally:
        # Single cleanup point. The handlers above used to disconnect as well,
        # which ran ``manager.disconnect`` twice on every exit path.
        manager.disconnect(user_id, websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
chatkit_server.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ChatKit Server implementation for task management with Gemini LLM.
2
+
3
+ [Task]: T010
4
+ [From]: specs/010-chatkit-migration/contracts/backend.md - ChatKitServer Implementation
5
+ [From]: specs/010-chatkit-migration/research.md - Section 3
6
+
7
+ This module implements the ChatKitServer class which handles ChatKit protocol
8
+ requests and streams responses using Server-Sent Events (SSE).
9
+
10
+ The server integrates:
11
+ - ChatKit Python SDK for protocol handling
12
+ - OpenAI Agents SDK for agent orchestration
13
+ - Gemini LLM via OpenAI-compatible endpoint
14
+ - MCP tools wrapped as Agents SDK functions
15
+
16
+ Architecture:
17
+ Frontend (ChatKit.js)
18
+ ↓ SSE with custom fetch
19
+ ChatKitServer (this module)
20
+ ↓ Agents SDK
21
+ Gemini API (via AsyncOpenAI with base_url)
22
+ """
23
+ import asyncio
24
+ import logging
25
+ from typing import Any, AsyncIterator, Optional
26
+ from uuid import UUID
27
+ from openai import AsyncOpenAI
28
+ from agents import Agent, set_default_openai_client, RunContextWrapper, Runner
29
+
30
+ from core.config import get_gemini_client, get_settings
31
+ from services.chatkit_store import PostgresChatKitStore
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class AgentContext:
    """Per-request state handed to the agent while it runs.

    Attributes:
        thread_id: Current thread/conversation ID
        user_id: Authenticated user ID
        store: Database store for persistence
        request_context: Additional request metadata (a new empty dict when
            none, or an empty one, is supplied)
    """

    def __init__(
        self,
        thread_id: str,
        user_id: str,
        store: PostgresChatKitStore,
        request_context: Optional[dict] = None,
    ):
        # Any falsy value (None or an empty mapping) becomes a fresh dict.
        metadata = request_context if request_context else {}

        self.thread_id, self.user_id = thread_id, user_id
        self.store = store
        self.request_context = metadata
57
+
58
+
59
class TaskManagerChatKitServer:
    """ChatKit Server for task management with Gemini LLM.

    [From]: specs/010-chatkit-migration/contracts/backend.md - ChatKitServer Implementation

    This server handles ChatKit-style requests using:
    - Caller-supplied authentication context (``user_id``)
    - Gemini LLM via OpenAI-compatible endpoint
    - Server-side tool execution through the configured agent
    - A ``PostgresChatKitStore`` for thread/message persistence

    Usage:
        from fastapi import FastAPI, Request
        from fastapi.responses import JSONResponse, StreamingResponse
        from chatkit_server import TaskManagerChatKitServer

        app = FastAPI()
        server = TaskManagerChatKitServer(store=postgres_store)

        @app.post("/api/chatkit")
        async def chatkit_endpoint(request: Request):
            body = await request.body()
            user_id = get_current_user_id(request)
            result = await server.process(body, {"user_id": user_id})
            if hasattr(result, '__aiter__'):
                return StreamingResponse(result, media_type="text/event-stream")
            # Validation/auth failures come back as a plain {"error": ...} dict
            return JSONResponse(result)
    """

    # [Task]: T034 - upper bound, in seconds, for one complete agent run
    # (LLM calls, tool executions and persisting the reply).
    AGENT_TIMEOUT_SECONDS = 120

    def __init__(self, store: "PostgresChatKitStore"):
        """Initialize the ChatKit server.

        Args:
            store: PostgresChatKitStore instance for persistence
        """
        self.store = store

        # Configure Gemini client as default for Agents SDK
        # [From]: specs/010-chatkit-migration/research.md - Section 2
        # Best effort: a missing configuration is logged and construction
        # still succeeds.
        try:
            gemini_client = get_gemini_client()
            set_default_openai_client(gemini_client)
            logger.info("Gemini client configured for Agents SDK")
        except Exception as e:
            logger.warning(f"Gemini client not configured: {e}")

        # Set later through set_agent(); until then streaming requests
        # answer with an "Agent not configured" error event.
        self.assistant_agent: Optional[Agent] = None

    def set_agent(self, agent: "Agent") -> None:
        """Set the assistant agent with tools.

        [From]: specs/010-chatkit-migration/tasks.md - T019

        Args:
            agent: Configured Agent with tools and instructions
        """
        self.assistant_agent = agent
        logger.info(f"Agent configured: {agent.name} with model {agent.model}")

    async def process(
        self,
        body: bytes,
        context: dict[str, Any]
    ) -> Any:
        """Process a ChatKit request and return a streaming result or an error dict.

        [From]: specs/010-chatkit-migration/contracts/backend.md - ChatKit SSE Endpoint

        Args:
            body: Raw request body bytes (UTF-8 JSON object)
            context: Request context; must contain a truthy ``user_id``

        Returns:
            On success, an async-iterable object yielding SSE-formatted
            strings; it also exposes ``.json`` (a JSON string carrying the
            thread id). On a malformed, unauthenticated or empty request,
            a ``{"error": ...}`` dict.

        Note: This is a placeholder implementation that parses a simplified
        request shape itself instead of delegating to the ChatKit Python SDK.
        """
        import json

        # Parse ChatKit protocol request
        try:
            request_data = json.loads(body.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to parse ChatKit request: {e}")
            return {"error": "Invalid request format"}
        if not isinstance(request_data, dict):
            # Valid JSON but not an object (e.g. a list); .get() would crash.
            logger.error("Failed to parse ChatKit request: payload is not a JSON object")
            return {"error": "Invalid request format"}

        conversation_id = request_data.get("conversation_id")
        item = request_data.get("item") or {}
        event_type = request_data.get("event", "conversation_item_created")

        logger.info(f"ChatKit request: event={event_type}, conversation_id={conversation_id}")

        # Require an authenticated user for every request, not only when a
        # new thread has to be created.
        # NOTE(review): ownership of an existing conversation_id is still not
        # verified here - confirm whether the store enforces it.
        user_id = context.get("user_id")
        if not user_id:
            return {"error": "Unauthorized: no user_id in context"}

        # Validate the message before touching the database so an empty
        # request does not leave an orphan thread behind.
        user_message = self._extract_message_content(item)
        if not user_message:
            return {"error": "No message content provided"}

        # Get or create thread
        thread_id = conversation_id
        if not thread_id:
            thread_meta = await self.store.create_thread(
                user_id=user_id,
                title=None,
                metadata={}
            )
            thread_id = thread_meta["id"]
            logger.info(f"Created new thread: {thread_id}")

        agent_context = AgentContext(
            thread_id=thread_id,
            user_id=user_id,
            store=self.store,
            request_context=context,
        )

        # Create user message in database
        await self.store.create_message(
            thread_id=thread_id,
            item={
                "type": "message",
                "role": "user",
                "content": [{"type": "text", "text": user_message}],
            }
        )

        # [From]: specs/010-chatkit-migration/research.md - Section 3
        async def stream_response():
            """Stream ChatKit events as SSE with timeout protection.

            [Task]: T034 - Timeout handling for long-running tool executions
            """
            if not self.assistant_agent:
                yield self._sse_event("error", {"message": "Agent not configured"})
                return

            try:
                async with asyncio.timeout(self.AGENT_TIMEOUT_SECONDS):
                    result = Runner.run_streamed(
                        self.assistant_agent,
                        [{"role": "user", "content": user_message}],
                        context=agent_context,
                    )

                    # run_streamed() returns a result object; the events are
                    # obtained from its stream_events() async iterator.
                    parts = []
                    async for event in result.stream_events():
                        delta = self._text_delta(event)
                        if delta:
                            parts.append(delta)
                            yield self._sse_event("message_delta", {
                                "type": "text",
                                "text": delta
                            })
                    full_response = "".join(parts)

                    # Create assistant message in database
                    await self.store.create_message(
                        thread_id=thread_id,
                        item={
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "text", "text": full_response}],
                        }
                    )

                    # message_id carries the thread id (the stored message's
                    # own id is not captured here); thread_id is included so
                    # a client can learn the id of a newly created thread.
                    yield self._sse_event("message_done", {
                        "message_id": thread_id,
                        "role": "assistant",
                        "thread_id": thread_id
                    })

            except TimeoutError:
                # [Task]: T034 - Handle timeout gracefully
                logger.error(f"Agent execution timeout for thread {thread_id}")
                yield self._sse_event("error", {
                    "message": "Request timed out",
                    "detail": "The AI assistant took too long to respond. Please try again."
                })
            except Exception as e:
                logger.error(f"Agent execution error: {e}", exc_info=True)
                yield self._sse_event("error", {"message": str(e)})

        class StreamingResult:
            """Async-iterable wrapper around the SSE generator."""

            def __init__(self, generator):
                self.generator = generator
                self.json = json.dumps({"thread_id": thread_id})

            def __aiter__(self):
                # ``generator`` is already an async generator object; calling
                # it (as the code did before) raises TypeError.
                return self.generator

        return StreamingResult(stream_response())

    @staticmethod
    def _text_delta(event: Any) -> str:
        """Return the text fragment carried by an Agents SDK stream event.

        Only raw response events whose payload is an output-text delta carry
        text; every other event yields an empty string. Attribute access is
        duck-typed so no SDK event classes need to be imported.
        """
        if getattr(event, "type", None) != "raw_response_event":
            return ""
        data = getattr(event, "data", None)
        if getattr(data, "type", None) != "response.output_text.delta":
            return ""
        delta = getattr(data, "delta", "")
        return delta if isinstance(delta, str) else ""

    def _extract_message_content(self, item: dict) -> str:
        """Extract text content from ChatKit item.

        Args:
            item: ChatKit message item. ``content`` may be a list of content
                blocks or a plain string.

        Returns:
            The text of the first ``{"type": "text"}`` block, the string
            itself when ``content`` is a string, or ``""`` when no text can
            be found (including malformed shapes).
        """
        if not isinstance(item, dict):
            return ""
        content = item.get("content", [])
        if isinstance(content, str):
            return content
        if not isinstance(content, (list, tuple)):
            return ""
        for content_block in content:
            # Skip entries that are not mappings instead of crashing on .get()
            if isinstance(content_block, dict) and content_block.get("type") == "text":
                text = content_block.get("text", "")
                return text if isinstance(text, str) else ""
        return ""

    def _sse_event(self, event_type: str, data: dict) -> str:
        """Format data as Server-Sent Event.

        [From]: specs/010-chatkit-migration/contracts/backend.md - SSE Event Types

        Args:
            event_type: Event type name
            data: Event data payload (must be JSON-serialisable)

        Returns:
            Formatted SSE string: an ``event:`` line, a ``data:`` line with
            the JSON payload, and the blank line that ends the event.
        """
        import json
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
302
+
303
+
304
# Process-wide singleton, created and configured on first use by
# get_chatkit_server() (T019).
_server_instance: Optional[TaskManagerChatKitServer] = None


def get_chatkit_server(store: PostgresChatKitStore) -> TaskManagerChatKitServer:
    """Get or create ChatKit server singleton with agent configuration.

    [From]: specs/010-chatkit-migration/contracts/backend.md

    [Task]: T019 - Configure TaskAssistant agent with Gemini model and wrapped tools

    Args:
        store: PostgresChatKitStore instance. Only the store passed on the
            first successful call is used; later calls return the existing
            server and ignore this argument.
            NOTE(review): if stores are created per request, confirm that
            binding the first one is intended.

    Returns:
        ChatKit server instance with configured agent
    """
    global _server_instance
    if _server_instance is not None:
        return _server_instance

    # Imported lazily, as before, so the tool wrappers are only loaded when
    # the server is first requested.
    # [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
    from ai_agent.tool_wrappers import TOOL_FUNCTIONS

    # Resolve the model once so the agent and the log line agree.
    model_name = get_settings().gemini_model or "gemini-2.0-flash-exp"

    server = TaskManagerChatKitServer(store)

    # Create the TaskAssistant agent with Gemini model
    # [From]: specs/010-chatkit-migration/research.md - Section 3
    assistant_agent = Agent[AgentContext](
        name="TaskAssistant",
        model=model_name,
        instructions="""You are a helpful task management assistant. You help users create, list, update, complete, and delete tasks through natural language.

Available tools:
- create_task: Create a new task with title, description, due date, priority, tags
- list_tasks: List all tasks with optional filters
- update_task: Update an existing task
- delete_task: Delete a task
- complete_task: Mark a task as completed or incomplete
- complete_all_tasks: Mark all tasks as completed (requires confirmation)
- delete_all_tasks: Delete all tasks (requires confirmation)

When users ask about tasks, use the appropriate tool. Always confirm destructive actions (complete_all_tasks, delete_all_tasks) by requiring the confirm parameter.

Be concise and helpful. If a user's request is unclear, ask for clarification.""",
        tools=TOOL_FUNCTIONS,
    )
    server.set_agent(assistant_agent)
    logger.info(f"ChatKit server initialized with {len(TOOL_FUNCTIONS)} tools and model {model_name}")

    # Publish only after configuration succeeded. Assigning the global first
    # (as before) left a permanent agent-less singleton if anything above
    # raised.
    _server_instance = server
    return _server_instance
config.py DELETED
@@ -1,54 +0,0 @@
1
- """Application configuration and settings.
2
-
3
- [Task]: T009
4
- [From]: specs/001-user-auth/plan.md
5
-
6
- [Task]: T003
7
- [From]: specs/004-ai-chatbot/plan.md
8
- """
9
- import os
10
- from pydantic_settings import BaseSettings, SettingsConfigDict
11
- from functools import lru_cache
12
-
13
-
14
- class Settings(BaseSettings):
15
- """Application settings loaded from environment variables."""
16
-
17
- # Database
18
- database_url: str
19
-
20
- # JWT
21
- jwt_secret: str
22
- jwt_algorithm: str = "HS256"
23
- jwt_expiration_days: int = 7
24
-
25
- # CORS
26
- frontend_url: str
27
-
28
- # Environment
29
- environment: str = "development"
30
-
31
- # Gemini API (Phase III: AI Chatbot)
32
- gemini_api_key: str | None = None # Optional for migration/setup
33
- gemini_model: str = "gemini-2.0-flash-exp"
34
-
35
- model_config = SettingsConfigDict(
36
- env_file=".env",
37
- case_sensitive=False,
38
- # Support legacy Better Auth environment variables
39
- env_prefix="",
40
- extra="ignore"
41
- )
42
-
43
-
44
- @lru_cache()
45
- def get_settings() -> Settings:
46
- """Get cached settings instance.
47
-
48
- Returns:
49
- Settings: Application settings
50
-
51
- Raises:
52
- ValueError: If required environment variables are not set
53
- """
54
- return Settings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
core/config.py CHANGED
@@ -5,6 +5,9 @@
5
 
6
  [Task]: T003
7
  [From]: specs/004-ai-chatbot/plan.md
 
 
 
8
  """
9
  import os
10
  from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -18,12 +21,12 @@ class Settings(BaseSettings):
18
  database_url: str
19
 
20
  # JWT
21
- jwt_secret: str = "change-me-in-production-use-env-var"
22
  jwt_algorithm: str = "HS256"
23
  jwt_expiration_days: int = 7
24
 
25
- # CORS (optional, defaults to allow all for public API)
26
- frontend_url: str = "*"
27
 
28
  # Environment
29
  environment: str = "development"
@@ -31,6 +34,7 @@ class Settings(BaseSettings):
31
  # Gemini API (Phase III: AI Chatbot)
32
  gemini_api_key: str | None = None # Optional for migration/setup
33
  gemini_model: str = "gemini-2.0-flash-exp"
 
34
 
35
  model_config = SettingsConfigDict(
36
  env_file=".env",
@@ -52,3 +56,38 @@ def get_settings() -> Settings:
52
  ValueError: If required environment variables are not set
53
  """
54
  return Settings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  [Task]: T003
7
  [From]: specs/004-ai-chatbot/plan.md
8
+
9
+ Extended for ChatKit migration with Gemini OpenAI-compatible endpoint.
10
+ [From]: specs/010-chatkit-migration/tasks.md - T008
11
  """
12
  import os
13
  from pydantic_settings import BaseSettings, SettingsConfigDict
 
21
  database_url: str
22
 
23
  # JWT
24
+ jwt_secret: str
25
  jwt_algorithm: str = "HS256"
26
  jwt_expiration_days: int = 7
27
 
28
+ # CORS
29
+ frontend_url: str
30
 
31
  # Environment
32
  environment: str = "development"
 
34
  # Gemini API (Phase III: AI Chatbot)
35
  gemini_api_key: str | None = None # Optional for migration/setup
36
  gemini_model: str = "gemini-2.0-flash-exp"
37
+ gemini_base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/" # ChatKit migration
38
 
39
  model_config = SettingsConfigDict(
40
  env_file=".env",
 
56
  ValueError: If required environment variables are not set
57
  """
58
  return Settings()
59
+
60
+
61
+ def get_gemini_client():
62
+ """Create and return an AsyncOpenAI client configured for Gemini.
63
+
64
+ [From]: specs/010-chatkit-migration/research.md - Section 2
65
+ [From]: specs/010-chatkit-migration/contracts/backend.md - Tool Contracts
66
+
67
+ This client uses Gemini's OpenAI-compatible endpoint, allowing us to use
68
+ the OpenAI SDK and Agents SDK with Gemini as the LLM provider.
69
+
70
+ Returns:
71
+ AsyncOpenAI: OpenAI client configured for Gemini
72
+
73
+ Example:
74
+ from core.config import get_gemini_client
75
+ from agents import set_default_openai_client
76
+
77
+ client = get_gemini_client()
78
+ set_default_openai_client(client)
79
+ """
80
+ from openai import AsyncOpenAI
81
+
82
+ settings = get_settings()
83
+
84
+ if not settings.gemini_api_key:
85
+ raise ValueError(
86
+ "GEMINI_API_KEY is not set. Please set it in your environment or .env file. "
87
+ "Get your API key from https://aistudio.google.com"
88
+ )
89
+
90
+ return AsyncOpenAI(
91
+ api_key=settings.gemini_api_key,
92
+ base_url=settings.gemini_base_url,
93
+ )
core/security.py CHANGED
@@ -145,3 +145,40 @@ def decode_access_token(token: str) -> dict:
145
  detail="Could not validate credentials",
146
  headers={"WWW-Authenticate": "Bearer"},
147
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  detail="Could not validate credentials",
146
  headers={"WWW-Authenticate": "Bearer"},
147
  )
148
+
149
+
150
+ async def get_current_user_id_from_cookie(request) -> Optional[str]:
151
+ """Extract and validate user ID from JWT token in httpOnly cookie.
152
+
153
+ [Task]: T011
154
+ [From]: specs/010-chatkit-migration/contracts/backend.md - Authentication Contracts
155
+
156
+ This function extracts the JWT token from the auth_token httpOnly cookie,
157
+ decodes it, and returns the user_id (sub claim).
158
+
159
+ Args:
160
+ request: FastAPI/Starlette request object
161
+
162
+ Returns:
163
+ User ID (UUID string) or None if authentication fails
164
+
165
+ Note:
166
+ Does not raise on authentication failure: a missing, invalid, or undecodable token yields None.
167
+ """
168
+ # Read the JWT from the auth_token httpOnly cookie (the only source checked here)
169
+ # [From]: specs/010-chatkit-migration/contracts/backend.md
170
+ auth_token = request.cookies.get("auth_token")
171
+ if not auth_token:
172
+ return None
173
+
174
+ try:
175
+ payload = decode_access_token(auth_token)
176
+ user_id = payload.get("sub")
177
+ return user_id
178
+ except HTTPException:
179
+ return None
180
+ except Exception as e:
181
+ # Log but don't raise for non-critical operations
182
+ import logging
183
+ logging.getLogger("api.auth").warning(f"Failed to decode token from cookie: {e}")
184
+ return None
migrations/migrate_threads.sql ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ChatKit Migration: Create threads table and update messages table
2
+ --
3
+ -- [From]: specs/010-chatkit-migration/data-model.md - Database Schema Migration
4
+ -- [Task]: T006
5
+ --
6
+ -- This migration:
7
+ -- 1. Creates the threads table for ChatKit conversation management
8
+ -- 2. Adds thread_id column to messages table
9
+ -- 3. Migrates existing conversation_id data to thread_id
10
+ -- 4. Creates indexes for query optimization
11
+ --
12
+ -- IMPORTANT: Run this migration after deploying the Thread model
13
+ --
14
+ -- To run: psql $DATABASE_URL < migrations/migrate_threads.sql
15
+
16
+ -- BEGIN;
17
+
18
+ -- 1. Create threads table
19
+ CREATE TABLE IF NOT EXISTS threads (
20
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
21
+ user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
22
+ title VARCHAR(255),
23
+ metadata JSONB,
24
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
25
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
26
+ );
27
+
28
+ -- 2. Create indexes for threads table
29
+ CREATE INDEX IF NOT EXISTS idx_thread_user_id ON threads(user_id);
30
+ CREATE INDEX IF NOT EXISTS idx_thread_updated_at ON threads(user_id, updated_at DESC);
31
+
32
+ -- 3. Add thread_id column to messages table (nullable initially)
33
+ ALTER TABLE message ADD COLUMN IF NOT EXISTS thread_id UUID REFERENCES threads(id) ON DELETE CASCADE;
34
+
35
+ -- 4. Create index for thread_id in messages table
36
+ CREATE INDEX IF NOT EXISTS idx_message_thread_id ON message(thread_id, created_at ASC);
37
+
38
+ -- 5. Migrate existing conversation data to threads
39
+ -- This creates a thread for each unique conversation and links messages to it
40
+ -- Skip this step if starting fresh (no existing conversations)
41
+ INSERT INTO threads (id, user_id, created_at, updated_at)
42
+ SELECT DISTINCT
43
+ c.id as id, -- Use same ID as conversation for easy mapping
44
+ c.user_id,
45
+ c.created_at,
46
+ c.updated_at
47
+ FROM conversation c
48
+ WHERE NOT EXISTS (SELECT 1 FROM threads t WHERE t.id = c.id);
49
+
50
+ -- 6. Update messages to point to the new thread_id
51
+ -- This maps existing messages to their corresponding threads
52
+ UPDATE message m
53
+ SET thread_id = m.conversation_id
54
+ WHERE m.conversation_id IS NOT NULL
55
+ AND m.thread_id IS NULL;
56
+
57
+ -- 7. After migration, make thread_id NOT NULL (only after validating data)
58
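+ -- Uncomment these lines after verifying successful migration (skip the ADD CONSTRAINT line if step 3 already created the thread_id foreign key, which PostgreSQL names message_thread_id_fkey by default):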
59
+ -- ALTER TABLE message ALTER COLUMN thread_id SET NOT NULL;
60
+ -- ALTER TABLE message ADD CONSTRAINT message_thread_id_fkey FOREIGN KEY (thread_id) REFERENCES threads(id) ON DELETE CASCADE;
61
+
62
+ -- Optional: Drop old conversation_id column after full migration
63
+ -- Uncomment ONLY after confirming ChatKit is working correctly:
64
+ -- ALTER TABLE message DROP COLUMN IF EXISTS conversation_id;
65
+ -- DROP INDEX IF EXISTS idx_message_conversation_created;
66
+
67
+ -- COMMIT;
68
+
69
+ -- Rollback commands (if needed):
70
+ -- BEGIN;
71
+ -- ALTER TABLE message DROP COLUMN IF EXISTS thread_id;
72
+ -- DROP INDEX IF EXISTS idx_message_thread_id;
73
+ -- DROP INDEX IF EXISTS idx_thread_user_id;
74
+ -- DROP INDEX IF EXISTS idx_thread_updated_at;
75
+ -- DROP TABLE IF EXISTS threads;
76
+ -- COMMIT;
models/thread.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Thread model for ChatKit integration.
2
+
3
+ [Task]: T004
4
+ [From]: specs/010-chatkit-migration/data-model.md
5
+
6
+ This model implements ChatKit's thread abstraction for grouping messages
7
+ into conversations. It is a standalone SQLModel table (it does not subclass
8
+ Conversation) that fills the same role for ChatKit.
9
+
10
+ Migration Note: The existing Conversation model serves a similar purpose.
11
+ During migration, we can either:
12
+ 1. Use Thread alongside Conversation (dual model approach)
13
+ 2. Migrate Conversation to Thread (single model approach)
14
+
15
+ This implementation uses Thread as the primary model for ChatKit integration.
16
+ """
17
+ import uuid
18
+ from datetime import datetime
19
+ from typing import Optional
20
+ from sqlmodel import Field, SQLModel
21
+ from sqlalchemy import Column, DateTime, JSON as SQLJSON, String as SQLString, Index
22
+ from sqlalchemy.dialects.postgresql import JSONB
23
+
24
+
25
+ class Thread(SQLModel, table=True):
26
+ """Thread model representing a ChatKit conversation session.
27
+
28
+ ChatKit uses "thread" terminology for conversation grouping.
29
+ A thread contains multiple messages and belongs to a single user.
30
+
31
+ [From]: specs/010-chatkit-migration/data-model.md - Thread Entity
32
+ """
33
+
34
+ __tablename__ = "threads"
35
+
36
+ id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
37
+ user_id: uuid.UUID = Field(foreign_key="users.id", index=True, nullable=False)
38
+
39
+ # Optional thread title (ChatKit allows threads to have titles)
40
+ title: Optional[str] = Field(
41
+ default=None,
42
+ max_length=255,
43
+ sa_column=Column(SQLString(255), nullable=True)
44
+ )
45
+
46
+ # Thread metadata (tags, settings, etc.) stored as JSONB
47
+ # Note: Using 'thread_metadata' instead of 'metadata' to avoid SQLAlchemy reserved attribute
48
+ thread_metadata: Optional[dict] = Field(
49
+ default=None,
50
+ sa_column=Column("metadata", JSONB, nullable=True), # Column name in DB is 'metadata'
51
+ )
52
+
53
+ # Thread creation timestamp
54
+ created_at: datetime = Field(
55
+ default_factory=datetime.utcnow,
56
+ sa_column=Column(DateTime(timezone=True), nullable=False)
57
+ )
58
+
59
+ # Last message/update timestamp (auto-updated by application logic)
60
+ updated_at: datetime = Field(
61
+ default_factory=datetime.utcnow,
62
+ sa_column=Column(DateTime(timezone=True), nullable=False, index=True)
63
+ )
64
+
65
+ # Table indexes for query optimization
66
+ __table_args__ = (
67
+ Index('idx_thread_user_id', 'user_id'),
68
+ Index('idx_thread_updated_at', 'user_id', 'updated_at'), # For sorting conversations by recent activity
69
+ )
70
+
71
+ def __repr__(self) -> str:
72
+ return f"<Thread(id={self.id}, user_id={self.user_id}, title={self.title})>"
security.py DELETED
@@ -1,276 +0,0 @@
1
- """Security utilities for the AI chatbot.
2
-
3
- [Task]: T057
4
- [From]: specs/004-ai-chatbot/tasks.md
5
-
6
- This module provides security functions including prompt injection sanitization,
7
- input validation, and content filtering.
8
- """
9
- import re
10
- import html
11
- from typing import Optional, List
12
-
13
-
14
- # Known prompt injection patterns
15
- PROMPT_INJECTION_PATTERNS = [
16
- # Direct instructions to ignore previous context
17
- r"(?i)ignore\s+(all\s+)?(previous|above|prior)",
18
- r"(?i)disregard\s+(all\s+)?(previous|above|prior)",
19
- r"(?i)forget\s+(everything|all\s+instructions|previous)",
20
- r"(?i)override\s+(your\s+)?programming",
21
- r"(?i)new\s+(instruction|direction|rule)s?",
22
- r"(?i)change\s+(your\s+)?(behavior|role|persona)",
23
-
24
- # Jailbreak attempts
25
- r"(?i)(jailbreak|jail\s*break)",
26
- r"(?i)(developer|admin|root|privileged)\s+mode",
27
- r"(?i)act\s+as\s+(a\s+)?(developer|admin|root)",
28
- r"(?i)roleplay\s+as",
29
- r"(?i)pretend\s+(to\s+be|you're)",
30
- r"(?i)simulate\s+being",
31
-
32
- # System prompt extraction
33
- r"(?i)show\s+(your\s+)?(instructions|system\s+prompt|prompt)",
34
- r"(?i)print\s+(your\s+)?(instructions|system\s+prompt)",
35
- r"(?i)reveal\s+(your\s+)?(instructions|system\s+prompt)",
36
- r"(?i)what\s+(are\s+)?your\s+instructions",
37
- r"(?i)tell\s+me\s+how\s+you\s+work",
38
-
39
- # DAN and similar jailbreaks
40
- r"(?i)do\s+anything\s+now",
41
- r"(?i)unrestricted\s+mode",
42
- r"(?i)no\s+limitations?",
43
- r"(?i)bypass\s+(safety|filters|restrictions)",
44
- r"(?i)\bDAN\b", # Do Anything Now
45
- ]
46
-
47
-
48
- def sanitize_message(message: str, max_length: int = 10000) -> str:
49
- """Sanitize a user message to prevent prompt injection attacks.
50
-
51
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
52
-
53
- Args:
54
- message: The raw user message
55
- max_length: Maximum allowed message length
56
-
57
- Returns:
58
- Sanitized message safe for processing by AI
59
-
60
- Raises:
61
- ValueError: If message contains severe injection attempts
62
- """
63
- if not message:
64
- return ""
65
-
66
- # Trim to max length
67
- message = message[:max_length]
68
-
69
- # Check for severe injection patterns
70
- detected = detect_prompt_injection(message)
71
- if detected:
72
- # For severe attacks, reject the message
73
- if detected["severity"] == "high":
74
- raise ValueError(
75
- "This message contains content that cannot be processed. "
76
- "Please rephrase your request."
77
- )
78
-
79
- # Apply sanitization
80
- sanitized = _apply_sanitization(message)
81
-
82
- return sanitized
83
-
84
-
85
- def detect_prompt_injection(message: str) -> Optional[dict]:
86
- """Detect potential prompt injection attempts in a message.
87
-
88
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
89
-
90
- Args:
91
- message: The message to check
92
-
93
- Returns:
94
- Dictionary with detection info if injection detected, None otherwise:
95
- {
96
- "detected": True,
97
- "severity": "low" | "medium" | "high",
98
- "pattern": "matched pattern",
99
- "confidence": 0.0-1.0
100
- }
101
- """
102
- message_lower = message.lower()
103
-
104
- for pattern in PROMPT_INJECTION_PATTERNS:
105
- match = re.search(pattern, message_lower)
106
-
107
- if match:
108
- # Determine severity based on pattern type
109
- severity = _get_severity_for_pattern(pattern)
110
-
111
- # Check for context that might indicate legitimate use
112
- is_legitimate = _check_legitimate_context(message, match.group())
113
-
114
- if not is_legitimate:
115
- return {
116
- "detected": True,
117
- "severity": severity,
118
- "pattern": match.group(),
119
- "confidence": 0.8
120
- }
121
-
122
- return None
123
-
124
-
125
- def _get_severity_for_pattern(pattern: str) -> str:
126
- """Determine severity level for a matched pattern.
127
-
128
- Args:
129
- pattern: The regex pattern that matched
130
-
131
- Returns:
132
- "low", "medium", or "high"
133
- """
134
- pattern_lower = pattern.lower()
135
-
136
- # High severity: direct jailbreak attempts
137
- if any(word in pattern_lower for word in ["jailbreak", "dan", "unrestricted", "bypass"]):
138
- return "high"
139
-
140
- # High severity: system prompt extraction
141
- if any(word in pattern_lower for word in ["show", "print", "reveal", "instructions"]):
142
- return "high"
143
-
144
- # Medium severity: role/persona manipulation
145
- if any(word in pattern_lower for word in ["act as", "pretend", "roleplay", "override"]):
146
- return "medium"
147
-
148
- # Low severity: ignore instructions
149
- if any(word in pattern_lower for word in ["ignore", "disregard", "forget"]):
150
- return "low"
151
-
152
- return "low"
153
-
154
-
155
- def _check_legitimate_context(message: str, matched_text: str) -> bool:
156
- """Check if a matched pattern might be legitimate user content.
157
-
158
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
159
-
160
- Args:
161
- message: The full message
162
- matched_text: The text that matched a pattern
163
-
164
- Returns:
165
- True if this appears to be legitimate context, False otherwise
166
- """
167
- message_lower = message.lower()
168
- matched_lower = matched_text.lower()
169
-
170
- # Check if the matched text is part of a task description (legitimate)
171
- legitimate_contexts = [
172
- # Common task-related phrases
173
- "task to ignore",
174
- "mark as complete",
175
- "disregard this",
176
- "role in the project",
177
- "change status",
178
- "update the role",
179
- "priority change",
180
- ]
181
-
182
- for context in legitimate_contexts:
183
- if context in message_lower:
184
- return True
185
-
186
- # Check if matched text is very short (likely false positive)
187
- if len(matched_text) <= 3:
188
- return True
189
-
190
- return False
191
-
192
-
193
- def _apply_sanitization(message: str) -> str:
194
- """Apply sanitization transformations to a message.
195
-
196
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
197
-
198
- Args:
199
- message: The message to sanitize
200
-
201
- Returns:
202
- Sanitized message
203
- """
204
- # Remove excessive whitespace
205
- message = re.sub(r"\s+", " ", message)
206
-
207
- # Remove control characters except newlines and tabs
208
- message = re.sub(r"[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]", "", message)
209
-
210
- # Normalize line endings
211
- message = message.replace("\r\n", "\n").replace("\r", "\n")
212
-
213
- # Limit consecutive newlines to 2
214
- message = re.sub(r"\n{3,}", "\n\n", message)
215
-
216
- return message.strip()
217
-
218
-
219
- def validate_task_input(task_data: dict) -> tuple[bool, Optional[str]]:
220
- """Validate task-related input for security issues.
221
-
222
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
223
-
224
- Args:
225
- task_data: Dictionary containing task fields
226
-
227
- Returns:
228
- Tuple of (is_valid, error_message)
229
- """
230
- if not isinstance(task_data, dict):
231
- return False, "Invalid task data format"
232
-
233
- # Check for SQL injection patterns in string fields
234
- sql_patterns = [
235
- r"(?i)(\bunion\b.*\bselect\b)",
236
- r"(?i)(\bselect\b.*\bfrom\b)",
237
- r"(?i)(\binsert\b.*\binto\b)",
238
- r"(?i)(\bupdate\b.*\bset\b)",
239
- r"(?i)(\bdelete\b.*\bfrom\b)",
240
- r"(?i)(\bdrop\b.*\btable\b)",
241
- r";\s*(union|select|insert|update|delete|drop)",
242
- ]
243
-
244
- for key, value in task_data.items():
245
- if isinstance(value, str):
246
- for pattern in sql_patterns:
247
- if re.search(pattern, value):
248
- return False, f"Invalid characters in {key}"
249
-
250
- # Check for script injection
251
- if re.search(r"<script[^>]*>.*?</script>", value, re.IGNORECASE):
252
- return False, f"Invalid content in {key}"
253
-
254
- return True, None
255
-
256
-
257
- def sanitize_html_content(content: str) -> str:
258
- """Sanitize HTML content by escaping potentially dangerous elements.
259
-
260
- [From]: specs/004-ai-chatbot/spec.md - NFR-017
261
-
262
- Args:
263
- content: Content that may contain HTML
264
-
265
- Returns:
266
- Escaped HTML string
267
- """
268
- return html.escape(content, quote=False)
269
-
270
-
271
- __all__ = [
272
- "sanitize_message",
273
- "detect_prompt_injection",
274
- "validate_task_input",
275
- "sanitize_html_content",
276
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
services/chatkit_store.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PostgreSQL Store implementation for ChatKit SDK.
2
+
3
+ [Task]: T009
4
+ [From]: specs/010-chatkit-migration/data-model.md - ChatKit SDK Interface Requirements
5
+ [From]: specs/010-chatkit-migration/contracts/backend.md - Store Interface Implementation
6
+
7
+ This module implements the ChatKit Store interface using SQLModel and PostgreSQL.
8
+ The Store interface is required by ChatKit's Python SDK for thread and message persistence.
9
+
10
+ ChatKit Store Protocol Methods:
11
+ - list_threads: List threads for a user with pagination
12
+ - get_thread: Get a specific thread by ID
13
+ - create_thread: Create a new thread
14
+ - update_thread: Update thread metadata
15
+ - delete_thread: Delete a thread
16
+ - list_messages: List messages in a thread
17
+ - get_message: Get a specific message
18
+ - create_message: Create a new message
19
+ - update_message: Update a message
20
+ - delete_message: Delete a message
21
+ """
22
+ import uuid
23
+ from datetime import datetime
24
+ from typing import Any, Optional
25
+ from sqlmodel import Session, select, col
26
+ from sqlalchemy.ext.asyncio import AsyncSession
27
+
28
+ from models.thread import Thread
29
+ from models.message import Message, MessageRole
30
+
31
+
32
+ class PostgresChatKitStore:
33
+ """PostgreSQL implementation of ChatKit Store interface.
34
+
35
+ [From]: specs/010-chatkit-migration/data-model.md - Store Interface
36
+
37
+ This store provides thread and message persistence for ChatKit using
38
+ the existing SQLModel models and PostgreSQL database.
39
+
40
+ Note: The ChatKit SDK uses a Protocol-based interface. The actual
41
+ protocol types (ThreadMetadata, MessageItem, etc.) would be imported
42
+ from the openai_chatkit package. For this implementation, we use
43
+ dictionary-based representations until the SDK is installed.
44
+
45
+ Usage:
46
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
47
+ from services.chatkit_store import PostgresChatKitStore
48
+
49
+ engine = create_async_engine(database_url)
50
+ async with AsyncSession(engine) as session:
51
+ store = PostgresChatKitStore(session)
52
+ thread = await store.create_thread(
53
+ user_id="123e4567-e89b-12d3-a456-426614174000",  # must be a valid UUID string
54
+ title="My Conversation",
55
+ metadata={"tag": "important"}
56
+ )
57
+ """
58
+
59
+ def __init__(self, session: AsyncSession):
60
+ """Initialize the store with a database session.
61
+
62
+ Args:
63
+ session: SQLAlchemy async session for database operations
64
+ """
65
+ self.session = session
66
+
67
+ async def list_threads(
68
+ self,
69
+ user_id: str,
70
+ limit: int = 50,
71
+ offset: int = 0
72
+ ) -> list[dict]:
73
+ """List threads for a user with pagination.
74
+
75
+ [From]: specs/010-chatkit-migration/data-model.md - Retrieve User's Conversations
76
+
77
+ Args:
78
+ user_id: User ID to filter threads
79
+ limit: Maximum number of threads to return (default: 50)
80
+ offset: Number of threads to skip (default: 0)
81
+
82
+ Returns:
83
+ List of thread metadata dictionaries
84
+ """
85
+ stmt = (
86
+ select(Thread)
87
+ .where(Thread.user_id == uuid.UUID(user_id))
88
+ .order_by(Thread.updated_at.desc())
89
+ .limit(limit)
90
+ .offset(offset)
91
+ )
92
+ result = await self.session.execute(stmt)
93
+ threads = result.scalars().all()
94
+
95
+ return [
96
+ {
97
+ "id": str(thread.id),
98
+ "user_id": str(thread.user_id),
99
+ "title": thread.title,
100
+ "metadata": thread.thread_metadata or {},
101
+ "created_at": thread.created_at.isoformat(),
102
+ "updated_at": thread.updated_at.isoformat(),
103
+ }
104
+ for thread in threads
105
+ ]
106
+
107
+ async def get_thread(self, thread_id: str) -> Optional[dict]:
108
+ """Get a specific thread by ID.
109
+
110
+ Args:
111
+ thread_id: Thread UUID as string
112
+
113
+ Returns:
114
+ Thread metadata dictionary or None if not found
115
+ """
116
+ stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
117
+ result = await self.session.execute(stmt)
118
+ thread = result.scalar_one_or_none()
119
+
120
+ if thread is None:
121
+ return None
122
+
123
+ return {
124
+ "id": str(thread.id),
125
+ "user_id": str(thread.user_id),
126
+ "title": thread.title,
127
+ "metadata": thread.thread_metadata or {},
128
+ "created_at": thread.created_at.isoformat(),
129
+ "updated_at": thread.updated_at.isoformat(),
130
+ }
131
+
132
+ async def create_thread(
133
+ self,
134
+ user_id: str,
135
+ title: Optional[str] = None,
136
+ metadata: Optional[dict] = None
137
+ ) -> dict:
138
+ """Create a new thread.
139
+
140
+ Args:
141
+ user_id: User ID who owns the thread
142
+ title: Optional thread title
143
+ metadata: Optional thread metadata
144
+
145
+ Returns:
146
+ Created thread metadata dictionary
147
+ """
148
+ thread = Thread(
149
+ user_id=uuid.UUID(user_id),
150
+ title=title,
151
+ thread_metadata=metadata or {},
152
+ )
153
+ self.session.add(thread)
154
+ await self.session.commit()
155
+ await self.session.refresh(thread)
156
+
157
+ return {
158
+ "id": str(thread.id),
159
+ "user_id": str(thread.user_id),
160
+ "title": thread.title,
161
+ "metadata": thread.thread_metadata or {},
162
+ "created_at": thread.created_at.isoformat(),
163
+ "updated_at": thread.updated_at.isoformat(),
164
+ }
165
+
166
+ async def update_thread(
167
+ self,
168
+ thread_id: str,
169
+ title: Optional[str] = None,
170
+ metadata: Optional[dict] = None
171
+ ) -> Optional[dict]:
172
+ """Update a thread.
173
+
174
+ Args:
175
+ thread_id: Thread UUID as string
176
+ title: New title (optional)
177
+ metadata: New metadata (optional)
178
+
179
+ Returns:
180
+ Updated thread metadata dictionary or None if not found
181
+ """
182
+ stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
183
+ result = await self.session.execute(stmt)
184
+ thread = result.scalar_one_or_none()
185
+
186
+ if thread is None:
187
+ return None
188
+
189
+ if title is not None:
190
+ thread.title = title
191
+ if metadata is not None:
192
+ thread.thread_metadata = metadata
193
+ thread.updated_at = datetime.utcnow()
194
+
195
+ await self.session.commit()
196
+ await self.session.refresh(thread)
197
+
198
+ return {
199
+ "id": str(thread.id),
200
+ "user_id": str(thread.user_id),
201
+ "title": thread.title,
202
+ "metadata": thread.thread_metadata or {},
203
+ "created_at": thread.created_at.isoformat(),
204
+ "updated_at": thread.updated_at.isoformat(),
205
+ }
206
+
207
+ async def delete_thread(self, thread_id: str) -> bool:
208
+ """Delete a thread.
209
+
210
+ Args:
211
+ thread_id: Thread UUID as string
212
+
213
+ Returns:
214
+ True if deleted, False if not found
215
+ """
216
+ stmt = select(Thread).where(Thread.id == uuid.UUID(thread_id))
217
+ result = await self.session.execute(stmt)
218
+ thread = result.scalar_one_or_none()
219
+
220
+ if thread is None:
221
+ return False
222
+
223
+ await self.session.delete(thread)
224
+ await self.session.commit()
225
+ return True
226
+
227
+ async def list_messages(
228
+ self,
229
+ thread_id: str,
230
+ limit: int = 50,
231
+ offset: int = 0
232
+ ) -> list[dict]:
233
+ """List messages in a thread.
234
+
235
+ [From]: specs/010-chatkit-migration/data-model.md - Retrieve Conversation Messages
236
+
237
+ Args:
238
+ thread_id: Thread UUID as string
239
+ limit: Maximum number of messages to return
240
+ offset: Number of messages to skip
241
+
242
+ Returns:
243
+ List of message item dictionaries
244
+ """
245
+ stmt = (
246
+ select(Message)
247
+ .where(Message.thread_id == uuid.UUID(thread_id))
248
+ .order_by(Message.created_at.asc())
249
+ .limit(limit)
250
+ .offset(offset)
251
+ )
252
+ result = await self.session.execute(stmt)
253
+ messages = result.scalars().all()
254
+
255
+ return [
256
+ {
257
+ "id": str(msg.id),
258
+ "type": "message",
259
+ "role": msg.role.value,
260
+ "content": [{"type": "text", "text": msg.content}],
261
+ "tool_calls": msg.tool_calls,
262
+ "created_at": msg.created_at.isoformat(),
263
+ }
264
+ for msg in messages
265
+ ]
266
+
267
+ async def get_message(self, message_id: str) -> Optional[dict]:
268
+ """Get a specific message by ID.
269
+
270
+ Args:
271
+ message_id: Message UUID as string
272
+
273
+ Returns:
274
+ Message item dictionary or None if not found
275
+ """
276
+ stmt = select(Message).where(Message.id == uuid.UUID(message_id))
277
+ result = await self.session.execute(stmt)
278
+ message = result.scalar_one_or_none()
279
+
280
+ if message is None:
281
+ return None
282
+
283
+ return {
284
+ "id": str(message.id),
285
+ "type": "message",
286
+ "role": message.role.value,
287
+ "content": [{"type": "text", "text": message.content}],
288
+ "tool_calls": message.tool_calls,
289
+ "created_at": message.created_at.isoformat(),
290
+ }
291
+
292
async def create_message(
    self,
    thread_id: str,
    item: dict
) -> dict:
    """Create a new message in a thread.

    The message insert and the parent thread's ``updated_at`` bump are
    written with a single commit. Committing twice would leave the two
    writes non-atomic, and (with the session default of expiring objects
    on commit) would expire the freshly refreshed message again, so that
    building the return value triggers implicit lazy-load I/O, which an
    async session does not allow.

    Args:
        thread_id: Thread UUID as string
        item: Message item from ChatKit (UserMessageItem or
            ClientToolCallOutputItem), e.g.
            {"type": "message", "role": "user",
             "content": [{"type": "text", "text": "..."}]}

    Returns:
        Created message item dictionary

    Raises:
        ValueError: If ``thread_id`` is not a valid UUID string.
            ``MessageRole(role)`` is presumably an Enum lookup and would
            raise the same error for an unknown role - confirm against
            the model definition.
    """
    # Validate the inputs up front so nothing is staged on bad data.
    thread_uuid = uuid.UUID(thread_id)
    role = MessageRole(item.get("role", "user"))

    if item.get("type", "message") == "client_tool_call_output":
        # Client tool output items carry their payload in "output".
        text_content = item.get("output", "")
    else:
        # Use the first text block; tolerate a missing or null "content".
        text_content = ""
        for content_block in item.get("content") or []:
            if content_block.get("type") == "text":
                text_content = content_block.get("text", "")
                break

    # Load the parent thread before staging the insert so the SELECT does
    # not autoflush a half-built unit of work.
    thread_result = await self.session.execute(
        select(Thread).where(Thread.id == thread_uuid)
    )
    thread = thread_result.scalar_one_or_none()

    message = Message(
        thread_id=thread_uuid,
        role=role,
        content=text_content,
        tool_calls=item.get("tool_calls"),
    )
    self.session.add(message)

    # Update thread's updated_at timestamp in the same transaction.
    if thread:
        thread.updated_at = datetime.utcnow()

    await self.session.commit()
    # Reload generated columns (id, created_at) after the only commit, so
    # the attribute reads below never hit an expired instance.
    await self.session.refresh(message)

    return {
        "id": str(message.id),
        "type": "message",
        "role": message.role.value,
        "content": [{"type": "text", "text": message.content}],
        "tool_calls": message.tool_calls,
        "created_at": message.created_at.isoformat(),
    }
352
+
353
async def update_message(
    self,
    message_id: str,
    item: dict
) -> Optional[dict]:
    """Apply the fields present in ``item`` to a stored message.

    Only the first ``text`` block of ``item["content"]`` is used for the
    new content, and ``tool_calls`` is overwritten only when that key is
    present in ``item``.

    Args:
        message_id: Message UUID as string
        item: Updated message item

    Returns:
        Updated message item dictionary or None if not found
    """
    query = select(Message).where(Message.id == uuid.UUID(message_id))
    row = (await self.session.execute(query)).scalar_one_or_none()
    if row is None:
        return None

    # A missing, empty or null "content" leaves the stored text untouched.
    blocks = item.get("content", []) or ()
    first_text = next(
        (block for block in blocks if block.get("type") == "text"),
        None,
    )
    if first_text is not None:
        row.content = first_text.get("text", row.content)

    if "tool_calls" in item:
        row.tool_calls = item["tool_calls"]

    await self.session.commit()
    await self.session.refresh(row)

    return dict(
        id=str(row.id),
        type="message",
        role=row.role.value,
        content=[{"type": "text", "text": row.content}],
        tool_calls=row.tool_calls,
        created_at=row.created_at.isoformat(),
    )
397
+
398
async def delete_message(self, message_id: str) -> bool:
    """Remove a message row and commit the deletion.

    Args:
        message_id: Message UUID as string

    Returns:
        True if deleted, False if not found
    """
    query = select(Message).where(Message.id == uuid.UUID(message_id))
    target = (await self.session.execute(query)).scalar_one_or_none()

    if target is not None:
        await self.session.delete(target)
        await self.session.commit()
        return True

    # Nothing matched the id, so there is nothing to commit.
    return False