Spaces:
Sleeping
Sleeping
File size: 13,916 Bytes
60ce079 d080972 f60c104 5146bd4 60ce079 5146bd4 60ce079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
from typing import Dict
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'models'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'config'))
from models import llm, retriever
from agent_state import AgentState
from config import config
def expand_query(state: AgentState) -> AgentState:
    """Rewrite the user's query into a richer search query and store it in the state.

    The original query and any earlier query feedback are sent through the
    expansion prompt and the module-level ``llm``; the model's text output is
    written to ``state["expanded_query"]``.

    Args:
        state: Agent state. ``query`` and ``query_feedback`` are read.

    Returns:
        The same state object with ``expanded_query`` set.
    """
    print("---------Expanding Query---------")

    # Sent verbatim as the system message; it instructs the model to enrich
    # the query without answering it.
    expansion_instructions = '''You are an AI specializing in improving search queries to retrieve the most relevant nutrition disorder-related information.
Your task is to **refine** and **expand** the given query so that better search results are obtained, while **keeping the original intent** unchanged.
Guidelines:
- Add **specific details** where needed. Example: If a user asks about "anorexia," specify aspects like symptoms, causes, or treatment options.
- Include **related terms** to improve retrieval (e.g., "bulimia" → "bulimia nervosa vs binge eating disorder").
- If the user provides an unclear query, suggest necessary clarifications.
- **DO NOT** answer the question. Your job is only to enhance the query.
Examples:
1. User Query: "Tell me about eating disorders."
Expanded Query: "Provide details on eating disorders, including types (e.g., anorexia nervosa, bulimia nervosa), symptoms, causes, and treatment options."
2. User Query: "What is anorexia?"
Expanded Query: "Explain anorexia nervosa, including its symptoms, causes, risk factors, and treatment options."
3. User Query: "How to treat bulimia?"
Expanded Query: "Describe treatment options for bulimia nervosa, including psychotherapy, medications, and lifestyle changes."
4. User Query: "What are the effects of malnutrition?"
Expanded Query: "Explain the effects of malnutrition on physical and mental health, including specific nutrient deficiencies and their consequences."
Now, expand the following query:'''

    prompt = ChatPromptTemplate.from_messages([
        ("system", expansion_instructions),
        ("user", "Expand this query: {query} using the feedback: {query_feedback}"),
    ])
    pipeline = prompt | llm | StrOutputParser()

    # Values substituted into the {query} / {query_feedback} placeholders.
    prompt_inputs = {
        "query": state['query'],
        "query_feedback": state["query_feedback"],
    }
    rewritten = pipeline.invoke(prompt_inputs)
    print("expanded_query", rewritten)

    state["expanded_query"] = rewritten
    return state
def retrieve_context(state: AgentState) -> AgentState:
    """Retrieve supporting documents for the expanded (or original) query.

    The module-level ``retriever`` is invoked with ``state['expanded_query']``.
    When that key is missing or empty, the original ``state['query']`` is used
    instead. Each returned document is reduced to a dict holding its
    ``page_content`` and ``metadata``, and the list is stored in
    ``state['context']``.

    Args:
        state: Agent state. ``expanded_query`` (optional) and ``query`` are
            read. Assumes the state is a dict-style mapping supporting
            ``.get`` -- TODO confirm against ``AgentState``.

    Returns:
        The same state object with ``context`` set to a list of
        ``{"content": ..., "metadata": ...}`` dicts.

    Raises:
        KeyError: If neither a usable ``expanded_query`` nor ``query`` exists.
    """
    print("---------retrieve_context---------")

    # Fall back to the raw user query so retrieval still works when the
    # expansion step was skipped or produced an empty string.
    query = state.get('expanded_query') or state['query']

    docs = retriever.invoke(query)
    print("Retrieved documents:", docs)

    context = [
        {
            "content": doc.page_content,
            "metadata": doc.metadata,
        }
        for doc in docs
    ]
    state['context'] = context
    print("Extracted context with metadata:", context)
    return state
