|
|
"""FastAPI backend service for RAG application.""" |
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import List, Optional, Dict |
|
|
import uvicorn |
|
|
from datetime import datetime |
|
|
import os |
|
|
|
|
|
from config import settings |
|
|
from dataset_loader import RAGBenchLoader |
|
|
from vector_store import ChromaDBManager |
|
|
from llm_client import GroqLLMClient, RAGPipeline |
|
|
from trace_evaluator import TRACEEvaluator |
|
|
|
|
|
|
|
|
# FastAPI application instance; interactive docs are served at /docs.
app = FastAPI(
    title="RAG Capstone API",
    description="API for RAG system with TRACE evaluation",
    version="1.0.0"
)
|
|
|
|
|
|
|
|
# Allow cross-origin requests so a separately-hosted frontend can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive — tighten the origin list before a production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Module-level singletons, populated by POST /load-dataset and read by the
# query/evaluation endpoints.
rag_pipeline: Optional[RAGPipeline] = None  # LLM + retrieval orchestrator
vector_store: Optional[ChromaDBManager] = None  # persistent ChromaDB manager
current_collection: Optional[str] = None  # name of the active collection
|
|
|
|
|
|
|
|
|
|
|
class DatasetLoadRequest(BaseModel):
    """Payload for POST /load-dataset: what to ingest and how to index it."""

    dataset_name: str = Field(..., description="Name of the dataset")
    num_samples: int = Field(default=50, description="Number of samples to load")
    chunking_strategy: str = Field(default="hybrid", description="Chunking strategy")
    chunk_size: int = Field(default=512, description="Size of chunks")
    overlap: int = Field(default=50, description="Overlap between chunks")
    embedding_model: str = Field(..., description="Embedding model name")
    llm_model: str = Field(default="llama-3.1-8b-instant", description="LLM model name")
    groq_api_key: str = Field(..., description="Groq API key")
|
|
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Payload for POST /query: the question plus generation knobs."""

    query: str = Field(..., description="User query")
    n_results: int = Field(default=5, description="Number of documents to retrieve")
    max_tokens: int = Field(default=1024, description="Maximum tokens to generate")
    temperature: float = Field(default=0.7, description="Sampling temperature")
|
|
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response model for query."""
    # Original user question, echoed back.
    query: str
    # Generated answer text from the LLM.
    response: str
    # Retrieved context chunks; each dict presumably holds the document text
    # and retrieval metadata — TODO confirm against RAGPipeline.query output.
    retrieved_documents: List[Dict]
    # ISO-8601 timestamp of when the answer was produced.
    timestamp: str
|
|
|
|
|
|
|
|
class EvaluationRequest(BaseModel):
    """Payload for POST /evaluate: how many held-out samples to score."""

    num_samples: int = Field(default=10, description="Number of test samples")
|
|
|
|
|
|
|
|
class CollectionInfo(BaseModel):
    """Collection information model."""
    # Collection name in the vector store.
    name: str
    # Number of stored documents/chunks.
    count: int
    # Arbitrary collection metadata (schema defined by ChromaDBManager).
    metadata: Dict
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: basic service metadata and a pointer to the docs."""
    payload = {
        "message": "RAG Capstone API",
        "version": "1.0.0",
        "docs": "/docs",
    }
    return payload
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: report service status with the current timestamp."""
    now = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": now}
|
|
|
|
|
|
|
|
@app.get("/datasets")
async def list_datasets():
    """Expose the configured RAGBench dataset names."""
    return {"datasets": settings.ragbench_datasets}
|
|
|
|
|
|
|
|
@app.get("/models/embedding")
async def list_embedding_models():
    """Expose the configured embedding model names."""
    return {"embedding_models": settings.embedding_models}
|
|
|
|
|
|
|
|
@app.get("/models/llm")
async def list_llm_models():
    """Expose the configured LLM model names."""
    return {"llm_models": settings.llm_models}
|
|
|
|
|
|
|
|
@app.get("/chunking-strategies")
async def list_chunking_strategies():
    """Expose the configured chunking strategy names."""
    return {"chunking_strategies": settings.chunking_strategies}
|
|
|
|
|
|
|
|
@app.get("/collections")
async def list_collections():
    """Enumerate every collection in the persistent vector store."""
    global vector_store

    # Lazily construct the ChromaDB manager on first use.
    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    names = vector_store.list_collections()
    return {"collections": names, "count": len(names)}
|
|
|
|
|
|
|
|
@app.get("/collections/{collection_name}")
async def get_collection_info(collection_name: str):
    """Return stats for one collection; 404 when it cannot be found."""
    global vector_store

    # Lazily construct the ChromaDB manager on first use.
    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        return vector_store.get_collection_stats(collection_name)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"Collection not found: {str(e)}")
