| """ |
| FastAPI application for Healthcare QA Chatbot. |
| """ |
| import os |
| import sys |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from typing import List, Dict, Optional |
| import uvicorn |
|
|
| |
# Application object that the middleware registration and every route
# decorator in this module attach to.
app = FastAPI(
    version="1.0.0",
    description=(
        "An explainable medical question-answering system "
        "combining LLM + RAG + XAI"
    ),
    title="Healthcare QA Chatbot API",
)
|
|
| |
| |
| |
# Allowed CORS origins come from the CORS_ORIGINS environment variable: a
# comma-separated list of origins, or "*" (the default) to accept any origin.
cors_origins_env = os.getenv("CORS_ORIGINS", "*")

# Trim each entry and drop blanks so a stray or trailing comma does not
# register an empty-string origin.
allow_origins = [
    origin.strip() for origin in cors_origins_env.split(",") if origin.strip()
]

# FastAPI's CORS documentation states that allow_origins cannot be ["*"] when
# allow_credentials is True, so credentialed requests are enabled only when
# an explicit origin allow-list has been configured.
_cors_allow_credentials = "*" not in allow_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Module-level cache slots for the three lazily built pipelines. Each one
# stays None until its matching get_*_pipeline() loader succeeds.
pipeline = langchain_pipeline = langgraph_pipeline = None
|
|
| |
# Request body accepted by POST /ask.
# (Documented with comments rather than a class docstring so the generated
# schema for this model is left untouched.)
class QuestionRequest(BaseModel):
    # Question text; validated to be between 5 and 1000 characters.
    question: str = Field(..., min_length=5, max_length=1000)
    # Forwarded to the standard pipeline's ``include_explanation`` argument.
    include_explanation: bool = Field(default=True)
    # Forwarded to the standard pipeline's ``num_documents`` argument (1-10).
    num_sources: int = Field(3, ge=1, le=10)
    # Pipeline selectors. /ask checks use_langgraph first, then use_langchain,
    # and falls back to the standard pipeline when neither is set.
    use_langchain: bool = Field(
        False,
        description="Use LangChain LCEL-based pipeline",
    )
    use_langgraph: bool = Field(
        False,
        description="Use LangGraph self-correcting RAG pipeline",
    )
|
|
# One entry of an answer's ``sources`` list. /ask builds each instance by
# unpacking an item of the pipeline response's ``sources`` as keyword arguments.
class SourceInfo(BaseModel):
    source: str  # identifier of the source (exact format is set by the pipeline)
    content: str  # text associated with the source
    score: float  # presumably a retrieval/relevance score -- confirm with the retriever
    # Optional link; defaults to an empty string rather than None.
    url: Optional[str] = Field(default="")
|
|
# Confidence section of an answer. /ask builds it by unpacking the pipeline
# response's ``confidence`` mapping as keyword arguments.
class ConfidenceInfo(BaseModel):
    score: float  # numeric confidence value (range is not defined in this module)
    level: str  # the value /ask/simple returns on its own as "confidence"
    explanation: str  # text accompanying the score
|
|
# One entry of an answer's ``attributions`` list. /ask builds each instance by
# unpacking an item of the pipeline response's ``attributions``.
class AttributionInfo(BaseModel):
    claim: str  # presumably a statement taken from the answer -- confirm with the attributor
    source: str  # source the claim is attributed to
    evidence: str  # supporting text for the claim
    similarity: float  # presumably claim/evidence similarity -- confirm with the attributor
|
|
# Response model of POST /ask; every field is copied from the pipeline's
# response object by the endpoint.
class AnswerResponse(BaseModel):
    question: str
    answer: str
    sources: List[SourceInfo]
    confidence: ConfidenceInfo
    attributions: List[AttributionInfo]
    disclaimer: str
    # Filled only when the pipeline response exposes a ``rationale`` attribute;
    # otherwise it stays None.
    rationale: Optional[str] = Field(default=None)
|
|
# Response model shared by GET / and GET /health.
class HealthResponse(BaseModel):
    status: str  # "ok" from the root endpoint, "healthy" from /health
    pipeline_ready: bool  # True once the standard pipeline global is populated
    message: str  # short human-readable status line
|
|
# Location of the optional fine-tuned adapter. The path is relative, so it is
# resolved against the process's current working directory.
_ADAPTER_PATH = Path("models/fine_tuned/medical_adapter")


