|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
from langchain.embeddings import OpenAIEmbeddings |
|
|
from langchain.vectorstores import Chroma |
|
|
from langchain.text_splitter import CharacterTextSplitter |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.llms import OpenAI |
|
|
from langchain.document_loaders import TextLoader |
|
|
|
|
|
|
|
|
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file
# before any OpenAI/LangChain objects are constructed below.
load_dotenv()
|
|
|
|
|
|
|
|
# ASGI application instance; served by uvicorn (see __main__ guard below).
app = FastAPI(title="Educational Research Methods Chatbot API")
|
|
|
|
|
|
|
|
# CORS: currently wide open so any browser origin can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# broader than the CORS spec intends (a literal "*" origin may not be sent
# with credentials; Starlette compensates by echoing the request origin).
# Consider pinning allow_origins to the known frontend host(s) — TODO confirm
# which origins actually need access before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /chat.

    Attributes are pydantic fields and define the JSON schema of the
    endpoint; do not rename without updating the frontend client.
    """

    # The user's current question/message.
    message: str
    # Prior turns as dicts; the /chat handler reads the 'message' and
    # 'response' keys of the last three entries. Defaults to an empty
    # history (pydantic copies mutable defaults per instance, so the
    # shared-[] pitfall does not apply here).
    conversation_history: Optional[List[dict]] = []
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    # The LLM-generated answer text.
    response: str
    # Source attributions: dicts with 'id', 'text' (source path), and
    # 'page' keys, built from the retriever's source documents.
    citations: List[dict] = []
|
|
|
|
|
|
|
|
def initialize_rag(doc_path: Optional[str] = None):
    """Build and return the RetrievalQA chain backing the /chat endpoint.

    Loads the research-methods corpus, splits it into overlapping chunks,
    embeds the chunks into an in-memory Chroma store, and wires a "stuff"
    RetrievalQA chain over a top-3 similarity retriever.

    Args:
        doc_path: Path to the markdown corpus file. Defaults to the
            RAG_DOC_PATH environment variable, falling back to the
            original hard-coded deployment path, so existing zero-arg
            callers are unaffected.

    Returns:
        A RetrievalQA chain configured to also return source documents
        (used by /chat to build citations).
    """
    # Resolve the corpus location: explicit arg > env var > legacy default.
    if doc_path is None:
        doc_path = os.getenv(
            "RAG_DOC_PATH",
            "/home/ubuntu/research_methods_chatbot/research_methods_info.md",
        )

    loader = TextLoader(doc_path)
    documents = loader.load()

    # 1000-char chunks with 200-char overlap so retrieval context does not
    # cut sentences off at chunk boundaries.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    # OpenAI embeddings + ephemeral (non-persisted) Chroma vector store;
    # the index is rebuilt on every process start.
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)

    # Top-3 nearest chunks by cosine/embedding similarity.
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})

    # "stuff" chain: concatenates all retrieved chunks into one prompt.
    # return_source_documents=True is required by /chat's citation builder.
    qa = RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )

    return qa
|
|
|
|
|
|
|
|
# Build the RAG chain once at import time and share it across requests.
# NOTE: this performs network I/O (OpenAI embeddings) during module import,
# so app startup fails fast if the corpus or API key is unavailable.
qa_chain = initialize_rag()
|
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer a research-methods question using the shared RAG chain.

    Folds up to the last three conversation turns into the query for
    context, asks the chain, and extracts source citations from the
    retrieved documents.

    Args:
        request: Parsed JSON body with the user's message and optional
            conversation history.

    Returns:
        ChatResponse with the LLM answer and a list of citation dicts.

    Raises:
        HTTPException: 500 with the underlying error message if the
            chain invocation fails.
    """
    try:
        query = request.message

        # Prepend recent history (last 3 turns) so the LLM keeps context.
        # .get() with defaults guards against malformed history entries
        # (missing 'message'/'response' keys), which previously raised
        # KeyError and surfaced as an opaque 500.
        if request.conversation_history:
            context = "\n".join(
                f"User: {msg.get('message', '')}\nAssistant: {msg.get('response', '')}"
                for msg in request.conversation_history[-3:]
            )
            query = f"Conversation history:\n{context}\n\nCurrent question: {query}"

        query += "\nPlease include APA7 citations for any information provided."

        result = qa_chain({"query": query})

        # Build citations from the retrieved source documents (enabled by
        # return_source_documents=True in initialize_rag).
        citations = []
        for i, doc in enumerate(result.get("source_documents", [])):
            metadata = getattr(doc, "metadata", None)
            if metadata and "source" in metadata:
                citations.append({
                    "id": i + 1,
                    "text": metadata["source"],
                    "page": metadata.get("page", ""),
                })

        return ChatResponse(
            response=result["result"],
            citations=citations,
        )

    except HTTPException:
        # Pass through any explicit HTTP errors unwrapped.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload
|
|
|
|
|
# Local/dev entry point: run the ASGI app directly with uvicorn on all
# interfaces. In production this module is typically served by an external
# process manager instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|