| """Chat endpoint with streaming support.""" |
|
|
| import asyncio |
| import uuid |
| from fastapi import APIRouter, Depends, HTTPException |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from src.db.postgres.connection import get_db |
| from src.db.postgres.models import ChatMessage |
| from src.agents.orchestration import orchestrator |
| from src.agents.chatbot import chatbot |
| from src.rag.retriever import retriever |
| from src.db.redis.connection import get_redis |
| from src.config.settings import settings |
| from src.middlewares.logging import get_logger, log_execution |
| from sse_starlette.sse import EventSourceResponse |
| from langchain_core.messages import HumanMessage |
| from pydantic import BaseModel |
| from typing import List, Dict, Any, Optional |
| import json |
|
|
| _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) |
| _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"]) |
|
|
|
|
| def _fast_intent(message: str) -> Optional[dict]: |
| """Bypass LLM orchestrator for obvious greetings and farewells.""" |
| lower = message.lower().strip().rstrip("!.,?") |
| if lower in _GREETINGS: |
| return {"intent": "greeting", "needs_search": False, |
| "direct_response": "Hello! How can I assist you today?", "search_query": ""} |
| if lower in _GOODBYES: |
| return {"intent": "goodbye", "needs_search": False, |
| "direct_response": "Goodbye! Have a great day!", "search_query": ""} |
| return None |
|
|
# Module-level logger used by every handler in this file.
logger = get_logger("chat_api")


# Chat routes live under /api/v1 and are grouped under the "Chat" tag.
router = APIRouter(tags=["Chat"], prefix="/api/v1")
|
|
|
|
class ChatRequest(BaseModel):
    # Request body for POST /api/v1/chat/stream.
    # NOTE(review): user_id is taken from the body as-is and scopes both
    # retrieval and the response cache key; confirm it is authenticated upstream.
    user_id: str
    # Chat room the user/assistant exchange is saved under.
    room_id: str
    # The user's chat message.
    message: str
