"""Streaming chat orchestration utilities for the frontend voicebot."""
from __future__ import annotations
import asyncio
import logging
import os
from queue import Queue
from threading import Lock, Thread
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional
from dotenv import load_dotenv
from langfuse import Langfuse
from langfuse.decorators import langfuse_context, observe
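# Make the backend-local modules below importable; assumes the process is
# started from the project root so that './backend' resolves correctly.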
import sys
sys.path.append(os.path.abspath('./backend'))
from models import LLMFinanceAnalyzer
from functions import MongoHybridSearch
from conversation_store import conversation_store
from utils import get_device
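# Load settings from a local .env only on Apple Silicon (MPS) hosts,
# where the app is presumably running in development.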
if get_device() == "mps":
load_dotenv(override=True)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
langfuse = Langfuse(
secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
host=os.getenv("LANGFUSE_HOST"),
)
langfuse_context.configure(environment="development")
try:
llm_analyzer = LLMFinanceAnalyzer()
search_engine = MongoHybridSearch()
logger.info("Initialized LLM analyzer and Mongo hybrid search for streaming chat.")
except Exception as exc:
    logger.critical("Failed to initialize backend components: %s", exc, exc_info=True)
raise
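# A single event loop runs forever on a daemon thread so that synchronous
# callers (e.g. the web layer) can consume the async streaming generator.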
_stream_loop: Optional[asyncio.AbstractEventLoop] = None
_stream_thread: Optional[Thread] = None
_stream_loop_lock: Lock = Lock()
def _loop_worker(loop: asyncio.AbstractEventLoop) -> None:
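    """Install *loop* on the current thread and run it until stopped."""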
asyncio.set_event_loop(loop)
loop.run_forever()
def _ensure_stream_loop() -> asyncio.AbstractEventLoop:
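    """Return the shared background event loop, creating it on first use."""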
global _stream_loop, _stream_thread
with _stream_loop_lock:
if _stream_loop is None or _stream_loop.is_closed():
_stream_loop = asyncio.new_event_loop()
_stream_thread = Thread(target=_loop_worker, args=(_stream_loop,), daemon=True)
_stream_thread.start()
return _stream_loop
def _create_truncated_history(
full_conversation: List[Dict[str, str]],
max_assistant_length: int,
) -> List[Dict[str, str]]:
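    """Copy the conversation, trimming assistant turns longer than *max_assistant_length*."""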
    truncated: List[Dict[str, str]] = []
    for msg in full_conversation:
        processed = msg.copy()
        if processed.get("role") == "assistant" and len(processed.get("content", "")) > max_assistant_length:
            # Keep the pseudo-conversation prompt compact by trimming long assistant turns.
            processed["content"] = processed["content"][:max_assistant_length] + "..."
        truncated.append(processed)
    return truncated
def _generate_pseudo_conversation(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
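    """Flatten a role-tagged conversation into a single user message for sub-query prompts."""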
pseudo = "".join(f"{msg.get('role', 'unknown')}: {msg.get('content', '')}\n" for msg in conversation)
return [{"role": "user", "content": pseudo.strip()}]
@observe()
async def _stream_chat_async(
session_id: str,
message: str,
persona: str,
persona_state: Optional[Dict[str, Any]] = None,
user_info: Optional[Any] = None,
mode: Optional[Any] = None,
) -> AsyncGenerator[str, None]:
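    """Stream the assistant's reply to *message* for the given session.

    Loads prior turns from ``conversation_store``, persists the new user
    message and session metadata, then yields LLM response chunks as they
    arrive.
    """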
try:
stored_history = await conversation_store.get_history(session_id)
except Exception as exc: # noqa: BLE001
logger.error("Failed to load conversation history for %s: %s", session_id, exc, exc_info=True)
stored_history = []
if session_id:
try:
await conversation_store.upsert_session_metadata(
session_id,
persona=persona_state,
user_info=user_info,
)
except Exception as exc: # noqa: BLE001
logger.error("Failed to persist session metadata for %s: %s", session_id, exc, exc_info=True)
full_conversation = [msg.copy() for msg in stored_history] + [{"role": "user", "content": message}]
if message and session_id:
try:
await conversation_store.append_messages(session_id, [{"role": "user", "content": message}])
except Exception as exc: # noqa: BLE001
logger.error("Failed to persist user message for %s: %s", session_id, exc, exc_info=True)
truncated_history = _create_truncated_history(full_conversation, 300)
pseudo_conversation = _generate_pseudo_conversation(truncated_history)
    # NOTE: Retrieval is currently disabled. The RAG decision is hard-coded to
    # "yes" and the sub-query / document-search steps below are commented out,
    # so pseudo_conversation is unused and the full history goes straight to the LLM.
    rag_decision = "yes"
    logger.info("RAG decision: %s", rag_decision)
    if rag_decision == "yes":
        # query = await llm_analyzer.generate_subquery(pseudo_conversation)
        # if query is None:
        #     # "Sorry, the question could not be analysed to retrieve information."
        #     yield "ขออภัยค่ะ ไม่สามารถวิเคราะห์คำถามเพื่อดึงข้อมูลได้"
        #     return
        # retrieved_data = ""
        # if query:
        #     try:
        #         docs = await search_engine.search_documents(query)
        #         retrieved_data = "\n-------\n".join(docs)
        #         logger.info("Retrieved %d documents for streaming response.", len(docs))
        #     except Exception as search_err:
        #         logger.error("Error during document search: %s", search_err, exc_info=True)
        #         # "Sorry, an error occurred while searching for information."
        #         yield "ขออภัยค่ะ เกิดข้อผิดพลาดขณะค้นหาข้อมูล"
        #         return
        # History windowing (also disabled):
        # if len(full_conversation) > 7:
        #     if full_conversation[-7].get("role") == "tool":
        #         limited_conversation = full_conversation
        #     else:
        #         limited_conversation = full_conversation[-7:]
        # else:
        #     limited_conversation = full_conversation
        limited_conversation = full_conversation
response_generator = llm_analyzer.generate_normal_response(
limited_conversation,
persona,
user_info,
session_id=session_id,
mode=mode
)
async for chunk in response_generator:
if chunk:
yield chunk
await asyncio.sleep(0.05)
else:
limited_conversation = full_conversation[-9:] if len(full_conversation) > 9 else full_conversation
final_response = await llm_analyzer.generate_non_rag_response(
limited_conversation,
session_id=session_id,
)
if final_response:
yield final_response
else:
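            # "Sorry, an error occurred while processing your question."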
yield "ขออภัยค่ะ เกิดข้อผิดพลาดในการประมวลผลคำถามของคุณ"
def stream_chat_response(
session_id: str,
message: str,
persona: str,
persona_state: Optional[Dict[str, Any]] = None,
user_info: Optional[Any] = None,
mode: Optional[Any] = None,
) -> Iterator[str]:
"""Synchronously iterate over streaming LLM chunks."""
loop = _ensure_stream_loop()
    output_queue: Queue[Optional[str]] = Queue()
async def runner() -> None:
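        """Drain the async stream into the queue, then signal completion with None."""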
try:
            async for chunk in _stream_chat_async(session_id, message, persona, persona_state, user_info, mode):
output_queue.put_nowait(chunk)
except Exception as exc: # noqa: BLE001
logger.error("Unhandled error in async chat stream: %s", exc, exc_info=True)
output_queue.put_nowait(f"[Error: {exc}]")
finally:
output_queue.put_nowait(None)
future = asyncio.run_coroutine_threadsafe(runner(), loop)
while True:
chunk = output_queue.get()
if chunk is None:
break
yield chunk
# Propagate any exception that was not handled in runner().
future.result()
__all__ = ["stream_chat_response"]
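# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the session id,
# message, and persona are made-up values, and running this would hit the
# real LLM / Mongo backends configured above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Consume the synchronous generator the same way an SSE/chunked HTTP
    # handler would, printing chunks as they arrive.
    for piece in stream_chat_response(
        session_id="demo-session",  # hypothetical session identifier
        message="What is unit linked insurance?",
        persona="default",  # assumes the analyzer accepts a simple persona name
    ):
        print(piece, end="", flush=True)
    print()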