def _build_shared_components():
    """Build the components that all three pipeline flavours are wired with.

    Project modules are imported here, at call time, so importing this module
    does not pull them in.

    NOTE(review): every loader calls this separately, so each pipeline gets
    its own retriever and its own LLM instance, as before. Sharing one set
    would avoid loading the model several times -- confirm the components are
    safe to share before changing that.

    Returns:
        A ``(retriever, llm, confidence_scorer, source_attributor)`` tuple.

    Raises:
        Exception: Whatever the imports or constructors raise; the public
            ``get_*_pipeline`` loaders catch and report it.
    """
    from src.embeddings.embedding_models import MedicalEmbedder
    from src.embeddings.vector_store import get_vector_store
    from src.retrieval.hybrid_retriever import HybridRetriever
    from src.generation.llm_wrapper import MedicalLLM
    from src.xai.confidence_scorer import ConfidenceScorer
    from src.xai.source_attribution import SourceAttributor

    embedder = MedicalEmbedder(model_name="all-minilm")
    vector_store = get_vector_store()
    retriever = HybridRetriever(embedder, vector_store)

    # Use the fine-tuned adapter (loaded in 4-bit) when it exists on disk,
    # otherwise fall back to the base model without 4-bit loading.
    if _ADAPTER_PATH.exists():
        print(f"✅ Found fine-tuned adapter at {_ADAPTER_PATH}")
        llm = MedicalLLM(
            model_name="tinyllama",
            adapter_path=str(_ADAPTER_PATH),
            load_in_4bit=True,
        )
    else:
        print("⚠️ No adapter found, using base model")
        llm = MedicalLLM(model_name="tinyllama", load_in_4bit=False)

    return retriever, llm, ConfidenceScorer(), SourceAttributor()


def get_pipeline():
    """Lazy load the pipeline.

    The instance is cached in the module-level ``pipeline`` global. A failed
    build is reported on stdout and leaves the cache empty, so the next call
    tries again.

    Returns:
        The ``HealthcareQAPipeline`` instance, or ``None`` if it could not be
        built.
    """
    global pipeline
    if pipeline is None:
        try:
            from src.generation.prompt_manager import MedicalPromptManager
            from src.pipeline.qa_pipeline import HealthcareQAPipeline

            print("🔄 Loading pipeline components...")
            retriever, llm, confidence_scorer, source_attributor = (
                _build_shared_components()
            )
            pipeline = HealthcareQAPipeline(
                retriever=retriever,
                llm=llm,
                prompt_manager=MedicalPromptManager(),
                confidence_scorer=confidence_scorer,
                source_attributor=source_attributor,
            )
            print("✅ Pipeline loaded successfully")
        except Exception as e:
            # Broad catch is deliberate: callers treat None as "not ready"
            # (the /ask endpoints answer with HTTP 503 in that case).
            print(f"❌ Failed to load pipeline: {e}")
            pipeline = None
    return pipeline


def get_langchain_pipeline():
    """Lazy load the LangChain-based pipeline.

    The instance is cached in the module-level ``langchain_pipeline`` global.
    A failed build is reported on stdout and leaves the cache empty, so the
    next call tries again.

    Returns:
        The object returned by ``create_langchain_pipeline``, or ``None`` if
        it could not be built.
    """
    global langchain_pipeline
    if langchain_pipeline is None:
        try:
            from src.langchain import create_langchain_pipeline

            print("🔄 Loading LangChain pipeline components...")
            retriever, llm, confidence_scorer, source_attributor = (
                _build_shared_components()
            )
            langchain_pipeline = create_langchain_pipeline(
                retriever=retriever,
                llm=llm,
                confidence_scorer=confidence_scorer,
                source_attributor=source_attributor,
            )
            print("✅ LangChain pipeline loaded successfully")
        except Exception as e:
            # Broad catch is deliberate: callers treat None as "not ready".
            print(f"❌ Failed to load LangChain pipeline: {e}")
            langchain_pipeline = None
    return langchain_pipeline


