"""Agent nodes for a nutrition-disorder RAG pipeline: query expansion,
retrieval, response generation, and groundedness/precision self-evaluation loops."""
import os
import sys

from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Make the sibling `models` and `config` packages importable.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'models'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'config'))
from models import llm, retriever
from agent_state import AgentState
from config import config


def expand_query(state: AgentState) -> AgentState:
    """Expands the user query to improve retrieval of nutrition disorder-related information."""
    print("---------Expanding Query---------")
    
    system_message = '''You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.

Guidelines:
- Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
- Include **related terms** to improve retrieval (e.g., "bulimia" → "bulimia nervosa vs binge eating disorder").
- If the user provides an unclear query, suggest necessary clarifications.
- **DO NOT** answer the question. Your job is only to enhance the query.

Examples:
1. User Query: "Tell me about eating disorders."
   Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."

2. User Query: "What is anorexia?"
   Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."

3. User Query: "How to treat bulimia?"
   Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."

4. User Query: "What are the effects of malnutrition?"
   Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."

Now, expand the following query:'''

    expand_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}")
    ])

    chain = expand_prompt | llm | StrOutputParser()
    expanded_query = chain.invoke({
        "query": state['query'],
        "query_feedback": state.get("query_feedback", "")  # empty on the first pass, before refine_query has run
    })
    print("expanded_query", expanded_query)
    state["expanded_query"] = expanded_query
    return state

def retrieve_context(state: AgentState) -> AgentState:
    """Retrieves context from the vector store using the expanded or original query."""
    print("---------retrieve_context---------")
    # Fall back to the original query if expansion has not run yet.
    query = state.get('expanded_query') or state['query']
    
    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)
    
    context = [
        {
            "content": doc.page_content,
            "metadata": doc.metadata
        }
        for doc in docs
    ]
    state['context'] = context
    print("Extracted context with metadata:", context)
    return state

def craft_response(state: AgentState) -> AgentState:
    """Generates a response using the retrieved context, focusing on nutrition disorders."""
    print("---------craft_response---------")
    system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.

Guidelines:
- **Be direct and concise** while ensuring completeness.
- **DO NOT include information that is not present in the context.**
- If multiple sources exist, synthesize them into a coherent response.
- If the context does not fully answer the query, state what additional information is needed.
- Use bullet points when explaining complex concepts.

Example:
User Query: "What are the symptoms of anorexia nervosa?"
Context:
1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
Response:
"Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
- Restricted eating
- Distorted body image
- Excessive exercise
If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nContext: {context}\n\nResponse:")
    ])

    chain = response_prompt | llm | StrOutputParser()
    state['response'] = chain.invoke({
        "query": state['query'],
        "context": "\n".join([doc["content"] for doc in state['context']])
    })
    return state

def score_groundedness(state: AgentState) -> AgentState:
    """Checks whether the response is grounded in the retrieved context."""
    print("---------check_groundedness---------")
    
    system_message = '''You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.

Guidelines:
1. **Groundedness Check**:
   - Verify that the response accurately reflects the information in the context.
   - Flag any unsupported claims or deviations from the context.

2. **Citation Check**:
   - Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
   - If citations are missing, suggest adding them.

3. **Scoring**:
   - Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.

Examples:
1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
   Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
   Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."

2. Response: "Bulimia nervosa can be cured with diet alone."
   Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
   Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."

3. Response: "Anorexia nervosa has a high mortality rate."
   Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
   Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7."

****Return only a float score (e.g., 0.9). Do not provide explanations.****

Now, evaluate the following response:'''

    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:")
    ])

    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join([doc["content"] for doc in state['context']]),
        "response": state['response']
    })
    # Strip whitespace before parsing; the LLM is instructed to return a bare float.
    groundedness_score = float(raw_score.strip())
    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score
    return state
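
# The two scorer nodes assume the LLM returns a bare float, but judge models
# sometimes prefix the number with text despite the prompt's instruction. A
# defensive parser like the illustrative helper below (not currently used by
# this module) would avoid the ValueError in that case.
import re

def parse_score(raw: str, default: float = 0.0) -> float:
    """Extract the first number from an LLM judge's output, clamped to [0, 1]."""
    match = re.search(r"\d*\.?\d+", raw)
    if not match:
        return default
    return min(max(float(match.group()), 0.0), 1.0)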

def check_precision(state: AgentState) -> AgentState:
    """Checks whether the response precisely addresses the user's query."""
    print("---------check_precision---------")
    
    system_message = '''You are an AI evaluator assessing the **precision** of the response.
Your task is to **score** how well the response addresses the user's original nutrition disorder-related query.

Scoring Criteria:
- 1.0 → The response is fully precise, directly answering the question.
- 0.7 → The response is mostly correct but contains some generalization.
- 0.5 → The response is somewhat relevant but lacks key details.
- 0.3 → The response is vague or only partially correct.
- 0.0 → The response is incorrect or misleading.

Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
   Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
   Precision Score: 1.0

2. Query: "How is bulimia nervosa treated?"
   Response: "Bulimia nervosa is treated with therapy and medications."
   Precision Score: 0.7

3. Query: "What causes binge eating disorder?"
   Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
   Precision Score: 0.5

4. Query: "What are the effects of malnutrition?"
   Response: "Malnutrition can lead to health problems."
   Precision Score: 0.3

5. Query: "What is the mortality rate of anorexia nervosa?"
   Response: "Anorexia nervosa is a type of eating disorder."
   Precision Score: 0.0

*****Return only a float score (e.g., 0.9). Do not provide explanations.*****
Now, evaluate the following query and response:'''

    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:")
    ])

    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response']
    })
    # Strip whitespace before parsing; the LLM is instructed to return a bare float.
    precision_score = float(raw_score.strip())
    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state

def refine_response(state: AgentState) -> AgentState:
    """Suggests improvements for the generated response."""
    print("---------refine_response---------")

    system_message = '''You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.

### Guidelines:
- Identify **gaps in the explanation** (missing key details).
- Highlight **unclear or vague parts** that need elaboration.
- Suggest **additional details** that should be included for better accuracy.
- Ensure the refined response is **precise** and **grounded** in the retrieved context.

### Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
   Response: "The symptoms include weight loss and fear of gaining weight."
   Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"

2. Query: "How is bulimia nervosa treated?"
   Response: "Bulimia nervosa is treated with therapy."
   Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."

3. Query: "What causes binge eating disorder?"
   Response: "Binge eating disorder is caused by psychological factors."
   Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."

Now, suggest improvements for the following response:'''

    refine_response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\n"
                 "What improvements can be made to enhance accuracy and completeness?")
    ])

    chain = refine_response_prompt | llm | StrOutputParser()
    feedback = f"Previous Response: {state['response']}\nSuggestions: {chain.invoke({'query': state['query'], 'response': state['response']})}"
    print("feedback: ", feedback)
    print(f"State: {state}")
    state['feedback'] = feedback
    return state

def refine_query(state: AgentState) -> AgentState:
    """Suggests improvements for the expanded query."""
    print("---------refine_query---------")
    
    system_message = '''You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.

### Guidelines:
- Add **specific keywords** to improve document retrieval.
- Identify **missing details** that should be included.
- Suggest **ways to narrow the scope** for better precision.

### Examples:
1. Original Query: "Tell me about eating disorders."
   Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
   Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."

2. Original Query: "What is anorexia?"
   Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
   Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."

3. Original Query: "How to treat bulimia?"
   Expanded Query: "Describe treatment options for bulimia nervosa."
   Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."

Now, suggest improvements for the following expanded query:'''

    refine_query_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
                 "What improvements can be made for a better search?")
    ])

    chain = refine_query_prompt | llm | StrOutputParser()
    query_feedback = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {chain.invoke({'query': state['query'], 'expanded_query': state['expanded_query']})}"
    print("query_feedback: ", query_feedback)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = query_feedback
    return state

def max_iterations_reached(state: AgentState) -> AgentState:
    """Handles the case when the maximum number of iterations is reached."""
    print("---------max_iterations_reached---------")
    response = "I'm unable to refine the response further. Please provide more context or clarify your question."
    state['response'] = response
    return state
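
# These nodes are typically wired into a graph with conditional edges that
# route on the scores and loop counters maintained above. The routing helpers
# below are an illustrative sketch; the threshold and iteration-cap values are
# assumptions, not read from `config`, and the returned node names would need
# to match whatever the graph registers.
GROUNDEDNESS_THRESHOLD = 0.8
PRECISION_THRESHOLD = 0.8
MAX_ITERATIONS = 3

def should_refine_response(state: dict) -> str:
    """After score_groundedness: proceed, refine the response, or give up."""
    if state["groundedness_score"] >= GROUNDEDNESS_THRESHOLD:
        return "check_precision"
    if state["groundedness_loop_count"] >= MAX_ITERATIONS:
        return "max_iterations_reached"
    return "refine_response"

def should_refine_query(state: dict) -> str:
    """After check_precision: finish, refine the query, or give up."""
    if state["precision_score"] >= PRECISION_THRESHOLD:
        return "END"
    if state["precision_loop_count"] >= MAX_ITERATIONS:
        return "max_iterations_reached"
    return "refine_query"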