gl-kp committed on
Commit
60ce079
·
verified ·
1 Parent(s): cd4eca5

Upload folder using huggingface_hub

Browse files
core/agent_state.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any, TypedDict
2
+
3
class AgentState(TypedDict):
    """Typed dictionary that carries all data shared between workflow nodes."""

    # --- Query and retrieval -------------------------------------------------
    query: str  # the user's question as supplied to the workflow
    expanded_query: str  # rewritten query stored by the expand_query node
    context: List[Dict[str, Any]]  # retrieved docs as {"content", "metadata"} dicts

    # --- Generation and evaluation -------------------------------------------
    response: str  # latest generated answer (or the max-iterations notice)
    precision_score: float  # parsed from the LLM output in check_precision
    groundedness_score: float  # parsed from the LLM output in score_groundedness

    # --- Loop bookkeeping ----------------------------------------------------
    groundedness_loop_count: int  # incremented on every groundedness scoring
    precision_loop_count: int  # incremented on every precision scoring

    # --- Refinement feedback -------------------------------------------------
    feedback: str  # suggestions written by refine_response
    query_feedback: str  # suggestions written by refine_query

    # NOTE(review): no visible node reads or writes this flag - confirm it is
    # still needed.
    groundedness_check: bool

    # Upper bound that both loop counters are compared against.
    loop_max_iter: int
