Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Any, AsyncIterator,List | |
| from pydantic import BaseModel | |
| from agents import RunConfig, Runner, SQLiteSession | |
| from agents.model_settings import ModelSettings | |
| from chatkit.agents import AgentContext, stream_agent_response | |
| from chatkit.server import ChatKitServer, StreamingResult | |
| import os | |
| from chatkit.types import ( | |
| Attachment, | |
| ClientToolCallItem, | |
| ThreadMetadata, | |
| ThreadStreamEvent, | |
| UserMessageItem, | |
| AssistantMessageItem | |
| ) | |
| from .MultiAgent import build_sugguestion_information_agent, build_kimi_information_agent, build_summarizer_agent,build_google_information_agent | |
| from chatkit.types import AssistantMessageItem as AssistantMsg | |
| from fastapi import Depends, FastAPI, Query, Request, HTTPException, BackgroundTasks, Header, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import Response, StreamingResponse | |
| from openai.types.responses import ResponseInputContentParam | |
| from starlette.responses import JSONResponse | |
| import json | |
| import asyncio | |
| from collections import defaultdict | |
| import random | |
| import string | |
| import traceback | |
| from datetime import datetime, timezone, timedelta | |
| from .sqlite_store import SQLiteStore | |
| from .memory_store import MemoryStore | |
| from .user_state import UserStateManager | |
| from dotenv import load_dotenv | |
# Load local environment overrides from ./.env.local before the os.getenv
# lookups performed further down in this module.
load_dotenv(dotenv_path="./.env.local")

# Thread key used whenever a request does not carry a thread identifier.
DEFAULT_THREAD_ID = "demo_default_thread"
def _user_message_text(item: UserMessageItem) -> str:
    """Return the text of *item* as a single stripped string.

    Every element of ``item.content`` is inspected for a ``text`` attribute;
    parts without one (or with an empty/falsy value) are skipped. The
    remaining texts are joined with single spaces.
    """
    candidate_texts = (getattr(chunk, "text", None) for chunk in item.content)
    return " ".join(value for value in candidate_texts if value).strip()
def _is_tool_completion_item(item: Any) -> bool:
    """Tell whether *item* is a ``ClientToolCallItem`` instance."""
    is_client_tool_call = isinstance(item, ClientToolCallItem)
    return is_client_tool_call