def get_langgraph_pipeline():
    """Lazy load the LangGraph-based pipeline.

    The instance is cached in the module-level ``langgraph_pipeline`` global.
    A failed build is reported on stdout and leaves the cache empty, so the
    next call tries again.

    Returns:
        The object returned by ``create_langgraph_pipeline``, or ``None`` if
        it could not be built.
    """
    global langgraph_pipeline
    if langgraph_pipeline is None:
        try:
            from src.langgraph import create_langgraph_pipeline

            print("🔄 Loading LangGraph pipeline components...")
            retriever, llm, confidence_scorer, source_attributor = (
                _build_shared_components()
            )
            langgraph_pipeline = create_langgraph_pipeline(
                retriever=retriever,
                llm=llm,
                confidence_scorer=confidence_scorer,
                source_attributor=source_attributor,
            )
            print("✅ LangGraph pipeline loaded successfully")
        except Exception as e:
            # Broad catch is deliberate: callers treat None as "not ready".
            print(f"❌ Failed to load LangGraph pipeline: {e}")
            langgraph_pipeline = None
    return langgraph_pipeline
|
|
@app.get("/", response_model=HealthResponse)
async def root():
    """Root endpoint."""
    # Only inspects the cached standard pipeline; it never triggers a load.
    standard_loaded = pipeline is not None
    return HealthResponse(
        message="Healthcare QA Chatbot API is running",
        pipeline_ready=standard_loaded,
        status="ok",
    )
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    # The status is always "healthy"; pipeline_ready separately reports
    # whether the standard pipeline has already been loaded into the cache.
    standard_loaded = pipeline is not None
    return HealthResponse(
        message="Service is healthy",
        pipeline_ready=standard_loaded,
        status="healthy",
    )
|
|
@app.post("/ask", response_model=AnswerResponse)
async def ask_question(request: QuestionRequest):
    """
    Ask a medical question and get an explainable answer.

    Set use_langchain=true for the LangChain LCEL-based pipeline.
    Set use_langgraph=true for the LangGraph self-correcting RAG pipeline.
    """
    # Pipeline selection: LangGraph wins when both flags are set; with
    # neither flag the standard pipeline is used. Each getter builds its
    # pipeline on first use and returns None when that fails.
    if request.use_langgraph:
        qa_pipeline = get_langgraph_pipeline()
        pipeline_name = "LangGraph"
    elif request.use_langchain:
        qa_pipeline = get_langchain_pipeline()
        pipeline_name = "LangChain"
    else:
        qa_pipeline = get_pipeline()
        pipeline_name = "Standard"

    if qa_pipeline is None:
        raise HTTPException(
            status_code=503,
            detail=f"{pipeline_name} pipeline not initialized. Please check that the knowledge base is built."
        )

    try:
        if request.use_langgraph or request.use_langchain:
            # Only the question text is forwarded to these pipelines;
            # num_sources and include_explanation are not passed on.
            response = qa_pipeline.answer(request.question)
        else:
            response = qa_pipeline.answer(
                question=request.question,
                num_documents=request.num_sources,
                include_explanation=request.include_explanation
            )

        return AnswerResponse(
            question=response.question,
            answer=response.answer,
            sources=[SourceInfo(**s) for s in response.sources],
            confidence=ConfidenceInfo(**response.confidence),
            attributions=[AttributionInfo(**a) for a in response.attributions],
            disclaimer=response.disclaimer,
            # Not every pipeline response carries a rationale.
            rationale=getattr(response, 'rationale', None)
        )
    except Exception as e:
        # Chain the original exception so tracebacks and logs show the root
        # cause explicitly; the HTTP response body is unchanged.
        raise HTTPException(
            status_code=500,
            detail=f"Error processing question with {pipeline_name} pipeline: {str(e)}"
        ) from e
|
|
@app.post("/ask/simple")
async def ask_simple(question: str):
    """
    Simple question endpoint (minimal response).
    """
    # NOTE(review): unlike QuestionRequest.question, this parameter carries
    # no length validation -- confirm whether that is intended.
    qa_pipeline = get_pipeline()

    if qa_pipeline is None:
        raise HTTPException(status_code=503, detail="Pipeline not initialized")

    try:
        # Always uses the standard pipeline with three documents and no
        # explanation.
        response = qa_pipeline.answer(
            question=question,
            num_documents=3,
            include_explanation=False
        )
        return {
            "answer": response.answer,
            "confidence": response.confidence["level"]
        }
    except Exception as e:
        # Chain the original exception so tracebacks and logs show the root
        # cause explicitly; the HTTP response body is unchanged.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
def main():
    """Load the standard pipeline, then serve the app with uvicorn."""
    # Eager load before serving. get_pipeline() reports a failed load itself
    # and returns None, so the server still starts in that case.
    get_pipeline()

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")


if __name__ == "__main__":
    main()
|
|