eaedk committed
Commit fab7c40 · Parent: 5f790a2
Files changed (1): _app.py +206 -0
_app.py ADDED
@@ -0,0 +1,206 @@
+ import os
+ import chromadb
+ from dotenv import load_dotenv
+ from uuid import uuid4
+
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_openai import OpenAIEmbeddings
+ from langchain.chat_models import init_chat_model
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_chroma import Chroma
+ import uvicorn
+
+ # ----------------------
+ # Configuration and Setup
+ # ----------------------
+
+ # Load environment variables from a .env file
+ load_dotenv()
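+ # The .env file is expected to hold the OpenAI credentials, e.g. (placeholder value):
+ #   OPENAI_API_KEY=sk-...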
+
+ # Directories for file uploads and persistent storage of the Chroma vector database
+ UPLOAD_DIR = "uploads"
+ CHROMA_DIR = "chroma_db"
+
+ # Model versions for the LLM and embeddings
+ LLM = "gpt-4o-mini-2024-07-18"
+ EMBEDDING_MODEL = "text-embedding-3-small"
+
+ # Ensure the necessary directories exist
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
+ os.makedirs(CHROMA_DIR, exist_ok=True)
+ # Ensure the OpenAI API key is present; load_dotenv() already exports it to the
+ # environment, and assigning a missing (None) value here would raise a TypeError
+ if not os.getenv("OPENAI_API_KEY"):
+     raise RuntimeError("OPENAI_API_KEY is not set; add it to your .env file.")
+
+ # Initialize a persistent client for Chroma, specifying where the data is stored
+ client = chromadb.PersistentClient(path=CHROMA_DIR)
+
+ # FastAPI application setup
+ app = FastAPI()
+
+ # Enable CORS (Cross-Origin Resource Sharing) for all origins, methods, and headers
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
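+ # NOTE: wildcard CORS is convenient for local development, but allowed origins
+ # should be restricted before exposing this API publicly.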
+
+ # ----------------------
+ # LangChain Initialization
+ # ----------------------
+
+ # Initialize the embedding model using OpenAI's API
+ embedding = OpenAIEmbeddings(model=EMBEDDING_MODEL)
+
+ # Initialize the language model (LLM) using OpenAI's API (temperature=0 for deterministic answers)
+ llm = init_chat_model(model=LLM, model_provider="openai", temperature=0)
+
+ # Text splitter to split documents into manageable chunks (for efficient processing)
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=1200,
+     chunk_overlap=50,
+     separators=["\n\n", "\n", ".", " ", ""],
+ )
+
+ # Set up the Chroma vector store to hold document embeddings and their metadata
+ vectorstore = Chroma(
+     client=client,
+     persist_directory=CHROMA_DIR,
+     embedding_function=embedding,
+     collection_name="legal_docs",
+ )
+
+ # Define the prompt template that the LLM uses to answer from retrieved context
+ prompt_template = """
+ Tu es un assistant utile qui répond en français de manière claire et concise.
+ Réponds uniquement en utilisant le contexte fourni.
+ Si tu ne sais pas, dis "Je ne sais pas".
+
+ contexte : {context}
+
+ question : {question}
+
+ réponse :
+ """
+
+ # Initialize the prompt template with its input variables
+ prompt = PromptTemplate(
+     input_variables=["question", "context"],
+     template=prompt_template,
+ )
+
+ # Format retrieved documents into one context string, tagging each chunk with its source page
+ def format_docs(docs):
+     return "\n\n".join([f"(Page {d.metadata.get('page', '?')}) {d.page_content}" for d in docs])
+
+ # Set up the retriever to pull relevant documents from the vector store based on a query
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
+
+ # Define the QA chain that links together the retriever, document formatting, and LLM for querying
+ qa_chain = (
+     {
+         "context": retriever | format_docs,
+         "question": RunnablePassthrough(),
+     }
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
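+ # The dict above fans the incoming question out to two branches: `retriever | format_docs`
+ # fills {context}, while RunnablePassthrough() forwards the raw question into {question};
+ # the prompt, LLM, and StrOutputParser() then turn the result into a plain string answer.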
+
+ # ----------------------
+ # Document Management Functions
+ # ----------------------
+
+ # Add a PDF document to the vector store (load, split into chunks, embed)
+ def add_pdf_to_vectorstore(file_path):
+     # Load the PDF file
+     loader = PyPDFLoader(file_path)
+     documents = loader.load()
+
+     # Split the document into smaller chunks
+     docs = text_splitter.split_documents(documents)
+
+     # Generate a unique ID for each chunk
+     uuids = [str(uuid4()) for _ in range(len(docs))]
+     print(f"Number of chunks created: {len(docs)}")
+
+     # Add the chunks to the vector store (Chroma)
+     vectorstore.add_documents(documents=docs, ids=uuids)
+
+ # ----------------------
+ # FastAPI Routes
+ # ----------------------
+
+ # Route to upload a PDF file and add its content to the vector store
+ @app.post("/upload/")
+ async def upload_pdf(file: UploadFile = File(...)):
+     # Check that the uploaded file is a PDF (case-insensitive extension check)
+     if not file.filename.lower().endswith(".pdf"):
+         raise HTTPException(status_code=400, detail="Seuls les fichiers PDF sont acceptés.")
+
+     # Save the uploaded file to disk (basename guards against path traversal in the client-supplied name)
+     file_path = os.path.join(UPLOAD_DIR, os.path.basename(file.filename))
+     with open(file_path, "wb") as buffer:
+         buffer.write(await file.read())
+
+     # Add the PDF document to the vector store
+     add_pdf_to_vectorstore(file_path)
+
+     # Return a success message
+     content = {"message": f"Fichier {file.filename} ajouté à la base de connaissances."}
+     print(f"{content=}")
+     return JSONResponse(content=content)
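+ # Example request (assuming the server runs locally on port 8000; the file name is a placeholder):
+ #   curl -F "file=@document.pdf" http://localhost:8000/upload/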
+
+ # Route to interact with the assistant via a chat-like interface
+ @app.get("/chat/")
+ async def chat(message: str):
+     # Use the QA chain to get a response from the assistant
+     response = qa_chain.invoke(message)
+
+     # Return the response from the assistant
+     print(f"{response=}")
+     return {"answer": response}
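+ # Example request (the question is a placeholder):
+ #   curl -G --data-urlencode "message=Résume le document." http://localhost:8000/chat/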
+
+ # ----------------------
+ # Streaming Response for Chat
+ # ----------------------
+
+ # Stream the QA chain's response to the client chunk by chunk as it is generated
+ async def stream_chat_response(message: str):
+     async for part in qa_chain.astream(message):
+         # Yield each part as a chunk for streaming to the client
+         print(part, end="", flush=True)
+         yield part
+
+ # FastAPI endpoint for streaming chat responses
+ @app.get("/chat_stream/")
+ async def chat_stream(message: str):
+     """
+     Endpoint to stream chat responses progressively.
+     """
+     # Return a StreamingResponse that streams chunks from the generator
+     return StreamingResponse(stream_chat_response(message), media_type="text/plain")
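+ # Example request (-N disables curl's output buffering so chunks print as they arrive):
+ #   curl -N -G --data-urlencode "message=Résume le document." http://localhost:8000/chat_stream/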
+
+ # ----------------------
+ # Start the FastAPI app using Uvicorn
+ # ----------------------
+ if __name__ == "__main__":
+     # Run the FastAPI application on all interfaces, port 8000 (no auto-reload)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
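
For reference, a minimal client sketch for the streaming endpoint, assuming the app above is running on localhost:8000 and the `requests` package is installed; the question text is only a placeholder:

import requests

def stream_answer(message: str) -> str:
    # Print /chat_stream/ output as it arrives and return the full answer.
    url = "http://localhost:8000/chat_stream/"
    parts = []
    with requests.get(url, params={"message": message}, stream=True) as resp:
        resp.raise_for_status()
        resp.encoding = "utf-8"  # the endpoint streams plain UTF-8 text
        # chunk_size=None yields data as soon as it is received
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)
            parts.append(chunk)
    return "".join(parts)

if __name__ == "__main__":
    stream_answer("Résume le document téléversé.")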