sepidnes's picture
Update backend/main.py
c78d03f verified
raw
history blame
4.95 kB
import os
import tempfile
import shutil
from typing import List, Dict, Any
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles # Add this import
from pydantic import BaseModel
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
UserRolePrompt,
SystemRolePrompt,
AssistantRolePrompt,
)
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
app = FastAPI()

# CORS configuration.
# When CORS_ALLOW_ORIGINS is unset, every origin is allowed ("*"), which keeps
# local development working out of the box. In production, set it to a
# comma-separated list of frontend URLs, for example:
#   CORS_ALLOW_ORIGINS="https://app.example.com,https://www.example.com"
# Surrounding whitespace and empty entries are ignored.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Prompt templates used by RetrievalAugmentedQAPipeline.
# ---------------------------------------------------------------------------
system_template = """\
Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
# {context} and {question} are filled in per request via create_message().
user_prompt_template = """\
Context:
{context}
Question:
{question}
"""

# ---------------------------------------------------------------------------
# Shared components, created once at import time.
# ---------------------------------------------------------------------------
text_splitter = CharacterTextSplitter()
# Set by the /upload endpoint; /ask refuses to answer while it is still None.
vector_db = None
chat_openai = ChatOpenAI(model_name="gpt-3.5-turbo")
system_role_prompt = SystemRolePrompt(system_template)
user_role_prompt = UserRolePrompt(user_prompt_template)
class QuestionRequest(BaseModel):
    """Request body accepted by the /ask endpoint."""

    # The user's natural-language question.
    question: str
class QuestionResponse(BaseModel):
    """Response body returned by the /ask endpoint."""

    # Text produced by the chat model.
    response: str
    # Retrieved items exactly as returned by the vector DB search; the first
    # element of each tuple is the chunk text that was placed in the prompt.
    context: List[tuple]
def process_file(file_path: str, file_name: str):
    """Load the document stored at ``file_path`` and split it into text chunks.

    ``file_name`` is the original upload name and only selects the loader:
    names ending in ``.pdf`` (case-insensitive) use PDFLoader, and anything
    else uses TextFileLoader.

    Returns:
        The chunks produced by the module-level ``text_splitter``.
    """
    print(f"Processing file: {file_name}")
    loader_cls = PDFLoader if file_name.lower().endswith('.pdf') else TextFileLoader
    return text_splitter.split_texts(loader_cls(file_path).load_documents())
class RetrievalAugmentedQAPipeline:
    """Answer a question with a chat model, grounded in text retrieved from a vector DB."""

    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        """Store the chat model and the vector database used for retrieval."""
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def run_pipeline(self, user_query: str) -> dict:
        """Retrieve context for ``user_query`` and ask the chat model to answer it.

        The top ``k=4`` results of ``search_by_text`` are written into the user
        prompt, one chunk per line, next to the question.

        Returns:
            A dict with "response" (the model's output) and "context" (the
            search results as returned by ``search_by_text``).
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
        # Element 0 of each result is the chunk text. Each chunk ends with a
        # newline, and str.join avoids quadratic += concatenation.
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=user_query, context=context_prompt
        )

        # ChatOpenAI.run is a regular (non-async) method, so its result must not
        # be awaited. Awaiting a plain non-awaitable value raises TypeError.
        # NOTE(review): because run() is synchronous, this call blocks the event
        # loop while the model responds. Consider offloading it to a thread if
        # concurrent requests matter.
        response = self.llm.run([formatted_system_prompt, formatted_user_prompt])
        return {"response": response, "context": context_list}
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Index an uploaded .txt or .pdf file so /ask can answer questions about it.

    The upload is copied to a temporary file (the loaders take a path), split
    into chunks by ``process_file``, and embedded into a new VectorDatabase.
    The module-level ``vector_db`` is replaced only after that build succeeds,
    so a failed upload leaves any previously indexed document in place.

    Returns:
        {"message": ..., "chunks": <number of text chunks indexed>}

    Raises:
        HTTPException(400): if the upload's name does not end in .txt/.pdf (or is
            missing), or if no text chunks could be extracted from the file.
    """
    global vector_db

    # The upload's filename may be None. Treat that as an unsupported name
    # instead of crashing on .lower().
    filename = file.filename or ""
    if not filename.lower().endswith(('.txt', '.pdf')):
        raise HTTPException(status_code=400, detail="Only .txt and .pdf files are allowed")
    # Give the temporary path the same extension as the upload, since the
    # loaders receive this path.
    suffix = f".{filename.split('.')[-1]}"

    content = await file.read()
    print(f"File read complete. Size: {len(content)} bytes")

    temp_path = None
    try:
        # delete=False: the file must outlive this handle so the loader can open
        # it by name. It is removed in the finally clause below, which also
        # covers failures during the write itself.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            # Write through the handle we already hold. Reopening the path while
            # the NamedTemporaryFile is open fails on Windows.
            temp_file.write(content)

        texts = process_file(temp_path, filename)
        print(f"Processing {len(texts)} text chunks")
        if not texts:
            raise HTTPException(
                status_code=400,
                detail="No text could be extracted from the uploaded file",
            )

        # Build into a local first. Assigning the global before the build
        # finished would leave an empty index behind if embedding failed.
        new_db = await VectorDatabase().abuild_from_list(texts)
        vector_db = new_db
        print("Document processing complete")
        return {"message": f"File {filename} processed successfully", "chunks": len(texts)}
    finally:
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError as e:
                print(f"Error cleaning up temporary file: {e}")
@app.post("/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    """Answer ``request.question`` using the most recently uploaded document.

    Raises:
        HTTPException(400): if no document has been uploaded yet.
    """
    # Only reading the module-level index here, so no `global` is needed.
    index = vector_db
    if index is None:
        raise HTTPException(status_code=400, detail="Please upload a file first")

    pipeline = RetrievalAugmentedQAPipeline(llm=chat_openai, vector_db_retriever=index)
    outcome = await pipeline.run_pipeline(request.question)

    answer_text = outcome["response"]
    retrieved = outcome["context"]
    return QuestionResponse(response=answer_text, context=retrieved)
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally reports the service as healthy."""
    return dict(status="healthy")
# app.mount("/", StaticFiles(directory="../frontend/build", html=True), name="frontend")