import os
import tempfile
from typing import List, Tuple
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
)
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
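
# Note: the OpenAI-backed components below (the vector store's embeddings and
# the chat model) follow the usual OpenAI client convention of reading the
# OPENAI_API_KEY environment variable; export it before starting the server.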

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize components
text_splitter = CharacterTextSplitter()
vector_db = None
chat_openai = ChatOpenAI(model_name="gpt-3.5-turbo")

system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""
user_role_prompt = UserRolePrompt(user_prompt_template)

class QuestionRequest(BaseModel):
    question: str

class QuestionResponse(BaseModel):
    response: str
    context: List[Tuple[str, float]]  # (chunk text, similarity score) pairs from the retriever
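
# Example exchange (shapes only; values are illustrative):
#   request:  {"question": "What is the document about?"}
#   response: {"response": "...", "context": [["chunk text", 0.87], ...]}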

def process_file(file_path: str, file_name: str):
    print(f"Processing file: {file_name}")
    
    # Create appropriate loader
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        loader = TextFileLoader(file_path)
        
    # Load and process the documents
    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts

class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def run_pipeline(self, user_query: str):
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)

        # ChatOpenAI.run is a regular (synchronous) method in aimakerspace,
        # so it is called directly rather than awaited
        response = self.llm.run([formatted_system_prompt, formatted_user_prompt])
        
        return {"response": response, "context": context_list}

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    global vector_db
    
    # Validate file type
    if not file.filename.lower().endswith(('.txt', '.pdf')):
        raise HTTPException(status_code=400, detail="Only .txt and .pdf files are allowed")
    
    # Persist the upload to a temporary file so the document loaders can read it from disk
    suffix = f".{file.filename.split('.')[-1]}"
    content = await file.read()
    print(f"File read complete. Size: {len(content)} bytes")
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(content)
        temp_path = temp_file.name

    try:
        # Split the document into chunks
        texts = process_file(temp_path, file.filename)
        print(f"Processing {len(texts)} text chunks")

        # Build a fresh vector store from the chunks
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)
        print("Document processing complete")

        return {"message": f"File {file.filename} processed successfully", "chunks": len(texts)}
    finally:
        # Clean up the temporary file
        try:
            os.unlink(temp_path)
        except OSError as e:
            print(f"Error cleaning up temporary file: {e}")

@app.post("/ask", response_model=QuestionResponse)
async def ask_question(request: QuestionRequest):
    global vector_db
    
    if vector_db is None:
        raise HTTPException(status_code=400, detail="Please upload a file first")
    
    # Assemble the RAG pipeline over the current vector store
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=vector_db,
        llm=chat_openai
    )
    
    # Run the pipeline
    result = await retrieval_augmented_qa_pipeline.run_pipeline(request.question)
    
    return QuestionResponse(
        response=result["response"],
        context=result["context"]
    )
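
# Example question (assumes the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this document about?"}'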

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

app.mount("/", StaticFiles(directory="../frontend/build", html=True), name="frontend")
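
# To run locally (assuming this module is saved as app.py and uvicorn is
# installed): uvicorn app:app --reload --port 8000
# The static mount above expects a built frontend at ../frontend/build.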