Akashpb13 commited on
Commit
7d7e8f7
·
verified ·
1 Parent(s): 2db5655

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +737 -0
app.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Import necessary libraries
3
+ import os # Interacting with the operating system (reading/writing files)
4
+ import chromadb # High-performance vector database for storing/querying dense vectors
5
+ from dotenv import load_dotenv # Loading environment variables from a .env file
6
+ import json # Parsing and handling JSON data
7
+
8
+ # LangChain imports
9
+ from langchain_core.documents import Document # Document data structures
10
+ from langchain_core.runnables import RunnablePassthrough # LangChain core library for running pipelines
11
+ from langchain_core.output_parsers import StrOutputParser # String output parser
12
+ from langchain.prompts import ChatPromptTemplate # Template for chat prompts
13
+ from langchain.chains.query_constructor.base import AttributeInfo # Base classes for query construction
14
+ from langchain.retrievers.self_query.base import SelfQueryRetriever # Base classes for self-querying retrievers
15
+ from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker # Document compressors
16
+ from langchain.retrievers import ContextualCompressionRetriever # Contextual compression retrievers
17
+
18
+ # LangChain community & experimental imports
19
+ from langchain_community.vectorstores import Chroma # Implementations of vector stores like Chroma
20
+ from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader # Document loaders for PDFs
21
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder # Cross-encoders from HuggingFace
22
+ from langchain_experimental.text_splitter import SemanticChunker # Experimental text splitting methods
23
+ from langchain.text_splitter import (
24
+ CharacterTextSplitter, # Splitting text by characters
25
+ RecursiveCharacterTextSplitter # Recursive splitting of text by characters
26
+ )
27
+ from langchain_core.tools import tool
28
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
29
+ from langchain_core.prompts import ChatPromptTemplate
30
+ from langchain_openai import ChatOpenAI
31
+
32
+ # LangChain OpenAI imports
33
+ from langchain_openai import AzureOpenAIEmbeddings, AzureChatOpenAI # OpenAI embeddings and models
34
+ from langchain.embeddings.openai import OpenAIEmbeddings # OpenAI embeddings for text vectors
35
+
36
+ # LlamaParse & LlamaIndex imports
37
+ from llama_parse import LlamaParse # Document parsing library
38
+ from llama_index.core import Settings, SimpleDirectoryReader # Core functionalities of the LlamaIndex
39
+
40
+ # LangGraph import
41
+ from langgraph.graph import StateGraph, END, START # State graph for managing states in LangChain
42
+
43
+ # Pydantic import
44
+ from pydantic import BaseModel # Pydantic for data validation
45
+
46
+ # Typing imports
47
+ from typing import Dict, List, Tuple, Any, TypedDict # Python typing for function annotations
48
+
49
+ # Other utilities
50
+ import numpy as np # Numpy for numerical operations
51
+ from groq import Groq
52
+ from mem0 import MemoryClient
53
+ import streamlit as st
54
+ from datetime import datetime
55
+
56
#====================================SETUP=====================================#
# Fetch secrets from the environment (Hugging Face Spaces exposes secrets as
# environment variables).
# BUG FIX: the original called `config.get(...)` but no `config` object is ever
# defined in this file, which raises NameError at import time.
api_key = os.environ.get("API_KEY")            # OpenAI-compatible API key
endpoint = os.environ.get("OPENAI_API_BASE")   # Base URL of the OpenAI-compatible endpoint
llama_api_key = os.environ['GROQ_API_KEY']     # Groq key used for Llama Guard moderation
MEM0_api_key = os.environ['mem0']              # mem0 memory-service key

# Embedding function for the raw chromadb client, pinned to ada-002.
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_base=endpoint,
    api_key=api_key,
    model_name='text-embedding-ada-002'  # fixed value; must match the vector store
)

# LangChain wrapper around the same embedding endpoint/model, used by Chroma below.
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)

# Chat model shared by all workflow nodes; streaming disabled because the
# LangGraph nodes consume complete responses, not token streams.
llm = ChatOpenAI(
    base_url=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False
)

