"""Chat endpoint with streaming support.""" import asyncio import uuid from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from src.db.postgres.connection import get_db from src.db.postgres.models import ChatMessage from src.agents.orchestration import orchestrator from src.agents.chatbot import chatbot from src.rag.retriever import retriever from src.db.redis.connection import get_redis from src.config.settings import settings from src.middlewares.logging import get_logger, log_execution from sse_starlette.sse import EventSourceResponse from langchain_core.messages import HumanMessage from pydantic import BaseModel from typing import List, Dict, Any, Optional import json _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"]) def _fast_intent(message: str) -> Optional[dict]: """Bypass LLM orchestrator for obvious greetings and farewells.""" lower = message.lower().strip().rstrip("!.,?") if lower in _GREETINGS: return {"intent": "greeting", "needs_search": False, "direct_response": "Hello! How can I assist you today?", "search_query": ""} if lower in _GOODBYES: return {"intent": "goodbye", "needs_search": False, "direct_response": "Goodbye! Have a great day!", "search_query": ""} return None logger = get_logger("chat_api") router = APIRouter(prefix="/api/v1", tags=["Chat"]) class ChatRequest(BaseModel): user_id: str room_id: str message: str def _format_context(results: List[Dict[str, Any]]) -> str: """Format retrieval results as context string for the LLM.""" lines = [] for result in results: filename = result["metadata"].get("filename", "Unknown") page = result["metadata"].get("page_label") source_label = f"{filename}, p.{page}" if page else filename lines.append(f"[Source: {source_label}]\n{result['content']}\n") return "\n".join(lines) def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Extract deduplicated source references from retrieval results.""" seen = set() sources = [] for result in results: meta = result["metadata"] key = (meta.get("document_id"), meta.get("page_label")) if key not in seen: seen.add(key) sources.append({ "document_id": meta.get("document_id"), "filename": meta.get("filename", "Unknown"), "page_label": meta.get("page_label"), }) return sources async def get_cached_response(redis, cache_key: str) -> Optional[str]: cached = await redis.get(cache_key) if cached: return json.loads(cached) return None async def cache_response(redis, cache_key: str, response: str): await redis.setex(cache_key, 86400, json.dumps(response)) async def save_messages(db: AsyncSession, room_id: str, user_content: str, assistant_content: str): """Persist user and assistant messages to chat_messages table.""" db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content)) db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="assistant", content=assistant_content)) await db.commit() @router.post("/chat/stream") @log_execution(logger) async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)): """Chat endpoint with streaming response. SSE event sequence: 1. sources — JSON array of {document_id, filename, page_label} 2. chunk — text fragments of the answer 3. 
async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Return the cached answer for this key, or None on a cache miss."""
    cached = await redis.get(cache_key)
    if cached:
        return json.loads(cached)
    return None


async def cache_response(redis, cache_key: str, response: str):
    """Cache the answer text for 24 hours (86400 seconds)."""
    await redis.setex(cache_key, 86400, json.dumps(response))


async def save_messages(db: AsyncSession, room_id: str, user_content: str, assistant_content: str):
    """Persist user and assistant messages to the chat_messages table."""
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="assistant", content=assistant_content))
    await db.commit()


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
      1. sources — JSON array of {document_id, filename, page_label}
      2. chunk   — text fragments of the answer
      3. done    — signals end of stream

    An illustrative client for this protocol appears at the bottom of this
    module.
    """
    redis = await get_redis()
    cache_key = f"{settings.redis_prefix}chat:{request.user_id}:{request.message}"
    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        # Only the answer text is cached, so sources cannot be replayed on a
        # cache hit and are reported as an empty list.
        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            # Re-chunk the cached answer so the client sees the same event
            # shape as a live stream.
            for i in range(0, len(cached), 50):
                yield {"event": "chunk", "data": cached[i:i + 50]}
            yield {"event": "done", "data": ""}

        return EventSourceResponse(stream_cached())

    try:
        # Step 1: Fast local intent check (skips the LLM for greetings/farewells).
        intent_result = _fast_intent(request.message)
        context = ""
        sources: List[Dict[str, Any]] = []

        if intent_result is None:
            # Step 2: Launch retrieval optimistically so it runs in parallel
            # with the orchestrator's intent analysis.
            retrieval_task = asyncio.create_task(
                retriever.retrieve(request.message, request.user_id, db)
            )
            intent_result = await orchestrator.analyze_message(request.message)

            if not intent_result.get("needs_search"):
                # Retrieval turned out to be unnecessary; cancel the
                # speculative task (fire-and-forget).
                retrieval_task.cancel()
                raw_results = []
            else:
                search_query = intent_result.get("search_query", request.message)
                logger.info(f"Searching for: {search_query}")
                if search_query != request.message:
                    # The orchestrator rewrote the query, so the speculative
                    # retrieval is stale; cancel it and retrieve again.
                    retrieval_task.cancel()
                    raw_results = await retriever.retrieve(
                        query=search_query,
                        user_id=request.user_id,
                        db=db,
                    )
                else:
                    raw_results = await retrieval_task

            context = _format_context(raw_results)
            sources = _extract_sources(raw_results)

        # Step 3: Direct response for greetings / non-document intents.
        if intent_result.get("direct_response"):
            response = intent_result["direct_response"]
            await cache_response(redis, cache_key, response)
            await save_messages(db, request.room_id, request.message, response)

            # Emit the same sources/chunk/done sequence documented above so
            # clients need only one parsing path.
            async def stream_direct():
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "chunk", "data": response}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_direct())

        # Step 4: Stream the answer token-by-token as the LLM generates it.
        messages = [HumanMessage(content=request.message)]

        async def stream_response():
            full_response = ""
            yield {"event": "sources", "data": json.dumps(sources)}
            async for token in chatbot.astream_response(messages, context):
                full_response += token
                yield {"event": "chunk", "data": token}
            yield {"event": "done", "data": ""}
            # Persist and cache only after the full answer has streamed.
            await cache_response(redis, cache_key, full_response)
            await save_messages(db, request.room_id, request.message, full_response)

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
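
# ---------------------------------------------------------------------------
# Example client for the /chat/stream protocol documented above. This is an
# illustrative sketch, not part of the application: the base URL, the demo
# payload, and the httpx dependency are assumptions, and the hand-rolled
# event:/data: parsing covers only the simple single-line frames that
# sse-starlette emits (a production client might use the httpx-sse package
# instead).
async def _example_stream_client() -> None:
    import httpx  # illustration-only dependency, not required by the app

    payload = {"user_id": "demo-user", "room_id": "demo-room", "message": "hello"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/api/v1/chat/stream", json=payload
        ) as response:
            event = None
            async for line in response.aiter_lines():
                if line.startswith("event:"):
                    event = line[len("event:"):].strip()
                elif line.startswith("data:"):
                    # Per the SSE spec, at most one leading space after the
                    # colon is stripped; token whitespace is preserved.
                    data = line[len("data:"):].removeprefix(" ")
                    if event == "sources":
                        print("sources:", json.loads(data))
                    elif event == "chunk":
                        print(data, end="", flush=True)
                    elif event == "done":
                        print()
                        return


if __name__ == "__main__":
    # Demo only: run against a live server started separately.
    asyncio.run(_example_stream_client())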