|
|
|
|
| def _format_context(results: List[Dict[str, Any]]) -> str: |
| """Format retrieval results as context string for the LLM.""" |
| lines = [] |
| for result in results: |
| filename = result["metadata"].get("filename", "Unknown") |
| page = result["metadata"].get("page_label") |
| source_label = f"{filename}, p.{page}" if page else filename |
| lines.append(f"[Source: {source_label}]\n{result['content']}\n") |
| return "\n".join(lines) |
|
|
|
|
| def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| """Extract deduplicated source references from retrieval results.""" |
| seen = set() |
| sources = [] |
| for result in results: |
| meta = result["metadata"] |
| key = (meta.get("document_id"), meta.get("page_label")) |
| if key not in seen: |
| seen.add(key) |
| sources.append({ |
| "document_id": meta.get("document_id"), |
| "filename": meta.get("filename", "Unknown"), |
| "page_label": meta.get("page_label"), |
| }) |
| return sources |
|
|
|
|
async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Return the JSON-decoded cached answer for *cache_key*, or None on a miss.

    A falsy stored value (missing key or empty payload) is treated as a miss.
    """
    raw = await redis.get(cache_key)
    return json.loads(raw) if raw else None
|
|
|
|
async def cache_response(redis, cache_key: str, response: str, ttl: int = 86400):
    """Store *response* in Redis under *cache_key* with an expiry.

    Args:
        redis: Async Redis client; only ``setex`` is called.
        cache_key: Key to write.
        response: Answer text; stored JSON-encoded so get_cached_response can
            decode it.
        ttl: Expiry in seconds. Defaults to 86400 (24 hours), the value that
            was previously hard-coded.
    """
    await redis.setex(cache_key, ttl, json.dumps(response))
|
|
|
|
async def save_messages(db: AsyncSession, room_id: str, user_content: str, assistant_content: str):
    """Add one user message and one assistant message to *room_id*, then commit.

    Each row gets a fresh UUID4 string id; the user row is added first.
    """
    exchange = (("user", user_content), ("assistant", assistant_content))
    for role, content in exchange:
        db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role=role, content=content))
    await db.commit()
|
|
|
|
async def _stream_text(text: str, chunk_size: int = 50):
    """Yield the full SSE sequence for an answer that is already known.

    Uses the same protocol as the live LLM stream: a ``sources`` event (an
    empty list here), the text split into ``chunk`` events of at most
    ``chunk_size`` characters, then a ``done`` event.
    """
    yield {"event": "sources", "data": json.dumps([])}
    for start in range(0, len(text), chunk_size):
        yield {"event": "chunk", "data": text[start:start + chunk_size]}
    yield {"event": "done", "data": ""}


async def _discard_task(task: asyncio.Task) -> None:
    """Cancel *task* and wait until it has actually finished.

    The speculative retrieval task is handed the request's AsyncSession, so
    we wait for it to stop before anything else touches that session.
    ``gather(..., return_exceptions=True)`` also consumes any exception the
    task raised, so it is not reported as "never retrieved". If the caller
    itself is cancelled while waiting, that cancellation still propagates.
    """
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)


async def _analyze_and_retrieve(request: ChatRequest, db: AsyncSession):
    """Classify the message and fetch RAG context when the intent needs it.

    Returns:
        A ``(intent_result, context, sources)`` tuple. ``context`` is the
        formatted prompt context and ``sources`` the deduplicated source list.
        Both are empty when no search was performed.
    """
    intent_result = _fast_intent(request.message)
    if intent_result is not None:
        return intent_result, "", []

    # Start retrieval on the raw message now so it overlaps the orchestrator
    # call; the result is reused only if the orchestrator asks to search for
    # exactly that text.
    retrieval_task = asyncio.create_task(
        retriever.retrieve(request.message, request.user_id, db)
    )
    try:
        intent_result = await orchestrator.analyze_message(request.message)
    except BaseException:
        # Don't leave the speculative query running against the session.
        await _discard_task(retrieval_task)
        raise

    if not intent_result.get("needs_search"):
        await _discard_task(retrieval_task)
        return intent_result, "", []

    # Fall back to the raw message when the orchestrator returns no query or
    # an empty one.
    search_query = intent_result.get("search_query") or request.message
    logger.info(f"Searching for: {search_query}")
    if search_query != request.message:
        # Wait for the speculative query to stop before reusing the session.
        await _discard_task(retrieval_task)
        raw_results = await retriever.retrieve(
            query=search_query,
            user_id=request.user_id,
            db=db,
        )
    else:
        raw_results = await retrieval_task

    return intent_result, _format_context(raw_results), _extract_sources(raw_results)


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    Every path (cache hit, direct answer, LLM answer) emits the same SSE
    event sequence:
        1. sources - JSON array of {document_id, filename, page_label}
           (always empty for cached and direct answers)
        2. chunk   - text fragments of the answer
        3. done    - signals end of stream

    Each exchange is saved to chat_messages. Direct and LLM answers are also
    cached in Redis under a key built from the user id and message text.

    Raises:
        HTTPException: 500 if anything fails before streaming starts. Errors
            raised while the LLM answer is streaming occur after this function
            has returned and are not converted.
    """
    try:
        redis = await get_redis()

        cache_key = f"{settings.redis_prefix}chat:{request.user_id}:{request.message}"
        cached = await get_cached_response(redis, cache_key)
        if cached:
            logger.info("Returning cached response")
            # Record cache hits too, so the room history has no gaps.
            await save_messages(db, request.room_id, request.message, cached)
            return EventSourceResponse(_stream_text(cached))

        intent_result, context, sources = await _analyze_and_retrieve(request, db)

        direct_response = intent_result.get("direct_response")
        if direct_response:
            await cache_response(redis, cache_key, direct_response)
            await save_messages(db, request.room_id, request.message, direct_response)
            return EventSourceResponse(_stream_text(direct_response))

        messages = [HumanMessage(content=request.message)]

        async def stream_response():
            parts: List[str] = []
            yield {"event": "sources", "data": json.dumps(sources)}
            async for token in chatbot.astream_response(messages, context):
                parts.append(token)
                yield {"event": "chunk", "data": token}
            yield {"event": "done", "data": ""}
            full_response = "".join(parts)
            # NOTE(review): this runs after chat_stream has returned. Confirm
            # the get_db session is still open while the response streams;
            # that depends on how get_db and the FastAPI version tear down
            # yield-dependencies.
            await cache_response(redis, cache_key, full_response)
            await save_messages(db, request.room_id, request.message, full_response)

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
|
|