| | """FastAPI application for EyeWiki RAG system.""" |
| |
|
| | import logging |
| | import time |
| | from contextlib import asynccontextmanager |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | from fastapi import FastAPI, HTTPException, Request, status |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import StreamingResponse |
| | from pydantic import BaseModel, Field |
| | import gradio as gr |
| |
|
| | from src.api.gradio_ui import create_gradio_interface |
| | from config.settings import LLMProvider, Settings |
| | from src.llm.llm_client import LLMClient |
| | from src.llm.ollama_client import OllamaClient |
| | from src.llm.openai_client import OpenAIClient |
| | from src.llm.sentence_transformer_client import SentenceTransformerClient |
| | from src.rag.query_engine import EyeWikiQueryEngine, QueryResponse |
| | from src.rag.reranker import CrossEncoderReranker |
| | from src.rag.retriever import HybridRetriever |
| | from src.vectorstore.qdrant_store import QdrantStoreManager |
| |
|
| |
|
| | |
# Configure the root logger once at import time, then create the logger used
# throughout this module.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class QueryRequest(BaseModel):
    """
    Payload accepted by the non-streaming query endpoint.

    Attributes:
        question: The user's question; at least 3 characters are required.
        include_sources: Whether source information should be included
            (defaults to True).
        filters: Optional metadata filters (disease_name, icd_codes, etc.);
            None when not supplied.
    """

    question: str = Field(
        ...,
        min_length=3,
        description="User's question",
    )
    include_sources: bool = Field(True, description="Include source documents")
    filters: Optional[dict] = Field(None, description="Metadata filters")
| |
|
| |
|
class StreamQueryRequest(BaseModel):
    """
    Payload accepted by the streaming query endpoint.

    Attributes:
        question: The user's question; at least 3 characters are required.
        filters: Optional metadata filters; None when not supplied.
    """

    question: str = Field(
        ...,
        min_length=3,
        description="User's question",
    )
    filters: Optional[dict] = Field(None, description="Metadata filters")
| |
|
| |
|
class HealthResponse(BaseModel):
    """
    Body returned by the health check endpoint.

    All fields are required.

    Attributes:
        status: Overall status; the health endpoint in this module reports
            "healthy", "degraded" or "unhealthy".
        llm: Status details for the LLM service.
        qdrant: Status details for the Qdrant service.
        query_engine: Query engine initialization status.
        timestamp: Unix timestamp of the check.
    """

    status: str = Field(default=..., description="Overall status")
    llm: dict = Field(default=..., description="LLM service status")
    qdrant: dict = Field(default=..., description="Qdrant service status")
    query_engine: dict = Field(default=..., description="Query engine status")
    timestamp: float = Field(default=..., description="Unix timestamp")
| |
|
| |
|
class StatsResponse(BaseModel):
    """
    Body returned by the statistics endpoint.

    All fields are required.

    Attributes:
        collection_info: Qdrant collection information.
        pipeline_config: Query engine pipeline configuration.
        documents_indexed: Number of indexed documents.
        timestamp: Unix timestamp at which the stats were gathered.
    """

    collection_info: dict = Field(default=..., description="Collection information")
    pipeline_config: dict = Field(default=..., description="Pipeline configuration")
    documents_indexed: int = Field(
        default=...,
        description="Number of indexed documents",
    )
    timestamp: float = Field(default=..., description="Unix timestamp")
