# backend/app/rag/routes.py
# Author: precison9 — commit 3557eaa: integrate RabbitMQ with CloudAMQP
import json
import logging
from datetime import datetime, timezone
from typing import List

from bson import ObjectId
from fastapi import APIRouter, Depends, Form, HTTPException, UploadFile, status
from fastapi.responses import StreamingResponse
from groq import Groq
from motor.motor_asyncio import AsyncIOMotorDatabase

from app.auth.models import UserPublic
from app.auth.routes import get_current_user
from app.config import settings
from app.database.connection import get_db
from app.database.schemas import ConversationDB
from app.rag.models import ALLOWED_MODELS, Message
from app.rag.rag_processor import build_context_from_files, web_search
from app.rabbitmq.publishers import (
    publish_conversation_created,
    publish_message_sent,
)
router = APIRouter(tags=["RAG Chat"])
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are a helpful assistant. Use the provided context if relevant. If web search is enabled and you need up-to-date information, use the web_search tool. Reason step-by-step before deciding to use tools."""
WEB_SEARCH_TOOL = {
"type": "function",
"function": {
"name": "web_search",
"description": "Search the web using DuckDuckGo for up-to-date information.",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string", "description": "The search query"}},
"required": ["query"],
},
},
}
@router.post("/conversations", status_code=status.HTTP_201_CREATED)
async def create_conversation(
current_user: UserPublic = Depends(get_current_user),
db: AsyncIOMotorDatabase = Depends(get_db),
):
conv = ConversationDB(user_id=current_user.username)
result = await db.conversations.insert_one(conv.dict(exclude={"id"}))
conv_id = str(result.inserted_id)
# Publish conversation created event
await publish_conversation_created(current_user.username, conv_id)
return {"conversation_id": conv_id}
@router.get("/conversations/{conv_id}")
async def get_conversation(
conv_id: str,
current_user: UserPublic = Depends(get_current_user),
db: AsyncIOMotorDatabase = Depends(get_db),
):
try:
oid = ObjectId(conv_id)
except Exception:
raise HTTPException(status_code=400, detail="Invalid conversation ID")
conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
if not conv:
raise HTTPException(status_code=404, detail="Conversation not found")
conv["id"] = str(conv["_id"])
del conv["_id"]
return conv
@router.post("/conversations/{conv_id}/messages")
async def send_message(
conv_id: str,
model: str = Form(...),
enable_web_search: bool = Form(False),
message: str = Form(...),
files: List[UploadFile] = None,
current_user: UserPublic = Depends(get_current_user),
db: AsyncIOMotorDatabase = Depends(get_db),
):
if model not in ALLOWED_MODELS:
raise HTTPException(status_code=400, detail="Invalid model")
try:
oid = ObjectId(conv_id)
except Exception:
raise HTTPException(status_code=400, detail="Invalid conversation ID")
conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
if not conv:
raise HTTPException(status_code=404, detail="Conversation not found")
messages = [Message(**m) for m in conv.get("messages", [])]
rag_context = ""
if files:
rag_context = build_context_from_files(files, message)
system_msg = {
"role": "system",
"content": SYSTEM_PROMPT + (f"\n\nContext:\n{rag_context}" if rag_context else "")
}
user_msg = Message(role="user", content=message)
messages.append(user_msg)
# Publish message sent event
await publish_message_sent(
user_id=current_user.username,
conversation_id=conv_id,
model=model,
message_preview=message,
)
client = Groq(api_key=settings.groq_api_key)
tools = [WEB_SEARCH_TOOL] if enable_web_search else None
chat_history = [system_msg] + [
m if isinstance(m, dict) else m.dict() for m in messages
]
max_tool_loops = 3
for _ in range(max_tool_loops):
completion = client.chat.completions.create(
model=model,
messages=chat_history,
temperature=1,
max_tokens=500,
top_p=1,
stream=False,
stop=None,
tools=tools,
)
choice = completion.choices[0].message
if not choice.tool_calls:
break
for tool_call in choice.tool_calls:
if tool_call.function.name == "web_search":
args = json.loads(tool_call.function.arguments)
result = web_search(args["query"])
chat_history.append({
"role": "tool",
"tool_call_id": tool_call.id,
"name": "web_search",
"content": result,
})
else:
logger.warning("Max tool loops reached")
raise HTTPException(status_code=500, detail="Too many tool calls")
completion = client.chat.completions.create(
model=model,
messages=chat_history,
temperature=1,
max_tokens=500,
top_p=1,
stream=True,
stop=None,
)
async def generate():
response_content = ""
for chunk in completion:
content = chunk.choices[0].delta.content or ""
response_content += content
yield content
messages.append(Message(role="assistant", content=response_content))
await db.conversations.update_one(
{"_id": oid},
{"$set": {
"messages": [m.dict() for m in messages],
"updated_at": datetime.utcnow()
}}
)
return StreamingResponse(generate(), media_type="text/event-stream")