core/workflow.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END, START
2
+ from langchain_core.tools import tool
3
+ from agent_state import AgentState
4
+ from workflow_nodes import (
5
+ expand_query, retrieve_context, craft_response,
6
+ score_groundedness, refine_response, check_precision,
7
+ refine_query, max_iterations_reached
8
+ )
9
+ from workflow_conditions import should_continue_groundedness, should_continue_precision
10
+
11
def create_workflow() -> StateGraph:
    """Assemble the (uncompiled) state graph for the AI nutrition agent.

    The graph runs expand_query -> retrieve_context -> craft_response ->
    score_groundedness, then branches on the groundedness decision and, if that
    passes, on the precision decision. Both refinement branches loop back into
    the main path, and both can bail out through ``max_iterations_reached``.

    Returns:
        The configured ``StateGraph``; the caller is responsible for compiling it.
    """
    graph = StateGraph(AgentState)

    # Register every processing node under its routing name.
    node_table = (
        ("expand_query", expand_query),
        ("retrieve_context", retrieve_context),
        ("craft_response", craft_response),
        ("score_groundedness", score_groundedness),
        ("refine_response", refine_response),
        ("check_precision", check_precision),
        ("refine_query", refine_query),
        ("max_iterations_reached", max_iterations_reached),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Straight-line part of the pipeline.
    main_path = (
        (START, "expand_query"),
        ("expand_query", "retrieve_context"),
        ("retrieve_context", "craft_response"),
        ("craft_response", "score_groundedness"),
    )
    for source, target in main_path:
        graph.add_edge(source, target)

    # After groundedness scoring: move on, refine the answer, or give up.
    groundedness_routes = {
        "check_precision": "check_precision",
        "refine_response": "refine_response",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "score_groundedness", should_continue_groundedness, groundedness_routes
    )
    graph.add_edge("refine_response", "craft_response")

    # After precision scoring: finish, refine the query, or give up.
    precision_routes = {
        "pass": END,
        "refine_query": "refine_query",
        "max_iterations_reached": "max_iterations_reached",
    }
    graph.add_conditional_edges(
        "check_precision", should_continue_precision, precision_routes
    )
    graph.add_edge("refine_query", "expand_query")
    graph.add_edge("max_iterations_reached", END)

    return graph
57
+
58
# Build the graph and compile it once, at import time; agentic_rag invokes
# this compiled application for every query.
WORKFLOW_APP = create_workflow().compile()
60
+
61
@tool
def agentic_rag(query: str):
    """
    Runs the RAG-based agent with conversation history for context-aware responses.

    Args:
        query (str): The current user query.

    Returns:
        Dict[str, Any]: The updated state with the generated response and conversation history.
    """
    # NOTE: the docstring above is left verbatim on purpose - `@tool` publishes
    # it as the tool description, so rewording it would change what the calling
    # model sees. Be aware that this function passes only `query` (plus blank
    # defaults) into the workflow; no conversation history is supplied here.

    # Seed every key declared on AgentState so the initial state matches the
    # schema and the returned state always contains all of them.
    initial_state = {
        "query": query,
        "expanded_query": "",
        "context": [],
        "response": "",
        "precision_score": 0.0,
        "groundedness_score": 0.0,
        "groundedness_loop_count": 0,
        "precision_loop_count": 0,
        "feedback": "",
        "query_feedback": "",
        # Declared on AgentState but previously never initialised. No visible
        # node reads it; False is a neutral placeholder - TODO confirm intent.
        "groundedness_check": False,
        # Both loop counters are compared against this limit by the routing
        # conditions.
        "loop_max_iter": 2,
    }

    return WORKFLOW_APP.invoke(initial_state)
core/workflow_conditions.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from config import config
3
+
4
def should_continue_groundedness(state: Dict) -> str:
    """Choose the node that follows groundedness scoring.

    Args:
        state: Workflow state; reads ``groundedness_score``,
            ``groundedness_loop_count`` and ``loop_max_iter``.

    Returns:
        ``"check_precision"`` when the score reaches
        ``config.GROUNDEDNESS_THRESHOLD``; otherwise
        ``"max_iterations_reached"`` once the loop count exceeds
        ``loop_max_iter``, and ``"refine_response"`` in all remaining cases.
    """
    print("---------should_continue_groundedness---------")
    print("groundedness loop count: ", state['groundedness_loop_count'])

    # Good enough: hand over to the precision check.
    if state['groundedness_score'] >= config.GROUNDEDNESS_THRESHOLD:
        print("Moving to precision")
        return "check_precision"

    # Below threshold and out of attempts: stop refining.
    if state["groundedness_loop_count"] > state['loop_max_iter']:
        return "max_iterations_reached"

    print("---------Groundedness Score Threshold Not met. Refining Response-----------")
    return "refine_response"
18
+
19
def should_continue_precision(state: Dict) -> str:
    """Choose the node that follows precision scoring.

    Args:
        state: Workflow state; reads ``precision_score``,
            ``precision_loop_count`` and ``loop_max_iter``.

    Returns:
        ``"pass"`` when the score reaches ``config.PRECISION_THRESHOLD``;
        otherwise ``"max_iterations_reached"`` once the loop count exceeds
        ``loop_max_iter``, and ``"refine_query"`` in all remaining cases.
    """
    print("---------should_continue_precision---------")
    print("precision loop count: ", state['precision_loop_count'])

    # Precise enough: the workflow can finish.
    if state['precision_score'] >= config.PRECISION_THRESHOLD:
        return "pass"

    # Below threshold and out of attempts: stop refining.
    if state['precision_loop_count'] > state['loop_max_iter']:
        return "max_iterations_reached"

    print("---------Precision Score Threshold Not met. Refining Query-----------")
    return "refine_query"
core/workflow_nodes.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ from langchain.prompts import ChatPromptTemplate
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from models import llm, retriever
5
+ from agent_state import AgentState
6
+ from config import config
7
+
8
def expand_query(state: AgentState) -> AgentState:
    """Rewrite the user's query into a richer search query via the LLM.

    Reads ``state['query']`` and ``state['query_feedback']``, runs them through
    the expansion prompt, and stores the model's text output in
    ``state['expanded_query']``.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same state object with ``expanded_query`` updated.
    """
    print("---------Expanding Query---------")

    system_message = '''You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
    Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.

    Guidelines:
    - Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
    - Include **related terms** to improve retrieval (e.g., "bulimia" → "bulimia nervosa vs binge eating disorder").
    - If the user provides an unclear query, suggest necessary clarifications.
    - **DO NOT** answer the question. Your job is only to enhance the query.

    Examples:
    1. User Query: "Tell me about eating disorders."
    Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."

    2. User Query: "What is anorexia?"
    Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."

    3. User Query: "How to treat bulimia?"
    Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."

    4. User Query: "What are the effects of malnutrition?"
    Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."

    Now, expand the following query:'''

    user_message = "Expand this query: {query} using the feedback: {query_feedback}"
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", user_message),
    ])

    # prompt -> model -> plain string
    pipeline = prompt | llm | StrOutputParser()
    prompt_inputs = {
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    }
    rewritten = pipeline.invoke(prompt_inputs)
    print("expanded_query", rewritten)

    state["expanded_query"] = rewritten
    return state