| |
|
| |
|
class ErrorResponse(BaseModel):
    """
    Schema describing an error payload.

    Attributes:
        error: Error message (required).
        detail: Optional detailed error information; None when absent.
        timestamp: Unix timestamp of the error (required).
    """

    error: str = Field(default=..., description="Error message")
    detail: Optional[str] = Field(None, description="Error details")
    timestamp: float = Field(default=..., description="Unix timestamp")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class AppState:
    """
    Mutable holder for the objects the API builds at startup.

    Every component slot starts as None and ``initialized`` starts as False;
    the slots are filled in by the application's startup routine.
    """

    def __init__(self) -> None:
        # Startup outcome flags.
        self.initialized: bool = False
        self.initialization_error: Optional[str] = None

        # Configuration and model clients.
        self.settings: Optional[Settings] = None
        self.llm_client: Optional[LLMClient] = None
        self.embedding_client: Optional[SentenceTransformerClient] = None

        # Storage and retrieval pipeline components.
        self.qdrant_manager: Optional[QdrantStoreManager] = None
        self.retriever: Optional[HybridRetriever] = None
        self.reranker: Optional[CrossEncoderReranker] = None
        self.query_engine: Optional[EyeWikiQueryEngine] = None
| |
|
| |
|
# Module-level state shared by the lifespan handler and the endpoints below.
app_state = AppState()
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _existing_prompt_path(path: Path, label: str) -> Optional[Path]:
    """
    Return ``path`` when the file exists, otherwise warn and return None.

    Args:
        path: Candidate prompt file location.
        label: Human-readable name used in the warning message.

    Returns:
        ``path`` if it exists on disk, else None.
    """
    if path.exists():
        return path
    logger.warning(f"{label} not found: {path}")
    return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Startup builds, in order, the settings, LLM client, embedding client,
    Qdrant manager, retriever, reranker and query engine, storing each on
    ``app_state``. Any exception raised while doing so is logged and recorded
    in ``app_state.initialization_error`` instead of being re-raised, so the
    server still starts and endpoints can report the failure.

    Shutdown closes the Qdrant manager if one was created. The cleanup runs in
    a ``finally`` clause so it also happens when an exception is thrown into
    the generator while the application is running.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    logger.info("Starting EyeWiki RAG API...")

    try:
        logger.info("Loading settings...")
        settings = Settings()
        app_state.settings = settings

        # LLM client: OpenAI when configured, otherwise Ollama.
        logger.info(f"Initializing LLM client (provider: {settings.llm_provider.value})...")
        if settings.llm_provider == LLMProvider.OPENAI:
            app_state.llm_client = OpenAIClient(
                api_key=settings.openai_api_key,
                base_url=settings.openai_base_url,
                model=settings.openai_model,
            )
        else:
            app_state.llm_client = OllamaClient(
                base_url=settings.ollama_base_url,
                # Embeddings come from the SentenceTransformer client below.
                embedding_model=None,
                llm_model=settings.llm_model,
                timeout=settings.ollama_timeout,
            )

        logger.info("Initializing embedding client...")
        app_state.embedding_client = SentenceTransformerClient(
            model_name=settings.embedding_model,
        )
        logger.info(f"Embedding model loaded: {settings.embedding_model}")

        logger.info("Initializing Qdrant manager...")
        app_state.qdrant_manager = QdrantStoreManager(
            collection_name=settings.qdrant_collection_name,
            path=settings.qdrant_path,
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
            embedding_dim=app_state.embedding_client.embedding_dim,
        )

        # Refuse to continue without an existing collection.
        collection_info = app_state.qdrant_manager.get_collection_info()
        if not collection_info:
            raise RuntimeError(
                f"Qdrant collection '{settings.qdrant_collection_name}' not found. "
                "Please run 'python scripts/build_index.py --index-vectors' first."
            )

        logger.info(
            f"Qdrant collection loaded: {collection_info['vectors_count']} vectors"
        )

        logger.info("Initializing retriever...")
        app_state.retriever = HybridRetriever(
            qdrant_manager=app_state.qdrant_manager,
            embedding_client=app_state.embedding_client,
        )

        logger.info("Initializing reranker...")
        app_state.reranker = CrossEncoderReranker(
            model_name=settings.reranker_model,
        )

        # Prompt files live in <project root>/prompts; a missing file is passed
        # to the query engine as None after a warning.
        prompts_dir = Path(__file__).parent.parent.parent / "prompts"
        system_prompt_path = _existing_prompt_path(
            prompts_dir / "system_prompt.txt", "System prompt"
        )
        query_prompt_path = _existing_prompt_path(
            prompts_dir / "query_prompt.txt", "Query prompt"
        )
        disclaimer_path = _existing_prompt_path(
            prompts_dir / "medical_disclaimer.txt", "Disclaimer"
        )

        logger.info("Initializing query engine...")
        app_state.query_engine = EyeWikiQueryEngine(
            retriever=app_state.retriever,
            reranker=app_state.reranker,
            llm_client=app_state.llm_client,
            system_prompt_path=system_prompt_path,
            query_prompt_path=query_prompt_path,
            disclaimer_path=disclaimer_path,
            max_context_tokens=settings.max_context_tokens,
            retrieval_k=20,
            rerank_k=5,
        )

        # Clear any error left over from an earlier failed startup in this process.
        app_state.initialization_error = None
        app_state.initialized = True
        logger.info("EyeWiki RAG API started successfully")
        logger.info("Gradio UI available at /ui")

    except Exception as e:
        error_msg = f"Failed to initialize application: {e}"
        logger.error(error_msg, exc_info=True)
        app_state.initialization_error = error_msg
        # Deliberately not re-raised: the app starts anyway and endpoints
        # report the failure via check_initialization() and /health.

    try:
        yield
    finally:
        logger.info("Shutting down EyeWiki RAG API...")

        if app_state.qdrant_manager:
            try:
                app_state.qdrant_manager.close()
                logger.info("Qdrant client closed")
            except Exception as e:
                logger.error(f"Error closing Qdrant client: {e}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Create the FastAPI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="EyeWiki RAG API",
    description="Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
    version="1.0.0",
    lifespan=lifespan,
)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# Register CORS middleware: every origin, method and header is allowed, and
# credentialed requests are permitted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| |
|
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    Request logging middleware.

    Logs the method, path and client host of every incoming request, then the
    response status code and the time taken to produce it.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        The response produced by ``call_next``, unmodified.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments while the request is in flight.
    started = time.perf_counter()

    # request.client can be None, in which case the host is logged as 'unknown'.
    client_host = request.client.host if request.client else 'unknown'
    logger.info(
        f"Request: {request.method} {request.url.path} "
        f"from {client_host}"
    )

    response = await call_next(request)

    duration = time.perf_counter() - started
    logger.info(
        f"Response: {response.status_code} "
        f"in {duration:.3f}s"
    )

    return response