def craft_response(state: AgentState) -> AgentState:
    """Generate an answer to the user's query from the retrieved context.

    The ``content`` of every entry in ``state['context']`` is joined with
    newlines and sent, together with ``state['query']``, through the response
    prompt and the module-level ``llm``. When ``state['feedback']`` is present
    and non-empty, it is added to the user message so a regenerated answer can
    act on the suggested improvements. Without feedback the prompt is exactly
    the plain query/context prompt.

    Args:
        state: Agent state. ``query`` and ``context`` are read; ``feedback``
            is read when available. Assumes the state is a dict-style mapping
            supporting ``.get`` -- TODO confirm against ``AgentState``.

    Returns:
        The same state object with ``response`` set to the model's text output.
    """
    system_message = '''You are a professional AI nutrition disorder specialist generating responses based on retrieved documents.
Your task is to use the given **context** to generate a highly accurate, informative, and user-friendly response.
Guidelines:
- **Be direct and concise** while ensuring completeness.
- **DO NOT include information that is not present in the context.**
- If multiple sources exist, synthesize them into a coherent response.
- If the context does not fully answer the query, state what additional information is needed.
- Use bullet points when explaining complex concepts.
Example:
User Query: "What are the symptoms of anorexia nervosa?"
Context:
1. Anorexia nervosa is characterized by extreme weight loss and fear of gaining weight.
2. Common symptoms include restricted eating, distorted body image, and excessive exercise.
Response:
"Anorexia nervosa is an eating disorder characterized by extreme weight loss and an intense fear of gaining weight. Common symptoms include:
- Restricted eating
- Distorted body image
- Excessive exercise
If you or someone you know is experiencing these symptoms, it is important to seek professional help."'''

    user_template = "Query: {query}\nContext: {context}\n\nResponse:"
    prompt_inputs = {
        "query": state['query'],
        "context": "\n".join(doc["content"] for doc in state['context']),
    }

    # Feedback is passed as a template variable rather than spliced into the
    # template text, so braces inside it are not parsed as placeholders.
    feedback = state.get('feedback')
    if feedback:
        user_template = (
            "Query: {query}\nContext: {context}\n"
            "Feedback on the previous attempt: {feedback}\n\nResponse:"
        )
        prompt_inputs["feedback"] = feedback

    response_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", user_template),
    ])
    chain = response_prompt | llm | StrOutputParser()
    state['response'] = chain.invoke(prompt_inputs)
    return state
def score_groundedness(state: AgentState) -> AgentState:
    """Score how well the current response is grounded in the retrieved context.

    The joined ``content`` of ``state['context']`` and ``state['response']``
    are sent through the groundedness prompt and the module-level ``llm``.
    The reply is parsed as a float. If the model wraps the number in extra
    text (for example ``"Groundedness score: 0.9"``), the first number found
    in the reply is used instead of failing.

    Args:
        state: Agent state. ``context``, ``response`` and
            ``groundedness_loop_count`` are read.

    Returns:
        The same state object with ``groundedness_score`` set and
        ``groundedness_loop_count`` incremented by one.

    Raises:
        ValueError: If the model output contains no number at all.
    """
    import re  # function-scope import keeps this block self-contained

    print("---------check_groundedness---------")
    system_message = '''You are an AI tasked with evaluating whether a response is grounded in the provided context and includes proper citations.
Guidelines:
1. **Groundedness Check**:
- Verify that the response accurately reflects the information in the context.
- Flag any unsupported claims or deviations from the context.
2. **Citation Check**:
- Ensure that the response includes citations to the source material (e.g., "According to [Source], ...").
- If citations are missing, suggest adding them.
3. **Scoring**:
- Assign a groundedness score between 0 and 1, where 1 means fully grounded and properly cited.
Examples:
1. Response: "Anorexia nervosa is caused by genetic factors (Source 1)."
Context: "Anorexia nervosa is influenced by genetic, environmental, and psychological factors (Source 1)."
Evaluation: "The response is grounded and properly cited. Groundedness score: 1.0."
2. Response: "Bulimia nervosa can be cured with diet alone."
Context: "Treatment for bulimia nervosa involves psychotherapy and medications (Source 2)."
Evaluation: "The response is ungrounded and lacks citations. Groundedness score: 0.2."
3. Response: "Anorexia nervosa has a high mortality rate."
Context: "Anorexia nervosa has one of the highest mortality rates among psychiatric disorders (Source 3)."
Evaluation: "The response is grounded but lacks a citation. Groundedness score: 0.7."
****Return only a float score (e.g., 0.9). Do not provide explanations.****
Now, evaluate the following response:'''
    groundedness_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Context: {context}\nResponse: {response}\n\nGroundedness score:"),
    ])
    chain = groundedness_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "context": "\n".join(doc["content"] for doc in state['context']),
        "response": state['response'],
    })

    try:
        # Fast path: the model obeyed "return only a float".
        groundedness_score = float(raw_score)
    except ValueError:
        # Models sometimes add a label or explanation; salvage the first
        # number rather than aborting the whole run.
        found = re.search(r"[-+]?(?:\d+(?:\.\d+)?|\.\d+)", raw_score)
        if found is None:
            raise ValueError(
                f"Could not parse a groundedness score from model output: {raw_score!r}"
            ) from None
        groundedness_score = float(found.group())

    print("groundedness_score: ", groundedness_score)
    state['groundedness_loop_count'] += 1
    print("#########Groundedness Incremented###########")
    state['groundedness_score'] = groundedness_score
    return state