46
+
47
def retrieve_context(state: AgentState) -> AgentState:
    """Retrieve context from the vector store using the expanded or original query.

    The expanded query is preferred. If it is missing, empty, or only
    whitespace, the original ``state['query']`` is used instead so the retriever
    is never called with a blank search string.

    Args:
        state: Workflow state; reads ``expanded_query`` / ``query`` and is
            mutated in place.

    Returns:
        The same state object with ``context`` set to a list of
        ``{"content": ..., "metadata": ...}`` dicts, one per retrieved document.
    """
    print("---------retrieve_context---------")

    expanded = state.get('expanded_query') or ""
    # Fall back to the user's own wording when expansion produced nothing usable.
    query = expanded if expanded.strip() else state['query']

    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)

    # Keep only the text and metadata of each document in the shared state.
    context = [
        {"content": doc.page_content, "metadata": doc.metadata}
        for doc in docs
    ]
    state['context'] = context
    print("Extracted context with metadata:", context)
    return state
65
+
66
def craft_response(state: AgentState) -> AgentState:
    """Generate a response from the retrieved context, focusing on nutrition disorders.

    On the first pass the prompt contains only the query and the concatenated
    context. When ``state['feedback']`` is non-empty (i.e. the refine_response
    node has run), that feedback is appended to the prompt so the regenerated
    answer can actually act on it; previously the feedback was ignored and the
    refinement loop re-ran an identical prompt.

    Args:
        state: Workflow state; reads ``query``, ``context`` and ``feedback`` and
            is mutated in place.

    Returns:
        The same state object with ``response`` set to the model's text output.
    """
    system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
    Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.

    Guidelines:
    - **Be direct and concise** while ensuring completeness.
    - **DO NOT include information that is not present in the context.**
    - If multiple sources exist, synthesize them into a coherent response.
    - If the context does not fully answer the query, state what additional information is needed.
    - Use bullet points when explaining complex concepts.

    Example:
    User Query: "What are the symptoms of anorexia nervosa?"
    Context:
    1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
    2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
    Response:
    "Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
    - Restricted eating
    - Distorted body image
    - Excessive exercise
    If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''

    prompt_inputs = {
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']]),
    }

    feedback = state.get('feedback') or ""
    if feedback.strip():
        # Refinement pass: surface the reviewer suggestions to the model.
        user_message = (
            "Query: {query}\nContext: {context}\n\n"
            "Feedback on the previous attempt (apply it while staying within the context):\n"
            "{feedback}\n\nResponse:"
        )
        prompt_inputs["feedback"] = feedback
    else:
        # First pass: same prompt as before, no feedback section.
        user_message = "Query: {query}\nContext: {context}\n\nResponse:"

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", user_message),
    ])

    chain = response_prompt | llm | StrOutputParser()
    state['response'] = chain.invoke(prompt_inputs)
    return state
