Vidhi00 commited on
Commit
b21affd
·
verified ·
1 Parent(s): 0ecdaec

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +736 -0
app.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Tuple, TypedDict

import numpy as np
import streamlit as st

# LangChain imports
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool

from langchain.agents import (
    AgentExecutor,
    create_tool_calling_agent,
    create_openai_tools_agent,
    initialize_agent,
    AgentType
)
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor, CrossEncoderReranker
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter

# LangChain Community imports
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders import PyPDFDirectoryLoader, PyPDFLoader
from langchain_community.vectorstores import Chroma

# LangChain OpenAI specific imports
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Misc imports
import chromadb
from groq import Groq  # Llama Guard client for filtering user input
from langgraph.graph import END, StateGraph, START
from llama_index.core.settings import Settings
from mem0 import MemoryClient
from pydantic import BaseModel, Field, validator
from tqdm import tqdm
44
# Fix numpy float type for compatibility.
# NOTE(review): `np.float_` was removed in NumPy 2.0; some transitive
# dependencies apparently still reference it — confirm which one before removing.
np.float_ = np.float64
#====================================
# Environment setup: all credentials come from environment variables.
api_key = os.getenv("OPENAI_API_KEY")    # OpenAI-compatible API key
endpoint = os.getenv("OPENAI_API_BASE")  # OpenAI-compatible base URL
# NOTE(review): read here from 'mem0' but NutritionBot reads "Mem0" — the
# casing differs; confirm the intended secret name. This variable is unused below.
memo_api_key = os.getenv('mem0')

# Embedding model used to index and query the Chroma vector store.
embedding_model = OpenAIEmbeddings(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model='text-embedding-ada-002'
)

# Chat model shared by all workflow nodes (query expansion, response
# crafting, groundedness/precision judging, refinement).
llm = ChatOpenAI(
    openai_api_base=endpoint,
    openai_api_key=api_key,
    model="gpt-4o-mini",
    streaming=False)

# Register the models as llama_index global defaults.
Settings.llm = llm
Settings.embedding = embedding_model
#====================================
70
class AgentState(TypedDict):
    """Shared state dict threaded through every node of the LangGraph workflow."""
    query: str                     # The current user query
    expanded_query: str            # The expanded version of the user query
    context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
    response: str                  # The generated response to the user query
    precision_score: float         # The precision score of the response (1.0 or 0.0)
    groundedness_score: float      # The groundedness score of the response (1.0 or 0.0)
    groundedness_loop_count: int   # Counter for groundedness refinement loops
    precision_loop_count: int      # Counter for precision refinement loops
    feedback: str                  # Refinement feedback fed back into craft_response
    query_feedback: str            # Feedback specifically related to the query expansion
    groundedness_check: bool       # Indicator for groundedness check
    loop_max_iter: int             # Maximum iterations for the refinement loops
#====================================
84
def expand_query(state):
    """Rewrite the user's question into up to three domain-specific queries.

    Asks the LLM to rephrase the query with medical terminology and synonyms,
    folding in any query feedback accumulated from earlier refinement loops.

    Args:
        state: Workflow state; reads 'query' and 'query_feedback'.

    Returns:
        The same state with 'expanded_query' populated.
    """
    print("---------Expanding Query---------")

    system_message = """
    You are a domain expert assisting in answering questions related to medical reference documentation.
    Convert the user query into more specific and domain-related phrasing that a Nutrition Disorder Specialist would understand.
    Expand the query by considering the use of appropriate medical terminology, synonyms, and various common ways to phrase the query.

    Guidelines:
    If the query has multiple distinct parts, break them into separate, simpler queries.
    If there are common synonyms or alternative phrasing for key terms, provide multiple versions of the query.
    Do not generate more than three queries, except when the query involves multiple separate parts (in which case, you can generate more than three).
    Do not attempt to rephrase unfamiliar acronyms or terms. Leave them as is.
    Return only a list of up to three queries. Do not include anything before or after the list.
    """

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    expander = prompt | llm | StrOutputParser()
    result = expander.invoke({
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    })
    print("expanded_query", result)

    state["expanded_query"] = result
    return state
