Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -59,6 +59,9 @@ uploads_dir = os.path.join(app.root_path,'static', 'uploads')
|
|
| 59 |
|
| 60 |
os.makedirs(uploads_dir, exist_ok=True)
|
| 61 |
|
|
|
|
|
|
|
|
|
|
| 62 |
defaultEmbeddingModelID = 3
|
| 63 |
defaultLLMID=0
|
| 64 |
|
|
@@ -201,6 +204,26 @@ def loadKB(fileprovided, urlProvided, uploads_dir, request):
|
|
| 201 |
|
| 202 |
|
| 203 |
def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
chain = RetrievalQA.from_chain_type(
|
| 205 |
llm=getLLMModel(llmID),
|
| 206 |
chain_type='stuff',
|
|
@@ -210,10 +233,7 @@ def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llm
|
|
| 210 |
chain_type_kwargs={
|
| 211 |
"verbose": False,
|
| 212 |
"prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
|
| 213 |
-
"memory":
|
| 214 |
-
k=3,
|
| 215 |
-
memory_key="history",
|
| 216 |
-
input_key="question"),
|
| 217 |
}
|
| 218 |
)
|
| 219 |
return chain
|
|
@@ -307,6 +327,10 @@ def aisearch():
|
|
| 307 |
def process_json():
|
| 308 |
print(f"\n{'*' * 100}\n")
|
| 309 |
print("Request Received >>>>>>>>>>>>>>>>>>", datetime.now().strftime("%H:%M:%S"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
content_type = request.headers.get('Content-Type')
|
| 311 |
if content_type == 'application/json':
|
| 312 |
requestQuery = request.get_json()
|
|
@@ -322,6 +346,14 @@ def process_json():
|
|
| 322 |
selectedLLMID=defaultLLMID
|
| 323 |
if "llmID" in requestQuery:
|
| 324 |
selectedLLMID=(int) (requestQuery['llmID'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
print("chain initiation")
|
| 326 |
chainRAG = getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,selectedLLMID)
|
| 327 |
print("chain created")
|
|
@@ -332,6 +364,7 @@ def process_json():
|
|
| 332 |
# message = answering(query)
|
| 333 |
|
| 334 |
relevantDoc = vectordb.similarity_search_with_score(query, distance_metric="cos", k=3)
|
|
|
|
| 335 |
print("Printing Retriever Docs")
|
| 336 |
for doc in getRetriever(vectordb).get_relevant_documents(query):
|
| 337 |
searchResult = {}
|
|
|
|
| 59 |
|
| 60 |
os.makedirs(uploads_dir, exist_ok=True)
|
| 61 |
|
| 62 |
+
# Initialize global variables for conversation history
|
| 63 |
+
conversation_history = []
|
| 64 |
+
|
| 65 |
defaultEmbeddingModelID = 3
|
| 66 |
defaultLLMID=0
|
| 67 |
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
|
| 207 |
+
|
| 208 |
+
# Retrieve conversation history if available
|
| 209 |
+
memory = ConversationBufferWindowMemory(k=3, memory_key="history", input_key="question")
|
| 210 |
+
memory.load_history(conversation_history)
|
| 211 |
+
|
| 212 |
+
# chain = RetrievalQA.from_chain_type(
|
| 213 |
+
# llm=getLLMModel(llmID),
|
| 214 |
+
# chain_type='stuff',
|
| 215 |
+
# retriever=getRetriever(vectordb),
|
| 216 |
+
# #retriever=vectordb.as_retriever(),
|
| 217 |
+
# verbose=False,
|
| 218 |
+
# chain_type_kwargs={
|
| 219 |
+
# "verbose": False,
|
| 220 |
+
# "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
|
| 221 |
+
# "memory": ConversationBufferWindowMemory(
|
| 222 |
+
# k=3,
|
| 223 |
+
# memory_key="history",
|
| 224 |
+
# input_key="question"),
|
| 225 |
+
# }
|
| 226 |
+
# )
|
| 227 |
chain = RetrievalQA.from_chain_type(
|
| 228 |
llm=getLLMModel(llmID),
|
| 229 |
chain_type='stuff',
|
|
|
|
| 233 |
chain_type_kwargs={
|
| 234 |
"verbose": False,
|
| 235 |
"prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
|
| 236 |
+
"memory": memory,
|
|
|
|
|
|
|
|
|
|
| 237 |
}
|
| 238 |
)
|
| 239 |
return chain
|
|
|
|
| 327 |
def process_json():
|
| 328 |
print(f"\n{'*' * 100}\n")
|
| 329 |
print("Request Received >>>>>>>>>>>>>>>>>>", datetime.now().strftime("%H:%M:%S"))
|
| 330 |
+
|
| 331 |
+
# Retrieve conversation ID from the request (use any suitable ID)
|
| 332 |
+
conversation_id = request.json.get('conversation_id', None)
|
| 333 |
+
|
| 334 |
content_type = request.headers.get('Content-Type')
|
| 335 |
if content_type == 'application/json':
|
| 336 |
requestQuery = request.get_json()
|
|
|
|
| 346 |
selectedLLMID=defaultLLMID
|
| 347 |
if "llmID" in requestQuery:
|
| 348 |
selectedLLMID=(int) (requestQuery['llmID'])
|
| 349 |
+
|
| 350 |
+
# Create a conversation ID-specific history list if not exists
|
| 351 |
+
conversation_history_id = f"{conversation_id}_history"
|
| 352 |
+
if conversation_history_id not in globals():
|
| 353 |
+
globals()[conversation_history_id] = []
|
| 354 |
+
conversation_history = globals()[conversation_history_id]
|
| 355 |
+
|
| 356 |
+
|
| 357 |
print("chain initiation")
|
| 358 |
chainRAG = getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,selectedLLMID)
|
| 359 |
print("chain created")
|
|
|
|
| 364 |
# message = answering(query)
|
| 365 |
|
| 366 |
relevantDoc = vectordb.similarity_search_with_score(query, distance_metric="cos", k=3)
|
| 367 |
+
conversation_history.append(query)
|
| 368 |
print("Printing Retriever Docs")
|
| 369 |
for doc in getRetriever(vectordb).get_relevant_documents(query):
|
| 370 |
searchResult = {}
|