cryogenic22's picture
Rename api/routers/rag,py to api/routers/rag.py
4904b73 verified
from datetime import datetime, timezone
from typing import Dict, Any, Optional

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel

from ...core.rag.engine import RAGEngine, RAGConfig, RAGResponse
# Router that collects the RAG endpoints declared below; presumably it is
# included into the main application elsewhere (not visible here).
router: APIRouter = APIRouter()
class GenerateRequest(BaseModel):
    """Request model for RAG generation.

    Used as the request body of ``POST /generate`` and as the value type of
    the keyed mapping accepted by ``POST /batch``.
    """

    # Query text forwarded to ``RAGEngine.generate(query=...)``.
    query: str
    # Optional extra context forwarded to ``RAGEngine.generate(context=...)``.
    context: Optional[Dict[str, Any]] = None
    # Optional engine configuration; ``/generate`` passes it to ``RAGEngine(...)``.
    # NOTE(review): ``/batch`` never reads this field — confirm that is intended.
    config: Optional[Dict[str, Any]] = None
@router.post("/generate", response_model=RAGResponse)
async def generate(request: GenerateRequest):
    """Generate a response for a single query using RAG.

    A fresh ``RAGEngine`` is constructed from ``request.config`` (``None`` when
    the client sent no config) and initialized on every call.

    Args:
        request: Validated request body carrying the query, optional context
            and optional engine configuration.

    Returns:
        The ``RAGResponse`` produced by ``RAGEngine.generate``.

    Raises:
        HTTPException: Re-raised unchanged when raised downstream; any other
            failure is reported as a 500 whose detail includes the error text.
    """
    try:
        rag = RAGEngine(request.config)
        await rag.initialize()
        return await rag.generate(
            query=request.query,
            context=request.context,
        )
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500,
            detail=f"RAG generation failed: {str(e)}",
        ) from e
@router.post("/batch", response_model=Dict[str, RAGResponse])
async def batch_generate(requests: Dict[str, GenerateRequest]):
    """Generate responses for several keyed queries with one shared engine.

    Args:
        requests: Mapping of caller-chosen keys to generation requests.

    Returns:
        A dict with the same keys, each mapped to its ``RAGResponse``.

    Raises:
        HTTPException: Re-raised unchanged when raised downstream; any other
            failure aborts the whole batch (no partial results) with a 500.
    """
    try:
        # A single engine, constructed without arguments, serves every request.
        # NOTE(review): each request's own ``config`` field is ignored here,
        # unlike ``/generate`` — confirm this is intended.
        rag = RAGEngine()
        await rag.initialize()

        results = {}
        # Kept sequential: nothing visible here shows the engine tolerates
        # concurrent ``generate`` calls.
        for key, request in requests.items():
            results[key] = await rag.generate(
                query=request.query,
                context=request.context,
            )
        return results
    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500,
            detail=f"Batch generation failed: {str(e)}",
        ) from e
@router.post("/config/validate")
@router.get("/config/validate")
async def validate_config(config: Dict[str, Any]):
    """Validate a RAG configuration supplied as a JSON object.

    FastAPI reads a ``Dict[str, Any]`` parameter from the request body. A body
    on GET has no defined semantics in HTTP and some clients cannot send one,
    so the handler is exposed via POST as well; the original GET route is kept
    so existing callers keep working.

    Args:
        config: Configuration mapping passed to ``RAGEngine(...)``.

    Returns:
        ``{"valid": <result of RAGEngine.validate_config()>, "config": config}``.

    Raises:
        HTTPException: Re-raised unchanged when raised downstream; any other
            failure while constructing or validating is reported as a 400.
    """
    try:
        rag = RAGEngine(config)
        is_valid = await rag.validate_config()
        return {
            "valid": is_valid,
            "config": config,
        }
    except HTTPException:
        # Already carries a deliberate status code; don't rewrite it to 400.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid configuration: {str(e)}",
        ) from e
@router.get("/health")
async def health_check():
    """Report whether a RAG engine can be constructed and initialized.

    Returns:
        ``{"status": "healthy", "timestamp": <timezone-aware UTC datetime>}``.

    Raises:
        HTTPException: Re-raised unchanged when raised downstream; any other
            failure while constructing or initializing the engine is a 500.
    """
    try:
        rag = RAGEngine()
        await rag.initialize()
    except HTTPException:
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500,
            detail=f"Health check failed: {str(e)}",
        ) from e

    # Aware UTC timestamp: a naive ``datetime.now()`` is serialized without an
    # offset, so clients cannot tell which time zone it refers to.
    return {
        "status": "healthy",
        "timestamp": datetime.now(timezone.utc),
    }