gl-kp commited on
Commit
4d0454b
·
verified ·
1 Parent(s): f4f8a36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +694 -0
app.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Import necessary libraries

# --- Standard library ---
import json  # Parsing and handling JSON data
import os  # Interacting with the operating system (reading/writing files)
import re  # Extracting numeric scores from LLM output
from datetime import datetime
from typing import Dict, List, Tuple, Any, TypedDict  # Python typing for function annotations

# --- Third-party utilities ---
import chromadb  # High-performance vector database for storing/querying dense vectors
import numpy as np  # Numpy for numerical operations
import streamlit as st
from dotenv import load_dotenv  # Loading environment variables from a .env file
from groq import Groq
from mem0 import MemoryClient
from pydantic import BaseModel  # Pydantic for data validation

# --- LangChain core ---
from langchain_core.documents import Document  # Document data structures
from langchain_core.output_parsers import StrOutputParser  # String output parser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough  # LangChain core library for running pipelines
from langchain_core.tools import tool

# --- LangChain ---
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain.chains.query_constructor.base import AttributeInfo  # Base classes for query construction
from langchain.embeddings.openai import OpenAIEmbeddings  # OpenAI embeddings for text vectors
from langchain.prompts import ChatPromptTemplate  # Template for chat prompts
from langchain.retrievers import ContextualCompressionRetriever  # Contextual compression retrievers
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker  # Document compressors
from langchain.retrievers.self_query.base import SelfQueryRetriever  # Base classes for self-querying retrievers
from langchain.text_splitter import (
    CharacterTextSplitter,  # Splitting text by characters
    RecursiveCharacterTextSplitter  # Recursive splitting of text by characters
)

# --- LangChain community & experimental ---
from langchain_community.cross_encoders import HuggingFaceCrossEncoder  # Cross-encoders from HuggingFace
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader  # Document loaders for PDFs
from langchain_community.vectorstores import Chroma  # Implementations of vector stores like Chroma
from langchain_experimental.text_splitter import SemanticChunker  # Experimental text splitting methods

# --- LangChain OpenAI ---
# ChatOpenAI added: the module instantiates ChatOpenAI but previously never imported it.
from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI, ChatOpenAI  # OpenAI embeddings and models

# --- LlamaParse & LlamaIndex ---
from llama_parse import LlamaParse  # Document parsing library
from llama_index.core import Settings, SimpleDirectoryReader  # Core functionalities of the LlamaIndex

# --- LangGraph ---
from langgraph.graph import StateGraph, END, START  # State graph for managing states in LangChain
53
+
54
#====================================SETUP=====================================#
# Fetch secrets from Hugging Face Spaces (exposed as environment variables).
# Fix: os.environ is a mapping, not a callable — os.environ("API_KEY") raised
# TypeError. Use .get() so a missing secret surfaces as None instead of
# crashing at import time (clients below will fail with a clearer error).
api_key = os.environ.get("API_KEY")             # OpenAI-compatible API key
endpoint = os.environ.get("OPENAI_API_BASE")    # Base URL of the OpenAI-compatible endpoint
llama_api_key = os.environ.get("GROQ_API_KEY")  # Groq key used for Llama Guard
MEM0_api_key = os.environ.get("mem0")           # mem0 memory-service API key
60
+
61
# Initialize the OpenAI embedding function for Chroma (chromadb-native helper).
# NOTE(review): this object is never referenced again in this file — the
# LangChain `embedding_model` below is what the vector store actually uses;
# confirm whether this is still needed.
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,  # OpenAI-compatible endpoint base URL
    api_key=api_key,  # API key for the endpoint
    model_name='text-embedding-ada-002'  # embedding model identifier (fixed)
)
67
+
68
+ # This initializes the OpenAI embedding function for the Chroma vectorstore, using the provided endpoint and API key.
69
+
70
# Initialize the LangChain OpenAI embeddings wrapper; this is the embedding
# object handed to the Chroma vector store further down.
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,  # OpenAI-compatible endpoint base URL
    openai_api_key=api_key,  # API key for the endpoint
    model='text-embedding-ada-002'  # must match the model used to build the index
)
76
+
77
+
78
# Initialize the chat model served behind the OpenAI-compatible endpoint.
# Fix: ChatOpenAI was used here but never imported; it is now imported from
# langchain_openai at the top of the file.
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False  # return complete responses, no token streaming
)

