sepidnes committed on
Commit
ea1eada
·
verified ·
1 Parent(s): dfed1e5

Update backend/main.py

Browse files
Files changed (1) hide show
  1. backend/main.py +144 -1
backend/main.py CHANGED
@@ -1 +1,144 @@
1
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import shutil
4
+ from typing import List, Dict, Any
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
9
+ from aimakerspace.openai_utils.prompts import (
10
+ UserRolePrompt,
11
+ SystemRolePrompt,
12
+ AssistantRolePrompt,
13
+ )
14
+ from aimakerspace.openai_utils.embedding import EmbeddingModel
15
+ from aimakerspace.vectordatabase import VectorDatabase
16
+ from aimakerspace.openai_utils.chatmodel import ChatOpenAI
17
+
18
+ app = FastAPI()
19
+
20
+ # Add CORS middleware
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"], # In production, replace with your frontend URL
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # Initialize components
30
+ text_splitter = CharacterTextSplitter()
31
+ vector_db = None
32
+ chat_openai = ChatOpenAI(model_name="gpt-3.5-turbo")
33
+
34
+ system_template = """\
35
+ Use the following context to answer a users question. If you cannot find the answer in the context, say you don't know the answer."""
36
+ system_role_prompt = SystemRolePrompt(system_template)
37
+
38
+ user_prompt_template = """\
39
+ Context:
40
+ {context}
41
+
42
+ Question:
43
+ {question}
44
+ """
45
+ user_role_prompt = UserRolePrompt(user_prompt_template)
46
+
47
+ class QuestionRequest(BaseModel):
48
+ question: str
49
+
50
+ class QuestionResponse(BaseModel):
51
+ response: str
52
+ context: List[tuple]
53
+
54
+ def process_file(file_path: str, file_name: str):
55
+ print(f"Processing file: {file_name}")
56
+
57
+ # Create appropriate loader
58
+ if file_name.lower().endswith('.pdf'):
59
+ loader = PDFLoader(file_path)
60
+ else:
61
+ loader = TextFileLoader(file_path)
62
+
63
+ # Load and process the documents
64
+ documents = loader.load_documents()
65
+ texts = text_splitter.split_texts(documents)
66
+ return texts
67
+
68
+ class RetrievalAugmentedQAPipeline:
69
+ def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
70
+ self.llm = llm
71
+ self.vector_db_retriever = vector_db_retriever
72
+
73
+ async def run_pipeline(self, user_query: str):
74
+ context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
75
+
76
+ context_prompt = ""
77
+ for context in context_list:
78
+ context_prompt += context[0] + "\n"
79
+
80
+ formatted_system_prompt = system_role_prompt.create_message()
81
+ formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
82
+
83
+ # Remove await since run is not an async method
84
+ response = self.llm.run([formatted_system_prompt, formatted_user_prompt])
85
+
86
+ return {"response": response, "context": context_list}
87
+
88
+ @app.post("/upload")
89
+ async def upload_file(file: UploadFile = File(...)):
90
+ global vector_db
91
+
92
+ # Validate file type
93
+ if not file.filename.lower().endswith(('.txt', '.pdf')):
94
+ raise HTTPException(status_code=400, detail="Only .txt and .pdf files are allowed")
95
+
96
+ # Create a temporary file
97
+ suffix = f".{file.filename.split('.')[-1]}"
98
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
99
+ # Copy the uploaded file content to the temporary file
100
+ content = await file.read()
101
+ with open(temp_file.name, "wb") as f:
102
+ f.write(content)
103
+
104
+ try:
105
+ # Process the file
106
+ texts = process_file(temp_file.name, file.filename)
107
+ print(f"Processing {len(texts)} text chunks")
108
+
109
+ # Create a vector store
110
+ vector_db = VectorDatabase()
111
+ vector_db = await vector_db.abuild_from_list(texts)
112
+
113
+ return {"message": f"File {file.filename} processed successfully", "chunks": len(texts)}
114
+ finally:
115
+ # Clean up the temporary file
116
+ try:
117
+ os.unlink(temp_file.name)
118
+ except Exception as e:
119
+ print(f"Error cleaning up temporary file: {e}")
120
+
121
+ @app.post("/ask", response_model=QuestionResponse)
122
+ async def ask_question(request: QuestionRequest):
123
+ global vector_db
124
+
125
+ if vector_db is None:
126
+ raise HTTPException(status_code=400, detail="Please upload a file first")
127
+
128
+ # Create a chain
129
+ retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
130
+ vector_db_retriever=vector_db,
131
+ llm=chat_openai
132
+ )
133
+
134
+ # Run the pipeline
135
+ result = await retrieval_augmented_qa_pipeline.run_pipeline(request.question)
136
+
137
+ return QuestionResponse(
138
+ response=result["response"],
139
+ context=result["context"]
140
+ )
141
+
142
+ @app.get("/health")
143
+ async def health_check():
144
+ return {"status": "healthy"}