111
#====================================
# Initialize the Chroma vector store for retrieving documents.
# assumes ./nutritional_db already contains a persisted index built with the
# same embedding model — TODO confirm the ingestion pipeline matches.

vector_store = Chroma(
    collection_name='Nutrition',           # Collection holding the nutrition corpus
    persist_directory='./nutritional_db',  # On-disk location of the persisted index
    embedding_function=embedding_model     # Must match the model used at indexing time
)

# Create a retriever from the vector store: plain similarity search,
# returning the 5 closest chunks per query.
retriever = vector_store.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
#====================================
125
#====================================
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context.
    """
    print("---------Retrieve_Context---------")
    search_query = state['expanded_query']
    print("Query used for retrieval:", search_query)

    # Similarity search against the Chroma index.
    documents = retriever.invoke(search_query)
    print("Retrieved documents:", documents)

    # Keep both the text and its metadata (source, page number, ...) so
    # downstream nodes can cite or filter on it.
    state['context'] = [
        {"content": doc.page_content, "metadata": doc.metadata}
        for doc in documents
    ]
    print("Extracted context with metadata:", state['context'])
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    return state
156
#====================================
def craft_response(state: Dict) -> Dict:
    """
    Generates a response using the retrieved context, focusing on nutrition disorders.

    Args:
        state (Dict): The current state of the workflow, containing the query and retrieved context.

    Returns:
        Dict: The updated state with the generated response.
    """
    print("---------Craft_Response---------")
    system_message = """
    Ensure the information is grounded in the context, avoid speculation, and prioritize clarity.
    You are a knowledgeable and empathetic medical assistant specializing in nutritional disorders.
    Given the retrieved context, generate a precise, informative, and concise response that directly addresses the user's query.
    Ensure the information is fully grounded in the provided context, and avoid introducing speculative or unsupported content.
    Focus on clarity and accuracy, ensuring the user receives helpful and relevant advice regarding nutritional disorders.
    """

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nfeedback: {feedback}")
    ])

    # Flatten the retrieved chunks into a single context string.
    context_text = "\n".join(doc["content"] for doc in state['context'])

    generator = response_prompt | llm
    answer = generator.invoke({
        "query": state['query'],
        "context": context_text,
        # Refinement feedback from earlier loops steers the regeneration.
        "feedback": state['feedback'],
    })

    state['response'] = answer
    print("intermediate response: ", answer)
    return state
191
+ #====================================
192
+ def score_groundedness(state: Dict) -> Dict:
193
+ """
194
+ Checks whether the response is grounded in the retrieved context.
195
+
196
+ Args:
197
+ state (Dict): The current state of the workflow, containing the response and context.
198
+
199
+ Returns:
200
+ Dict: The updated state with the groundedness score.
201
+ """
202
+ print("---------Check_Groundedness---------")
203
+ system_message = """
204
+ You are an expert evaluator for Retrieval-Augmented Generation (RAG) systems. Your task is to assess the GROUNDEDNESS of a response.
205
+ Given an answer and its retrieved context, determine whether the response is based on or supported by the provided context.
206
+ Respond with:
207
+ 1.0 if the answer is grounded in the context
208
+ 0.0 if the answer is not grounded in the context
209
+ The output must be a float: either 1.0 or 0.0. Do not include any explanation.
210
+ """
211
+
212
+ groundedness_prompt = ChatPromptTemplate.from_messages([
213
+ ("system", system_message),
214
+ ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
215
+ ])
216
+
217
+ chain = groundedness_prompt | llm | StrOutputParser()
218
+ groundedness_score = float(chain.invoke({
219
+ "context": "\n".join([doc["content"] for doc in state['context']]),
220
+ "response": state['response'] # Complete the code to define the response
221
+ }))
222
+ print("groundedness_score: ", groundedness_score)
223
+ state['groundedness_loop_count'] += 1
224
+ print("#########Groundedness Incremented###########")
225
+ state['groundedness_score'] = groundedness_score
226
+
227
+ return state
228
+ #====================================
229
+ def check_precision(state: Dict) -> Dict:
230
+ """
231
+ Checks whether the response precisely addresses the user’s query.
232
+
233
+ Args:
234
+ state (Dict): The current state of the workflow, containing the query and response.
235
+
236
+ Returns:
237
+ Dict: The updated state with the precision score.
238
+ """
239
+ print("---------Check_Precision---------")
240
+ system_message = """
241
+ You are an expert evaluator for Retrieval-Augmented Generation (RAG) systems. Your task is to assess the PRECISION of a response.
242
+ Given a query and a response, determine whether the response precisely addresses the user's query without including unrelated or irrelevant information.
243
+ Respond with:
244
+ 1.0 if the response is precise
245
+ 0.0 if the response is not precise
246
+ Your answer must be a float: either 1.0 or 0.0. Do not include any explanation.
247
+ """
248
+
249
+ precision_prompt = ChatPromptTemplate.from_messages([
250
+ ("system", system_message),
251
+ ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
252
+ ])
253
+
254
+ chain = precision_prompt | llm | StrOutputParser() # Complete the code to define the chain of processing
255
+ precision_score = float(chain.invoke({
256
+ "query": state['query'],
257
+ "response":state['response'] # Complete the code to access the response from the state
258
+ }))
259
+ state['precision_score'] = precision_score
260
+ print("precision_score:", precision_score)
261
+ state['precision_loop_count'] +=1
262
+ print("#########Precision Incremented###########")
263
+ return state
264
#====================================
def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with response refinement suggestions.
    """
    print("---------Refine_Response---------")

    system_message = """
    You are an expert evaluator for medical responses. Given the generated response, your task is to identify areas for improvement.
    Provide feedback on any gaps, ambiguities, or missing details in the response.
    Ensure that the suggestions focus on:
    Factual grounding: Ensure the information is scientifically accurate and well-supported by evidence.
    Logic and coherence: Assess the clarity and flow of the response. Ensure the response makes sense and is easy to follow.
    Completeness: Identify any missing details or important context that should be included.
    Empathy and tone: If the response is intended for a medical or patient-facing audience, ensure the tone is supportive, clear, and empathetic.

    After reviewing the response, provide detailed suggestions for improvement in the following format:
    1. [Description of Issue] - [Suggested Improvement].
    2. [Description of Issue] - [Suggested Improvement].
    (Continue providing numbered suggestions as needed.)

    Make sure to provide constructive, actionable suggestions that can enhance the response.
    """

    critique_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    critic = critique_prompt | llm | StrOutputParser()
    suggestions = critic.invoke({'query': state['query'], 'response': state['response']})

    # Bundle the previous answer with the critique so craft_response can see
    # both on the next loop iteration.
    feedback = f"Previous Response: {state['response']}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state
