"""Chat endpoint with streaming support."""

import json
import uuid
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from src.agents.chat_handler import ChatHandler
from src.config.settings import settings
from src.db.postgres.connection import get_db
from src.db.postgres.models import ChatMessage, MessageSource
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger, log_execution

logger = get_logger("chat_api")
router = APIRouter(prefix="/api/v1", tags=["Chat"])

# Exact-match (after normalisation) phrases that are answered without the LLM.
_GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"])
_GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"])

# Cached answers expire after 24 hours.
_CACHE_TTL_SECONDS = 86400
# Cached answers are replayed to the client in slices of this many characters.
_CACHED_CHUNK_SIZE = 50


def _fast_intent(message: str) -> Optional[str]:
    """Return a direct response for obvious greetings/farewells, else None.

    The message is lower-cased, trimmed, stripped of trailing ``!.,?`` and
    trimmed again, so inputs such as ``"Hi!"`` and ``"hi !"`` both match.
    """
    # The second strip() removes whitespace exposed by dropping the punctuation.
    normalized = message.lower().strip().rstrip("!.,?").strip()
    if normalized in _GREETINGS:
        return "Hello! How can I assist you today?"
    if normalized in _GOODBYES:
        return "Goodbye! Have a great day!"
    return None


class ChatRequest(BaseModel):
    """Request body for the streaming chat endpoint."""

    user_id: str
    room_id: str
    message: str


async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Return the JSON-decoded value stored under ``cache_key``, or None.

    A missing/empty entry and an entry that cannot be decoded as JSON are
    both reported as a cache miss (None); the latter is logged.
    """
    raw = await redis.get(cache_key)
    if not raw:
        return None
    try:
        return json.loads(raw)
    except ValueError as exc:
        # A corrupt entry must not fail the request; fall through to a
        # fresh answer, which will overwrite the bad value.
        logger.error("Ignoring undecodable cache entry", error=str(exc))
        return None


async def cache_response(redis, cache_key: str, response: str):
    """Store ``response`` as JSON under ``cache_key`` with a 24-hour expiry."""
    await redis.setex(cache_key, _CACHE_TTL_SECONDS, json.dumps(response))


async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load recent chat messages for a room as LangChain message objects (oldest-first).

    Args:
        db: Async database session used for the query.
        room_id: Room whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` of the most recent messages, ordered oldest-first.
        Rows with role ``"user"`` become ``HumanMessage``; every other role
        becomes ``AIMessage``.
    """
    # Select newest-first so LIMIT keeps the *latest* messages. Ascending
    # order with LIMIT would instead return the first messages ever sent
    # in the room.
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    newest_first = result.scalars().all()
    # NOTE(review): rows sharing an identical created_at have no defined
    # relative order here - confirm whether a tie-breaker column is needed.
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in reversed(newest_first)
    ]


async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_content: str,
    assistant_content: str,
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Persist user and assistant messages, and attach sources to the assistant message.

    Args:
        db: Async database session; everything is committed in one commit.
        room_id: Room the two messages belong to.
        user_content: Text of the user's message.
        assistant_content: Text of the assistant's reply.
        sources: Optional source dicts; the ``document_id``, ``filename`` and
            ``page_label`` keys are read (missing keys are stored as None).
    """
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))

    assistant_id = str(uuid.uuid4())
    db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant", content=assistant_content))

    for src in sources or []:
        page = src.get("page_label")
        db.add(MessageSource(
            id=str(uuid.uuid4()),
            message_id=assistant_id,
            document_id=src.get("document_id"),
            filename=src.get("filename"),
            # page_label is stored as text; keep None rather than "None".
            page_label=str(page) if page is not None else None,
        ))

    await db.commit()


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
    1. sources — JSON array of source refs from ChatHandler
       (table for structured; deduped document_id/page_label for unstructured)
    2. chunk   — text fragments of the answer
    3. done    — signals end of stream

    An ``error`` event may replace ``done``: either forwarded from
    ChatHandler, or emitted here when streaming, caching or persisting fails.

    Raises:
        HTTPException: 500 when setting up the response fails before
            streaming starts.
    """
    redis = await get_redis()
    cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

    # Redis cache hit
    # NOTE(review): a cache hit replays the stored text with empty sources and
    # does not call save_messages - confirm that this is intended.
    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            for start in range(0, len(cached), _CACHED_CHUNK_SIZE):
                yield {"event": "chunk", "data": cached[start:start + _CACHED_CHUNK_SIZE]}
            yield {"event": "done", "data": ""}

        return EventSourceResponse(stream_cached())

    try:
        # Fast intent: greetings/farewells bypass LLM entirely
        direct = _fast_intent(request.message)
        if direct:
            await cache_response(redis, cache_key, direct)
            await save_messages(db, request.room_id, request.message, direct, sources=[])

            async def stream_direct():
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "chunk", "data": direct}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_direct())

        history = await load_history(db, request.room_id, limit=10)
        handler = ChatHandler()

        async def stream_response():
            chunks: List[str] = []
            sources: List[Dict[str, Any]] = []
            # This generator is consumed after chat_stream has returned, so
            # the enclosing try/except cannot see failures raised here; they
            # are handled locally and reported to the client as an SSE event.
            try:
                async for event in handler.handle(request.message, request.user_id, history):
                    kind = event["event"]
                    if kind == "sources":
                        try:
                            sources = json.loads(event["data"]) or []
                        except (TypeError, ValueError):
                            sources = []
                        yield event
                    elif kind == "chunk":
                        chunks.append(event["data"])
                        yield event
                    elif kind == "done":
                        full_response = "".join(chunks)
                        await cache_response(redis, cache_key, full_response)
                        await save_messages(
                            db, request.room_id, request.message, full_response, sources=sources
                        )
                        yield event
                    elif kind == "error":
                        yield event
                        return
                    # "intent" event: consumed internally, not forwarded to frontend
            except Exception as exc:
                logger.error("Chat stream failed", error=str(exc))
                # Generic text only: exception details stay in the server log.
                yield {"event": "error", "data": "Chat failed"}

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e