import os
import tempfile
from typing import List
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles  # serves the built frontend
from pydantic import BaseModel
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with your frontend URL
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize components
text_splitter = CharacterTextSplitter()
vector_db = None
chat_openai = ChatOpenAI(model_name="gpt-3.5-turbo")
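# Note: the OpenAI-backed components above are assumed to read OPENAI_API_KEY
# from the environment; set it before starting the server.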
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)
class QuestionRequest(BaseModel):
question: str
class QuestionResponse(BaseModel):
response: str
context: List[tuple]
def process_file(file_path: str, file_name: str):
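    """Load a .txt or .pdf file from disk and split its contents into text chunks."""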
print(f"Processing file: {file_name}")
# Create appropriate loader
if file_name.lower().endswith('.pdf'):
loader = PDFLoader(file_path)
else:
loader = TextFileLoader(file_path)
# Load and process the documents
documents = loader.load_documents()
texts = text_splitter.split_texts(documents)
return texts
class RetrievalAugmentedQAPipeline:
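    """A minimal RAG pipeline: retrieve relevant chunks, then answer with them as context."""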
def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
self.llm = llm
self.vector_db_retriever = vector_db_retriever
async def run_pipeline(self, user_query: str):
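        """Retrieve the top-k chunks for the query, build the prompts, and call the LLM."""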
context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
context_prompt = ""
for context in context_list:
context_prompt += context[0] + "\n"
formatted_system_prompt = system_role_prompt.create_message()
formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
        # ChatOpenAI.run is synchronous, so call it without await
        response = self.llm.run([formatted_system_prompt, formatted_user_prompt])
return {"response": response, "context": context_list}
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
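    """Accept a .txt or .pdf upload, chunk it, and (re)build the in-memory vector database."""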
global vector_db
# Validate file type
if not file.filename.lower().endswith(('.txt', '.pdf')):
raise HTTPException(status_code=400, detail="Only .txt and .pdf files are allowed")
    # Persist the upload to a named temporary file so the loaders can read it from disk
    suffix = f".{file.filename.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        # Copy the uploaded file content to the temporary file
        content = await file.read()
        print(f"File read complete. Size: {len(content)} bytes")
        temp_file.write(content)
try:
# Process the file
texts = process_file(temp_file.name, file.filename)
print(f"Processing {len(texts)} text chunks")
# Create a vector store
vector_db = VectorDatabase()
vector_db = await vector_db.abuild_from_list(texts)
print("Document processing complete")
return {"message": f"File {file.filename} processed successfully", "chunks": len(texts)}
finally:
# Clean up the temporary file
try:
os.unlink(temp_file.name)
except Exception as e:
print(f"Error cleaning up temporary file: {e}")
@app.post("/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
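    """Answer a question using the most recently uploaded document as context."""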
global vector_db
if vector_db is None:
raise HTTPException(status_code=400, detail="Please upload a file first")
# Create a chain
retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
vector_db_retriever=vector_db,
llm=chat_openai
)
# Run the pipeline
result = await retrieval_augmented_qa_pipeline.run_pipeline(request.question)
return QuestionResponse(
response=result["response"],
context=result["context"]
)
@app.get("/health")
async def health_check():
return {"status": "healthy"}
app.mount("/", StaticFiles(directory="../frontend/build", html=True), name="frontend")
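
# To run locally (assuming this file lives in backend/ and the frontend was built
# into frontend/build), something like the following should work:
#   cd backend && uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example session (hypothetical file name and port):
#   curl -F "file=@doc.pdf" http://localhost:8000/upload
#   curl -H "Content-Type: application/json" -d '{"question": "What is this about?"}' http://localhost:8000/ask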