def check_precision(state: AgentState) -> AgentState:
    """Score how precisely the current response addresses the user's query.

    ``state['query']`` and ``state['response']`` are sent through the
    precision prompt and the module-level ``llm``. The reply is parsed as a
    float. If the model wraps the number in extra text (for example
    ``"Precision Score: 0.7"``), the first number found in the reply is used
    instead of failing.

    Args:
        state: Agent state. ``query``, ``response`` and
            ``precision_loop_count`` are read.

    Returns:
        The same state object with ``precision_score`` set and
        ``precision_loop_count`` incremented by one.

    Raises:
        ValueError: If the model output contains no number at all.
    """
    import re  # function-scope import keeps this block self-contained

    print("---------check_precision---------")
    system_message = '''You are an AI evaluator assessing the **precision** of the response.
Your task is to **score** how well the response addresses the user's original nutrition disorder-related query.
Scoring Criteria:
- 1.0 → The response is fully precise, directly answering the question.
- 0.7 → The response is mostly correct but contains some generalization.
- 0.5 → The response is somewhat relevant but lacks key details.
- 0.3 → The response is vague or only partially correct.
- 0.0 → The response is incorrect or misleading.
Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
Response: "The symptoms of anorexia nervosa include extreme weight loss, fear of gaining weight, and a distorted body image."
Precision Score: 1.0
2. Query: "How is bulimia nervosa treated?"
Response: "Bulimia nervosa is treated with therapy and medications."
Precision Score: 0.7
3. Query: "What causes binge eating disorder?"
Response: "Binge eating disorder is caused by a combination of genetic, psychological, and environmental factors."
Precision Score: 0.5
4. Query: "What are the effects of malnutrition?"
Response: "Malnutrition can lead to health problems."
Precision Score: 0.3
5. Query: "What is the mortality rate of anorexia nervosa?"
Response: "Anorexia nervosa is a type of eating disorder."
Precision Score: 0.0
*****Return only a float score (e.g., 0.9). Do not provide explanations.*****
Now, evaluate the following query and response:'''
    precision_prompt = ChatPromptTemplate.from_messages([
        ("system", system_message),
        ("user", "Query: {query}\nResponse: {response}\n\nPrecision score:"),
    ])
    chain = precision_prompt | llm | StrOutputParser()
    raw_score = chain.invoke({
        "query": state['query'],
        "response": state['response'],
    })

    try:
        # Fast path: the model obeyed "return only a float".
        precision_score = float(raw_score)
    except ValueError:
        # Models sometimes add a label or explanation; salvage the first
        # number rather than aborting the whole run.
        found = re.search(r"[-+]?(?:\d+(?:\.\d+)?|\.\d+)", raw_score)
        if found is None:
            raise ValueError(
                f"Could not parse a precision score from model output: {raw_score!r}"
            ) from None
        precision_score = float(found.group())

    state['precision_score'] = precision_score
    print("precision_score:", precision_score)
    state['precision_loop_count'] += 1
    print("#########Precision Incremented###########")
    return state
