Spaces:
Build error
Build error
| """Streaming chat orchestration utilities for the frontend voicebot.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| from queue import Queue | |
| from threading import Lock, Thread | |
| from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional | |
| from dotenv import load_dotenv | |
| from langfuse import Langfuse | |
| from langfuse.decorators import langfuse_context, observe | |
| import sys | |
| sys.path.append(os.path.abspath('./backend')) | |
| from models import LLMFinanceAnalyzer | |
| from functions import MongoHybridSearch | |
| from conversation_store import conversation_store | |
| from utils import get_device | |
# Load a local .env file only when running on an Apple Silicon ("mps") device —
# presumably the local development setup, where secrets come from .env rather
# than the process environment. TODO(review): confirm this heuristic is intended.
if get_device() == "mps":
    load_dotenv(override=True)

# Module-wide logging configuration and the module logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Langfuse observability client, configured entirely from environment variables.
# NOTE(review): environment is hard-coded to "development" — verify before deploy.
langfuse = Langfuse(
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)
langfuse_context.configure(environment="development")
# Eagerly construct the heavyweight backend components at import time so that a
# misconfiguration fails fast; any failure is logged at CRITICAL and re-raised,
# aborting the module import.
try:
    llm_analyzer = LLMFinanceAnalyzer()
    search_engine = MongoHybridSearch()
    logger.info("Initialized LLM analyzer and Mongo hybrid search for streaming chat.")
except Exception as exc:
    logger.critical("Failed to initialise backend components: %s", exc, exc_info=True)
    raise
| _stream_loop: Optional[asyncio.AbstractEventLoop] = None | |
| _stream_thread: Optional[Thread] = None | |
| _stream_loop_lock: "Lock" = Lock() | |
| def _loop_worker(loop: asyncio.AbstractEventLoop) -> None: | |
| asyncio.set_event_loop(loop) | |
| loop.run_forever() | |
| def _ensure_stream_loop() -> asyncio.AbstractEventLoop: | |
| global _stream_loop, _stream_thread | |
| with _stream_loop_lock: | |
| if _stream_loop is None or _stream_loop.is_closed(): | |
| _stream_loop = asyncio.new_event_loop() | |
| _stream_thread = Thread(target=_loop_worker, args=(_stream_loop,), daemon=True) | |
| _stream_thread.start() | |
| return _stream_loop | |
| def _create_truncated_history( | |
| full_conversation: List[Dict[str, str]], | |
| max_assistant_length: int, | |
| ) -> List[Dict[str, str]]: | |
| # truncated = [] | |
| # for msg in full_conversation: | |
| # processed = msg.copy() | |
| # if processed.get("role") == "assistant" and len(processed.get("content", "")) > max_assistant_length: | |
| # processed["content"] = processed["content"][:max_assistant_length] + "..." | |
| # truncated.append(processed) | |
| # return truncated | |
| return full_conversation | |
| def _generate_pseudo_conversation(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]: | |
| pseudo = "".join(f"{msg.get('role', 'unknown')}: {msg.get('content', '')}\n" for msg in conversation) | |
| return [{"role": "user", "content": pseudo.strip()}] | |
async def _stream_chat_async(
    session_id: str,
    message: str,
    persona: str,
    persona_state: Optional[Dict[str, Any]] = None,
    user_info: Optional[Any] = None,
    mode: Optional[Any] = None,
) -> AsyncGenerator[str, None]:
    """Yield LLM response chunks for *message* within the session's conversation.

    Loads prior history, best-effort persists session metadata and the new
    user message (store failures are logged, never raised), then streams the
    model reply chunk by chunk. The RAG decision below is currently
    hard-wired to "yes" and the retrieval pipeline itself is commented out,
    so the "yes" branch effectively streams a plain (non-RAG) response.

    Args:
        session_id: Conversation/session identifier used by the store.
        message: The new user utterance.
        persona: Persona identifier forwarded to the LLM.
        persona_state: Optional persona metadata persisted with the session.
        user_info: Opaque user details forwarded to the LLM and the store.
        mode: Opaque mode flag forwarded to the LLM.

    Yields:
        Response text chunks (or a Thai-language apology on failure).
    """
    # History load is best-effort: a store failure degrades to an empty history.
    try:
        stored_history = await conversation_store.get_history(session_id)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to load conversation history for %s: %s", session_id, exc, exc_info=True)
        stored_history = []
    if session_id:
        try:
            await conversation_store.upsert_session_metadata(
                session_id,
                persona=persona_state,
                user_info=user_info,
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to persist session metadata for %s: %s", session_id, exc, exc_info=True)
    # Work on copies so objects held by the store are never mutated in place.
    full_conversation = [msg.copy() for msg in stored_history] + [{"role": "user", "content": message}]
    if message and session_id:
        try:
            await conversation_store.append_messages(session_id, [{"role": "user", "content": message}])
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to persist user message for %s: %s", session_id, exc, exc_info=True)
    truncated_history = _create_truncated_history(full_conversation, 300)
    # NOTE(review): pseudo_conversation only feeds the disabled sub-query
    # generation below; it is currently computed but unused.
    pseudo_conversation = _generate_pseudo_conversation(truncated_history)
    # RAG routing is hard-coded; the else-branch below is unreachable until a
    # real classifier replaces this constant.
    rag_decision = "yes"
    logger.info("RAG decision: %s", rag_decision)
    if rag_decision == "yes":
        # --- Disabled retrieval pipeline (kept for reference) -------------
        # query = await llm_analyzer.generate_subquery(pseudo_conversation)
        # if query is None:
        #     yield "ขออภัยค่ะ ไม่สามารถวิเคราะห์คำถามเพื่อดึงข้อมูลได้"
        #     return
        # retrieved_data = ""
        # if query:
        #     try:
        #         docs = await search_engine.search_documents(query)
        #         retrieved_data = "\n-------\n".join(docs)
        #         logger.info("Retrieved %d documents for streaming response.", len(docs))
        #     except Exception as search_err:
        #         logger.error("Error during document search: %s", search_err, exc_info=True)
        #         yield "ขออภัยค่ะ เกิดข้อผิดพลาดขณะค้นหาข้อมูล"
        #         return
        # --- Disabled history-window logic (kept for reference) -----------
        # if len(full_conversation) > 7:
        #     if full_conversation[-7].get("role") == "tool":
        #         limited_conversation = full_conversation
        #     else:
        #         limited_conversation = full_conversation[-7:]
        # else:
        #     limited_conversation = full_conversation
        limited_conversation = full_conversation
        response_generator = llm_analyzer.generate_normal_response(
            limited_conversation,
            persona,
            user_info,
            session_id=session_id,
            mode=mode
        )
        async for chunk in response_generator:
            if chunk:
                yield chunk
                # Small pacing delay between chunks for smoother client-side
                # rendering of the stream.
                await asyncio.sleep(0.05)
    else:
        # Unreachable while rag_decision is hard-coded to "yes" above.
        limited_conversation = full_conversation[-9:] if len(full_conversation) > 9 else full_conversation
        final_response = await llm_analyzer.generate_non_rag_response(
            limited_conversation,
            session_id=session_id,
        )
        if final_response:
            yield final_response
        else:
            yield "ขออภัยค่ะ เกิดข้อผิดพลาดในการประมวลผลคำถามของคุณ"
def stream_chat_response(
    session_id: str,
    message: str,
    persona: str,
    persona_state: Optional[Dict[str, Any]] = None,
    user_info: Optional[Any] = None,
    mode: Optional[Any] = None,
) -> Iterator[str]:
    """Synchronously iterate over streaming LLM chunks.

    Bridges the async generator ``_stream_chat_async`` into a plain iterator:
    a coroutine running on the shared background loop pumps chunks into a
    thread-safe queue, and this generator drains it until the ``None``
    sentinel arrives.
    """
    loop = _ensure_stream_loop()
    handoff: "Queue[Optional[str]]" = Queue()

    async def runner() -> None:
        # Pump the async stream into the queue; the trailing None sentinel
        # tells the synchronous consumer that the stream has finished.
        try:
            async for piece in _stream_chat_async(session_id, message, persona, persona_state, user_info, mode):
                handoff.put_nowait(piece)
        except Exception as exc:  # noqa: BLE001
            logger.error("Unhandled error in async chat stream: %s", exc, exc_info=True)
            handoff.put_nowait(f"[Error: {exc}]")
        finally:
            handoff.put_nowait(None)

    future = asyncio.run_coroutine_threadsafe(runner(), loop)
    while (piece := handoff.get()) is not None:
        yield piece
    # Propagate any exception that was not handled in runner().
    future.result()
| __all__ = ["stream_chat_response"] | |