| |
|
| |
|
| | |
| | |
| | |
| |
|
def check_initialization():
    """
    Guard used by endpoints that need the fully initialized pipeline.

    Returns silently when startup succeeded.

    Raises:
        HTTPException: 503 when the application is not initialized; the detail
            is the recorded initialization error, or a generic message when no
            error was recorded.
    """
    if app_state.initialized:
        return

    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail=app_state.initialization_error or "Application not initialized",
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.get("/")
async def root():
    """
    Root endpoint.

    Returns:
        A dict with the API name, version, description and a map of the
        available endpoints.
    """
    endpoints = {
        "health": "GET /health",
        "query": "POST /query",
        "stream": "POST /query/stream",
        "stats": "GET /stats",
        "docs": "GET /docs",
    }
    info = {
        "name": "EyeWiki RAG API",
        "version": "1.0.0",
        "description": "Retrieval-Augmented Generation API for EyeWiki medical knowledge base",
    }
    info["endpoints"] = endpoints
    return info
| |
|
| |
|
def _llm_health_status() -> dict:
    """Describe the state of the configured LLM client as a status dict."""
    client = app_state.llm_client
    if not client:
        return {"status": "not_initialized", "detail": "Client not created"}

    provider = app_state.settings.llm_provider.value if app_state.settings else "unknown"
    report = {"status": "unknown", "detail": None, "provider": provider}
    try:
        if isinstance(client, OllamaClient):
            # Only the Ollama client is actively probed.
            report["status"] = "healthy" if client.check_health() else "unhealthy"
        else:
            # Other providers are reported healthy once the client exists.
            report["status"] = "healthy"
        report["model"] = client.llm_model
    except Exception as e:
        return {"status": "unhealthy", "detail": str(e), "provider": provider}
    return report