def refine_response(state: AgentState) -> AgentState:
    """Ask the model to critique the current response and store the critique.

    ``state['query']`` and ``state['response']`` are sent through the
    refinement prompt and the module-level ``llm``. The stored feedback text
    combines the previous response with the model's suggestions.

    Args:
        state: Agent state. ``query`` and ``response`` are read.

    Returns:
        The same state object with ``feedback`` set.
    """
    print("---------refine_response---------")

    critique_instructions = '''You are an AI response refinement assistant. Your task is to suggest **improvements** for the given response.
### Guidelines:
- Identify **gaps in the explanation** (missing key details).
- Highlight **unclear or vague parts** that need elaboration.
- Suggest **additional details** that should be included for better accuracy.
- Ensure the refined response is **precise** and **grounded** in the retrieved context.
### Examples:
1. Query: "What are the symptoms of anorexia nervosa?"
Response: "The symptoms include weight loss and fear of gaining weight."
Suggestions: "The response is missing key details about behavioral and emotional symptoms. Add details like 'distorted body image' and 'restrictive eating patterns.'"
2. Query: "How is bulimia nervosa treated?"
Response: "Bulimia nervosa is treated with therapy."
Suggestions: "The response is too vague. Specify the types of therapy (e.g., cognitive-behavioral therapy) and mention other treatments like nutritional counseling and medications."
3. Query: "What causes binge eating disorder?"
Response: "Binge eating disorder is caused by psychological factors."
Suggestions: "The response is incomplete. Add details about genetic and environmental factors, and explain how they contribute to the disorder."
Now, suggest improvements for the following response:'''

    user_turn = (
        "Query: {query}\nResponse: {response}\n\n"
        "What improvements can be made to enhance accuracy and completeness?"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", critique_instructions),
        ("user", user_turn),
    ])
    critique_chain = prompt | llm | StrOutputParser()

    suggestions = critique_chain.invoke({
        'query': state['query'],
        'response': state['response'],
    })
    # Keep the previous answer next to the suggestions so both are available
    # wherever the feedback is read.
    feedback_text = f"Previous Response: {state['response']}\nSuggestions: {suggestions}"

    print("feedback: ", feedback_text)
    print(f"State: {state}")
    state['feedback'] = feedback_text
    return state
def refine_query(state: AgentState) -> AgentState:
    """Ask the model how the expanded query could be improved and store the advice.

    ``state['query']`` and ``state['expanded_query']`` are sent through the
    query-refinement prompt and the module-level ``llm``. The stored feedback
    text combines the previous expanded query with the model's suggestions.

    Args:
        state: Agent state. ``query``, ``expanded_query`` and
            ``groundedness_loop_count`` are read.

    Returns:
        The same state object with ``query_feedback`` set.
    """
    print("---------refine_query---------")

    critique_instructions = '''You are an AI query refinement assistant. Your task is to suggest **improvements** for the expanded query.
### Guidelines:
- Add **specific keywords** to improve document retrieval.
- Identify **missing details** that should be included.
- Suggest **ways to narrow the scope** for better precision.
### Examples:
1. Original Query: "Tell me about eating disorders."
Expanded Query: "Provide details on eating disorders, including types, symptoms, causes, and treatment options."
Suggestions: "Add specific types of eating disorders like 'anorexia nervosa' and 'bulimia nervosa' to improve retrieval."
2. Original Query: "What is anorexia?"
Expanded Query: "Explain anorexia nervosa, including its symptoms and causes."
Suggestions: "Include details about treatment options and risk factors to make the query more comprehensive."
3. Original Query: "How to treat bulimia?"
Expanded Query: "Describe treatment options for bulimia nervosa."
Suggestions: "Specify types of treatments like 'cognitive-behavioral therapy' and 'medications' for better precision."
Now, suggest improvements for the following expanded query:'''

    user_turn = (
        "Original Query: {query}\nExpanded Query: {expanded_query}\n\n"
        "What improvements can be made for a better search?"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", critique_instructions),
        ("user", user_turn),
    ])
    critique_chain = prompt | llm | StrOutputParser()

    suggestions = critique_chain.invoke({
        'query': state['query'],
        'expanded_query': state['expanded_query'],
    })
    # Keep the previous expansion next to the suggestions so both are
    # available wherever the feedback is read.
    advice = f"Previous Expanded Query: {state['expanded_query']}\nSuggestions: {suggestions}"

    print("query_feedback: ", advice)
    print(f"Groundedness loop count: {state['groundedness_loop_count']}")
    state['query_feedback'] = advice
    return state
def max_iterations_reached(state: AgentState) -> AgentState:
    """Replace the response with a fixed fallback message.

    Args:
        state: Agent state whose ``response`` entry is overwritten.

    Returns:
        The same state object with ``response`` set to the fallback text.
    """
    print("---------max_iterations_reached---------")
    # Adjacent literals concatenate into the single fallback sentence pair.
    state['response'] = (
        "I'm unable to refine the response further. "
        "Please provide more context or clarify your question."
    )
    return state
|