308
#====================================
def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with query refinement suggestions.
    """
    print("---------Refine_Query---------")
    system_message = """
    You are an expert in query refinement, helping to improve search precision for medical topics.
    Given the original and expanded queries, your task is to suggest improvements that can enhance search effectiveness.

    Areas to focus on:
    1. **Missing Details**: Identify any essential details or context that could improve the specificity of the query.
    2. **Keywords**: Recommend more relevant or precise keywords that are specific to the topic (e.g., nutrition disorders, medical terminology).
    3. **Scope Refinement**: Suggest ways to narrow or broaden the query to improve the relevance of the results (e.g., focus on specific conditions, age groups, etc.).
    4. **Clarity**: Ensure the query is clear, concise, and free from ambiguity to help improve search engine interpretation.
    5. **Synonym Usage**: Suggest any common synonyms or alternate phrases for key terms that could improve search recall.

    After reviewing the original and expanded queries, provide your suggestions in a clear and actionable format, such as:
    1. [Description of Issue] - [Suggested Improvement].
    2. [Description of Issue] - [Suggested Improvement].
    (Continue providing numbered suggestions as needed.)

    Your goal is to make the query more precise, comprehensive, and search-friendly to enhance the retrieval of relevant information.
    """

    critique_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    critic = critique_prompt | llm | StrOutputParser()
    suggestions = critic.invoke({
        'query': state['query'],
        'expanded_query': state['expanded_query'],
    })

    # Store the suggestions alongside the previous expansion; expand_query
    # consumes this on the next pass. The expanded query itself is untouched.
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state
353
#====================================
def should_continue_groundedness(state):
    """Route after the groundedness check.

    Returns 'check_precision' when the response is grounded, 'refine_response'
    to retry, or 'max_iterations_reached' once the loop budget is exhausted.
    """
    print("---------Should_Continue_Groundedness---------")
    print("Groundedness Loop Count: ", state['groundedness_loop_count'])

    # Any positive score counts as grounded (judge emits 1.0 or 0.0).
    if state['groundedness_score'] > 0:
        print("Moving to Precision")
        return "check_precision"

    # Not grounded: retry unless the loop budget is spent.
    if state["groundedness_loop_count"] > state['loop_max_iter']:
        return "max_iterations_reached"

    print(f"---------Groundedness Score Threshold Not Met. Refining Response-----------")
    return "refine_response"
367
#====================================
def should_continue_precision(state: Dict) -> str:
    """Route after the precision check.

    Returns 'pass' (finish) when the response is precise, 'refine_query' to
    retry, or 'max_iterations_reached' once the loop budget is exhausted.
    """
    print("---------Should_Continue_Precision---------")
    print("precision loop count: ", state['precision_loop_count'])

    # Any positive score counts as precise (judge emits 1.0 or 0.0).
    if state['precision_score'] > 0:
        return "pass"

    # Not precise: refine the query unless the loop budget is spent.
    if state['precision_loop_count'] > state['loop_max_iter']:
        return "max_iterations_reached"

    print(f"---------Precision Score Threshold Not Met. Refining Query-----------")
    return "refine_query"
380
#====================================
def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case when the maximum number of iterations is reached.

    Overwrites the response with a fixed message asking the user for more
    context, so the workflow can terminate cleanly.

    Args:
        state (Dict): The current workflow state.

    Returns:
        Dict: The same state with 'response' replaced by the fallback message.
    """
    # BUGFIX: the original body repeated the docstring as a stray string
    # expression (dead code); removed.
    print("---------Max_Iterations_Reached---------")
    state['response'] = "I'm unable to refine the response further. Please provide more context or clarify your question."
    return state
