Spaces:
Running
Running
| """FastAPI backend service for RAG application.""" | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import List, Optional, Dict | |
| import uvicorn | |
| from datetime import datetime | |
| import os | |
| from config import settings | |
| from dataset_loader import RAGBenchLoader | |
| from vector_store import ChromaDBManager | |
| from llm_client import GroqLLMClient, RAGPipeline | |
| from trace_evaluator import TRACEEvaluator | |
# Application instance; title/description/version feed the generated OpenAPI docs.
# NOTE(review): none of the handler functions in this listing carry route
# decorators (e.g. `@app.get(...)`); confirm they are registered on `app`.
app = FastAPI(
    version="1.0.0",
    title="RAG Capstone API",
    description="API for RAG system with TRACE evaluation",
)

# Wide-open CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Process-wide mutable state. Handlers (re)assign these via `global`;
# all start unset until a dataset is loaded.
current_collection: Optional[str] = None
vector_store: Optional[ChromaDBManager] = None
rag_pipeline: Optional[RAGPipeline] = None
| # Request/Response models | |
class DatasetLoadRequest(BaseModel):
    """Request body for loading a dataset and building a vector collection.

    Numeric fields are range-checked so obviously invalid values (a negative
    sample count, a zero chunk size, a negative overlap) are rejected during
    request validation instead of being passed to the loader and chunker.
    """

    dataset_name: str = Field(..., description="Name of the dataset")
    num_samples: int = Field(50, gt=0, description="Number of samples to load")
    chunking_strategy: str = Field("hybrid", description="Chunking strategy")
    chunk_size: int = Field(512, gt=0, description="Size of chunks")
    overlap: int = Field(50, ge=0, description="Overlap between chunks")
    # Used (split on "/") to build the collection name and to load embeddings.
    embedding_model: str = Field(..., description="Embedding model name")
    llm_model: str = Field("llama-3.1-8b-instant", description="LLM model name")
    # Supplied per request and forwarded to GroqLLMClient when the pipeline is built.
    groq_api_key: str = Field(..., description="Groq API key")
class QueryRequest(BaseModel):
    """Request body for querying the RAG pipeline.

    Validation rejects an empty query, non-positive retrieval/generation
    sizes and a negative sampling temperature before any retrieval or LLM
    call is made.
    """

    query: str = Field(..., min_length=1, description="User query")
    n_results: int = Field(5, gt=0, description="Number of documents to retrieve")
    max_tokens: int = Field(1024, gt=0, description="Maximum tokens to generate")
    temperature: float = Field(0.7, ge=0.0, description="Sampling temperature")
class QueryResponse(BaseModel):
    """Response model for query.

    Declares the shape of a query result: the query text, the generated
    response, the retrieved documents (each a dict), and a timestamp string.

    NOTE(review): no code in this listing references this model (query_rag
    returns a plain dict); confirm whether it is meant to be the query
    endpoint's response_model.
    """
    query: str
    response: str
    retrieved_documents: List[Dict]
    timestamp: str
class EvaluationRequest(BaseModel):
    """Request body for a TRACE evaluation run.

    ``num_samples`` must be positive; it is passed to the loader as the number
    of test samples to fetch and evaluate.
    """

    num_samples: int = Field(10, gt=0, description="Number of test samples")
class CollectionInfo(BaseModel):
    """Collection information model.

    Declares a collection's name, its item count, and a free-form metadata dict.

    NOTE(review): no code in this listing references this model; the
    collection endpoints return whatever get_collection_stats produces.
    Confirm whether it is still needed.
    """
    name: str
    count: int
    metadata: Dict
| # API endpoints | |
async def root():
    """Return a small banner naming the API, its version and the docs path."""
    return dict(
        message="RAG Capstone API",
        version="1.0.0",
        docs="/docs",
    )
async def health_check():
    """Report liveness along with the server's current naive local time in ISO 8601."""
    checked_at = datetime.now()
    return dict(status="healthy", timestamp=checked_at.isoformat())
async def list_datasets():
    """Expose the dataset names configured in ``settings.ragbench_datasets``."""
    available = settings.ragbench_datasets
    return {"datasets": available}
async def list_embedding_models():
    """Expose the embedding models configured in ``settings.embedding_models``."""
    available = settings.embedding_models
    return {"embedding_models": available}