class _BaseCustomerSupportServer(ChatKitServer[dict[str, Any]]):
    """Shared implementation behind the provider-specific support servers.

    The three public servers differ only in the information agent they run,
    so prompt building, history handling, background summarization and
    response streaming live here once.
    """

    # Maximum number of trailing user/assistant messages considered per request.
    HISTORY_WINDOW = 15
    # Number of trailing messages quoted verbatim in the prompt.
    RECENT_CONTEXT_SIZE = 10
    # A background summary is scheduled each time the stored message count
    # reaches a multiple of this value.
    SUMMARY_INTERVAL = 5
    # Trailing messages left out of a summarization pass.
    SUMMARY_TAIL = 5
    # Reply the model is told to give when a human agent is requested while
    # the status is "offline".
    OFFLINE_HANDOFF_REPLY = (
        "Our human sales agent is currently offline, "
        "May i help in book an appointment for you."
    )

    def __init__(self, agent_state: UserStateManager, information_agent: Any) -> None:
        """Create the thread store and remember the agents to run.

        Args:
            agent_state: Per-thread user state and summary persistence.
            information_agent: Agent used to answer user messages.
        """
        # All servers honour CHATKIT_DB_PATH and fall back to the same file.
        store = SQLiteStore(db_path=os.getenv("CHATKIT_DB_PATH", "chatkit_threads.db"))
        super().__init__(store)
        self.store = store
        self.agent_state = agent_state
        self.information_agent = information_agent
        self.summarizer_agent = build_summarizer_agent()
        # Strong references to in-flight background tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

    def _resolve_thread_id(self, thread: ThreadMetadata | None) -> str:
        """Return ``thread.id`` when present, otherwise ``DEFAULT_THREAD_ID``."""
        return thread.id if thread and thread.id else DEFAULT_THREAD_ID

    @staticmethod
    def _format_transcript(messages: list) -> str:
        """Render *messages* as ``User: ...`` / ``Assistant: ...`` lines."""
        return "\n".join(
            f"{'User' if isinstance(m, UserMessageItem) else 'Assistant'}: {_user_message_text(m)}"
            for m in messages
        )

    def _schedule_background(self, coro: Any) -> None:
        """Run *coro* as a task and keep it referenced until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def prepare_conversation_context(self, thread_key: str, message_text: str, status: str) -> str:
        """Build the prompt sent to the information agent.

        Args:
            thread_key: Thread whose history, summary and user data are used.
            message_text: Text of the current user request.
            status: Agent availability; ``"offline"`` (case-insensitive) adds
                the human-handoff instruction to the prompt.

        Returns:
            The combined prompt string.
        """
        summary_text, recent_context = await self.handle_history(thread_key)
        user_data = self.agent_state.get_user(thread_key)
        customer_context = (
            "Customer context:\n"
            f"- Name: {user_data.customer_name or ''}\n"
            f"- Email: {user_data.customer_email or ''}\n"
            f"- Phone: {user_data.customer_phone or ''}\n"
            f"- Timezone: {user_data.Timezone or ''}\n"
        )

        handoff_rule = ""
        if status.lower() == "offline":
            # The instruction previously ended at "respond: " with no reply
            # text, so the model was never told what to answer.
            handoff_rule = (
                "If the user asks to talk to a human sales agent, respond: "
                f'"{self.OFFLINE_HANDOFF_REPLY}"\n'
            )

        return (
            f"{customer_context}\n"
            f"Previous summary:\n{summary_text}\n\n"
            f"Recent conversation (last {self.RECENT_CONTEXT_SIZE} messages):\n{recent_context}\n\n"
            f"{handoff_rule}"
            "-This Company you are representing for : Sunmarke School\n"
            f"Current request: {message_text}\n"
        )

    async def _async_summarize(self, thread_key: str, user_messages: list, previous_summary: str):
        """Background task to perform summarization without blocking the main flow.

        Summarizes everything in *user_messages* except the last
        ``SUMMARY_TAIL`` entries and stores the result through
        ``agent_state.set_summary``. Failures are reported on stdout and never
        propagate.
        """
        try:
            to_summarize = user_messages[:-self.SUMMARY_TAIL]
            if not to_summarize:
                # Nothing older than the tail yet; skip the model call.
                return
            summarizer_prompt = (
                f"Previous summary:\n{previous_summary}\n\n"
                f"Add the following messages into the summary:\n{self._format_transcript(to_summarize)}\n"
                f"Return a concise updated summary of the entire conversation."
            )
            session = SQLiteSession(thread_key)
            result = await Runner.run(
                self.summarizer_agent,
                summarizer_prompt,
                session=session,
            )
            self.agent_state.set_summary(thread_key, result.final_output)
            print(f"🧠 [BACKGROUND] Summary updated for thread: {thread_key}")
        except Exception as e:
            # Best effort: a failed summary must not affect the live response.
            print(f"⚠️ [BACKGROUND] Summarization failed for thread {thread_key}: {e}")

    async def handle_history(self, thread_key: str) -> tuple[str, str]:
        """Handles message history, returns current state, and triggers summarization in background if needed.

        Returns:
            ``(summary_text, recent_context)`` where ``summary_text`` is the
            stored summary for the thread and ``recent_context`` is the
            formatted tail of the conversation.
        """
        history = self.store._items(thread_key)
        messages = [i for i in history if isinstance(i, (UserMessageItem, AssistantMessageItem))]
        # Count before truncating: testing the truncated length would make the
        # trigger fire on every request once the history exceeds the window.
        total_messages = len(messages)
        window = messages[-self.HISTORY_WINDOW:]

        summary_text = self.agent_state.get_summary(thread_key)

        if total_messages >= self.SUMMARY_INTERVAL and total_messages % self.SUMMARY_INTERVAL == 0:
            print(f"🧠 [HISTORY] Triggering background summarization for {thread_key}")
            self._schedule_background(self._async_summarize(thread_key, window, summary_text))

        recent_context = self._format_transcript(window[-self.RECENT_CONTEXT_SIZE:])
        return summary_text, recent_context

    async def respond(
        self,
        thread: ThreadMetadata,
        item: UserMessageItem | None,
        context: dict[str, Any],
    ) -> AsyncIterator[ThreadStreamEvent]:
        """Stream the information agent's reply to a user message.

        Yields nothing when *item* is ``None`` or carries no text.
        """
        if item is None:
            return
        message_text = _user_message_text(item)
        if not message_text:
            return

        thread_key = self._resolve_thread_id(thread)
        # Shallow copy so the caller's dict is not shared with the agent context.
        request_context_enriched = dict(context or {})
        session = SQLiteSession(thread_key)
        agent_context = AgentContext(thread=thread, store=self.store, request_context=request_context_enriched)
        # NOTE(review): status is hard-coded to 'offline' — confirm whether it
        # should come from real agent availability.
        combined_prompt = await self.prepare_conversation_context(thread_key, message_text, 'offline')
        result_stream = Runner.run_streamed(self.information_agent, combined_prompt, context=agent_context, session=session)
        async for event in stream_agent_response(agent_context, result_stream):
            yield event


class deepseek_CustomerSupportServer(_BaseCustomerSupportServer):
    """Support server running the agent from ``build_sugguestion_information_agent``."""

    def __init__(
        self,
        agent_state: UserStateManager,
    ) -> None:
        super().__init__(agent_state, build_sugguestion_information_agent())


class kimi_CustomerSupportServer(_BaseCustomerSupportServer):
    """Support server running the agent from ``build_kimi_information_agent``."""

    def __init__(
        self,
        agent_state: UserStateManager,
    ) -> None:
        super().__init__(agent_state, build_kimi_information_agent())


class google_CustomerSupportServer(_BaseCustomerSupportServer):
    """Support server running the agent from ``build_google_information_agent``."""

    def __init__(
        self,
        agent_state: UserStateManager,
    ) -> None:
        super().__init__(agent_state, build_google_information_agent())
# Per-user state shared by every server; the database path can be overridden
# through the USER_STATE_DB_PATH environment variable.
state_manager = UserStateManager(db_path=os.getenv("USER_STATE_DB_PATH", "user_state.db"))

# One server instance per agent backend, all using the same state manager.
support_server = deepseek_CustomerSupportServer(agent_state=state_manager)
kimi_support_server = kimi_CustomerSupportServer(agent_state=state_manager)
google_support_server = google_CustomerSupportServer(agent_state=state_manager)

app = FastAPI(title="ChatKit Customer Support API")

# NOTE(review): CORS accepts any origin, method and header — confirm this is
# intended outside of demo deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_server() -> deepseek_CustomerSupportServer:
    """Dependency provider: hand back the module-level ``support_server``."""
    return support_server
def get_kimi_server() -> kimi_CustomerSupportServer:
    """Dependency provider: hand back the module-level ``kimi_support_server``."""
    return kimi_support_server
def get_google_server() -> google_CustomerSupportServer:
    """Dependency provider: hand back the module-level ``google_support_server``."""
    return google_support_server
async def chatkit_endpoint(
    request: Request, server: deepseek_CustomerSupportServer = Depends(get_server)
) -> Response:
    """Forward the raw request body to the server from ``get_server``.

    A ``StreamingResult`` is returned as a ``text/event-stream`` response, a
    result exposing ``json`` as an ``application/json`` response, and anything
    else through ``JSONResponse``.

    NOTE(review): no route decorator is visible here, and two later functions
    in this module reuse the name ``chatkit_endpoint`` — confirm how this
    handler is registered.
    """
    raw_body = await request.body()
    outcome = await server.process(raw_body, {"request": request})
    if isinstance(outcome, StreamingResult):
        response: Response = StreamingResponse(outcome, media_type="text/event-stream")
    elif hasattr(outcome, "json"):
        response = Response(content=outcome.json, media_type="application/json")
    else:
        response = JSONResponse(outcome)
    return response
async def chatkit_endpoint(
    request: Request, server: kimi_CustomerSupportServer = Depends(get_kimi_server)
) -> Response:
    """Forward the raw request body to the server from ``get_kimi_server``.

    A ``StreamingResult`` is returned as a ``text/event-stream`` response, a
    result exposing ``json`` as an ``application/json`` response, and anything
    else through ``JSONResponse``.

    NOTE(review): no route decorator is visible here, and other functions in
    this module reuse the name ``chatkit_endpoint`` — confirm how this handler
    is registered.
    """
    raw_body = await request.body()
    outcome = await server.process(raw_body, {"request": request})
    if isinstance(outcome, StreamingResult):
        response: Response = StreamingResponse(outcome, media_type="text/event-stream")
    elif hasattr(outcome, "json"):
        response = Response(content=outcome.json, media_type="application/json")
    else:
        response = JSONResponse(outcome)
    return response
async def chatkit_endpoint(
    request: Request, server: google_CustomerSupportServer = Depends(get_google_server)
) -> Response:
    """Forward the raw request body to the server from ``get_google_server``.

    A ``StreamingResult`` is returned as a ``text/event-stream`` response, a
    result exposing ``json`` as an ``application/json`` response, and anything
    else through ``JSONResponse``.

    NOTE(review): no route decorator is visible here, and two earlier
    functions in this module reuse the name ``chatkit_endpoint`` — confirm how
    this handler is registered.
    """
    raw_body = await request.body()
    outcome = await server.process(raw_body, {"request": request})
    if isinstance(outcome, StreamingResult):
        response: Response = StreamingResponse(outcome, media_type="text/event-stream")
    elif hasattr(outcome, "json"):
        response = Response(content=outcome.json, media_type="application/json")
    else:
        response = JSONResponse(outcome)
    return response
async def chat_debug(request: Request):
    """Debug helper: print the raw request body and echo it back decoded.

    Returns a dict with the single key ``"received"`` holding the body decoded
    with the default codec of ``bytes.decode``.
    """
    raw = await request.body()
    print("RAW BODY RECEIVED:", raw)
    echoed = raw.decode()
    return {"received": echoed}
def _thread_param(thread_id: str | None) -> str:
    """Return *thread_id*, or ``DEFAULT_THREAD_ID`` when it is ``None`` or empty."""
    if thread_id:
        return thread_id
    return DEFAULT_THREAD_ID
async def deepseek_customer_snapshot(
    thread_id: str | None = Query(None, description="ChatKit thread identifier"),
    server: deepseek_CustomerSupportServer = Depends(get_server),
) -> dict[str, Any]:
    """Return the stored customer state for a thread under the ``"customer"`` key.

    A missing ``thread_id`` is resolved through ``_thread_param``. The state
    dict is also printed to stdout.
    """
    snapshot = server.agent_state.to_dict(_thread_param(thread_id))
    print("data", snapshot)
    return {"customer": snapshot}
async def kimi_customer_snapshot(
    thread_id: str | None = Query(None, description="ChatKit thread identifier"),
    server: kimi_CustomerSupportServer = Depends(get_kimi_server),
) -> dict[str, Any]:
    """Return the stored customer state for a thread under the ``"customer"`` key.

    A missing ``thread_id`` is resolved through ``_thread_param``. The state
    dict is also printed to stdout.
    """
    snapshot = server.agent_state.to_dict(_thread_param(thread_id))
    print("data", snapshot)
    return {"customer": snapshot}
async def google_customer_snapshot(
    thread_id: str | None = Query(None, description="ChatKit thread identifier"),
    server: google_CustomerSupportServer = Depends(get_google_server),
) -> dict[str, Any]:
    """Return the stored customer state for a thread under the ``"customer"`` key.

    A missing ``thread_id`` is resolved through ``_thread_param``. The state
    dict is also printed to stdout.
    """
    snapshot = server.agent_state.to_dict(_thread_param(thread_id))
    print("data", snapshot)
    return {"customer": snapshot}
async def customer_snapshot(
    thread_id: str | None = Query(None, description="ChatKit thread identifier"),
    server: google_CustomerSupportServer = Depends(get_google_server),
) -> dict[str, Any]:
    """Return the stored customer state for a thread under the ``"customer"`` key.

    Uses the server provided by ``get_google_server``. A missing ``thread_id``
    is resolved through ``_thread_param``. The state dict is also printed to
    stdout.
    """
    snapshot = server.agent_state.to_dict(_thread_param(thread_id))
    print("data", snapshot)
    return {"customer": snapshot}
async def _transcribe_upload(file: UploadFile) -> dict[str, Any]:
    """Transcribe an uploaded audio file with Groq's Whisper model.

    The upload is written to a temporary file, sent to
    ``client.audio.transcriptions.create`` and the temporary file is removed
    whether or not the call succeeds.

    Returns:
        A dict with ``"text"`` (the transcription text) and ``"details"``.

    Raises:
        HTTPException: status 500 when ``GROQ_API_KEY`` is not set or when
            transcription fails for any other reason.
    """
    import tempfile
    from groq import Groq

    # Checked outside the try block so this error reaches the client verbatim
    # instead of being re-wrapped as a generic "Transcription failed" message.
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        raise HTTPException(status_code=500, detail="GROQ_API_KEY not found in environment")

    temp_file_path: str | None = None
    try:
        client = Groq(api_key=groq_api_key)
        audio_data = await file.read()

        # Keep the upload's extension; fall back to .webm when the filename is
        # missing or has no extension.
        file_extension = os.path.splitext(file.filename or "")[1] or ".webm"
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name

        def _run_transcription() -> Any:
            with open(temp_file_path, "rb") as audio_file:
                return client.audio.transcriptions.create(
                    file=audio_file,
                    model="whisper-large-v3-turbo",
                    response_format="verbose_json",
                    timestamp_granularities=["word", "segment"],
                    language="en",
                    temperature=0.0
                )

        # The Groq client call is synchronous; run it in a worker thread so
        # the event loop keeps serving other requests meanwhile.
        transcription = await asyncio.to_thread(_run_transcription)

        return {
            "text": transcription.text,
            # NOTE(review): default=str stringifies any object json cannot
            # encode — confirm "details" has the shape clients expect.
            "details": json.loads(json.dumps(transcription, default=str))
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Transcription error: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)


async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe audio using Groq Whisper model"""
    return await _transcribe_upload(file)