# Register the LLM and embedding model with the LlamaIndex global settings.
Settings.llm = llm
Settings.embedding = embedding_model
92
+
93
+ #================================Creating Langgraph agent======================#
94
+
95
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph RAG workflow."""
    query: str                         # Original user question
    expanded_query: str                # Query after LLM-based expansion
    context: List[Dict[str, Any]]      # Retrieved docs as {content, metadata} dicts
    response: str                      # Latest generated answer
    precision_score: float             # How directly the answer addresses the query (0-1)
    groundedness_score: float          # How well the answer is supported by context (0-1)
    groundedness_loop_count: int       # Refinement iterations spent on groundedness
    precision_loop_count: int          # Refinement iterations spent on precision
    feedback: str                      # Suggestions for improving the response
    query_feedback: str                # Suggestions for improving the expanded query
    groundedness_check: bool           # Flag available to conditional routing
    loop_max_iter: int                 # Cap on refinement iterations per loop
108
+
109
def expand_query(state):
    """
    Expand the user query to improve retrieval of nutrition disorder-related information.

    Args:
        state (Dict): Current workflow state; must contain 'query' and 'query_feedback'.

    Returns:
        Dict: The same state with 'expanded_query' filled in.
    """
    print("---------Expanding Query---------")
    system_message = '''You are an expert in medical and nutritional information retrieval.
    Your task is to expand the user query using the provided feedback so that it retrieves
    the most relevant documents about nutrition disorders, vitamin deficiencies, metabolic disorders, and related health conditions.
    Generate a clear, precise, and comprehensive expanded query.'''

    expansion_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    pipeline = expansion_prompt | llm | StrOutputParser()
    expanded = pipeline.invoke({
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    })
    print("expanded_query", expanded)
    state["expanded_query"] = expanded
    return state
137
+
138
+
139
# Chroma vector store holding the pre-built nutrition documents, persisted on disk.
vector_store = Chroma(
    collection_name="nutritional_hypotheticals",
    persist_directory="./nutritional_db",
    embedding_function=embedding_model,
)

# Similarity-search retriever over the store; returns the top-3 matches per query.
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 3},
)
152
+
153
def retrieve_context(state):
    """
    Retrieve supporting documents for the expanded query from the vector store.

    Args:
        state (Dict): Workflow state containing 'expanded_query'.

    Returns:
        Dict: State with 'context' set to a list of {'content', 'metadata'} dicts.
    """
    print("---------retrieve_context---------")
    query = state['expanded_query']  # retrieval always uses the expanded form

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: Print the raw docs object

    # Keep both the text and the metadata (source, page number, ...) per document.
    context = [
        {
            "content": doc.page_content,
            "metadata": doc.metadata,
        }
        for doc in docs
    ]
    # BUG FIX: the original stored this under state['retrieved_context'], but
    # AgentState declares (and craft_response reads) the 'context' key, so the
    # retrieved documents were silently discarded downstream.
    state['context'] = context
    print("Extracted context with metadata:", context)  # Debugging
    return state
183
+
184
+
185
+
186
def craft_response(state: Dict) -> Dict:
    """
    Generate an answer to the user query from the retrieved context.

    Args:
        state (Dict): Workflow state containing 'query', 'context' and optional 'feedback'.

    Returns:
        Dict: State with 'response' set to the generated answer as a plain string.
    """
    print("---------craft_response---------")
    system_message = '''You are a medical assistant specialized in nutritional disorders.
    Provide clear, accurate, and evidence-based answers based on the context provided.
    Focus on causes, symptoms, treatments, and prevention of nutritional disorders.'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # FIX: parse to a plain string so state['response'] matches its declared type
    # (str in AgentState) instead of holding a raw AIMessage object, and formats
    # cleanly when embedded into the groundedness/precision scoring prompts.
    chain = response_prompt | llm | StrOutputParser()
    response = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "feedback": state.get("feedback", "None")  # carry refinement feedback into the prompt
    })
    state['response'] = response
    print("intermediate response: ", response)

    return state
216
+
217
+
218
+
219
def score_groundedness(state: Dict) -> Dict:
    """
    Score how well the response is grounded in the retrieved context.

    Args:
        state (Dict): Workflow state containing 'response' and 'context'.

    Returns:
        Dict: State with 'groundedness_score' set (0.0-1.0) and the
        groundedness loop counter incremented.
    """
    print("---------check_groundedness---------")
    system_message = '''You are a helpful assistant that evaluates whether a response is grounded in the provided context.
    Score the response from 0.0 to 1.0, where 1.0 means fully grounded in the context,
    and 0.0 means not grounded at all. Respond with only the numeric score.'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    })
    # BUG FIX: the original bare float(...) raised ValueError whenever the model
    # returned anything beyond a plain number; treat an unparsable reply as
    # "not grounded" so the workflow refines instead of crashing.
    try:
        groundedness_score = float(raw_score.strip())
    except ValueError:
        groundedness_score = 0.0
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score

    return state
250
+
251
+
252
+
253
def check_precision(state: Dict) -> Dict:
    """
    Score how precisely the response addresses the user's query.

    Args:
        state (Dict): Workflow state containing 'query' and 'response'.

    Returns:
        Dict: State with 'precision_score' set (0.0-1.0) and the precision
        loop counter incremented.
    """
    print("---------check_precision---------")
    system_message = '''You are an assistant that evaluates whether the response precisely answers the user's query.
    Score the response from 0.0 to 1.0, where 1.0 means the response fully and directly addresses the query,
    and 0.0 means it does not address the query at all. Respond with only the numeric score.'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response']
    })
    # BUG FIX: bare float(...) crashed on any non-numeric model reply; fall back
    # to 0.0 so the workflow refines the query instead of raising ValueError.
    try:
        precision_score = float(raw_score.strip())
    except ValueError:
        precision_score = 0.0
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
283
+
284
+
285
+
286
def refine_response(state: Dict) -> Dict:
    """
    Collect improvement suggestions for the current response.

    Args:
        state (Dict): Workflow state containing 'query' and 'response'.

    Returns:
        Dict: State with 'feedback' set to the previous response plus suggestions.
    """
    print("---------refine_response---------")

    system_message = '''You are an expert medical assistant specializing in nutritional disorders.
    Your task is to review a generated response and suggest concrete improvements
    to make it more accurate, precise, and complete based on the query.
    Provide actionable suggestions only.'''

    suggestion_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    pipeline = suggestion_prompt | llm | StrOutputParser()
    suggestions = pipeline.invoke({'query': state['query'], 'response': state['response']})

    # Keep the old response alongside the suggestions so craft_response sees both.
    feedback = f"Previous Response: {state['response']}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
317
+
318
+
319
+
320
def refine_query(state: Dict) -> Dict:
    """
    Collect improvement suggestions for the expanded query.

    Args:
        state (Dict): Workflow state containing 'query' and 'expanded_query'.

    Returns:
        Dict: State with 'query_feedback' set to the previous expansion plus suggestions.
    """
    print("---------refine_query---------")
    system_message = '''You are an expert in optimizing search queries for retrieving medical and nutritional information.
    Your task is to suggest improvements to an expanded query to make it more precise, comprehensive,
    and likely to retrieve relevant documents. Provide actionable suggestions only.'''

    suggestion_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    pipeline = suggestion_prompt | llm | StrOutputParser()
    suggestions = pipeline.invoke({
        'query': state['query'],
        'expanded_query': state['expanded_query'],
    })

    # Record suggestions without mutating the expanded query itself; expand_query
    # consumes this feedback on the next loop iteration.
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
349
+
350
+
351
+
352
def should_continue_groundedness(state):
    """Route after groundedness scoring: advance, refine the response, or bail out."""
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])
    # Well-grounded answers move straight to the precision check.
    if state['groundedness_score'] >= 0.7:
        print("Moving to precision")
        return "check_precision"
    # Stop refining once the loop budget is exhausted.
    if state["groundedness_loop_count"] > state['loop_max_iter']:
        return "max_iterations_reached"
    print(f"---------Groundedness Score Threshold Not met. Refining Response-----------")
    return "refine_response"
365
+
366
+
367
def should_continue_precision(state: Dict) -> str:
    """Route after precision scoring: finish, refine the query, or bail out."""
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])
    # A sufficiently precise answer completes the workflow.
    if state['precision_score'] >= 0.7:
        return "pass"
    # Stop refining once the loop budget is exhausted.
    if state['precision_loop_count'] > state['loop_max_iter']:
        return "max_iterations_reached"
    print(f"---------Precision Score Threshold Not met. Refining Query-----------")
    return "refine_query"
379
+
380
+
381
+
382
+
383
def max_iterations_reached(state: Dict) -> Dict:
    """Finalize the state with a fallback answer once refinement loops are exhausted.

    Args:
        state (Dict): Current workflow state.

    Returns:
        Dict: State with 'response' replaced by a fixed fallback message.
    """
    # FIX: the original body repeated the docstring as a stray bare-string
    # statement; the no-op duplicate has been removed.
    print("---------max_iterations_reached---------")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state
390
+
391
+
392
+
393
+ from langgraph.graph import END, StateGraph, START
394
+
395
def create_workflow() -> StateGraph:
    """Build the LangGraph state machine for the AI nutrition agent."""
    workflow = StateGraph(AgentState)

    # Register every processing node with its handler (insertion order preserved).
    handlers = {
        "expand_query": expand_query,                    # 1. expand the user query
        "retrieve_context": retrieve_context,            # 2. fetch relevant documents
        "craft_response": craft_response,                # 3. generate an answer
        "score_groundedness": score_groundedness,        # 4. score grounding in context
        "refine_response": refine_response,              # 5. suggest response fixes
        "check_precision": check_precision,              # 6. score answer precision
        "refine_query": refine_query,                    # 7. suggest query fixes
        "max_iterations_reached": max_iterations_reached,  # 8. loop-budget fallback
    }
    for node_name, handler in handlers.items():
        workflow.add_node(node_name, handler)

    # Linear backbone: expand -> retrieve -> respond -> score.
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Groundedness gate: proceed, refine the response, or give up.
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": "check_precision",
            "refine_response": "refine_response",
            "max_iterations_reached": "max_iterations_reached",
        },
    )
    workflow.add_edge("refine_response", "craft_response")  # refined responses loop back

    # Precision gate: finish, refine the query, or give up.
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": "refine_query",
            "max_iterations_reached": "max_iterations_reached",
        },
    )
    workflow.add_edge("refine_query", "expand_query")  # refined queries re-expand

    workflow.add_edge("max_iterations_reached", END)

    return workflow
444
+
445
+
446
+
447
+
448
#=========================== Defining the agentic rag tool ====================#
# Compile the graph once at import time; the tool reuses it for every call.
WORKFLOW_APP = create_workflow().compile()

@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # Seed every AgentState field before invoking the compiled graph.
    initial_state = {
        "query": query,
        "expanded_query": query,        # expansion starts from the raw query
        "context": [],                  # no documents retrieved yet
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 3,             # refinement-loop budget
    }

    return WORKFLOW_APP.invoke(initial_state)
479
+
480
+
481
#================================ Guardrails ===========================#
# Groq client backing the Llama Guard moderation calls.
llama_guard_client = Groq(api_key=llama_api_key)

# Function to filter user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """
    Filters user input using Llama Guard to ensure it is safe.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering
      (default is "meta-llama/llama-guard-4-12b").

    Returns:
    - The stripped moderation verdict string (e.g. "safe", or "unsafe" followed
      by a category code), or None if the API call fails.
    """
    try:
        # Create a request to Llama Guard to filter the user input
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        # Return the verdict text; callers must handle the None failure case.
        return response.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort guardrail: report the failure and let the caller decide.
        print(f"Error with Llama Guard: {e}")
        return None
506
+
507
+
508
+ #============================= Adding Memory to the agent using mem0 ===============================#
509
+
510
class NutritionBot:
    """Customer-facing nutrition-disorder chatbot backed by mem0 long-term memory."""

    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """

        # Memory client used to store and retrieve customer interactions.
        self.memory = MemoryClient(api_key=MEM0_api_key)

        # Chat model driving the agent.
        # BUG FIX: the original referenced an undefined `config` object; read
        # the credentials from the environment like the rest of the file.
        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",
            api_key=os.environ.get("API_KEY"),
            base_url=os.environ.get("OPENAI_API_BASE"),
            temperature=0  # deterministic responses
        )

        # Tools the agent may call; agentic_rag runs the full RAG workflow.
        tools = [agentic_rag]

        # System prompt fixing the assistant's persona and guidelines.
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        Maintain a polite, professional, and reassuring tone.
        Show genuine empathy for customer concerns and health challenges.
        Reference past interactions to provide personalized and consistent advice.
        Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
        Ensure consistent and accurate information across conversations.
        If any detail is unclear or missing, proactively ask for clarification.
        Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
        Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.

        """

        # Build the prompt template for the agent.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),               # System instructions
            ("human", "{input}"),                    # Placeholder for human input
            ("placeholder", "{agent_scratchpad}")    # Intermediate reasoning steps
        ])

        # Create the tool-calling agent and wrap it in an executor.
        agent = create_tool_calling_agent(self.client, tools, prompt)
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Timestamp each interaction for traceability.
        metadata["timestamp"] = datetime.now().isoformat()

        # Store the exchange as a two-turn conversation.
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        return self.memory.search(
            query=query,      # Search for interactions related to the query
            user_id=user_id,  # Restrict search to the specific user
            limit=3           # Keep the context window small
        )

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """

        # Retrieve relevant past interactions for context.
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history.
        # BUG FIX: the original emitted each memory twice, once labelled
        # "Customer:" and once "Support:", both with the same text. Each mem0
        # search hit carries a single consolidated 'memory' string, so it is
        # included once.
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"Past interaction: {memory['memory']}\n"
            context += "---\n"

        # Print context for debugging purposes.
        print("Context: ", context)

        # Combine past context with the current query for the agent.
        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """

        # Generate a response using the agent.
        response = self.agent_executor.invoke({"input": prompt})

        # Persist this interaction for future context.
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        return response['output']
656
+
657
+
658
#=====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.

    Handles login, chat history, Llama Guard input moderation, and delegation
    of safe queries to the NutritionBot agent.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Login form: only shown while the user is not logged in.
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                })
                st.session_state.login_submitted = True  # flag to trigger rerun
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Replay the conversation so far.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        user_query = st.chat_input("Type your query here ")
        if user_query:
            if user_query.lower() == "exit":
                st.session_state.chat_history.append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None  # log the user out
                st.rerun()
                return

            st.session_state.chat_history.append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Moderate the input with Llama Guard.
            # BUG FIX: the guard returns None on API failure, and the original
            # called .replace on it unconditionally, crashing the app. A failed
            # check is now treated as unsafe input.
            filtered_result = filter_input_with_llama_guard(user_query)
            if filtered_result is None:
                filtered_result = ""
            filtered_result = filtered_result.replace("\n", " ")  # normalize verdict

            # BUG FIX: the allowed-status list previously held "S6"/"S7", which
            # can never match the guard's "unsafe S6"/"unsafe S7" verdicts; use
            # the full verdict strings the template comment describes.
            if filtered_result in ["safe", "unsafe S6", "unsafe S7"]:
                try:
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = NutritionBot()
                    response = st.session_state.chatbot.handle_customer_query(st.session_state.user_id, user_query)
                    st.write(response)
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                    st.write(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                st.write(inappropriate_msg)
                st.session_state.chat_history.append({"role": "assistant", "content": inappropriate_msg})

if __name__ == "__main__":
    nutrition_disorder_streamlit()