def _qdrant_health_status() -> dict:
    """Describe the state of the Qdrant collection as a status dict."""
    manager = app_state.qdrant_manager
    if not manager:
        return {"status": "not_initialized", "detail": "Manager not created"}

    try:
        info = manager.get_collection_info()
        if not info:
            return {"status": "unhealthy", "detail": "Collection not found"}
        return {
            "status": "healthy",
            "collection": info["name"],
            "vectors_count": info["vectors_count"],
        }
    except Exception as e:
        return {"status": "unhealthy", "detail": str(e)}


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Checks status of:
    - LLM service (actively probed for Ollama; other providers are reported
      healthy once their client exists)
    - Qdrant service
    - Query engine initialization

    The overall status is "unhealthy" when the application is not initialized,
    "degraded" when it is initialized but the LLM or Qdrant check is not
    healthy, and "healthy" otherwise. This endpoint never raises for an
    unhealthy dependency; failures are reported in the body.

    Returns:
        HealthResponse with service statuses
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    timestamp = time.time()

    # The probes make blocking calls, so run them off the event loop.
    llm_status = await run_in_threadpool(_llm_health_status)
    qdrant_status = await run_in_threadpool(_qdrant_health_status)

    query_engine_status = {
        "status": "initialized" if app_state.initialized else "not_initialized",
        "error": app_state.initialization_error,
    }

    if not app_state.initialized:
        overall_status = "unhealthy"
    elif llm_status["status"] != "healthy" or qdrant_status["status"] != "healthy":
        overall_status = "degraded"
    else:
        overall_status = "healthy"

    return HealthResponse(
        status=overall_status,
        llm=llm_status,
        qdrant=qdrant_status,
        query_engine=query_engine_status,
        timestamp=timestamp,
    )