101
+
102
def score_groundedness(state: AgentState) -> AgentState:
    """Score how well the response is grounded in the retrieved context.

    The LLM is asked for a bare float, but models often wrap the number in
    text (e.g. "Groundedness score: 0.7."). Instead of letting ``float()``
    raise and abort the workflow, the output is parsed tolerantly: a strict
    parse is tried first, then the last number in [0, 1] found in the text is
    used, and 0.0 is the fallback when no usable number exists. The final
    score is clamped to [0, 1].

    Args:
        state: Workflow state; reads ``context`` and ``response`` and is
            mutated in place.

    Returns:
        The same state object with ``groundedness_score`` updated and
        ``groundedness_loop_count`` incremented by one.
    """
    import re  # local import: only needed for the tolerant score parsing

    print("---------check_groundedness---------")

    system_message = '''You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.

    Guidelines:
    1. **Groundedness Check**:
    - Verify that the response accurately reflects the information in the context.
    - Flag any unsupported claims or deviations from the context.

    2. **Citation Check**:
    - Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
    - If citations are missing, suggest adding them.

    3. **Scoring**:
    - Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.

    Examples:
    1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
    Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
    Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."

    2. Response: "Bulimia nervosa can be cured with diet alone."
    Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
    Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."

    3. Response: "Anorexia nervosa has a high mortality rate."
    Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
    Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7."

    ****Return only a float score (e.g., 0.9). Do not provide explanations.****

    Now, evaluate the following response:'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    })

    try:
        groundedness_score = float(raw_score)
    except (TypeError, ValueError):
        # The model added text around the number: take the last value that is
        # a plausible score, since the score normally closes the evaluation.
        numbers = [float(tok) for tok in re.findall(r"\d*\.\d+|\d+", str(raw_score))]
        plausible = [value for value in numbers if 0.0 <= value <= 1.0]
        groundedness_score = plausible[-1] if plausible else 0.0
        print("Could not parse groundedness score directly from:", repr(raw_score))

    # Keep the score inside the documented 0..1 range (NaN collapses to 0.0).
    groundedness_score = min(1.0, max(0.0, groundedness_score))

    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score
    return state
152
+
153
def check_precision(state: AgentState) -> AgentState:
    """Score how precisely the response addresses the user's query.

    The LLM is asked for a bare float, but models often wrap the number in
    text (e.g. "Precision Score: 0.7"). Instead of letting ``float()`` raise
    and abort the workflow, the output is parsed tolerantly: a strict parse is
    tried first, then the last number in [0, 1] found in the text is used, and
    0.0 is the fallback when no usable number exists. The final score is
    clamped to [0, 1].

    Args:
        state: Workflow state; reads ``query`` and ``response`` and is mutated
            in place.

    Returns:
        The same state object with ``precision_score`` updated and
        ``precision_loop_count`` incremented by one.
    """
    import re  # local import: only needed for the tolerant score parsing

    print("---------check_precision---------")

    system_message = '''You are an AI evaluator assessing the **precision** of the response.
    Your task is to **score** how well the response addresses the user's original nutrition disorder-related query.

    Scoring Criteria:
    - 1.0 → The response is fully precise, directly answering the question.
    - 0.7 → The response is mostly correct but contains some generalization.
    - 0.5 → The response is somewhat relevant but lacks key details.
    - 0.3 → The response is vague or only partially correct.
    - 0.0 → The response is incorrect or misleading.

    Examples:
    1. Query: "What are the symptoms of anorexia nervosa?"
    Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
    Precision Score: 1.0

    2. Query: "How is bulimia nervosa treated?"
    Response: "Bulimia nervosa is treated with therapy and medications."
    Precision Score: 0.7

    3. Query: "What causes binge eating disorder?"
    Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
    Precision Score: 0.5

    4. Query: "What are the effects of malnutrition?"
    Response: "Malnutrition can lead to health problems."
    Precision Score: 0.3

    5. Query: "What is the mortality rate of anorexia nervosa?"
    Response: "Anorexia nervosa is a type of eating disorder."
    Precision Score: 0.0

    *****Return only a float score (e.g., 0.9). Do not provide explanations.*****
    Now, evaluate the following query and response:'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response']
    })

    try:
        precision_score = float(raw_score)
    except (TypeError, ValueError):
        # The model added text around the number: take the last value that is
        # a plausible score, since the score normally closes the evaluation.
        numbers = [float(tok) for tok in re.findall(r"\d*\.\d+|\d+", str(raw_score))]
        plausible = [value for value in numbers if 0.0 <= value <= 1.0]
        precision_score = plausible[-1] if plausible else 0.0
        print("Could not parse precision score directly from:", repr(raw_score))

    # Keep the score inside the documented 0..1 range (NaN collapses to 0.0).
    precision_score = min(1.0, max(0.0, precision_score))

    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