async def deepseek_transcribe_audio(file: UploadFile = File(...)):
    """Transcribe audio using Groq Whisper model"""
    return await _transcribe_upload(file)


async def google_transcribe_audio(file: UploadFile = File(...)):
    """Transcribe audio using Groq Whisper model"""
    return await _transcribe_upload(file)


async def kimi_transcribe_audio(file: UploadFile = File(...)):
    """Transcribe audio using Groq Whisper model"""
    return await _transcribe_upload(file)
async def deepseek_get_thread_messages(
    thread_id: str,
    server: deepseek_CustomerSupportServer = Depends(get_server)
):
    """Return the last ten user/assistant messages stored for a thread.

    Args:
        thread_id: Identifier of the thread whose items are read.
        server: Server instance injected by FastAPI; its ``store`` is queried
            through the private ``_items`` accessor.

    Returns:
        A dict with the thread id, the total number of user/assistant
        messages found, the number returned, and the messages themselves
        (as ``model_dump()`` dicts with ``created_at`` converted to a string).

    Raises:
        HTTPException: 500 when reading or serialising the items fails.
    """
    try:
        messages = []
        for item in server.store._items(thread_id):
            # Anything that is not a user or assistant message is ignored.
            if not isinstance(item, (UserMessageItem, AssistantMessageItem)):
                continue

            message_dict = item.model_dump()
            created_at = message_dict.get('created_at')
            # Prefer ISO-8601 for datetime-like values; otherwise fall back
            # to str() for any other truthy timestamp.
            if hasattr(created_at, 'isoformat'):
                message_dict['created_at'] = created_at.isoformat()
            elif created_at:
                message_dict['created_at'] = str(created_at)
            messages.append(message_dict)

        # A negative slice already copes with lists shorter than ten items.
        last_10_messages = messages[-10:]
        return {
            "thread_id": thread_id,
            "total_message_count": len(messages),
            "returned_message_count": len(last_10_messages),
            "messages": last_10_messages
        }
    except Exception as e:
        print(f"❌ Error fetching messages for thread {thread_id}: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
async def get_thread_messages(
    thread_id: str,
    server: deepseek_CustomerSupportServer = Depends(get_server)
):
    """Return the last ten user/assistant messages stored for a thread.

    This handler previously carried a statement-for-statement copy of
    ``deepseek_get_thread_messages`` (same ``get_server`` dependency, same
    filtering, same response shape, same error handling). It now delegates
    to that function so the two cannot drift apart.

    Args:
        thread_id: Identifier of the thread whose items are read.
        server: Server instance injected by FastAPI.

    Returns:
        The dict produced by ``deepseek_get_thread_messages``: ``thread_id``,
        ``total_message_count``, ``returned_message_count`` and ``messages``.

    Raises:
        HTTPException: 500, propagated from the delegate, when reading or
            serialising the thread items fails.
    """
    # The already-resolved dependency is passed explicitly: a direct call to
    # the coroutine function does not go through FastAPI's Depends injection.
    return await deepseek_get_thread_messages(thread_id, server)
async def kimi_get_thread_messages(
    thread_id: str,
    server: kimi_CustomerSupportServer = Depends(get_kimi_server)
):
    """Return the last ten user/assistant messages of a thread on the Kimi server.

    Args:
        thread_id: Identifier of the thread whose items are read.
        server: Kimi server instance injected by FastAPI; its ``store`` is
            queried through the private ``_items`` accessor.

    Returns:
        A dict with ``thread_id``, ``total_message_count``,
        ``returned_message_count`` and ``messages`` (``model_dump()`` dicts
        whose ``created_at`` has been converted to a string).

    Raises:
        HTTPException: 500 when reading or serialising the items fails.
    """

    def _to_payload(item):
        # Dump the model and make sure created_at is JSON-friendly text.
        data = item.model_dump()
        created_at = data.get('created_at')
        if hasattr(created_at, 'isoformat'):
            data['created_at'] = created_at.isoformat()
        elif created_at:
            data['created_at'] = str(created_at)
        return data

    try:
        # Keep only user and assistant messages, in store order.
        messages = [
            _to_payload(item)
            for item in server.store._items(thread_id)
            if isinstance(item, (UserMessageItem, AssistantMessageItem))
        ]
        # Slicing from -10 returns the whole list when it has ten items or fewer.
        last_10_messages = messages[-10:]
        return {
            "thread_id": thread_id,
            "total_message_count": len(messages),
            "returned_message_count": len(last_10_messages),
            "messages": last_10_messages
        }
    except Exception as e:
        print(f"❌ Error fetching messages for thread {thread_id}: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
async def google_get_thread_messages(
    thread_id: str,
    server: google_CustomerSupportServer = Depends(get_google_server)
):
    """Return the last ten user/assistant messages of a thread on the Google server.

    Args:
        thread_id: Identifier of the thread whose items are read.
        server: Google server instance injected by FastAPI; its ``store`` is
            queried through the private ``_items`` accessor.

    Returns:
        A dict with ``thread_id``, ``total_message_count``,
        ``returned_message_count`` and ``messages`` (``model_dump()`` dicts
        whose ``created_at`` has been converted to a string).

    Raises:
        HTTPException: 500 when reading or serialising the items fails.
    """
    try:
        chat_types = (UserMessageItem, AssistantMessageItem)
        messages = []
        for entry in server.store._items(thread_id):
            if isinstance(entry, chat_types):
                dumped = entry.model_dump()
                stamp = dumped.get('created_at')
                # datetime-like values become ISO strings; any other truthy
                # value is stringified; falsy values are left untouched.
                if hasattr(stamp, 'isoformat'):
                    dumped['created_at'] = stamp.isoformat()
                elif stamp:
                    dumped['created_at'] = str(stamp)
                messages.append(dumped)

        total = len(messages)
        # Take the trailing ten entries (or all of them when fewer exist).
        tail = messages[-10:]
        return {
            "thread_id": thread_id,
            "total_message_count": total,
            "returned_message_count": len(tail),
            "messages": tail
        }
    except Exception as e:
        print(f"❌ Error fetching messages for thread {thread_id}: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to fetch messages: {str(e)}")
async def customer_update(
    request: Request,
    thread_id: str | None = Query(None, description="ChatKit thread identifier"),
    server: deepseek_CustomerSupportServer = Depends(get_server),
) -> dict[str, str]:
    """Store customer contact details (and optionally a timezone) for a thread.

    The JSON body may contain ``name``, ``email``, ``phone``,
    ``company_name`` (or ``company``) and ``timeZone`` (or ``timezone``).
    Values are stripped; empty values are passed on as ``None``.

    Args:
        request: Incoming request whose JSON body carries the customer fields.
        thread_id: Optional thread identifier; it is passed through
            ``_thread_param`` to obtain the key used for the agent state.
        server: Server instance injected by FastAPI; its ``agent_state`` is
            updated.

    Returns:
        ``{"status": "ok"}`` on success, ``{"status": "error"}`` when the
        payload could not be processed or the state update failed.
    """
    try:
        payload = await request.json()
    except Exception:
        # Best effort: a missing or malformed body is treated as empty.
        payload = {}

    key = _thread_param(thread_id)
    try:
        # NOTE(review): this prints the raw payload, which contains customer
        # contact details; consider redacting before this reaches shared logs.
        print(f"payload: {payload}")
        name = (payload.get("name") or "").strip()
        email = (payload.get("email") or "").strip()
        phone = (payload.get("phone") or "").strip()
        company_name = (payload.get("company_name") or payload.get("company") or "").strip()
        # Named tz_name so the module-level `datetime.timezone` import is not
        # shadowed inside this function.
        tz_name = (payload.get("timeZone") or payload.get("timezone") or "").strip()

        server.agent_state.set_customer_info(
            key,
            name=name or None,
            email=email or None,
            phone=phone or None,
            company_name=company_name or None,
        )
        if tz_name:
            server.agent_state.set_timezone(key, tz_name)
        # NOTE(review): an earlier comment here mentioned preloading a vector
        # index for the company, but no preload call exists in this function.
        return {"status": "ok"}
    except Exception as e:
        # Keep the soft-failure response, but log the cause the same way the
        # other handlers in this file do instead of discarding it silently.
        print(f"❌ Error updating customer info for thread {key}: {e}")
        traceback.print_exc()
        return {"status": "error"}
async def health_check() -> dict[str, str]:
    """Return a fixed payload indicating that the service is up."""
    payload: dict[str, str] = {"status": "healthy"}
    return payload
def root():
    """Return a static status payload confirming the backend is running."""
    response = {"status": "ok"}
    response["message"] = "ChatKit backend is running 🚀"
    return response