# ERMA / app.py
import os
from typing import List, Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Import RAG components (legacy LangChain import paths; recent releases expose
# these under langchain_community / langchain_openai instead)
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.document_loaders import TextLoader

# Load environment variables (OPENAI_API_KEY must be set for the OpenAI components below)
load_dotenv()

# Initialize FastAPI app
app = FastAPI(title="Educational Research Methods Chatbot API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define request and response models
class ChatRequest(BaseModel):
    message: str
    conversation_history: Optional[List[dict]] = []


class ChatResponse(BaseModel):
    response: str
    citations: List[dict] = []
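
# Illustrative payload shapes for the two models above (example values, not from the source):
#   POST /chat request:  {"message": "What is a case study?",
#                         "conversation_history": [{"message": "...", "response": "..."}]}
#   Response:            {"response": "...", "citations": [{"id": 1, "text": "...", "page": ""}]}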


# Initialize RAG components
def initialize_rag():
    # Load the research methods reference document. The original code pointed at a
    # hard-coded absolute path (/home/ubuntu/research_methods_chatbot/...); here the
    # path is assumed relative to the working directory and overridable via an
    # environment variable.
    doc_path = os.getenv("RESEARCH_METHODS_DOC", "research_methods_info.md")
    documents = TextLoader(doc_path).load()

    # Split documents into overlapping chunks so retrieval returns focused passages
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    # Create embeddings
    embeddings = OpenAIEmbeddings()

    # Create an in-memory vector store over the chunks
    db = Chroma.from_documents(texts, embeddings)

    # Create a retriever that returns the 3 most similar chunks per query
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

    # Create QA chain ("stuff" packs all retrieved chunks into a single prompt)
    qa = RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    return qa


# Build the RAG pipeline once at import time: embedding the corpus happens here, so
# startup needs a valid OpenAI API key and network access before the server can serve.
qa_chain = initialize_rag()


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        # Prepend up to the last three turns of conversation history to the query
        query = request.message
        if request.conversation_history:
            context = "\n".join(
                f"User: {msg['message']}\nAssistant: {msg['response']}"
                for msg in request.conversation_history[-3:]
            )
            query = f"Conversation history:\n{context}\n\nCurrent question: {query}"

        # Add instruction for APA7 citations
        query += "\nPlease include APA7 citations for any information provided."

        # Get response from RAG pipeline
        result = qa_chain({"query": query})

        # Extract citations from the retrieved source documents
        citations = []
        for i, doc in enumerate(result.get("source_documents", [])):
            if hasattr(doc, "metadata") and "source" in doc.metadata:
                citations.append({
                    "id": i + 1,
                    "text": doc.metadata["source"],
                    "page": doc.metadata.get("page", ""),
                })

        return ChatResponse(response=result["result"], citations=citations)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
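

# Example request against the /chat endpoint (assuming the server runs locally on
# port 8000; the question text is illustrative):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "How does a case study differ from an experiment?"}'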


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
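
# Alternatively, start the server with the uvicorn CLI (assuming this file is app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000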