async def list_llm_models():
    """Expose the LLM models configured in ``settings.llm_models``."""
    available = settings.llm_models
    return {"llm_models": available}
async def list_chunking_strategies():
    """Expose the chunking strategies configured in ``settings.chunking_strategies``."""
    available = settings.chunking_strategies
    return {"chunking_strategies": available}
async def list_collections():
    """Return every collection in the persistent vector store and how many there are.

    The module-level ``vector_store`` is created on first use if it is still unset.
    """
    global vector_store
    vector_store = vector_store or ChromaDBManager(settings.chroma_persist_directory)
    names = vector_store.list_collections()
    return {"collections": names, "count": len(names)}
async def get_collection_info(collection_name: str):
    """Return the stats for ``collection_name`` from the vector store.

    Creates the module-level ``vector_store`` on first use. Any error raised
    while fetching stats is reported to the client as a 404.
    """
    global vector_store
    vector_store = vector_store or ChromaDBManager(settings.chroma_persist_directory)
    try:
        return vector_store.get_collection_stats(collection_name)
    except Exception as exc:
        raise HTTPException(status_code=404, detail=f"Collection not found: {str(exc)}")
async def load_dataset(request: DatasetLoadRequest, background_tasks: BackgroundTasks):
    """Load a dataset, index it into a new vector collection, and activate it.

    Args:
        request: Dataset, chunking, embedding and LLM settings for the load.
        background_tasks: Accepted for interface compatibility; not used here.

    Returns:
        A dict with ``status``, ``collection_name``, ``num_documents`` (``len`` of
        the loaded dataset) and a human-readable ``message``.

    Raises:
        HTTPException: 400 if the loader returns an empty/falsy dataset; 500 for
            any other failure while loading, indexing or building the pipeline.
    """
    global rag_pipeline, vector_store, current_collection
    try:
        loader = RAGBenchLoader()
        dataset = loader.load_dataset(
            request.dataset_name,
            split="train",
            max_samples=request.num_samples,
        )
        if not dataset:
            raise HTTPException(status_code=400, detail="Failed to load dataset")
        num_documents = len(dataset)

        # Build everything in locals first so a failure part-way through leaves
        # the previously active store, pipeline and collection untouched.
        store = ChromaDBManager(settings.chroma_persist_directory)

        # Name pattern: "<dataset>_<strategy>_<model basename>" with "-" and "."
        # replaced by "_" (run_evaluation relies on the leading dataset part).
        model_basename = request.embedding_model.split("/")[-1]
        collection_name = f"{request.dataset_name}_{request.chunking_strategy}_{model_basename}"
        collection_name = collection_name.replace("-", "_").replace(".", "_")

        store.load_dataset_into_collection(
            collection_name=collection_name,
            embedding_model_name=request.embedding_model,
            chunking_strategy=request.chunking_strategy,
            dataset_data=dataset,
            chunk_size=request.chunk_size,
            overlap=request.overlap,
        )

        llm_client = GroqLLMClient(
            api_key=request.groq_api_key,
            model_name=request.llm_model,
            max_rpm=settings.groq_rpm_limit,
            rate_limit_delay=settings.rate_limit_delay,
        )
        pipeline = RAGPipeline(llm_client, store)
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client unchanged
        # rather than being re-wrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading dataset: {str(e)}") from e

    # Publish the new state only after every step above succeeded.
    vector_store = store
    rag_pipeline = pipeline
    current_collection = collection_name

    return {
        "status": "success",
        "collection_name": collection_name,
        "num_documents": num_documents,
        "message": f"Collection '{collection_name}' created successfully",
    }
async def query_rag(request: QueryRequest):
    """Answer ``request.query`` with the active RAG pipeline.

    The pipeline's result dict is returned with a ``timestamp`` key added
    (naive local time, ISO 8601).

    Raises:
        HTTPException: 400 when no dataset has been loaded yet; 500 when the
            pipeline call fails.
    """
    global rag_pipeline
    if rag_pipeline:
        try:
            generation_args = dict(
                query=request.query,
                n_results=request.n_results,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
            )
            answer = rag_pipeline.query(**generation_args)
            answer["timestamp"] = datetime.now().isoformat()
            return answer
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Error processing query: {str(exc)}")
    raise HTTPException(
        status_code=400,
        detail="RAG pipeline not initialized. Load a dataset first.",
    )