388
#====================================
def create_workflow() -> StateGraph:
    """Build the LangGraph state machine for the AI nutrition agent.

    Pipeline: expand_query -> retrieve_context -> craft_response ->
    score_groundedness, with refinement loops back through refine_response
    (regenerate) or refine_query (re-expand), and a max-iterations escape
    hatch on both loops.
    """
    workflow = StateGraph(AgentState)

    # Register every processing node in one pass.
    node_table = {
        "expand_query": expand_query,               # Step 1: expand the user query
        "retrieve_context": retrieve_context,       # Step 2: fetch relevant documents
        "craft_response": craft_response,           # Step 3: generate an answer
        "score_groundedness": score_groundedness,   # Step 4: judge grounding
        "refine_response": refine_response,         # Step 5: critique weak answers
        "check_precision": check_precision,         # Step 6: judge precision
        "refine_query": refine_query,               # Step 7: critique the query
        "max_iterations_reached": max_iterations_reached,  # Step 8: loop escape hatch
    }
    for node_name, node_fn in node_table.items():
        workflow.add_node(node_name, node_fn)

    # Linear happy path.
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")

    # Branch on groundedness: proceed, regenerate, or bail out.
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,
        {
            "check_precision": 'check_precision',
            "refine_response": 'refine_response',
            "max_iterations_reached": 'max_iterations_reached',
        },
    )
    workflow.add_edge('refine_response', 'craft_response')  # regenerate with feedback

    # Branch on precision: finish, re-expand the query, or bail out.
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,
        {
            "pass": END,
            "refine_query": 'refine_query',
            "max_iterations_reached": 'max_iterations_reached',
        },
    )
    workflow.add_edge('refine_query', 'expand_query')  # re-expand with feedback

    workflow.add_edge("max_iterations_reached", END)

    return workflow