| |
|
| |
|
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """
    Main query endpoint.

    Processes a question using the full RAG pipeline:
    1. Retrieval (hybrid search)
    2. Reranking (cross-encoder)
    3. Context assembly
    4. LLM generation

    The pipeline call is synchronous, so it is executed in the threadpool to
    keep the event loop free for other requests while it runs.

    Args:
        request: QueryRequest with question and options

    Returns:
        QueryResponse with answer, sources, and disclaimer

    Raises:
        HTTPException: 503 if the application is not initialized, 500 if the
            query fails
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    check_initialization()

    try:
        logger.info(f"Processing query: '{request.question}'")

        response = await run_in_threadpool(
            app_state.query_engine.query,
            question=request.question,
            include_sources=request.include_sources,
            filters=request.filters,
        )

        logger.info(
            f"Query complete: {len(response.sources)} sources, "
            f"confidence: {response.confidence:.2f}"
        )

        return response

    except Exception as e:
        logger.error(f"Error processing query: {e}", exc_info=True)
        # Chain the original exception so the cause stays visible in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing query: {str(e)}"
        ) from e
| |
|
| |
|
def _format_sse_data(payload) -> str:
    """
    Encode ``payload`` as one Server-Sent Event.

    Each line of the payload gets its own ``data:`` field, so payloads that
    contain newlines are framed correctly instead of ending the event early.
    A payload without line breaks is encoded as ``data: <payload>`` followed
    by a blank line.
    """
    # Normalize every line-break form to "\n" before splitting into fields.
    text = str(payload).replace("\r\n", "\n").replace("\r", "\n")
    return "".join(f"data: {line}\n" for line in text.split("\n")) + "\n"


@app.post("/query/stream")
async def stream_query(request: StreamQueryRequest):
    """
    Streaming query endpoint.

    Returns answer as Server-Sent Events (SSE) for real-time streaming.

    Args:
        request: StreamQueryRequest with question and options

    Returns:
        StreamingResponse with SSE

    Raises:
        HTTPException: 503 if the application is not initialized. A failure
            while streaming cannot change the status code any more, so it is
            reported in-stream as a final ``[ERROR] <message>`` event.
    """
    check_initialization()

    def generate():
        """Generate SSE stream."""
        # Plain (non-async) generator: the query engine's stream is consumed
        # with a blocking for-loop, and StreamingResponse iterates synchronous
        # generators in the threadpool rather than on the event loop.
        try:
            logger.info(f"Processing streaming query: '{request.question}'")

            for chunk in app_state.query_engine.stream_query(
                question=request.question,
                filters=request.filters,
            ):
                yield _format_sse_data(chunk)

            logger.info("Streaming query complete")

        except Exception as e:
            logger.error(f"Error in streaming query: {e}", exc_info=True)
            yield _format_sse_data(f"[ERROR] {str(e)}")

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
    )
| |
|
| |
|
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """
    Get index and pipeline statistics.

    Returns:
        StatsResponse with collection info and pipeline config.
        ``documents_indexed`` is taken from the collection's ``vectors_count``
        and falls back to 0 when that value is missing or None.

    Raises:
        HTTPException: 503 if the application is not initialized, 404 if the
            collection is not found, 500 if stats retrieval fails
    """
    check_initialization()

    try:
        collection_info = app_state.qdrant_manager.get_collection_info()
        if not collection_info:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Collection not found"
            )

        pipeline_config = app_state.query_engine.get_pipeline_info()

        return StatsResponse(
            collection_info=collection_info,
            pipeline_config=pipeline_config,
            # `or 0` also covers a key that is present but None, which would
            # otherwise fail the int validation on StatsResponse.
            documents_indexed=collection_info.get("vectors_count") or 0,
            timestamp=time.time(),
        )

    except HTTPException:
        # Let the 404 above pass through instead of turning it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error retrieving stats: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving stats: {str(e)}"
        ) from e
| |
|
| |
|
| | |
| | |
| | |
| |
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handle HTTP exceptions.

    Args:
        request: The request that triggered the exception (unused).
        exc: The raised HTTPException.

    Returns:
        JSONResponse carrying the exception's status code, with a body holding
        ``error`` (the exception detail), ``status_code`` and ``timestamp``.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.responses import JSONResponse

    # An exception handler must return a Response object; a bare dict is not
    # turned into one for handlers the way it is for path operations.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": time.time(),
        },
        # Preserve headers attached to the exception, if any.
        headers=getattr(exc, "headers", None),
    )
| |
|
| |
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """
    Handle general exceptions.

    Logs the exception with its traceback and replies with a generic error
    body.

    Args:
        request: The request that triggered the exception (unused).
        exc: The unhandled exception.

    Returns:
        JSONResponse with status 500 and a body holding ``error``, ``detail``
        (the exception text), ``status_code`` and ``timestamp``.
    """
    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.responses import JSONResponse

    logger.error(f"Unhandled exception: {exc}", exc_info=True)

    # An exception handler must return a Response object, not a bare dict.
    # NOTE(review): `detail` exposes the raw exception text to the client —
    # confirm this is acceptable outside development.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={
            "error": "Internal server error",
            "detail": str(exc),
            "status_code": status.HTTP_500_INTERNAL_SERVER_ERROR,
            "timestamp": time.time(),
        },
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
# Build the Gradio interface and mount it on the FastAPI app under /ui.
# The query engine is handed over as a getter instead of a value because
# app_state.query_engine is still None at import time; it is only assigned
# during application startup.
gradio_interface = create_gradio_interface(
    query_engine_getter=lambda: app_state.query_engine
)
# mount_gradio_app returns the application object, which is rebound to `app`.
app = gr.mount_gradio_app(app, gradio_interface, path="/ui")
logger.info("Gradio UI mounted at /ui")
| |
|