File size: 10,206 Bytes
a67367b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""FastAPI entrypoint for the ChatKit backend with TodoAgent."""

from __future__ import annotations
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing import Optional, Dict, Any
import logging
from contextlib import asynccontextmanager
from uuid import uuid4

from chatkit.server import ChatKitServer, StreamingResult
from chatkit.store import Store, Page
from chatkit.types import ThreadMetadata, UserMessageItem, ThreadItem, AssistantMessageItem, ThreadItemAddedEvent, ErrorEvent
import json
from todo_agent import TodoAgent

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MemoryStore(Store):
    """In-memory implementation of ChatKit Store for development."""

    def __init__(self):
        self.threads: Dict[str, ThreadMetadata] = {}
        self.items: Dict[str, Dict[str, Any]] = {}
        self.attachments: Dict[str, Any] = {}

    async def load_threads(self, *, context: Dict[str, Any] | None = None, **kwargs) -> Page:
        """Load all threads for a user, extracting user_id from context."""
        user_id = context.get("user_id") if context else None
        if not user_id:
            # If no user_id is found, return an empty page.
            return Page(data=[], has_more=False)
            
        threads = [t for t in self.threads.values() if t.user and t.user.id == user_id]
        return Page(data=threads, has_more=False)

    async def load_thread(self, thread_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> ThreadMetadata | None:
        """Load a single thread, create if doesn't exist."""
        thread = self.threads.get(thread_id)
        if thread is None:
            # Create a new thread if it doesn't exist to prevent NoneType errors downstream.
            user_id = context.get("user_id", "unknown") if context else "unknown"
            thread = ThreadMetadata(
                id=thread_id,
                user_id=user_id,
                created_at=0,  # Using a simple timestamp
                updated_at=0,
            )
            self.threads[thread_id] = thread
        return thread

    async def save_thread(self, thread: ThreadMetadata, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Save a thread."""
        self.threads[thread.id] = thread

    async def delete_thread(self, thread_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Delete a thread."""
        self.threads.pop(thread_id, None)
        self.items.pop(thread_id, None)

    async def load_thread_items(self, thread_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> Page:
        """Load all items for a thread."""
        items = list(self.items.get(thread_id, {}).values())
        return Page(data=items, has_more=False)

    async def load_item(self, thread_id: str, item_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> ThreadItem | None:
        """Load a single item."""
        return self.items.get(thread_id, {}).get(item_id)

    async def save_item(self, thread_id: str, item: ThreadItem, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Save an item."""
        if thread_id not in self.items:
            self.items[thread_id] = {}
        self.items[thread_id][item.id] = item

    async def add_thread_item(self, thread_id: str, item: ThreadItem, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Add an item to a thread."""
        await self.save_item(thread_id, item, context=context)

    async def delete_thread_item(self, thread_id: str, item_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Delete an item from a thread."""
        if thread_id in self.items:
            self.items[thread_id].pop(item_id, None)

    async def save_attachment(self, attachment: Any, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Save an attachment."""
        self.attachments[attachment.id] = attachment

    async def load_attachment(self, attachment_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> Any | None:
        """Load an attachment."""
        return self.attachments.get(attachment_id)

    async def delete_attachment(self, attachment_id: str, *args, context: Dict[str, Any] | None = None, **kwargs) -> None:
        """Delete an attachment."""
        self.attachments.pop(attachment_id, None)


# Global instances
todo_agent: Optional[TodoAgent] = None
chatkit_server: Optional[ChatKitServer] = None


class TodoChatKitServer(ChatKitServer):
    """ChatKit server that wraps the TodoAgent."""

    def __init__(self, todo_agent: TodoAgent):
        self.todo_agent = todo_agent
        super().__init__(MemoryStore())

    async def respond(self, thread: ThreadMetadata, input_user_message: UserMessageItem | None, context: Dict[str, Any]):
        """Implement the abstract respond method to handle user messages."""
        try:
            user_id = context.get("user_id", "unknown")
            
            # Get the message content
            message_content = ""
            if input_user_message and hasattr(input_user_message, 'content'):
                message_content = input_user_message.content
            
            logger.info(f"Processing message from user {user_id}: {message_content}")
            
            # Use TodoAgent to generate response
            response = await self.todo_agent.process_message(user_id, message_content)
            
            # Create an AssistantMessageItem for the response, ensuring all fields are correctly formatted.
            assistant_item = AssistantMessageItem(
                id=f"msg_{uuid4()}",
                thread_id=thread.id,  # Pass the thread_id
                created_at=0,  # Pass a timestamp
                content=[{"type": "output_text", "text": response}],  # Format content as a list of blocks
            )
            
            # Yield ThreadItemAddedEvent to add the assistant message to the thread
            yield ThreadItemAddedEvent(
                type="thread.item.added",
                thread_id=thread.id,
                item=assistant_item
            )
            
        except Exception as e:
            logger.error(f"Error in respond(): {e}", exc_info=True)
            yield ErrorEvent(
                type="error",
                error=str(e)
            )

    async def process(self, payload: bytes | str, context: Dict[str, Any]) -> StreamingResult | dict | str:
        """Process ChatKit events using TodoAgent."""
        try:
            # Parse the incoming payload
            if isinstance(payload, bytes):
                body = json.loads(payload.decode('utf-8'))
            else:
                body = json.loads(payload)

            logger.info(f"Processing ChatKit event: {body.get('type')}")

            # Extract user_id from context or payload
            user_id = str(context.get("user_id", body.get("user", {}).get("id", "unknown")))

            # Call the parent process method which handles ChatKit protocol
            result = await super().process(payload, context)

            return result

        except Exception as e:
            logger.error(f"Error in ChatKitServer.process(): {e}", exc_info=True)
            return {"type": "error", "content": str(e)}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize agents on startup, cleanup on shutdown."""
    global todo_agent, chatkit_server

    logger.info("Initializing TodoAgent...")
    todo_agent = TodoAgent()
    logger.info("TodoAgent initialized")

    logger.info("Initializing ChatKit Server...")
    chatkit_server = TodoChatKitServer(todo_agent)
    logger.info("ChatKit Server initialized")

    yield

    logger.info("Shutting down...")
    todo_agent = None
    chatkit_server = None


app = FastAPI(
    title="ChatKit Backend with TodoAgent",
    description="ChatKit backend powered by TodoAgent",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/chatkit")
async def chatkit_endpoint(request: Request) -> Response:
    """Proxy the ChatKit web component payload to the server implementation."""
    global chatkit_server

    if chatkit_server is None:
        logger.error("ChatKit server not initialized")
        return JSONResponse({"error": "Server not initialized"}, status_code=500)

    try:
        payload = await request.body()
        
        # Extract user_id from the request if available
        context = {"request": request}
        try:
            import json
            body = json.loads(payload.decode('utf-8'))
            if "user" in body and "id" in body["user"]:
                context["user_id"] = str(body["user"]["id"])
        except Exception as e:
            logger.debug(f"Could not extract user_id from payload: {e}")

        result = await chatkit_server.process(payload, context)

        if isinstance(result, StreamingResult):
            return StreamingResponse(result, media_type="text/event-stream")
        if hasattr(result, "json"):
            return Response(content=result.json, media_type="application/json")
        return JSONResponse(result)

    except Exception as e:
        logger.error(f"Error in chatkit_endpoint: {e}", exc_info=True)
        return JSONResponse({"error": str(e)}, status_code=500)


@app.get("/")
async def health_check():
    """Health check endpoint."""
    return {
        "service": "chatkit-todo-assistant",
        "initialized": chatkit_server is not None
    }

@app.get("/health")
async def health():
        return{"status": "ok"}

@app.post("/chatkit/api")
async def chatkit_api_endpoint(request: Request) -> Response:
    """Alias for /chatkit endpoint for compatibility."""
    return await chatkit_endpoint(request)



if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)