438
#====================================
# Compile the graph once at import time; reused by every agentic_rag call.
WORKFLOW_APP = create_workflow().compile()
#====================================
441
@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # Fresh state per invocation: counters zeroed, feedback empty, and the
    # refinement loops capped at 3 iterations.
    initial_state = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0,
        "groundedness_score": 0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 3,
    }

    final_state = WORKFLOW_APP.invoke(initial_state)
    answer = final_state.get("response", "Error: No response generated.")

    # Normalise to a plain stripped string whether the graph produced a str,
    # an AIMessage-like object with .content, or anything else.
    if isinstance(answer, str):
        return answer.strip()
    if hasattr(answer, "content"):
        return answer.content.strip()
    return str(answer).strip()
480
#====================================
# Retrieve the Llama Guard (Groq) API key from the environment.
# NOTE(review): the secret is read from the env var literally named 'Groq' —
# confirm this matches the deployment's secret name; other keys in this file
# use different naming conventions.
groq_api_key = os.getenv('Groq')

# Initialize the Llama Guard client with the API key.
llama_guard_client = Groq(api_key=groq_api_key)
#====================================
486
#====================================
def filter_input_with_llama_guard(user_input, model="llama-guard-3-8b"):
    """
    Filters user input using Llama Guard to ensure it is safe.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering (default is "llama-guard-3-8b").

    Returns:
    - The filtered and safe input.
    """
    try:
        # Ask Llama Guard to classify the raw user message; its verdict text
        # (e.g. "safe" / "unsafe S6") is returned stripped of whitespace.
        guard_reply = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        return guard_reply.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: a Guard outage yields None; callers must treat that
        # as "could not verify" rather than "safe".
        print(f"Error with Llama Guard: {e}")
        return None
510
#====================================
class NutritionBot:
    """Conversational agent that answers nutrition-disorder questions.

    Combines a Mem0 memory store for past interactions, a gpt-4o-mini chat
    model, and the agentic_rag tool wired into a LangChain agent executor.
    """

    def __init__(self):
        """
        Initialize the NutritionBot class with memory, LLM client, tools, and the agent executor.
        """

        # Memory to store/retrieve customer interactions.
        # NOTE(review): reads env var "Mem0" while the module top reads 'mem0' —
        # the casing differs; confirm which secret name is actually set.
        self.memory = MemoryClient(api_key=os.getenv("Mem0"))

        # LLM setup.
        # NOTE(review): reads "API_KEY" while the module top uses "OPENAI_API_KEY" —
        # verify both env vars exist or unify the naming.
        self.llm = ChatOpenAI(
            model="gpt-4o-mini",
            openai_api_key=os.getenv("API_KEY"),
            openai_api_base=os.getenv("OPENAI_API_BASE"),
            temperature=0,  # deterministic answers for consistency across turns
            streaming=False
        )

        # Define the system prompt to set the behavior of the chatbot.
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
        Guidelines for Interaction:
        Maintain a polite, professional, and reassuring tone.
        Show genuine empathy for customer concerns and health challenges.
        Reference past interactions to provide personalized and consistent advice.
        Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
        Ensure consistent and accurate information across conversations.
        If any detail is unclear or missing, proactively ask for clarification.
        Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
        Keep track of ongoing issues and follow-ups to ensure continuity in support.
        Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences."""

        # Build the prompt template for the agent.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),      # System instructions
            ("human", "{input}"),           # Placeholder for human input
            ("placeholder", "{agent_scratchpad}")  # Placeholder for intermediate reasoning steps
        ])

        # Tool setup: the RAG workflow is exposed as the agent's only tool.
        self.tools = [agentic_rag]

        # Agent initialization.
        # NOTE(review): initialize_agent is deprecated in newer LangChain and
        # may not consume a `prompt` keyword directly (custom prompts usually
        # go through agent_kwargs) — verify the system prompt actually takes
        # effect at runtime.
        self.agent_executor = initialize_agent(
            tools=self.tools,
            llm=self.llm,
            prompt=prompt,
            agent=AgentType.OPENAI_FUNCTIONS,
            verbose=True,
            handle_parsing_errors=True,
            return_intermediate_steps=True
        )

    def store_customer_interaction(self, user_id: str, message: str, response: str, metadata: Dict = None):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Add a timestamp to the metadata for tracking purposes.
        metadata["timestamp"] = datetime.now().isoformat()

        # Format the conversation as a user/assistant message pair for storage.
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response}
        ]

        # Store the interaction in the memory client.
        self.memory.add(
            conversation,
            user_id=user_id,
            output_format="v1.1",
            metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        return self.memory.search(
            query=query,        # Search for interactions related to the query
            user_id=user_id,    # Restrict search to the specific user
            limit=5             # Cap retrieved interactions at the 5 most relevant
        )

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.
        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """
        # Retrieve relevant past interactions for context.
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history.
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"{memory}\n---\n"
        print(f"Context:{context}")

        # Prepare a prompt combining past context and the current query.
        prompt = f"""
        Context: {context}
        Current customer query: {query}
        Provide a helpful response that takes into account any relevant past interactions.
        """
        try:
            response = self.agent_executor.invoke({"input": prompt})
        except Exception as e:
            # Best-effort: surface a friendly message instead of crashing the UI.
            print(f"An error occurred while invoking the agent executor: {e}")
            return "I'm sorry, something went wrong while processing your request."

        # Persist the exchange so future queries can reference it.
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"}
        )

        return response["output"]
