AI_API / features /rag_chatbot /controller.py
Sangyog10's picture
set up rag pipeline for chatbot
29fbb51
raw
history blame
6.1 kB
import asyncio
import hmac
import logging
import os
from io import BytesIO
from typing import Dict, Any

from fastapi import HTTPException, UploadFile, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from .rag_pipeline import route_and_process_query, add_document_to_rag, check_system_health
from .document_handler import extract_text_from_file
# Logging: root configuration at INFO level plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Bearer-token extractor, used as the dependency of verify_token.
security = HTTPBearer()

# Upload constraints: largest accepted payload in bytes (100MB) and the
# MIME types accepted for upload.
MAX_FILE_SIZE = 100 * 1024 ** 2
SUPPORTED_CONTENT_TYPES = {
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "text/plain",
}
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the Bearer token from the Authorization header.

    Args:
        credentials: Bearer credentials injected by the ``security``
            dependency.

    Returns:
        The presented token string when it matches the value of the
        ``MY_SECRET_TOKEN`` environment variable.

    Raises:
        HTTPException: 500 when ``MY_SECRET_TOKEN`` is unset or empty;
            403 when the presented token does not match it.
    """
    token = credentials.credentials
    expected_token = os.getenv("MY_SECRET_TOKEN")
    if not expected_token:
        logger.error("MY_SECRET_TOKEN not configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error",
        )

    # Compare in constant time so response timing does not reveal how much of
    # the secret was guessed. Encode to bytes first: compare_digest rejects
    # str arguments that contain non-ASCII characters.
    tokens_match = hmac.compare_digest(
        token.encode("utf-8"),
        expected_token.encode("utf-8"),
    )
    if not tokens_match:
        # Deliberately log no part of the presented token: a near-miss or
        # mistyped value would put secret material into the logs.
        logger.warning("Invalid token attempt")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token",
        )
    return token
async def handle_rag_query(query: str) -> Dict[str, Any]:
    """Validate a query, route it through the RAG pipeline and return the answer.

    Args:
        query: The user's question. Must be non-blank and at most 1000
            characters long.

    Returns:
        The response produced by ``route_and_process_query``.

    Raises:
        HTTPException: 400 when the query is empty, blank, or longer than
            1000 characters; 500 when processing fails unexpectedly.
    """
    # Input validation
    if not query or not query.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query cannot be empty",
        )
    if len(query) > 1000:  # Reasonable limit
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query too long. Please limit to 1000 characters.",
        )

    try:
        logger.info("Processing query: %s...", query[:50])
        # The pipeline call is synchronous; run it in a worker thread so the
        # event loop is not blocked.
        response = await asyncio.to_thread(route_and_process_query, query)
        logger.info(
            "Query processed successfully. Route: %s",
            response.get('route', 'Unknown'),
        )
        return response
    except HTTPException:
        # Same passthrough as the other handlers in this module: deliberate
        # HTTP errors are not rewrapped as 500s.
        raise
    except Exception as e:
        # logger.exception records the traceback; the client only sees a
        # generic message.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing your query. Please try again.",
        ) from e
async def handle_document_upload(file: UploadFile) -> Dict[str, Any]:
    """Validate an uploaded document and add its text to the RAG vector store.

    Args:
        file: The uploaded file. Its content type must be one of
            ``SUPPORTED_CONTENT_TYPES`` and its size at most
            ``MAX_FILE_SIZE`` bytes.

    Returns:
        A dict with a confirmation ``message``, the ``filename``, the
        extracted ``text_length`` (an int) and the ``content_type``.

    Raises:
        HTTPException: 400 when no filename is given or the extracted text
            is empty or shorter than 50 characters; 415 for an unsupported
            content type; 413 when the file exceeds ``MAX_FILE_SIZE``; 500
            when the document cannot be added or processing fails
            unexpectedly.
    """
    # File validation
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided",
        )
    if file.content_type not in SUPPORTED_CONTENT_TYPES:
        # Sort so the message is stable; set iteration order is not.
        supported = ', '.join(sorted(SUPPORTED_CONTENT_TYPES))
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Supported types: {supported}",
        )

    # Check file size. Read at most one byte past the limit: that is enough
    # to detect an oversized upload without loading all of it into memory.
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.1f}MB",
        )

    # Reset file pointer so text extraction starts from the beginning.
    await file.seek(0)

    try:
        logger.info("Processing file upload: %s", file.filename)

        # Extract text from file
        text = await extract_text_from_file(file)
        if not text or not text.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The file appears to be empty or could not be read.",
            )
        if len(text) < 50:  # Too short to be meaningful
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The extracted text is too short to be meaningful.",
            )

        # Add to RAG system in a worker thread so the event loop stays free.
        metadata = {
            "source": file.filename,
            "content_type": file.content_type,
            "size": len(contents),
        }
        success = await asyncio.to_thread(add_document_to_rag, text, metadata)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to add document to the knowledge base",
            )

        logger.info("Successfully processed file: %s", file.filename)
        return {
            "message": f"Successfully uploaded and processed '{file.filename}'. "
                       f"It is now available for querying.",
            "filename": file.filename,
            "text_length": len(text),
            "content_type": file.content_type,
        }
    except HTTPException:
        # Validation errors raised above already carry the right status code.
        raise
    except Exception as e:
        # logger.exception records the traceback; the client only sees a
        # generic message.
        logger.exception("Error processing file %s: %s", file.filename, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing the file. Please try again.",
        ) from e
async def handle_health_check() -> Dict[str, Any]:
    """Run the system health check and return its report.

    Returns:
        The dict produced by ``check_system_health`` when its ``"status"``
        entry is anything other than ``"unhealthy"``.

    Raises:
        HTTPException: 503 when the report's status is ``"unhealthy"`` or
            when the check itself raises.
    """
    try:
        # The check is synchronous, so it runs in a worker thread.
        report = await asyncio.to_thread(check_system_health)
        is_unhealthy = report["status"] == "unhealthy"
        if not is_unhealthy:
            return report
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is currently unhealthy",
        )
    except HTTPException:
        # Let the deliberate 503 above through unchanged.
        raise
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Health check failed",
        )