# Register the LLM and embedding model with LlamaIndex's global settings.
# NOTE(review): LlamaIndex normally expects its own LLM/embedding wrappers
# here, and its attribute for embeddings is `Settings.embed_model`, not
# `Settings.embedding` — this assignment is likely ignored by LlamaIndex
# components; confirm before relying on it.
Settings.llm = llm
Settings.embedding = embedding_model
90
+
91
+ #================================Creating Langgraph agent======================#
92
+
93
class AgentState(TypedDict):
    """Mutable state threaded through every node of the LangGraph workflow."""
    query: str  # The current user query
    expanded_query: str  # The expanded version of the user query
    context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
    response: str  # The generated response to the user query
    precision_score: float  # The precision score of the response
    groundedness_score: float  # The groundedness score of the response
    groundedness_loop_count: int  # Counter for groundedness refinement loops
    precision_loop_count: int  # Counter for precision refinement loops
    feedback: str  # Refinement suggestions fed back into response crafting
    query_feedback: str  # Refinement suggestions fed back into query expansion
    groundedness_check: bool  # NOTE(review): declared but not set by any node here — confirm it is used elsewhere
    loop_max_iter: int  # Maximum allowed iterations for each refinement loop
106
+
107
def expand_query(state):
    """
    Rewrite the user query into a fuller form to improve document retrieval.

    Args:
        state (Dict): Workflow state; reads 'query' and 'query_feedback'.

    Returns:
        Dict: The same state with 'expanded_query' filled in.
    """
    print("---------Expanding Query---------")
    system_message = '''You are a medical assistant. Use the provided context to generate a clear, accurate, and grounded response to the user's query.'''

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    # Prompt -> LLM -> plain string.
    pipeline = prompt | llm | StrOutputParser()
    rewritten = pipeline.invoke({
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    })
    print("expanded_query", rewritten)
    state["expanded_query"] = rewritten
    return state
130
+
131
+
132
# Initialize the Chroma vector store used for retrieval.
# NOTE(review): assumes ./nutritional_db was populated offline with embeddings
# from the same 'text-embedding-ada-002' model — confirm before deploying.
vector_store = Chroma(
    collection_name="nutritional_hypotheticals",  # pre-built collection name
    persist_directory="./nutritional_db",  # on-disk Chroma persistence path
    embedding_function=embedding_model  # must match the model used at indexing time

)

# Create a retriever over the store: plain similarity search, top 3 chunks.
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3}
)
145
+
146
def retrieve_context(state):
    """
    Fetch the top-matching documents for the expanded query.

    Args:
        state (Dict): Workflow state; reads 'expanded_query'.

    Returns:
        Dict: The same state with 'context' set to a list of
            {"content": ..., "metadata": ...} dicts.
    """
    print("---------retrieve_context---------")
    search_query = state['expanded_query']

    # Similarity search against the vector store (top-k is configured on the retriever).
    docs = retriever.invoke(search_query)
    print("Retrieved documents:", docs)  # Debugging: raw retriever output

    # Keep both the text and the metadata (e.g. source, page number) per hit.
    hits = []
    for doc in docs:
        hits.append({
            "content": doc.page_content,
            "metadata": doc.metadata,
        })

    state['context'] = hits
    print("Extracted context with metadata:", hits)  # Debugging: structured context
    return state