|
|
|
|
|
|
|
|
@app.post("/load-dataset")
async def load_dataset(request: DatasetLoadRequest, background_tasks: BackgroundTasks):
    """Load a RAGBench dataset, index it into a ChromaDB collection, and
    initialize the global RAG pipeline.

    Side effects: replaces the module-level ``vector_store``, ``rag_pipeline``
    and ``current_collection`` singletons.

    Raises:
        HTTPException: 400 when the dataset cannot be loaded, 500 on any
            other failure during ingestion or client setup.
    """
    global rag_pipeline, vector_store, current_collection

    try:
        # Pull the requested number of samples from the train split.
        loader = RAGBenchLoader()
        dataset = loader.load_dataset(
            request.dataset_name,
            split="train",
            max_samples=request.num_samples
        )

        if not dataset:
            raise HTTPException(status_code=400, detail="Failed to load dataset")

        vector_store = ChromaDBManager(settings.chroma_persist_directory)

        # Derive a collection name free of '-' and '.' from the dataset,
        # chunking strategy, and the embedding model's short name.
        model_short_name = request.embedding_model.split('/')[-1]
        collection_name = f"{request.dataset_name}_{request.chunking_strategy}_{model_short_name}"
        collection_name = collection_name.replace("-", "_").replace(".", "_")

        # Chunk, embed, and persist the documents.
        vector_store.load_dataset_into_collection(
            collection_name=collection_name,
            embedding_model_name=request.embedding_model,
            chunking_strategy=request.chunking_strategy,
            dataset_data=dataset,
            chunk_size=request.chunk_size,
            overlap=request.overlap
        )

        # Build the LLM client with the service-wide rate limits.
        llm_client = GroqLLMClient(
            api_key=request.groq_api_key,
            model_name=request.llm_model,
            max_rpm=settings.groq_rpm_limit,
            rate_limit_delay=settings.rate_limit_delay
        )

        rag_pipeline = RAGPipeline(llm_client, vector_store)
        current_collection = collection_name

        return {
            "status": "success",
            "collection_name": collection_name,
            "num_documents": len(dataset),
            "message": f"Collection '{collection_name}' created successfully"
        }

    except HTTPException:
        # Bug fix: the 400 raised above was previously swallowed by the
        # generic handler below and re-reported as a 500. Re-raise as-is so
        # clients see the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading dataset: {str(e)}")
|
|
|
|
|
|
|
|
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Answer a user question through the initialized RAG pipeline."""
    global rag_pipeline

    # The pipeline only exists after /load-dataset has been called.
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        answer = rag_pipeline.query(
            query=request.query,
            n_results=request.n_results,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )
        answer["timestamp"] = datetime.now().isoformat()
        return answer
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
|
|
|
|
|
|
|
|
@app.get("/chat-history")
async def get_chat_history():
    """Return the pipeline's accumulated conversation history."""
    global rag_pipeline

    # History lives on the pipeline, so it must exist first.
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    history = rag_pipeline.get_chat_history()
    return {"history": history}
|
|
|
|
|
|
|
|
@app.delete("/chat-history")
async def clear_chat_history():
    """Reset the pipeline's conversation history."""
    global rag_pipeline

    # History lives on the pipeline, so it must exist first.
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    rag_pipeline.clear_history()
    return {"status": "success", "message": "Chat history cleared"}
|
|
|
|
|
|
|
|
@app.post("/evaluate")
async def run_evaluation(request: EvaluationRequest):
    """Run TRACE evaluation: answer held-out test questions through the live
    pipeline and score the results in batch.

    Raises:
        HTTPException: 400 when no pipeline has been initialized, 500 on any
            failure while querying or scoring.
    """
    global rag_pipeline, current_collection

    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )

    try:
        # Bug fix: the original fetched vector_store.current_collection.metadata
        # into an unused variable, which could raise an unhelpful AttributeError;
        # the line has been removed.
        # Collection names are "<dataset>_<strategy>_<model>", so the first
        # token recovers the dataset name (with a sensible fallback).
        dataset_name = current_collection.split("_")[0] if current_collection else "hotpotqa"

        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, request.num_samples)

        # Answer each question through the pipeline and collect the inputs
        # the TRACE evaluator expects.
        test_cases = []
        for sample in test_data:
            result = rag_pipeline.query(sample["question"], n_results=5)

            test_cases.append({
                "query": sample["question"],
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", "")
            })

        evaluator = TRACEEvaluator()
        results = evaluator.evaluate_batch(test_cases)

        return {
            "status": "success",
            "results": results
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during evaluation: {str(e)}")
|
|
|
|
|
|
|
|
@app.delete("/collections/{collection_name}")
async def delete_collection(collection_name: str):
    """Drop a collection from the persistent vector store."""
    global vector_store

    # Lazily construct the ChromaDB manager on first use.
    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)

    try:
        vector_store.delete_collection(collection_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}")

    return {
        "status": "success",
        "message": f"Collection '{collection_name}' deleted"
    }
|
|
|
|
|
|
|
|
@app.get("/current-collection")
async def get_current_collection():
    """Describe the collection currently backing the pipeline, if any."""
    global current_collection, vector_store

    # Nothing loaded yet — report that rather than erroring.
    if not current_collection:
        return {"collection": None, "message": "No collection loaded"}

    try:
        stats = vector_store.get_collection_stats(current_collection)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")

    return {"collection": current_collection, "stats": stats}
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Dev entry point: serve the app on all interfaces with hot-reload.
    # Not for production — use a process manager (e.g. gunicorn/uvicorn workers).
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        reload=True
    )
|
|
|