206
+
207
def refine_response(state: AgentState) -> AgentState:
    """Ask the LLM for improvement suggestions on the current response.

    The suggestions are combined with the previous response into a single
    feedback string and stored in ``state['feedback']``.

    Args:
        state: Workflow state; reads ``query`` and ``response`` and is mutated
            in place.

    Returns:
        The same state object with ``feedback`` updated.
    """
    print("---------refine_response---------")

    system_message = '''You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.

    ### Guidelines:
    - Identify **gaps in the explanation** (missing key details).
    - Highlight **unclear or vague parts** that need elaboration.
    - Suggest **additional details** that should be included for better accuracy.
    - Ensure the refined response is **precise** and **grounded** in the retrieved context.

    ### Examples:
    1. Query: "What are the symptoms of anorexia nervosa?"
    Response: "The symptoms include weight loss and fear of gaining weight."
    Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"

    2. Query: "How is bulimia nervosa treated?"
    Response: "Bulimia nervosa is treated with therapy."
    Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."

    3. Query: "What causes binge eating disorder?"
    Response: "Binge eating disorder is caused by psychological factors."
    Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."

    Now, suggest improvements for the following response:'''

    user_message = (
        "Query: {query}\nResponse: {response}\n\n"
        "What improvements can be made to enhance accuracy and completeness?"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", user_message),
    ])

    # prompt -> model -> plain string
    pipeline = prompt | llm | StrOutputParser()
    previous_response = state['response']
    suggestions = pipeline.invoke({'query': state['query'], 'response': previous_response})

    # Bundle the old answer with the suggestions so both travel together.
    feedback = f"Previous Response: {previous_response}\nSuggestions: {suggestions}"
    print("feedback: ", feedback)
    print(f"State: {state}")

    state['feedback'] = feedback
    return state
246
+
247
def refine_query(state: AgentState) -> AgentState:
    """Ask the LLM for improvement suggestions on the expanded query.

    The suggestions are combined with the previous expanded query into a single
    string and stored in ``state['query_feedback']``, which expand_query feeds
    into its prompt.

    Args:
        state: Workflow state; reads ``query`` and ``expanded_query`` and is
            mutated in place.

    Returns:
        The same state object with ``query_feedback`` updated.
    """
    print("---------refine_query---------")

    system_message = '''You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.

    ### Guidelines:
    - Add **specific keywords** to improve document retrieval.
    - Identify **missing details** that should be included.
    - Suggest **ways to narrow the scope** for better precision.

    ### Examples:
    1. Original Query: "Tell me about eating disorders."
    Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
    Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."

    2. Original Query: "What is anorexia?"
    Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
    Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."

    3. Original Query: "How to treat bulimia?"
    Expanded Query: "Describe treatment options for bulimia nervosa."
    Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."

    Now, suggest improvements for the following expanded query:'''

    user_message = (
        "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
        "What improvements can be made for a better search?"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", user_message),
    ])

    # prompt -> model -> plain string
    pipeline = prompt | llm | StrOutputParser()
    previous_expansion = state['expanded_query']
    suggestions = pipeline.invoke({'query': state['query'], 'expanded_query': previous_expansion})

    # Bundle the old expansion with the suggestions so both travel together.
    query_feedback = f"Previous Expanded Query: {previous_expansion}\nSuggestions: {suggestions}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")

    state['query_feedback'] = query_feedback
    return state
285
+
286
def max_iterations_reached(state: AgentState) -> AgentState:
    """Replace the response with a fixed notice once the refinement budget is spent.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same state object with ``response`` overwritten by the notice.
    """
    print("---------max_iterations_reached---------")
    # Whatever answer was drafted so far is discarded in favour of this message.
    state['response'] = (
        "I'm unable to refine the response further. "
        "Please provide more context or clarify your question."
    )
    return state