650
#=====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.

    Handles login, per-user chat history in session state, Llama Guard input
    screening, and delegation of safe queries to the NutritionBot agent.
    """
    st.title("Nutrition Disorder Specialist")
    st.write("Ask me anything about nutrition disorders, symptoms, causes, treatments, and more.")
    st.write("Type 'exit' to end the conversation.")

    # Initialize user_id (None until the user logs in).
    if 'user_id' not in st.session_state:
        st.session_state.user_id = None

    # Define per-user chat history key so histories don't leak between users.
    user_chat_key = f"chat_history_{st.session_state.user_id}" if st.session_state.user_id else "chat_history_temp"

    if user_chat_key not in st.session_state:
        st.session_state[user_chat_key] = []

    # Login form
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                user_chat_key = f"chat_history_{user_id}"
                st.session_state[user_chat_key] = [{
                    "role": "assistant",
                    "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?"
                }]
                st.session_state.login_submitted = True
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Display chat history
        for message in st.session_state[user_chat_key]:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Chat input
        user_query = st.chat_input("Type your question here or 'exit' to end")
        if user_query:
            if user_query.lower() == "exit":
                st.session_state[user_chat_key].append({"role": "user", "content": "exit"})
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state[user_chat_key].append({"role": "assistant", "content": goodbye_msg})
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.pop(user_chat_key, None)
                st.session_state.user_id = None
                st.rerun()
                return

            st.session_state[user_chat_key].append({"role": "user", "content": user_query})
            with st.chat_message("user"):
                st.write(user_query)

            # Filter input using Llama Guard.
            filtered_result = filter_input_with_llama_guard(user_query)
            # BUGFIX: filter_input_with_llama_guard returns None when the Guard
            # call fails; calling .replace() on None raised AttributeError and
            # crashed the app. Fail closed: an unavailable guard means the
            # input is NOT treated as safe.
            if filtered_result is None:
                filtered_result = ""
            filtered_result = filtered_result.replace("\n", " ")

            # Validate safe input (S6/S7 categories are deliberately allowed
            # since medical-advice questions are this app's purpose).
            if filtered_result in ['safe', 'unsafe S6', 'unsafe S7']:
                try:
                    # Lazily construct the agent once per session.
                    if 'chatbot' not in st.session_state:
                        st.session_state.chatbot = NutritionBot()
                    response = st.session_state.chatbot.handle_customer_query(
                        st.session_state.user_id,
                        user_query
                    )
                    st.write(response)
                    st.session_state[user_chat_key].append({"role": "assistant", "content": response})
                except Exception as e:
                    error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                    st.write(error_msg)
                    st.session_state[user_chat_key].append({"role": "assistant", "content": error_msg})
            else:
                inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                st.write(inappropriate_msg)
                st.session_state[user_chat_key].append({"role": "assistant", "content": inappropriate_msg})
734

if __name__ == "__main__":
    # Entry point: launch the Streamlit chat UI (run via `streamlit run app.py`).
    nutrition_disorder_streamlit()