# deepfake-fastapi/app/chatbot.py
"""
Chatbot API routes.
"""
import logging
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException
from psycopg import Connection
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_tavily import TavilySearch
from app.core.chatbot.config import CHAT_HISTORY_TABLE
from app.core.chatbot.chains import general_chain, news_chain
from app.core.chatbot.database import db_connection_dependency
from app.schemas.chat import ChatRequest
from app.services.chatbot.history import delete_guest_session, get_message_history
from app.services.chatbot.search import (
extract_search_query,
format_search_results,
get_tavily_search,
should_use_search,
)
logger = logging.getLogger(__name__)
chat = APIRouter()
@chat.post("/chat")
async def chat_endpoint(
    request: ChatRequest,
    tavily_search: TavilySearch = Depends(get_tavily_search),
):
    """
    Chat endpoint with intelligent query processing.

    Flow:
    1. LLM classifies if query needs web search
    2. If NO search needed -> respond directly with LLM knowledge
    3. If search needed -> extract optimized search query -> search with
       Tavily -> respond with sources

    Returns:
        response: The LLM's response (``{"content": ...}``)
        session_id: Session ID for conversation history
        used_search: Whether web search was used
        search_reason: Why search was/wasn't used

    Raises:
        HTTPException: 500 on any unexpected failure; deliberate
            HTTPExceptions from dependencies are re-raised unchanged.
    """
    try:
        # Ensure every session has an id for history tracking (DB or in-memory)
        session_id = request.session_id or str(uuid4())
        save_to_db = request.save_to_db

        # Step 1: LLM decides if search is needed
        needs_search, search_reason = should_use_search(request.query)
        logger.info("Search needed: %s, Reason: %s", needs_search, search_reason)

        # Closure so RunnableWithMessageHistory can resolve history with the
        # caller's persistence preference (DB vs in-memory guest session).
        def get_history(sid: str):
            return get_message_history(sid, save_to_db)

        # Step 2: Take appropriate path
        if not needs_search:
            # Respond directly from model knowledge without search
            chain_with_history = RunnableWithMessageHistory(
                general_chain,
                get_history,
                input_messages_key="question",
                history_messages_key="chat_history",
            )
            response = chain_with_history.invoke(
                {"question": request.query},
                config={"configurable": {"session_id": session_id}},
            )
        else:
            # Step 2a: Extract optimized search query from the user's text
            optimized_query = extract_search_query(request.query)
            logger.info("Optimized search query: %s", optimized_query)

            # Step 2b: Search with the optimized query
            search_results = tavily_search.invoke(optimized_query)
            formatted_results = format_search_results(search_results)
            logger.debug("Search results: %s", search_results)

            # Step 2c: Respond, grounding the answer in the search results
            chain_with_history = RunnableWithMessageHistory(
                news_chain,
                get_history,
                input_messages_key="question",
                history_messages_key="chat_history",
            )
            response = chain_with_history.invoke(
                {
                    "question": request.query,
                    "search_results": formatted_results,
                },
                config={"configurable": {"session_id": session_id}},
            )

        logger.debug("Response: %s", response.content)
        return {
            "response": {
                "content": response.content,
            },
            "session_id": session_id,
            "used_search": needs_search,
            "search_reason": search_reason,
        }
    except HTTPException:
        # Preserve deliberate HTTP errors raised by dependencies/helpers
        raise
    except Exception as exc:  # pragma: no cover - defensive server guard
        logger.exception("Unhandled error in chat_endpoint")
        # Chain the cause so the original traceback isn't lost (lint B904)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@chat.delete("/chat/{session_id}")
async def delete_chat_history(session_id: str, conn: Connection = Depends(db_connection_dependency)):
    """Delete chat history for a specific session.

    Removes the session from the in-memory guest store, then deletes all of
    its persisted messages from the chat-history table.

    Returns:
        success: Always True on completion
        deleted_count: Number of database rows removed
        session_id: The session that was cleared

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        # Delete from in-memory guest sessions (no-op if session isn't a guest)
        delete_guest_session(session_id)

        # Delete from database. The table name is a quoted identifier from
        # trusted config (not user input); the session id is passed as a
        # bound parameter, so this is not injectable.
        with conn.cursor() as cursor:
            cursor.execute(
                f'DELETE FROM "{CHAT_HISTORY_TABLE}" WHERE session_id = %s',
                (session_id,),
            )
            deleted_count = cursor.rowcount
        conn.commit()
        return {
            "success": True,
            "deleted_count": deleted_count,
            "session_id": session_id,
        }
    except Exception as e:
        logger.exception("Error in delete_chat_history")
        # Chain the cause so the original traceback isn't lost (lint B904)
        raise HTTPException(status_code=500, detail=str(e)) from e