174
+
175
+
176
+
177
def craft_response(state: Dict) -> Dict:
    """
    Draft an answer to the user's query from the retrieved context.

    Args:
        state (Dict): Workflow state; reads 'query', 'context', and 'feedback'.

    Returns:
        Dict: The same state with 'response' set to the model output.
    """
    print("---------craft_response---------")
    system_message = '''You are a medical assistant. Use the provided context to generate a clear, accurate, and grounded response to the user's query.'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # Flatten the retrieved chunks into one context string for the prompt.
    context_text = "\n".join(doc["content"] for doc in state['context'])

    answer = (response_prompt | llm).invoke({
        "query": state['query'],
        "context": context_text,
        "feedback": state['feedback']  # prior refinement suggestions, if any
    })
    state['response'] = answer
    print("intermediate response: ", answer)

    return state
203
+
204
+
205
+
206
def score_groundedness(state: Dict) -> Dict:
    """
    Score how well the response is supported by the retrieved context.

    Args:
        state (Dict): Workflow state; reads 'context' and 'response'.

    Returns:
        Dict: The same state with 'groundedness_score' set and
            'groundedness_loop_count' incremented.
    """
    print("---------check_groundedness---------")
    system_message = '''Evaluate how well the response is grounded in the provided context. Return a score between 0 and 1.'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join(doc["content"] for doc in state['context']),
        "response": state['response']
    })
    # Robustness fix: the model occasionally wraps the number in extra text,
    # which made the original float(raw_score) raise ValueError. Extract the
    # first numeric token instead; default to 0.0 (forces refinement) if none.
    match = re.search(r"[-+]?\d*\.?\d+", raw_score)
    groundedness_score = float(match.group()) if match else 0.0
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score

    return state
233
+
234
+
235
+
236
def check_precision(state: Dict) -> Dict:
    """
    Score how precisely the response addresses the user's query.

    Args:
        state (Dict): Workflow state; reads 'query' and 'response'.

    Returns:
        Dict: The same state with 'precision_score' set and
            'precision_loop_count' incremented.
    """
    print("---------check_precision---------")
    system_message = '''Evaluate how precisely the response addresses the user's query. Return a score between 0 and 1.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response']
    })
    # Robustness fix: the model may decorate the number with extra text, which
    # made the original float(raw_score) raise ValueError. Pull out the first
    # numeric token instead; default to 0.0 (forces refinement) if none found.
    match = re.search(r"[-+]?\d*\.?\d+", raw_score)
    precision_score = float(match.group()) if match else 0.0
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
262
+
263
+
264
+
265
def refine_response(state: Dict) -> Dict:
    """
    Collect suggestions for improving the generated response.

    Args:
        state (Dict): Workflow state; reads 'query' and 'response'.

    Returns:
        Dict: The same state with 'feedback' set to the suggestions.
    """
    print("---------refine_response---------")

    system_message = '''You are a helpful assistant. Suggest improvements to the response to enhance accuracy and completeness.'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    # Ask the model to critique the current response; keep the previous
    # response alongside the suggestions so the next draft has full context.
    suggestions = (refine_response_prompt | llm | StrOutputParser()).invoke({
        'query': state['query'],
        'response': state['response'],
    })
    feedback = f"Previous Response: {state['response']}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
