|
|
import os
from functools import lru_cache

import chromadb
from fastapi import FastAPI, Depends, HTTPException
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from pydantic import BaseModel
|
|
|
|
|
|
|
|
class DocumentChatbot:
    """Retrieval-augmented chatbot over a directory of PDF documents.

    PDFs are split into chunks, embedded with OpenAI embeddings, and stored
    in a persistent Chroma index; answers are generated strictly from the
    retrieved context.
    """

    def __init__(self, model_name: str, embedding_model: str, documents_path: str, chroma_path: str):
        """Initialize the chat model, embeddings, and vector store.

        Args:
            model_name: OpenAI chat model name (e.g. "gpt-4").
            embedding_model: OpenAI embedding model name.
            documents_path: Directory containing the source PDF files.
            chroma_path: Directory where the Chroma index is persisted.
        """
        # temperature=0 keeps answers deterministic for grounded Q&A.
        self.model = ChatOpenAI(model=model_name, temperature=0)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        # NOTE(review): this loads and re-embeds every PDF on each
        # construction — construct the chatbot once per process.
        self.db_chroma = self._load_documents(documents_path, chroma_path)
        self.prompt_template = """
Answer the question based only on the following context:
{context}
Answer the question based on the above context: {question}.
Provide a detailed answer.
Don’t justify your answers.
Don’t give information not mentioned in the CONTEXT INFORMATION.
Do not say "according to the context" or "mentioned in the context" or similar.
"""

    def _load_documents(self, documents_path: str, chroma_path: str):
        """Load, chunk, embed, and index all PDFs under *documents_path*.

        Returns the persistent Chroma vector store.
        """
        loader = PyPDFDirectoryLoader(documents_path)
        pages = loader.load_and_split(self.text_splitter)
        return Chroma.from_documents(pages, self.embeddings, persist_directory=chroma_path)

    def generate_response(self, message: str) -> str:
        """Answer *message* using the five most similar indexed chunks.

        Propagates any OpenAI / Chroma client errors to the caller.
        """
        docs_chroma = self.db_chroma.similarity_search_with_score(message, k=5)
        context_text = "\n\n".join(doc.page_content for doc, _score in docs_chroma)
        prompt = ChatPromptTemplate.from_template(self.prompt_template).format(
            context=context_text, question=message
        )
        # "".join over the stream avoids quadratic += string concatenation.
        return "".join(chunk.content for chunk in self.model.stream(prompt))
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request payload for the POST /chat endpoint."""

    # The user's question / chat message to answer.
    message: str
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_chatbot() -> DocumentChatbot:
    """Build (once) and return the shared DocumentChatbot instance.

    The cache is essential: constructing DocumentChatbot reloads and
    re-embeds every PDF, which must not happen on every request when this
    is used as a FastAPI dependency.

    Paths and model names can be overridden via environment variables;
    the defaults preserve the original hard-coded values.
    """
    return DocumentChatbot(
        model_name=os.environ.get("CHAT_MODEL", "gpt-4"),
        embedding_model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small"),
        documents_path=os.environ.get("DOCUMENTS_PATH", "/content/drive/MyDrive/Test Documents"),
        chroma_path=os.environ.get("CHROMA_PATH", "test-documents-2"),
    )
|
|
|
|
|
|
|
|
# ASGI application exposing the /chat and /health endpoints.
app = FastAPI()
|
|
|
|
|
|
|
|
@app.post("/chat")
async def chat(request: ChatRequest, chatbot: DocumentChatbot = Depends(get_chatbot)):
    """Answer a chat message with a context-grounded response."""
    try:
        answer = chatbot.generate_response(request.message)
    except Exception as exc:
        # Surface any retrieval/LLM failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"response": answer}
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
|
|
|