madhan2211 committed (verified)
Commit a8c5c48 · 1 Parent(s): 4708803

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +960 -0
app.py ADDED
@@ -0,0 +1,960 @@
# Import necessary libraries
import os  # Interacting with the operating system (reading/writing files)
import chromadb  # High-performance vector database for storing/querying dense vectors
from dotenv import load_dotenv  # Loading environment variables from a .env file
import json  # Parsing and handling JSON data

# LangChain imports
from langchain_core.documents import Document  # Document data structures
from langchain_core.runnables import (
    RunnablePassthrough,
)  # LangChain core library for running pipelines
from langchain_core.output_parsers import StrOutputParser  # String output parser
from langchain.prompts import ChatPromptTemplate  # Template for chat prompts
from langchain.chains.query_constructor.base import (
    AttributeInfo,
)  # Base classes for query construction
from langchain.retrievers.self_query.base import (
    SelfQueryRetriever,
)  # Base classes for self-querying retrievers
from langchain.retrievers.document_compressors import (
    LLMChainExtractor,
    CrossEncoderReranker,
)  # Document compressors
from langchain.retrievers import (
    ContextualCompressionRetriever,
)  # Contextual compression retrievers

# LangChain community & experimental imports
from langchain_chroma import Chroma  # Implementations of vector stores like Chroma
from langchain_community.document_loaders import (
    PyPDFDirectoryLoader,
    PyPDFLoader,
)  # Document loaders for PDFs
from langchain_community.cross_encoders import (
    HuggingFaceCrossEncoder,
)  # Cross-encoders from HuggingFace
from langchain_experimental.text_splitter import (
    SemanticChunker,
)  # Experimental text splitting methods
from langchain.text_splitter import (
    CharacterTextSplitter,  # Splitting text by characters
    RecursiveCharacterTextSplitter,  # Recursive splitting of text by characters
)
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate

# LangChain OpenAI imports
from langchain_openai import (
    OpenAIEmbeddings,
    ChatOpenAI,
)  # OpenAI embeddings and models


# LlamaParse & LlamaIndex imports
from llama_parse import LlamaParse  # Document parsing library
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
)  # Core functionalities of the LlamaIndex

# LangGraph import
from langgraph.graph import (
    StateGraph,
    END,
    START,
)  # State graph for managing states in LangChain

# Pydantic import
from pydantic import BaseModel  # Pydantic for data validation

# Typing imports
from typing import (
    Dict,
    List,
    Tuple,
    Any,
    TypedDict,
)  # Python typing for function annotations

# Other utilities
import numpy as np  # Numpy for numerical operations
from groq import Groq
from mem0 import MemoryClient
import streamlit as st
from datetime import datetime

# ====================================SETUP=====================================#
# Fetch secrets from Hugging Face Spaces
api_key = os.getenv("OPENAI_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")
mem0_api_key = os.getenv("MEM0_API_KEY")
os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY", "")  # default to "" so a missing secret does not raise
os.environ["LANGSMITH_PROJECT"] = "nutrition_bot"

# Initialize the OpenAI embedding function for Chroma
embedding_function = chromadb.utils.embedding_functions.OpenAIEmbeddingFunction(
    api_key=api_key,  # OpenAI API key fetched above
    model_name="text-embedding-3-small",  # Embedding model used throughout the app
)


# Initialize the OpenAI Embeddings
embedding_model = OpenAIEmbeddings(
    openai_api_key=api_key, model="text-embedding-3-small"
)


# Initialize the Chat OpenAI model
llm = ChatOpenAI(openai_api_key=api_key, model="gpt-4o-mini", streaming=False)


# Set the LLM and embedding model in the LlamaIndex settings.
Settings.llm = llm
Settings.embed_model = embedding_model  # LlamaIndex exposes this setting as `embed_model`

# ================================Creating Langgraph agent======================#


class AgentState(TypedDict):
    query: str  # The current user query
    expanded_query: str  # The expanded version of the user query
    context: List[Dict[str, Any]]  # Retrieved documents (content and metadata)
    response: str  # The generated response to the user query
    precision_score: float  # The precision score of the response
    groundedness_score: float  # The groundedness score of the response
    groundedness_loop_count: int  # Counter for groundedness refinement loops
    precision_loop_count: int  # Counter for precision refinement loops
    feedback: str
    query_feedback: str
    groundedness_check: bool
    loop_max_iter: int

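# For reference, a typical initial state (as constructed in `agentic_rag` below) looks
# roughly like this -- illustrative values only:
#
#     {
#         "query": "What is anorexia?",
#         "expanded_query": "",
#         "context": [],
#         "response": "",
#         "precision_score": 0.0,
#         "groundedness_score": 0.0,
#         "groundedness_loop_count": 0,
#         "precision_loop_count": 0,
#         "feedback": "",
#         "query_feedback": "",
#         "loop_max_iter": 2,
#     }
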
def expand_query(state):
    """
    Expands the user query to improve retrieval of nutrition disorder-related information.

    Args:
        state (Dict): The current state of the workflow, containing the user query.

    Returns:
        Dict: The updated state with the expanded query.
    """
    print("---------Expanding Query---------")
    system_message = """You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.

Guidelines:
- Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
- Include **related terms** to improve retrieval (e.g., “bulimia” → “bulimia nervosa vs binge eating disorder”).
- If the user provides an unclear query, suggest necessary clarifications.
- **DO NOT** answer the question. Your job is only to enhance the query.

Examples:
1. User Query: "Tell me about eating disorders."
   Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."

2. User Query: "What is anorexia?"
   Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."

3. User Query: "How to treat bulimia?"
   Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."

4. User Query: "What are the effects of malnutrition?"
   Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."

Now, expand the following query:"""

    expand_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("user", "Expand this query: {query} using the feedback: {query_feedback}"),
        ]
    )

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke(
        {"query": state["query"], "query_feedback": state["query_feedback"]}
    )
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state

# Initialize the Chroma vector store for retrieving documents
vector_store = Chroma(
    collection_name="nutritional_hypotheticals",
    persist_directory="./nutritional_db",
    embedding_function=embedding_model,
)

# Create a retriever from the vector store
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})

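# A minimal sketch of what the retriever returns (values are illustrative, not taken from
# the actual vector store):
#
#     docs = retriever.invoke("symptoms of anorexia nervosa")
#     # -> [Document(page_content="Anorexia nervosa is ...", metadata={"source": "...", "page": 3}), ...]
#
# `retrieve_context` below relies on exactly this shape: each Document exposes
# `page_content` and a `metadata` dict.
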
def retrieve_context(state):
    """
    Retrieves context from the vector store using the expanded or original query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with the retrieved context.
    """
    print("---------retrieve_context---------")
    query = state["expanded_query"]
    # print("Query used for retrieval:", query)  # Debugging: Print the query

    # Retrieve documents from the vector store
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)  # Debugging: Print the raw docs object

    # Extract both page_content and metadata from each document
    context = [
        {
            "content": doc.page_content,  # The actual content of the document
            "metadata": doc.metadata,  # The metadata (e.g., source, page number, etc.)
        }
        for doc in docs
    ]
    state["context"] = context
    print(
        "Extracted context with metadata:", context
    )  # Debugging: Print the extracted context
    # print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    return state

def craft_response(state: Dict) -> Dict:
    """
    Generates a response using the retrieved context, focusing on nutrition disorders.

    Args:
        state (Dict): The current state of the workflow, containing the query and retrieved context.

    Returns:
        Dict: The updated state with the generated response.
    """
    system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.

Guidelines:
- **Be direct and concise** while ensuring completeness.
- **DO NOT include information that is not present in the context.**
- If multiple sources exist, synthesize them into a coherent response.
- If the context does not fully answer the query, state what additional information is needed.
- Use bullet points when explaining complex concepts.

Example:
User Query: "What are the symptoms of anorexia nervosa?"
Context:
1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
Response:
"Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
- Restricted eating
- Distorted body image
- Excessive exercise
If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''

    response_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("user", "Query: {query}\nContext: {context}\n\nResponse:"),
        ]
    )

    chain = response_prompt | llm | StrOutputParser()
    state["response"] = chain.invoke(
        {
            "query": state["query"],
            "context": "\n".join(
                [doc["content"] for doc in state["context"]]
            ),  # Extract content from each document
        }
    )
    return state

def score_groundedness(state: Dict) -> Dict:
    """
    Checks whether the response is grounded in the retrieved context.

    Args:
        state (Dict): The current state of the workflow, containing the response and context.

    Returns:
        Dict: The updated state with the groundedness score.
    """
    print("---------check_groundedness---------")
    system_message = """You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.

Guidelines:
1. **Groundedness Check**:
   - Verify that the response accurately reflects the information in the context.
   - Flag any unsupported claims or deviations from the context.

2. **Citation Check**:
   - Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
   - If citations are missing, suggest adding them.

3. **Scoring**:
   - Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.

Examples:
1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
   Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
   Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."

2. Response: "Bulimia nervosa can be cured with diet alone."
   Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
   Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."

3. Response: "Anorexia nervosa has a high mortality rate."
   Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
   Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7."

****Return only a float score (e.g., 0.9). Do not provide explanations.****

Now, evaluate the following response:
"""

    groundedness_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:"),
        ]
    )

    chain = groundedness_prompt | llm | StrOutputParser()
    groundedness_score = float(
        chain.invoke(
            {
                "context": "\n".join([doc["content"] for doc in state["context"]]),
                "response": state["response"],
            }
        )
    )
    print("groundedness_score: ", groundedness_score)
    state["groundedness_loop_count"] += 1
    print("#########Groundedness Incremented###########")
    state["groundedness_score"] = groundedness_score
    return state

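# Note: `float(chain.invoke(...))` above assumes the judge model returns a bare number.
# If you want to harden this, a small helper along these lines (hypothetical, not part of
# the original app) could extract the first float from an arbitrary reply:
#
#     import re
#
#     def parse_score(text: str, default: float = 0.0) -> float:
#         match = re.search(r"\d*\.?\d+", text)
#         return float(match.group()) if match else default
#
# and then `groundedness_score = parse_score(chain.invoke({...}))`. The same idea applies
# to `check_precision` below.
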
def check_precision(state: Dict) -> Dict:
    """
    Checks whether the response precisely addresses the user’s query.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with the precision score.
    """
    print("---------check_precision---------")
    system_message = """You are an AI evaluator assessing the **precision** of the response.
Your task is to **score** how well the response addresses the user’s original nutrition disorder-related query.

Scoring Criteria:
- 1.0 → The response is fully precise, directly answering the question.
- 0.7 → The response is mostly correct but contains some generalization.
- 0.5 → The response is somewhat relevant but lacks key details.
- 0.3 → The response is vague or only partially correct.
- 0.0 → The response is incorrect or misleading.

Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
   Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
   Precision Score: 1.0

2. Query: "How is bulimia nervosa treated?"
   Response: "Bulimia nervosa is treated with therapy and medications."
   Precision Score: 0.7

3. Query: "What causes binge eating disorder?"
   Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
   Precision Score: 0.5

4. Query: "What are the effects of malnutrition?"
   Response: "Malnutrition can lead to health problems."
   Precision Score: 0.3

5. Query: "What is the mortality rate of anorexia nervosa?"
   Response: "Anorexia nervosa is a type of eating disorder."
   Precision Score: 0.0

*****Return only a float score (e.g., 0.9). Do not provide explanations.*****
Now, evaluate the following query and response:
"""
    precision_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:"),
        ]
    )

    chain = precision_prompt | llm | StrOutputParser()
    precision_score = float(
        chain.invoke({"query": state["query"], "response": state["response"]})
    )
    state["precision_score"] = precision_score
    print("precision_score:", precision_score)
    state["precision_loop_count"] += 1
    print("#########Precision Incremented###########")
    return state

def refine_response(state: Dict) -> Dict:
    """
    Suggests improvements for the generated response.

    Args:
        state (Dict): The current state of the workflow, containing the query and response.

    Returns:
        Dict: The updated state with response refinement suggestions.
    """
    print("---------refine_response---------")

    system_message = """You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.

### Guidelines:
- Identify **gaps in the explanation** (missing key details).
- Highlight **unclear or vague parts** that need elaboration.
- Suggest **additional details** that should be included for better accuracy.
- Ensure the refined response is **precise** and **grounded** in the retrieved context.

### Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
   Response: "The symptoms include weight loss and fear of gaining weight."
   Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"

2. Query: "How is bulimia nervosa treated?"
   Response: "Bulimia nervosa is treated with therapy."
   Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."

3. Query: "What causes binge eating disorder?"
   Response: "Binge eating disorder is caused by psychological factors."
   Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."

Now, suggest improvements for the following response:
"""

    refine_response_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            (
                "user",
                "Query: {query}\nResponse: {response}\n\n"
                "What improvements can be made to enhance accuracy and completeness?",
            ),
        ]
    )

    chain = refine_response_prompt | llm | StrOutputParser()

    # Store response suggestions in a structured format
    feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state["feedback"] = feedback
    return state

def refine_query(state: Dict) -> Dict:
    """
    Suggests improvements for the expanded query.

    Args:
        state (Dict): The current state of the workflow, containing the query and expanded query.

    Returns:
        Dict: The updated state with query refinement suggestions.
    """
    print("---------refine_query---------")
    system_message = """You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.

### Guidelines:
- Add **specific keywords** to improve document retrieval.
- Identify **missing details** that should be included.
- Suggest **ways to narrow the scope** for better precision.

### Examples:
1. Original Query: "Tell me about eating disorders."
   Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
   Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."

2. Original Query: "What is anorexia?"
   Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
   Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."

3. Original Query: "How to treat bulimia?"
   Expanded Query: "Describe treatment options for bulimia nervosa."
   Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."

Now, suggest improvements for the following expanded query:
"""

    refine_query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_message),
            (
                "user",
                "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                "What improvements can be made for a better search?",
            ),
        ]
    )

    chain = refine_query_prompt | llm | StrOutputParser()

    # Store refinement suggestions without modifying the original expanded query
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state["query_feedback"] = query_feedback
    return state

def should_continue_groundedness(state):
    """Decides if groundedness is sufficient or needs improvement."""
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state["groundedness_loop_count"])
    if state["groundedness_score"] >= 0.4:  # Threshold for groundedness
        print("Moving to precision")
        return "check_precision"
    else:
        if state["groundedness_loop_count"] > state["loop_max_iter"]:
            return "max_iterations_reached"
        else:
            print(
                "---------Groundedness Score Threshold Not met. Refining Response-----------"
            )
            return "refine_response"


def should_continue_precision(state: Dict) -> str:
    """Decides if precision is sufficient or needs improvement."""
    print("---------should_continue_precision---------")
    print("precision loop count: ", state["precision_loop_count"])
    if state["precision_score"] >= 0.7:  # Threshold for precision
        return "pass"  # Complete the workflow
    else:
        if (
            state["precision_loop_count"] > state["loop_max_iter"]
        ):  # Maximum allowed loops
            return "max_iterations_reached"
        else:
            print(
                "---------Precision Score Threshold Not met. Refining Query-----------"
            )  # Debugging
            return "refine_query"  # Refine the query and loop back

def max_iterations_reached(state: Dict) -> Dict:
    """Handles the case when the maximum number of iterations is reached."""
    print("---------max_iterations_reached---------")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state["response"] = response
    return state

def create_workflow() -> StateGraph:
    """Creates the updated workflow for the AI nutrition agent."""
    workflow = StateGraph(AgentState)

    # Add processing nodes
    workflow.add_node("expand_query", expand_query)  # Step 1: Expand user query.
    workflow.add_node(
        "retrieve_context", retrieve_context
    )  # Step 2: Retrieve relevant documents.
    workflow.add_node(
        "craft_response", craft_response
    )  # Step 3: Generate a response based on retrieved data.
    workflow.add_node(
        "score_groundedness", score_groundedness
    )  # Step 4: Evaluate response grounding.
    workflow.add_node(
        "refine_response", refine_response
    )  # Step 5: Improve response if it's weakly grounded.
    workflow.add_node(
        "check_precision", check_precision
    )  # Step 6: Evaluate response precision.
    workflow.add_node(
        "refine_query", refine_query
    )  # Step 7: Improve query if response lacks precision.
    workflow.add_node(
        "max_iterations_reached", max_iterations_reached
    )  # Step 8: Handle max iterations.
    # workflow.add_node("groundedness_decider", groundedness_decider)

    # Main flow edges
    workflow.add_edge(START, "expand_query")
    workflow.add_edge("expand_query", "retrieve_context")
    workflow.add_edge("retrieve_context", "craft_response")
    workflow.add_edge("craft_response", "score_groundedness")
    # workflow.add_edge("score_groundedness", "groundedness_decider")

    # Conditional edges based on groundedness check
    workflow.add_conditional_edges(
        "score_groundedness",
        should_continue_groundedness,  # Use the conditional function
        {
            "check_precision": "check_precision",  # If well-grounded, proceed to precision check.
            "refine_response": "refine_response",  # If not, refine the response.
            "max_iterations_reached": "max_iterations_reached",  # If max loops reached, exit.
        },
    )
    workflow.add_edge(
        "refine_response", "craft_response"
    )  # Refined responses are reprocessed.

    # Conditional edges based on precision check
    workflow.add_conditional_edges(
        "check_precision",
        should_continue_precision,  # Use the conditional function
        {
            "pass": END,  # If precise, complete the workflow.
            "refine_query": "refine_query",  # If imprecise, refine the query.
            "max_iterations_reached": "max_iterations_reached",  # If max loops reached, exit.
        },
    )
    workflow.add_edge(
        "refine_query", "expand_query"
    )  # Refined queries go through expansion again.

    workflow.add_edge("max_iterations_reached", END)
    # Set entry point
    # workflow.set_entry_point("expand_query")

    return workflow

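# For orientation, the compiled graph has this shape (derived from the edges above):
#
#     START -> expand_query -> retrieve_context -> craft_response -> score_groundedness
#     score_groundedness --(grounded)--------> check_precision
#     score_groundedness --(not grounded)----> refine_response -> craft_response
#     score_groundedness --(max iterations)--> max_iterations_reached -> END
#     check_precision ----(precise)----------> END
#     check_precision ----(imprecise)--------> refine_query -> expand_query
#     check_precision ----(max iterations)---> max_iterations_reached -> END
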
# =========================== Defining the agentic rag tool ====================#
WORKFLOW_APP = create_workflow().compile()


@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agentic workflow to generate a response for the given query.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The final workflow state, including the generated response.
    """
    # Initialize state with necessary parameters
    inputs = {
        "query": query,  # Current user query
        "expanded_query": "",  # Expanded version of the query
        "context": [],  # Retrieved documents (initially empty)
        "response": "",  # AI-generated response
        "precision_score": 0.0,  # Precision score of the response
        "groundedness_score": 0.0,  # Groundedness score of the response
        "groundedness_loop_count": 0,  # Counter for groundedness loops
        "precision_loop_count": 0,  # Counter for precision loops
        "feedback": "",
        "query_feedback": "",
        "loop_max_iter": 2,
    }

    output = WORKFLOW_APP.invoke(inputs)

    return output

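# Illustrative usage (hypothetical values; in the app the tool is normally invoked by the
# agent executor defined in NutritionBot below):
#
#     result = agentic_rag.invoke({"query": "What are the symptoms of anorexia nervosa?"})
#     print(result["response"])
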
# ================================ Guardrails ===========================#
llama_guard_client = Groq(api_key=groq_api_key)


# Function to filter user input with Llama Guard
def filter_input_with_llama_guard(user_input, model="meta-llama/llama-guard-4-12b"):
    """
    Filters user input using Llama Guard to ensure it is safe.

    Parameters:
    - user_input: The input provided by the user.
    - model: The Llama Guard model to be used for filtering (default is "meta-llama/llama-guard-4-12b").

    Returns:
    - Llama Guard's verdict for the input (e.g., "safe" or an "unsafe" label), or None on error.
    """
    try:
        # Create a request to Llama Guard to classify the user input
        response = llama_guard_client.chat.completions.create(
            messages=[{"role": "user", "content": user_input}],
            model=model,
        )
        # Return Llama Guard's verdict
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error with Llama Guard: {e}")
        return None

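# Note: Llama Guard typically replies with "safe", or with "unsafe" followed by the
# violated category code(s) (e.g., "S6" for specialized advice). The membership check in
# the Streamlit handler below ("safe", "unsafe S7", "unsafe S6") assumes the verdict and
# category arrive on one line; if the model returns them on separate lines, the comparison
# may need to normalize whitespace first.
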
# ============================= Adding Memory to the agent using mem0 ===============================#


class NutritionBot:
    def __init__(self):
        """
        Initialize the NutritionBot class, setting up memory, the LLM client, tools, and the agent executor.
        """

        # Initialize a memory client to store and retrieve customer interactions
        self.memory = MemoryClient(api_key=mem0_api_key)

        self.client = ChatOpenAI(
            model_name="gpt-4o-mini",  # Specify the model to use (e.g., GPT-4 optimized version)
            api_key=api_key,  # API key for authentication
            temperature=0,  # Controls randomness in responses; 0 ensures deterministic results
        )

        # Define tools available to the chatbot (here, the agentic RAG workflow)
        tools = [agentic_rag]

        # Define the system prompt to set the behavior of the chatbot
        system_prompt = """You are a caring and knowledgeable Medical Support Agent, specializing in nutrition disorder-related guidance. Your goal is to provide accurate, empathetic, and tailored nutritional recommendations while ensuring a seamless customer experience.
Guidelines for Interaction:
Maintain a polite, professional, and reassuring tone.
Show genuine empathy for customer concerns and health challenges.
Reference past interactions to provide personalized and consistent advice.
Engage with the customer by asking about their food preferences, dietary restrictions, and lifestyle before offering recommendations.
Ensure consistent and accurate information across conversations.
If any detail is unclear or missing, proactively ask for clarification.
Always use the agentic_rag tool to retrieve up-to-date and evidence-based nutrition insights.
Keep track of ongoing issues and follow-ups to ensure continuity in support.
Your primary goal is to help customers make informed nutrition decisions that align with their health conditions and personal preferences.

"""

        # Build the prompt template for the agent
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),  # System instructions
                ("human", "{input}"),  # Placeholder for human input
                (
                    "placeholder",
                    "{agent_scratchpad}",
                ),  # Placeholder for intermediate reasoning steps
            ]
        )

        # Create an agent capable of interacting with tools and executing tasks
        agent = create_tool_calling_agent(self.client, tools, prompt)

        # Wrap the agent in an executor to manage tool interactions and execution flow
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def store_customer_interaction(
        self, user_id: str, message: str, response: str, metadata: Dict = None
    ):
        """
        Store customer interaction in memory for future reference.

        Args:
            user_id (str): Unique identifier for the customer.
            message (str): Customer's query or message.
            response (str): Chatbot's response.
            metadata (Dict, optional): Additional metadata for the interaction.
        """
        if metadata is None:
            metadata = {}

        # Add a timestamp to the metadata for tracking purposes
        metadata["timestamp"] = datetime.now().isoformat()

        # Format the conversation for storage
        conversation = [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response},
        ]

        # Store the interaction in the memory client
        self.memory.add(
            conversation, user_id=user_id, output_format="v1.1", metadata=metadata
        )

    def get_relevant_history(self, user_id: str, query: str) -> List[Dict]:
        """
        Retrieve past interactions relevant to the current query.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): The customer's current query.

        Returns:
            List[Dict]: A list of relevant past interactions.
        """
        return self.memory.search(
            query=query,  # Search for interactions related to the query
            user_id=user_id,  # Restrict search to the specific user
            limit=5,  # Retrieve up to 5 relevant interactions
        )

    def handle_customer_query(self, user_id: str, query: str) -> str:
        """
        Process a customer's query and provide a response, taking into account past interactions.

        Args:
            user_id (str): Unique identifier for the customer.
            query (str): Customer's query.

        Returns:
            str: Chatbot's response.
        """

        # Retrieve relevant past interactions for context
        relevant_history = self.get_relevant_history(user_id, query)

        # Build a context string from the relevant history
        context = "Previous relevant interactions:\n"
        for memory in relevant_history:
            context += f"- {memory['memory']}\n"  # Each mem0 search hit exposes its text under 'memory'
            context += "---\n"

        # Print context for debugging purposes
        print("Context: ", context)

        # Prepare a prompt combining past context and the current query
        prompt = f"""
        Context:
        {context}

        Current customer query: {query}

        Provide a helpful response that takes into account any relevant past interactions.
        """

        # Generate a response using the agent
        response = self.agent_executor.invoke({"input": prompt})

        # Store the current interaction for future reference
        self.store_customer_interaction(
            user_id=user_id,
            message=query,
            response=response["output"],
            metadata={"type": "support_query"},
        )

        # Return the chatbot's response
        return response["output"]

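# Illustrative usage outside Streamlit (hypothetical values; requires the OpenAI, Groq and
# mem0 keys above to be set):
#
#     bot = NutritionBot()
#     reply = bot.handle_customer_query("demo_user", "What are early signs of bulimia nervosa?")
#     print(reply)
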
# =====================User Interface using streamlit ===========================#
def nutrition_disorder_streamlit():
    """
    A Streamlit-based UI for the Nutrition Disorder Specialist Agent.
    """
    st.title("Nutrition Disorder Specialist")
    st.write(
        "Ask me anything about nutrition disorders, symptoms, causes, treatments, and more."
    )
    st.write("Type 'exit' to end the conversation.")

    # Initialize session state for chat history and user_id if they don't exist
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "user_id" not in st.session_state:
        st.session_state.user_id = None

    # Login form: Only if user is not logged in
    if st.session_state.user_id is None:
        with st.form("login_form", clear_on_submit=True):
            user_id = st.text_input("Please enter your name to begin:")
            submit_button = st.form_submit_button("Login")
            if submit_button and user_id:
                st.session_state.user_id = user_id
                st.session_state.chat_history.append(
                    {
                        "role": "assistant",
                        "content": f"Welcome, {user_id}! How can I help you with nutrition disorders today?",
                    }
                )
                st.session_state.login_submitted = True  # Set flag to trigger rerun

        # Trigger rerun outside the form if login was successful
        if st.session_state.get("login_submitted", False):
            st.session_state.pop("login_submitted")
            st.rerun()
    else:
        # Display chat history
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])

        # Chat input
        user_query = st.chat_input("Type your question here (or 'exit' to end)...")

        if user_query:
            # Check if user wants to exit
            if user_query.lower() == "exit":
                st.session_state.chat_history.append(
                    {"role": "user", "content": "exit"}
                )
                with st.chat_message("user"):
                    st.write("exit")
                goodbye_msg = "Goodbye! Feel free to return if you have more questions about nutrition disorders."
                st.session_state.chat_history.append(
                    {"role": "assistant", "content": goodbye_msg}
                )
                with st.chat_message("assistant"):
                    st.write(goodbye_msg)
                st.session_state.user_id = None
                st.rerun()
                return

            # Add user message to chat history
            st.session_state.chat_history.append(
                {"role": "user", "content": user_query}
            )
            with st.chat_message("user"):
                st.write(user_query)

            # Filter input
            filtered_result = filter_input_with_llama_guard(user_query)

            # Process through the agent
            with st.chat_message("assistant"):
                if filtered_result in ["safe", "unsafe S7", "unsafe S6"]:
                    try:
                        # Initialize chatbot if not already done
                        if "chatbot" not in st.session_state:
                            st.session_state.chatbot = NutritionBot()

                        # Get response from the chatbot
                        response = st.session_state.chatbot.handle_customer_query(
                            st.session_state.user_id, user_query
                        )

                        st.write(response)
                        st.session_state.chat_history.append(
                            {"role": "assistant", "content": response}
                        )
                    except Exception as e:
                        error_msg = f"Sorry, I encountered an error while processing your query. Please try again. Error: {str(e)}"
                        st.write(error_msg)
                        st.session_state.chat_history.append(
                            {"role": "assistant", "content": error_msg}
                        )
                else:
                    inappropriate_msg = "I apologize, but I cannot process that input as it may be inappropriate. Please try again."
                    st.write(inappropriate_msg)
                    st.session_state.chat_history.append(
                        {"role": "assistant", "content": inappropriate_msg}
                    )

if __name__ == "__main__":
    nutrition_disorder_streamlit()