async def get_chat_history():
    """Return the active pipeline's chat history under the ``history`` key.

    Raises:
        HTTPException: 400 when no dataset has been loaded yet.
    """
    global rag_pipeline
    if rag_pipeline:
        return {"history": rag_pipeline.get_chat_history()}
    raise HTTPException(
        status_code=400,
        detail="RAG pipeline not initialized. Load a dataset first.",
    )
async def clear_chat_history():
    """Wipe the active pipeline's chat history and acknowledge success.

    Raises:
        HTTPException: 400 when no dataset has been loaded yet.
    """
    global rag_pipeline
    if rag_pipeline:
        rag_pipeline.clear_history()
        return dict(status="success", message="Chat history cleared")
    raise HTTPException(
        status_code=400,
        detail="RAG pipeline not initialized. Load a dataset first.",
    )
async def run_evaluation(request: EvaluationRequest):
    """Run a TRACE evaluation against the active RAG pipeline.

    For each test sample from ``RAGBenchLoader.get_test_data`` the pipeline is
    queried (5 documents retrieved), and the question, generated response,
    retrieved document texts and ground-truth answer form one test case.
    ``TRACEEvaluator.evaluate_batch`` then scores all test cases.

    Args:
        request: Carries ``num_samples``, the number of test samples to evaluate.

    Returns:
        ``{"status": "success", "results": <evaluate_batch output>}``.

    Raises:
        HTTPException: 400 if no dataset has been loaded; 500 if any step of
            the evaluation fails.
    """
    global rag_pipeline, current_collection
    if not rag_pipeline:
        raise HTTPException(
            status_code=400,
            detail="RAG pipeline not initialized. Load a dataset first."
        )
    try:
        # The dataset name is the first "_"-separated part of the collection
        # name built in load_dataset; "hotpotqa" is the fallback.
        # NOTE(review): load_dataset rewrites "-" to "_" across the whole name,
        # so a dataset name containing "-" or "_" is truncated here — consider
        # tracking the loaded dataset name explicitly.
        dataset_name = current_collection.split("_")[0] if current_collection else "hotpotqa"

        loader = RAGBenchLoader()
        test_data = loader.get_test_data(dataset_name, request.num_samples)

        test_cases = []
        for sample in test_data:
            question = sample["question"]
            result = rag_pipeline.query(question, n_results=5)
            test_cases.append({
                "query": question,
                "response": result["response"],
                "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
                "ground_truth": sample.get("answer", ""),
            })

        evaluator = TRACEEvaluator()
        results = evaluator.evaluate_batch(test_cases)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during evaluation: {str(e)}") from e

    return {
        "status": "success",
        "results": results,
    }
async def delete_collection(collection_name: str):
    """Delete a vector-store collection by name.

    If the deleted collection is the active one, the active pipeline and
    collection name are cleared as well. Later query/history calls then report
    "RAG pipeline not initialized", and the current-collection lookup reports
    "No collection loaded". Without this, they would keep pointing at a
    collection that presumably no longer exists.

    Raises:
        HTTPException: 500 if the vector store fails to delete the collection.
    """
    global vector_store, rag_pipeline, current_collection
    if not vector_store:
        vector_store = ChromaDBManager(settings.chroma_persist_directory)
    try:
        vector_store.delete_collection(collection_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(e)}") from e

    if collection_name == current_collection:
        rag_pipeline = None
        current_collection = None

    return {
        "status": "success",
        "message": f"Collection '{collection_name}' deleted"
    }
async def get_current_collection():
    """Describe the active collection, or say that none is loaded.

    Raises:
        HTTPException: 500 if fetching stats for the active collection fails.
    """
    global current_collection, vector_store
    if current_collection:
        try:
            stats = vector_store.get_collection_stats(current_collection)
        except Exception as exc:
            raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(exc)}")
        return {"collection": current_collection, "stats": stats}
    return {"collection": None, "message": "No collection loaded"}
if __name__ == "__main__":
    # Local development server on all interfaces, port 8000, with auto-reload.
    # The app is passed as the import string "api:app" rather than the object.
    server_options = dict(host="0.0.0.0", port=8000, reload=True)
    uvicorn.run("api:app", **server_options)