import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Optional

# Import RAG components
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader

# Load environment variables
load_dotenv()

# Initialize FastAPI app
app = FastAPI(title="Educational Research Methods Chatbot API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Define request and response models
class ChatRequest(BaseModel):
    message: str
    conversation_history: Optional[List[dict]] = []


class ChatResponse(BaseModel):
    response: str
    citations: List[dict] = []


# Initialize RAG components
def initialize_rag():
    # Load research methods information
    loader = TextLoader("/home/ubuntu/research_methods_chatbot/research_methods_info.md")
    documents = loader.load()

    # Split documents into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    # Create embeddings
    embeddings = OpenAIEmbeddings()

    # Create vector store
    db = Chroma.from_documents(texts, embeddings)

    # Create retriever
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

    # Create QA chain
    qa = RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )

    return qa


# Initialize RAG pipeline
qa_chain = initialize_rag()


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        # Prepare the query with conversation history context
        query = request.message
        if request.conversation_history:
            context = "\n".join(
                f"User: {msg['message']}\nAssistant: {msg['response']}"
                for msg in request.conversation_history[-3:]
            )
            query = f"Conversation history:\n{context}\n\nCurrent question: {query}"

        # Add instruction for APA7 citations
        query += "\nPlease include APA7 citations for any information provided."

        # Get response from RAG pipeline
        result = qa_chain({"query": query})

        # Extract citations from source documents
        citations = []
        if "source_documents" in result:
            for i, doc in enumerate(result["source_documents"]):
                if hasattr(doc, "metadata") and "source" in doc.metadata:
                    citations.append(
                        {
                            "id": i + 1,
                            "text": doc.metadata["source"],
                            "page": doc.metadata.get("page", ""),
                        }
                    )

        return ChatResponse(response=result["result"], citations=citations)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
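
# Example client call (an illustrative sketch, not part of the application; it assumes
# the server is running locally on port 8000 and that the server process can read a
# valid OPENAI_API_KEY from the .env file loaded above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/chat",
#       json={"message": "What is a quasi-experimental design?", "conversation_history": []},
#   )
#   print(resp.json())  # {"response": "...", "citations": [{"id": 1, "text": "...", "page": ""}]}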