# Source: ishaq101 — feat/Catalog Retrieval System (#1), commit 6bff5d9 (6.04 kB)
"""Chat endpoint with streaming support."""
import uuid
import json
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse
from src.agents.chat_handler import ChatHandler
from src.config.settings import settings
from src.db.postgres.connection import get_db
from src.db.postgres.models import ChatMessage, MessageSource
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger, log_execution
# Module-level logger, registered under the name "chat_api".
logger = get_logger("chat_api")
# Router for the chat endpoints; every route is mounted under /api/v1 and tagged "Chat".
router = APIRouter(prefix="/api/v1", tags=["Chat"])
_GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"])
_GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"])
def _fast_intent(message: str) -> Optional[str]:
"""Return a direct response for obvious greetings/farewells, else None."""
lower = message.lower().strip().rstrip("!.,?")
if lower in _GREETINGS:
return "Hello! How can I assist you today?"
if lower in _GOODBYES:
return "Goodbye! Have a great day!"
return None
class ChatRequest(BaseModel):
    """Request body accepted by the chat streaming endpoint."""

    # Identifier of the user sending the message; passed on to the chat handler.
    user_id: str
    # Identifier of the chat room; used for history lookup, persistence and the cache key.
    room_id: str
    # The user's message text.
    message: str
async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Look up a previously cached answer.

    Args:
        redis: Client object exposing an awaitable ``get(key)``.
        cache_key: Key under which the answer was stored.

    Returns:
        The JSON-decoded cached value, or ``None`` when the key is absent
        or holds an empty value.
    """
    payload = await redis.get(cache_key)
    return json.loads(payload) if payload else None
async def cache_response(redis, cache_key: str, response: str, ttl_seconds: int = 86400):
    """Store an answer in the cache as JSON with an expiry.

    Args:
        redis: Client object exposing an awaitable ``setex(key, ttl, value)``.
        cache_key: Key to store the answer under.
        response: Answer text to cache; it is JSON-encoded before storing.
        ttl_seconds: Lifetime of the entry in seconds. Defaults to 86400
            (24 hours), the value that was previously hard-coded.
    """
    await redis.setex(cache_key, ttl_seconds, json.dumps(response))
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the most recent chat messages for a room as LangChain message objects.

    Args:
        db: Async database session used to run the query.
        room_id: Room whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` of the newest messages in the room, ordered
        oldest-first. Rows with role ``"user"`` become ``HumanMessage``;
        every other role becomes ``AIMessage``.
    """
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        # Newest first, so that LIMIT keeps the latest messages. Ordering
        # ascending here would return the oldest messages of the room instead.
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    # Flip back to chronological order for the conversation history.
    # NOTE(review): rows sharing the same created_at have no defined relative
    # order; confirm whether a tie-breaker column is needed.
    rows = list(reversed(result.scalars().all()))
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in rows
    ]
async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_content: str,
    assistant_content: str,
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Store one user/assistant exchange and the assistant's source references.

    Args:
        db: Async database session; everything is committed in a single commit.
        room_id: Room the two messages belong to.
        user_content: Text of the user message.
        assistant_content: Text of the assistant reply.
        sources: Optional source dicts; the ``document_id``, ``filename`` and
            ``page_label`` keys are read from each one and linked to the
            assistant message. ``None`` or an empty list stores no sources.
    """
    user_row = ChatMessage(
        id=str(uuid.uuid4()),
        room_id=room_id,
        role="user",
        content=user_content,
    )
    db.add(user_row)

    # The assistant id is generated up front because each source row refers to it.
    assistant_id = str(uuid.uuid4())
    assistant_row = ChatMessage(
        id=assistant_id,
        room_id=room_id,
        role="assistant",
        content=assistant_content,
    )
    db.add(assistant_row)

    entries = sources if sources else []
    for ref in entries:
        raw_page = ref.get("page_label")
        # Page labels are stored as text; a missing label stays None.
        if raw_page is None:
            label = None
        else:
            label = str(raw_page)
        source_row = MessageSource(
            id=str(uuid.uuid4()),
            message_id=assistant_id,
            document_id=ref.get("document_id"),
            filename=ref.get("filename"),
            page_label=label,
        )
        db.add(source_row)

    await db.commit()
@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
      1. sources — JSON array of source refs from ChatHandler (table for
         structured; deduped document_id/page_label for unstructured)
      2. chunk   — text fragments of the answer
      3. done    — signals end of stream

    An ``error`` event ends the stream instead of ``done`` when the handler
    reports an error or when an exception is raised while streaming.

    Resolution order:
      * Redis cache hit: the cached answer is replayed in 50-character chunks
        with an empty sources list; nothing is written to the database.
      * Greeting/farewell: a canned reply is cached, persisted and streamed.
      * Otherwise: recent room history is loaded and the ChatHandler output
        is streamed; on ``done`` the full answer is cached and persisted.

    Args:
        request: Chat request carrying user id, room id and message.
        db: Async database session provided by the ``get_db`` dependency.

    Returns:
        An ``EventSourceResponse`` producing the events described above.

    Raises:
        HTTPException: 500 when preparing the response fails before
            streaming has started.
    """
    redis = await get_redis()
    cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

    # Redis cache hit
    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            for i in range(0, len(cached), 50):
                yield {"event": "chunk", "data": cached[i:i + 50]}
            yield {"event": "done", "data": ""}

        return EventSourceResponse(stream_cached())

    try:
        # Fast intent: greetings/farewells bypass LLM entirely
        direct = _fast_intent(request.message)
        if direct:
            await cache_response(redis, cache_key, direct)
            await save_messages(db, request.room_id, request.message, direct, sources=[])

            async def stream_direct():
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "chunk", "data": direct}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_direct())

        history = await load_history(db, request.room_id, limit=10)
        handler = ChatHandler()

        async def stream_response():
            parts: List[str] = []
            sources: List[Dict[str, Any]] = []
            # This generator runs after the endpoint has returned, so the
            # outer try/except cannot see its failures. Catch them here and
            # report them to the client as an SSE "error" event.
            try:
                async for event in handler.handle(request.message, request.user_id, history):
                    if event["event"] == "sources":
                        try:
                            sources = json.loads(event["data"]) or []
                        except (TypeError, ValueError):
                            # Unparseable sources are treated as "no sources".
                            sources = []
                        yield event
                    elif event["event"] == "chunk":
                        parts.append(event["data"])
                        yield event
                    elif event["event"] == "done":
                        full_response = "".join(parts)
                        await cache_response(redis, cache_key, full_response)
                        await save_messages(
                            db, request.room_id, request.message, full_response, sources=sources
                        )
                        yield event
                    elif event["event"] == "error":
                        # Handler-reported failure: forward it and stop
                        # without caching or persisting a partial answer.
                        yield event
                        return
                    # "intent" event: consumed internally, not forwarded to frontend
            except Exception as e:
                logger.error("Chat failed", error=str(e))
                # NOTE(review): the payload format of handler "error" events is
                # not visible here; confirm the frontend accepts plain text.
                yield {"event": "error", "data": f"Chat failed: {str(e)}"}

        return EventSourceResponse(stream_response())
    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e