import os
import tempfile
from typing import List

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles  # Used by the optional static mount at the bottom
from pydantic import BaseModel

from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI

app = FastAPI()

# Allow cross-origin requests from the browser frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components
text_splitter = CharacterTextSplitter()
vector_db = None  # Built on first upload; stays None until a file is processed
chat_openai = ChatOpenAI(model_name="gpt-3.5-turbo")

system_template = """\
Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}
Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)

class QuestionRequest(BaseModel):
    question: str


class QuestionResponse(BaseModel):
    response: str
    context: List[tuple]

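# Example JSON shapes for /ask (the second element of each context tuple is
# assumed to be a similarity score returned by the vector store; only the
# first element, the chunk text, is confirmed by the code below):
#   request:  {"question": "What is this document about?"}
#   response: {"response": "...", "context": [["chunk text", 0.87], ...]}
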
def process_file(file_path: str, file_name: str):
    """Load a .txt or .pdf file from disk and split it into text chunks."""
    print(f"Processing file: {file_name}")

    # Create the appropriate loader for the file type
    if file_name.lower().endswith(".pdf"):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)

    # Load and split the documents
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts

class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def run_pipeline(self, user_query: str):
        # Retrieve the k most similar chunks for the query
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # Concatenate the retrieved chunk texts into a single context block
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=user_query, context=context_prompt
        )

        response = await self.llm.run([formatted_system_prompt, formatted_user_prompt])
        return {"response": response, "context": context_list}

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
global vector_db
# Validate file type
if not file.filename.lower().endswith(('.txt', '.pdf')):
raise HTTPException(status_code=400, detail="Only .txt and .pdf files are allowed")
# Create a temporary file
suffix = f".{file.filename.split('.')[-1]}"
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
# Copy the uploaded file content to the temporary file
content = await file.read()
print(f"File read complete. Size: {len(content)} bytes")
with open(temp_file.name, "wb") as f:
f.write(content)
try:
# Process the file
texts = process_file(temp_file.name, file.filename)
print(f"Processing {len(texts)} text chunks")
# Create a vector store
vector_db = VectorDatabase()
vector_db = await vector_db.abuild_from_list(texts)
print("Document processing complete")
return {"message": f"File {file.filename} processed successfully", "chunks": len(texts)}
finally:
# Clean up the temporary file
try:
os.unlink(temp_file.name)
except Exception as e:
print(f"Error cleaning up temporary file: {e}")
@app.post("/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
global vector_db
if vector_db is None:
raise HTTPException(status_code=400, detail="Please upload a file first")
# Create a chain
retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
vector_db_retriever=vector_db,
llm=chat_openai
)
# Run the pipeline
result = await retrieval_augmented_qa_pipeline.run_pipeline(request.question)
return QuestionResponse(
response=result["response"],
context=result["context"]
)
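# Example question from the command line (same localhost:8000 assumption):
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this document about?"}'
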
@app.get("/health")
async def health_check():
return {"status": "healthy"}
# Optionally serve the built frontend as static files (uncomment once the build exists):
# app.mount("/", StaticFiles(directory="../frontend/build", html=True), name="frontend")
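# A minimal way to run the app directly (an assumption; the Space may instead
# launch it via a Dockerfile or `uvicorn app:app` on the command line):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)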