291
+
292
+
293
+
294
def refine_query(state: Dict) -> Dict:
    """
    Collect suggestions for improving the expanded query (stored as feedback;
    the expanded query itself is not modified here).

    Args:
        state (Dict): Workflow state; reads 'query' and 'expanded_query'.

    Returns:
        Dict: The same state with 'query_feedback' set to the suggestions.
    """
    print("---------refine_query---------")
    system_message = '''You are a helpful assistant. Suggest improvements to the expanded query to improve search relevance.'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    # Critique the expanded query; keep the previous expansion alongside the
    # suggestions so the next expansion round has full context.
    suggestions = (refine_query_prompt | llm | StrOutputParser()).invoke({
        'query': state['query'],
        'expanded_query': state['expanded_query'],
    })
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
319
+
320
+
321
+
322
def should_continue_groundedness(state):
    """Route after the groundedness check: advance, refine, or give up."""
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])
    if state['groundedness_score'] >= 0.8:  # grounded enough → move on
        print("Moving to precision")
        return "check_precision"
    if state["groundedness_loop_count"] > state['loop_max_iter']:  # budget spent
        return "max_iterations_reached"
    print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
    return "refine_response"
335
+
336
+
337
def should_continue_precision(state: Dict) -> str:
    """Route after the precision check: finish, refine the query, or give up."""
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])
    if state['precision_score'] >= 0.8:  # precise enough → complete the workflow
        return "pass"
    if state['precision_loop_count'] > state['loop_max_iter']:  # budget spent
        return "max_iterations_reached"
    print(f"---------Precision Score Threshold Not met. Refining Query-----------")
    return "refine_query"
349
+
350
+
351
+
352
+
353
def max_iterations_reached(state: Dict) -> Dict:
    """Terminal handler: emit a fallback answer once the refinement budget is spent."""
    print("---------max_iterations_reached---------")
    fallback = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = fallback
    return state
360
+
361
+
362
+
363
def create_workflow() -> StateGraph:
    """
    Build the LangGraph workflow for the AI nutrition agent.

    Flow: expand_query -> retrieve_context -> craft_response ->
    score_groundedness, with conditional refinement loops for groundedness
    (refine_response) and precision (refine_query), and a max-iterations exit.

    Returns:
        StateGraph: The un-compiled workflow graph.

    Fix: removed the redundant `from langgraph.graph import END, StateGraph,
    START` that sat above this function — those names are already imported at
    the top of the file.
    """
    workflow = StateGraph(AgentState)

    # Processing nodes.
    workflow.add_node("expand_query", expand_query)  # Step 1: expand the user query
    workflow.add_node("retrieve_context", retrieve_context)  # Step 2: fetch relevant documents
    workflow.add_node("craft_response", craft_response)  # Step 3: draft a grounded answer
    workflow.add_node("score_groundedness", score_groundedness)  # Step 4: grade grounding
    workflow.add_node("refine_response", refine_response)  # Step 5: improve weakly-grounded answers
    workflow.add_node("check_precision", check_precision)  # Step 6: grade precision
    workflow.add_node("refine_query", refine_query)  # Step 7: improve imprecise queries
    workflow.add_node("max_iterations_reached", max_iterations_reached)  # Step 8: give-up handler

    # Main flow.
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Branch on the groundedness check.
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",  # well-grounded → precision check
            "refine_response": "refine_response",  # weak → refine the response
            "max_iterations_reached": "max_iterations_reached"  # budget spent → exit
        }
    )

    workflow.add_edge("refine_response", "craft_response")  # refined responses are re-drafted

    # Branch on the precision check.
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,  # precise → done
            "refine_query": "refine_query",  # imprecise → refine the query
            "max_iterations_reached": "max_iterations_reached"  # budget spent → exit
        }
    )

    workflow.add_edge("refine_query", "expand_query")  # refined queries re-enter expansion

    workflow.add_edge("max_iterations_reached", END)

    return workflow
414
+
415
+
416
+
417
+
418
#=========================== Defining the agentic rag tool ====================#
# Compile the LangGraph workflow once at import time so every tool call reuses it.
WORKFLOW_APP = create_workflow().compile()

@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.
    Args:
        query (str): The current user query.
    Returns:
        Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # Fresh workflow state for this query: empty context/feedback, zeroed
    # scores and loop counters, and a cap of 3 iterations per refinement loop.
    initial_state = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 3,
    }
    return WORKFLOW_APP.invoke(initial_state)
447
+
448
+
449
#================================ Guardrails ===========================#
# Groq client used solely to run the Llama Guard safety classifier.
llama_guard_client = Groq(api_key=llama_api_key)

def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
    """
    Classify user input with Llama Guard before it reaches the agent.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to use (default "llama-guard-3-8b").

    Returns:
    - The stripped classifier verdict string, or None if the API call failed.
    """
    try:
        verdict = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        return verdict.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort guardrail: on API failure, report and return None so the
        # caller can decide how to proceed.
        print(f"Error with Llama Guard: {e}")
        return None
472
+
473
+
474
+ #============================= Adding Memory to the agent using mem0 ===============================#
475
+
476
class NutritionBot:
    """Customer-facing nutrition-support chatbot with persistent mem0 memory."""

    def __init__(self):
        """
        Initialize the NutritionBot: memory client, LLM client, tools, and the
        agent executor.
        """
        # Memory client used to store and retrieve customer interactions.
        # Fix: the original called `userdata.get("mem0")` — `userdata` is a
        # Google Colab helper that does not exist here. Use the module-level
        # key read from the environment instead.
        self.memory = MemoryClient(api_key=MEM0_api_key)

        # Chat model behind the OpenAI-compatible endpoint.
        # Fix: the original referenced an undefined `config` object and passed
        # an `endpoint=` kwarg that ChatOpenAI does not accept; use the
        # module-level credentials and the `base_url` parameter.
        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",  # model served behind the endpoint
            api_key=api_key,
            base_url=endpoint,
            temperature=0  # deterministic responses
        )

        # Tools the agent may call (the RAG workflow wrapped as a tool).
        tools = [agentic_rag]

        # System prompt fixing the agent's persona and behavior.
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
Guidelines for Interaction:
Maintain a polite, professional, and reassuring tone.
Show genuine empathy for customer concerns and health challenges.
Reference past interactions to provide personalized and consistent advice.
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
Ensure consistent and accurate information across conversations.
If any detail is unclear or missing, proactively ask for clarification.
Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
Keep track of ongoing issues and follow-ups to ensure continuity in support.
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.
"""

        # Prompt template wiring system instructions, the user turn, and the
        # scratchpad used for intermediate tool-calling steps.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
            ("placeholder", "{agent_scratchpad}")
        ])

        # Tool-calling agent plus the executor that manages the tool loop.
        agent = create_tool_calling_agent(self.client, tools, prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store one user/assistant exchange in mem0 for future personalization.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Timestamp the record so interactions can be ordered later.
        metadata["timestamp"] = datetime.now().isoformat()

        # mem0 expects a chat-style list of role/content messages.
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: Up to five relevant past interactions.
        """
        return self.memory.search(
            query=query,  # semantic search over stored memories
            user_id=user_id,  # restrict to this customer's history
            limit=5  # cap the amount of history injected into the prompt
        )

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Answer a customer query, grounding the agent in relevant past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """
        # Pull past interactions relevant to this query.
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from past memories.
        # Fix: the original printed the same memory['memory'] text twice,
        # labelled once as "Customer" and once as "Support" — each mem0 search
        # hit carries a single 'memory' text, so list it once.
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"Interaction: {memory['memory']}\n"
            context += "---\n"

        # Print context for debugging purposes.
        print("Context: ", context)

        # Combine past context with the current query for the agent.
        prompt = f"""
        Context:
        {context}
        Current customer query: {query}
        Provide a helpful response that takes into account any relevant past interactions.
        """

        # Run the tool-calling agent.
        response = self.agent_executor.invoke({"input": prompt})

        # Persist this exchange for future personalization.
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        return response['output']
614
+
615
+
616
#=====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    Streamlit UI for the Nutrition Disorder Specialist agent: login, chat
    history replay, Llama Guard input screening, and query dispatch.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Make sure the session-state slots exist before first use.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Not logged in yet: show the login form and stop.
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            name = st.text_input("Please enter your name to begin:")
            submitted = st.form_submit_button("Login")
            if submitted and name:
                st.session_state.user_id = name
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {name}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True  # flag a rerun once the form closes
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
        return

    # Logged in: replay the conversation so far.
    for entry in st.session_state.chat_history:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])

    user_query = st.chat_input('Type your question here or exit to end the conversation')
    if not user_query:
        return

    # "exit" ends the session: say goodbye and log the user out.
    if user_query.lower() == "exit":
        st.session_state.chat_history.append({"role": "user", "content": "exit"})
        with st.chat_message("user"):
            st.write("exit")
        goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
        st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
        with st.chat_message("assistant"):
            st.write(goodbye_msg)
        st.session_state.user_id = None
        st.rerun()
        return

    st.session_state.chat_history.append({"role": "user", "content": user_query})
    with st.chat_message("user"):
        st.write(user_query)

    # Screen the input with Llama Guard before answering.
    filtered_result = filter_input_with_llama_guard(user_query)
    # NOTE(review): filter_input_with_llama_guard returns None on API errors,
    # which makes the next line raise AttributeError — confirm intended handling.
    filtered_result = filtered_result.replace("\n", " ")  # normalize the verdict

    # Only "safe" plus the two tolerated unsafe categories are answered.
    if filtered_result in ["safe", "unsafe S6", "unsafe S7"]:
        try:
            # Lazily build the chatbot once per session.
            if 'chatbot' not in st.session_state:
                st.session_state.chatbot = NutritionBot()
            response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
            st.write(response)
            st.session_state.chat_history.append({"role": "assistant", "content": response})
        except Exception as e:
            error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
            st.write(error_msg)
            st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
    else:
        inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
        st.write(inappropriate_msg)
        st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})


if __name__ == "__main__":
    nutrition_disorder_streamlit()