moazeldegwy committed on
Commit
c504fac
·
2 Parent(s): bcd961e 34b037e

Merge Phase 2 into Phase 3 base

Browse files
.env.example ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Comma-separated list of Gemini API keys. The system rotates through them
2
+ # and respects per-key RPM/RPD limits when NUTRITION_MAS_ENABLE_RATE_LIMITING=true.
3
+ NUTRITION_MAS_GEMINI_API_KEYS=key_one,key_two
4
+
5
+ # Optional infra paths.
6
+ # When set, every agent/tool I/O is dumped to LOG_DIR/<subdir>/<timestamp>.json
7
+ NUTRITION_MAS_LOG_DIR=
8
+ # When set, LangGraph checkpoints are persisted to disk (instead of memory).
9
+ NUTRITION_MAS_PERSISTENCE_DIR=
10
+
11
+ # Debug switches (default: off)
12
+ NUTRITION_MAS_DEBUG_MODE=false
13
+ NUTRITION_MAS_DEBUG_LEVEL=full # 'full' | 'output'
14
+ # JSON-encoded dict, see config.Settings for shape. Defaults to all/all.
15
+ # NUTRITION_MAS_DEBUG_SCOPES={"agents": ["CoachAgent"], "tools": ["all"]}
16
+
17
+ # Rate limiting (default: on)
18
+ NUTRITION_MAS_ENABLE_RATE_LIMITING=true
19
+
20
+ # Optional LangSmith tracing (Phase 6 will wire this up properly)
21
+ # LANGCHAIN_TRACING_V2=true
22
+ # LANGCHAIN_API_KEY=
23
+ # LANGCHAIN_PROJECT=Nutrition-MAS
.gitignore ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+
24
+ # Virtual environments
25
+ .venv/
26
+ venv/
27
+ env/
28
+ ENV/
29
+ .env
30
+
31
+ # IDE
32
+ .vscode/
33
+ .idea/
34
+ .claude/
35
+ *.swp
36
+ *.swo
37
+ *~
38
+ .DS_Store
39
+
40
+ # Project
41
+ logs/
42
+ checkpoints/
43
+ data/cache/
44
+ *.sqlite
45
+ *.sqlite3
46
+ *.db
47
+ .cache/
48
+
49
+ # Pytest / coverage
50
+ .pytest_cache/
51
+ .coverage
52
+ htmlcov/
53
+ .mypy_cache/
54
+ .ruff_cache/
55
+
56
+ # Notebook
57
+ .ipynb_checkpoints/
58
+
59
+ # LangSmith / tracing
60
+ .langsmith/
Makefile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: install dev test test-cov lint format clean run
2
+
3
+ install:
4
+ pip install -r requirements.txt
5
+
6
+ dev:
7
+ pip install -e ".[dev]"
8
+
9
+ test:
10
+ pytest -ra -q
11
+
12
+ test-cov:
13
+ pytest -ra --cov=. --cov-report=term-missing --cov-report=html
14
+
15
+ lint:
16
+ ruff check .
17
+
18
+ format:
19
+ ruff format .
20
+
21
+ clean:
22
+ rm -rf .pytest_cache .ruff_cache .mypy_cache htmlcov .coverage
23
+ find . -type d -name __pycache__ -exec rm -rf {} +
24
+
25
+ run:
26
+ python -c "import nutritionmas; print('Module imports OK')"
agents.py CHANGED
@@ -1,489 +1,622 @@
1
- from typing import Dict, Any
2
- from utils import extract_and_parse_json, set_nested, update_memory_partition, save_to_json, should_debug
3
- from tools import ComputationTool, WebSearchTool, QuantitiesFinder
4
- from datetime import datetime
 
 
 
 
 
 
 
 
 
 
5
  import json
6
- import config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class CoachAgent:
9
  def __init__(self, llm_instance):
10
  self.llm = llm_instance
11
 
12
  def handle_task(self, state: Dict[str, Any]) -> Dict[str, Any]:
13
- memory_str = json.dumps(state["memory"], indent=2)
 
14
  response_steps = state.get("response_steps", [])
15
- response_steps_str = json.dumps(response_steps, indent=2) if response_steps else "None"
16
- truncated_history = []
 
 
 
17
  for msg in state["conversation_history"]:
18
  if msg["role"] == "assistant" and len(msg["content"]) > 200:
19
- truncated_content = msg["content"][:200] + "... (full response in memory)"
20
- truncated_history.append({"role": "assistant", "content": truncated_content})
 
21
  else:
22
  truncated_history.append(msg)
23
- history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in truncated_history])
24
-
25
- observation = f"""User query: {state['user_question']}
26
- Memory State: {memory_str}
27
- Current Response Steps: {response_steps_str}
28
- Previous Tool Result: {state.get('agent_result', 'None')}
29
- Conversation history: {history_str}"""
30
-
31
- prompt = f"""
32
- You are the Coach Agent (central orchestrator) of a nutrition MAS.
33
-
34
- Current State: {observation}
35
-
36
- Primary responsibilities:
37
- - Translate user intent to a concrete workflow of response_steps (use the shared response_step schema).
38
- - Enforce system rules (MedicalAssessment must be completed before Planner.
39
- - Decide and perform actions: call_agent, call_tool, ask_user, write_memory, compose_response.
40
-
41
- Inputs:
42
- - observation (string)
43
- - memory partitions: user_profile, medical_history, flags_and_assessments, plans
44
- - response_steps (may be None or list)
45
-
46
- Behavior rules (mandatory):
47
- 1. If response_steps is None or empty, generate a response_steps list with explicit ordered steps (max 6 steps). Each step must include id, actor, prerequisites, and status "pending".
48
- - Typical personal-workflow (if user asks for personalized plan):
49
- 1) Validate required user data (height, weight, age, sex, activity_level, allergies, goal). If missing -> ask_user.
50
- 2) Update memory (if user provided new data). [action: write_memory]
51
- 3) Call MedicalAssessmentAgent with task to assess user.
52
- 4) Wait for assessment to be completed and stored into memory.
53
- 5) Call PlannerAgent with relevent task.
54
- 2. When calling any agent, set the called step status to "in_progress" and include `prerequisites` satisfied by your observation.
55
- 3. Only call PlannerAgent if memory.flags_and_assessments exists and contains "assessment_status":"assessment_complete". If not, call MedicalAssessmentAgent.
56
- 4. When new user personal data is detected in user input, add steps to:
57
- - propose memory update (write_memory)
58
- - call MedicalAssessmentAgent if needed
59
- - re-plan if needed
60
- 5. For any "write_memory" action, provide the full partition contents in params.data (not diffs). The Coach is responsible to merge and store.
61
-
62
- Action outputs: respond with a JSON object:
63
- {{
64
- "observation": "...",
65
- "thought": "...",
66
- "response_steps": [ ... ],
67
- "action": "call_agent | call_tool | ask_user | write_memory | compose_response",
68
- "params": {{ ... }}
69
- }}
70
-
71
- Examples:
72
- - call_agent params: {{"agent_name":"MedicalAssessmentAgent", "task":"task description"}}
73
- - compose_response params:{{"text":"Complete response in markdown"}}
74
-
75
-
76
- Rules:
77
- - When composing the response, extract and include relevant information from the memory state (e.g., calorie target, plan details, dietary restrictions) in markdown format for readability.
78
- - Always include a "trace" field in composed responses summarizing which agents/tools were called for and which sources were used.
79
- - For high-risk profiles (e.g., requires_professional_consultation: true); in such cases append a bold warning at the end of the diet plan response advising professional consultation before implementation.
80
- """
81
-
82
- if should_debug('agents', 'CoachAgent'):
83
- print(f"\n--- Coach Agent Turn {state['num_turns'] + 1} ---")
84
- if should_debug('agents', 'CoachAgent') and config.DEBUG_LEVEL == 'full':
85
- print(f"Raw LLM input:\n{prompt}")
86
- response = self.llm(prompt)[0]
87
- if should_debug('agents', 'CoachAgent'):
88
- print(f"Coach Raw Response:\n{response}")
89
-
90
- parsed = extract_and_parse_json(response)
91
-
92
- # Add high-level print for user mode
93
- if not config.DEBUG_MODE:
94
- action = parsed.get("action")
95
- params = parsed.get("params", {})
96
- print_str = "\n🏋️‍♂️Coach Agent: "
97
- if action == "call_agent":
98
- print_str += f"Calling {params.get('agent_name')} with task '{params.get('task')}'"
99
- elif action == "call_tool":
100
- print_str += f"Using {params.get('tool_name')} with task '{params.get('task')}'"
101
- elif action == "ask_user":
102
- print_str += f"Asking user: {params.get('prompt')}"
103
- elif action == "write_memory":
104
- print_str += f"Writing to memory partition '{params.get('partition')}'"
105
- elif action == "compose_response":
106
- print_str += "Composing final response"
107
- print(print_str)
108
-
109
- current_action = {
110
- "action": parsed.get("action"),
111
- "params": parsed.get("params", {})
112
- }
113
 
114
- response_steps = parsed.get("response_steps", state.get("response_steps", []))
115
-
116
- log_data = {
117
- "prompt": prompt,
118
- "output":response,
119
- "parsed": parsed,
120
- "timestamp": datetime.now().isoformat()
121
- }
122
- save_to_json(log_data, f'coach_agent_{datetime.now().isoformat()}.json', subdirectory='CoachAgent')
123
-
124
  return {
125
  **state,
126
  "current_action": current_action,
127
- "response_steps": response_steps,
128
  "num_turns": state["num_turns"] + 1,
129
- "agent_result": None
130
  }
131
 
132
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  class MedicalAssessmentAgent:
134
- def __init__(self, llm_instance, computation_tool: ComputationTool, web_search_tool: WebSearchTool):
 
 
 
 
 
 
 
135
  self.llm = llm_instance
136
  self.computation_tool = computation_tool
137
  self.web_search_tool = web_search_tool
138
 
139
  def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
140
- print(f"\n👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT STARTED")
 
141
 
142
- # Build relevant memory context
143
  relevant_memory = {
144
  "user_profile": memory.get("user_profile", {}),
145
  "medical_history": memory.get("medical_history", {}),
146
  }
147
- memory_str = json.dumps(relevant_memory, indent=2)
148
- tool_results = []
149
- assessment_plan = []
150
- max_iterations = 15
151
- iteration = 0
152
-
153
- while iteration < max_iterations:
154
-
155
- tool_results_str = "\n".join([f"Tool Result {i+1}: {result}" for i, result in enumerate(tool_results)])
156
-
157
- prompt = f"""
158
- You are the Medical Assessment Agent. Your job: produce an evidence-based assessment and the set of clinical flags and calculations needed by the Planner and Validation agents.
159
- Task: {task}
160
- Current Memory: {memory_str}
161
- Current Assessment Plan: {assessment_plan}
162
- Previous Tool Results: {tool_results_str}
163
- Available tools: ComputationTool, WebSearchTool
164
- Mandatory behavior (do not skip):
165
- 1. Critical data check: confirm presence of age, sex, height, weight, activity_level, allergies, medications. If any critical field is missing -> action: ask_user (return which fields).
166
- 2. Use ComputationTool for all numeric calculations (BMI, BMR, TDEE, calorie targets, macro targets). Provide computation inputs with inside the task description.
167
- 3. Use WebSearchTool to fetch authoritative guidelines where relevant (WHO, USDA, clinical guidelines). Always capture the source(s) used with timestamped citations.
168
- 4. Produce a compact assessment_plan (3-6 steps max) that lists each computational/search step, its status, and result.
169
- - When generating the assessment_plan (if empty or None), follow this exact sequence (assuming critical data is present; if not, prepend a step for ask_user):
170
- 1. Call ComputationTool to calculate BMI, BMR, TDEE, and a single daily_target_calories (integer) based on the user's goal, all in one tool call.
171
- 2. Call ComputationTool to calculate macro_targets (protein_g, fat_g, carbohydrates_g as single integers) optimized for the user's goal given the daily_target_calories.
172
- 3. Call WebSearchTool to find dietary guidelines related to the user based on their profile and medical history to manage conditions.
173
- 4-6. Additional steps if needed (e.g., synthesis, further searches/computations for specific risks).
174
- 5. Return a `assessment_complete` containing:
175
- - assessment_summary
176
- - calculations: {{BMI, BMR, TDEE, daily_target_calories, macro_targets}}
177
- - daily_target_calories: a single integer value (e.g., 2750)
178
- - macro_targets: {{"protein_g": int, "fat_g": int, "carbohydrates_g": int}} (single integer values for each, no ranges)
179
- - flags_to_set: [e.g., "high_ldl", "diabetes_risk"]
180
- - recommendations: clinical dietary constraints or urgent issues (e.g., "refer to PCP for suspected iron deficiency")
181
- - requires_professional_consultation: True/False (True if the case is medically sensitive)
182
- - trace: a single paragraph summarizing which agents/tools were called and key steps.
183
- 6. If any calculation or guideline retrieval fails due to tool error:
184
- - fallback to best-known guideline values only if necessary (mark "data_confidence": 0.xx).
185
- - set "requires_tool_retry": true in the response.
186
- Response JSON must contain:
187
- - medical_reasoning: detailed rationale
188
- - observation: missing/available info
189
- - risk_assessment_priorities: ordered list of 1-4 priorities
190
- - assessment_plan: list of response_step objects (schema above)
191
- - action: either {{"type":"call_tool","tool_name":"ComputationTool" or "WebSearchTool","tool_task": "<task string>"}} or {{"type":"assessment_complete",...}}
192
- """
193
-
194
- if should_debug('agents', 'MedicalAssessmentAgent'):
195
- print(f"\n--- Medical Assessment Agent Iteration {iteration + 1} ---")
196
- if should_debug('agents', 'MedicalAssessmentAgent') and config.DEBUG_LEVEL == 'full':
197
- print(f"Raw LLM input:\n{prompt}")
198
- response = self.llm(prompt)[0]
199
- if should_debug('agents', 'MedicalAssessmentAgent'):
200
- print(f"Medical Assessment Raw Response:\n{response}")
201
-
202
- parsed = extract_and_parse_json(response)
203
-
204
- # Add high-level print for user mode
205
- if not config.DEBUG_MODE:
206
- action_type = parsed.get("action", {}).get("type")
207
- if action_type == "call_tool":
208
- tool_name = parsed["action"].get("tool_name")
209
- tool_task = parsed["action"].get("tool_task")
210
- print(f"👨🏻‍⚕️ Medical Assessment Agent: Using {tool_name} for '{tool_task}'")
211
- elif action_type == "ask_user":
212
- fields = parsed["action"].get("fields", [])
213
- print(f"👨🏻‍⚕️ Medical Assessment Agent: Asking user for missing fields: {', '.join(fields)}")
214
- elif action_type == "assessment_complete":
215
- print("👨🏻‍⚕️ Medical Assessment Agent: Completing assessment")
216
-
217
- if "assessment_plan" in parsed:
218
- assessment_plan = parsed["assessment_plan"]
219
-
220
- action = parsed.get("action", {})
221
- action_type = action.get("type")
222
-
223
- if action_type == "call_tool":
224
- tool_name = action.get("tool_name")
225
- tool_task = action.get("tool_task")
226
-
227
- if tool_name == "ComputationTool":
228
- if tool_task:
229
- result = self.computation_tool.handle_task(tool_task)
230
- else:
231
- result = "Missing 'tool_task' for ComputationTool"
232
- elif tool_name == "WebSearchTool":
233
- if tool_task:
234
- result = self.web_search_tool.handle_task(tool_task)
235
- else:
236
- result = "Missing 'tool_task' for WebSearchTool"
237
- else:
238
- result = f"Unknown tool: {tool_name}"
239
-
240
- tool_results.append(f"{tool_name}: {result}")
241
-
242
- elif action_type == "ask_user":
243
- fields = action.get("fields", [])
244
- result = f"Missing critical fields: {', '.join(fields)}. Please provide the following information to continue the assessment."
245
- print(f"👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT: User query needed - {result}")
246
- return result
247
-
248
- elif action_type == "assessment_complete":
249
- assessment_summary = action.get("assessment_summary")
250
- flags_to_set = action.get("flags_to_set", [])
251
- recommendations = action.get("recommendations", [])
252
- requires_professional_consultation = action.get("requires_professional_consultation", False)
253
- calculations = action.get("calculations", {}) # Now a dict as per new prompt
254
- evidence_sources = action.get("evidence_sources", [])
255
- trace = action.get("trace", "")
256
-
257
- if action.get("requires_tool_retry", False):
258
- result = "Assessment requires tool retry due to failures. Please re-run with fixed tools."
259
- print(f"👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT: Tool retry needed - {result}")
260
- return result # Return early without updating memory
261
-
262
- # Update memory using update_memory_partition
263
- update_memory_partition(memory, "flags_and_assessments", {
264
- "assessment_summary": assessment_summary,
265
- "flags": flags_to_set,
266
- "recommendations": recommendations,
267
- "requires_professional_consultation": requires_professional_consultation,
268
- "calculations": calculations,
269
- "evidence_sources": evidence_sources,
270
- "trace": trace,
271
- "assessment_timestamp": datetime.now().isoformat() # Retained timestamp
272
- })
273
-
274
- # Log the assessment (updated to include new fields)
275
- log_data = {
276
- "task": task,
277
- "memory_input": relevant_memory,
278
- "tool_results": tool_results,
279
- "assessment_summary": assessment_summary,
280
- "flags_set": flags_to_set,
281
- "recommendations": recommendations,
282
- "requires_professional_consultation": requires_professional_consultation,
283
- "evidence_sources": evidence_sources,
284
- "trace": trace,
285
- "timestamp": datetime.now().isoformat()
286
- }
287
- save_to_json(log_data, f'medical_assessment_{datetime.now().isoformat()}.json', subdirectory='MedicalAssessment')
288
-
289
- result = assessment_summary
290
- print(f"👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT COMPLETED: {result}")
291
- return result
292
 
293
  else:
294
- print(f"Unknown action type: {parsed}")
295
  break
296
 
297
- iteration += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
- # Fallback if max iterations reached
300
- result = f"Medical assessment stopped after {max_iterations} iterations"
301
- print(f"👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT Stopped (MAX ITERATIONS)")
302
- return result
303
 
304
  class PlannerAgent:
305
- def __init__(self, llm_instance, computation_tool: ComputationTool, web_search_tool: WebSearchTool, quantities_finder: QuantitiesFinder):
 
 
 
 
 
 
 
 
306
  self.llm = llm_instance
307
  self.computation_tool = computation_tool
308
  self.web_search_tool = web_search_tool
309
  self.quantities_finder = quantities_finder
310
 
311
  def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
312
- print(f"\n📋 PLANNER AGENT STARTED")
 
313
 
314
  relevant_memory = {
315
  "user_profile": memory.get("user_profile", {}),
316
  "flags_and_assessments": memory.get("flags_and_assessments", {}),
317
  }
318
- memory_str = json.dumps(relevant_memory, indent=2)
319
- tool_results = []
320
- planning_steps = []
321
- max_iterations = 15
322
- iteration = 0
323
-
324
- while iteration < max_iterations:
325
- tool_results_str = "\n".join([f"Tool Result {i+1}: {res}" for i, res in enumerate(tool_results)]) if tool_results else "None"
326
- planning_steps_str = json.dumps(planning_steps, indent=2) if planning_steps else "None"
327
- plan_status = relevant_memory.get("flags_and_assessments", {}).get("assessment_status", "none")
328
- prompt = f"""
329
- You are the Planner Agent. Create personalized meal plans constrained by the medical assessment.
330
-
331
- Task: {task}
332
- Current Memory: {memory_str}
333
- Current Planning Steps: {planning_steps_str}
334
- Previous Tool Results: {tool_results_str}
335
-
336
- Available Tools: WebSearchTool, QuantitiesFinder
337
-
338
- Mandatory behavior & rules:
339
- 1. Precondition: Do NOT start planing unless user medical assessment exists in memory (flags_and_assessments is not empty). If missing, return action: {{"type":"provide_plan", "final_plan":{{"Can't draft plan as flags_and_assessments is empty, please use MedicalAssessmentAgent"}}}}
340
-
341
- 2. Batch behavior:
342
- - Always group related items when using tools. Example: fetch nutrition facts for all foods in one WebSearchTool call instead of multiple calls.
343
-
344
- 3. For each food in the draft:
345
- - Use WebSearchTool to fetch nutrition facts for a standard serving size (or 100g cooked) (e.g., "Find nutrition facts (calories, protein, fat, carbohydrates) for the following items,...").
346
- - If WebSearchTool fails for >2 items, stop retrying and use your internal knowledge.
347
-
348
- 4. Acceptable tolerances:
349
- - Calories: within ±3% of daily_target_calories
350
- - Macronutrients: within ±5% of each macro target
351
-
352
- 5. Exclude all items listed in allergies and avoid disliked foods unless necessary for balance, in which case propose alternatives.
353
-
354
- 6. Flexible Planning: If task requests a multi-day plan (e.g., 7 days), fall back to a shorter balanced plan (1–2 unique days) and instruct user to repeat/rotate.
355
-
356
- 7. QuantitiesFinder Format: When calling 'QuantitiesFinder', the 'tool_task' MUST be a JSON STRING. This string is the serialized version of an object containing "foods" and "targets".
357
-     - "foods": A list of dictionaries. Each dictionary must have:
358
-       - name, calories, protein, fat, carbohydrates (per 100g)
359
-       - estimated_g: Your "best guess" for a realistic quantity (e.g., 150g). The solver will be penalized for deviating from this, so it will try to stay close.
360
-     - "targets": A dictionary containing: calories, protein, fat, carbohydrates.
361
-     - Example: "tool_task": "{{\"foods\": [...], \"targets\": {{...}}}}"
362
-
363
- Planning Steps Handling:
364
- - If Current Planning Steps is empty or 'None', you MUST adopt the following fixed 6-step plan as your primary workflow.
365
- [
366
- {{"id": 1, "description": "Analyze requirements, "Draft a realistic diet plan. For each food, assign a realistic 'estimated_g' (e.g., 150g chicken)."", "status": "pending"}},
367
- {{"id": 1, "description": "Analyze drafted plan, determine a list of all ingredients in the darafted plan, and batch-gather their nutritional facts (calories, protein, fat, carbohydrates) using WebSearchTool.", "status": "pending"}},
368
- {{"id": 3, "description": "Call 'QuantitiesFinder' (PuLP solver) with all nutritional data, targets, and bounds to calculate precise quantities.", "status": "pending"}},
369
- {{"id": 4, "description": "Update the drafted plan with the precise quantities returned by the QuantitiesFinder.", "status": "pending"}},
370
- {{"id": 4, "description": "Provide the final plan 'provide_plan'", "status": "pending"}}
371
- ]
372
-
373
- - If Current Planning Steps is provided... You may remain in a step for multiple iterations if necessary to meet all targets, as outlined in the Iterative Correction Loop rule.
374
-
375
- Return JSON:
376
- - observation, thought
377
- - planning_steps (full list of response_step objects)
378
- - action: one of {{
379
- "type":"call_tool","tool_name":...,"tool_task":...,
380
- "type":"draft_plan","drafted_plan":{{...}},
381
- "type":"provide_plan","final_plan":{{...}}
382
- }}
383
-
384
- Notes:
385
- - Keep each plan realistic and culturally appropriate (regional foods if provided).
386
- - Trace: at the end of the plan, summarize which agents/tools were called.
387
- - Always include the full updated planning_steps in your response JSON to persist across iterations.
388
- """
389
-
390
- if should_debug('agents', 'PlannerAgent'):
391
- print(f"\n--- Planner Agent Iteration {iteration + 1} ---")
392
- if should_debug('agents', 'PlannerAgent') and config.DEBUG_LEVEL == 'full':
393
- print(f"Raw LLM input:\n{prompt}")
394
- response = self.llm(prompt)[0]
395
- if should_debug('agents', 'PlannerAgent'):
396
- print(f"Planner Raw Response:\n{response}")
397
-
398
- parsed = extract_and_parse_json(response)
399
-
400
- # Add high-level print for user mode
401
- if not config.DEBUG_MODE:
402
- action_type = parsed.get("action", {}).get("type")
403
- print_str = "📋 Planner Agent: "
404
- if action_type == "call_tool":
405
- tool_name = parsed["action"].get("tool_name")
406
- tool_task = parsed["action"].get("tool_task")
407
- print_str += f"Using {tool_name} for '{tool_task}'"
408
- elif action_type == "draft_plan":
409
- print_str += "Drafting plan"
410
- elif action_type == "provide_plan":
411
- print_str += "Finalizing plan"
412
- print(print_str)
413
-
414
- planning_steps = parsed.get("planning_steps", planning_steps)
415
-
416
- action = parsed.get("action", {})
417
- action_type = action.get("type")
418
-
419
- if action_type == "call_tool":
420
- tool_name = action.get("tool_name")
421
- tool_task = action.get("tool_task")
422
- if tool_name and tool_task:
423
- print(f"Calling {tool_name} with task: {tool_task}")
424
- if tool_name == "ComputationTool":
425
- result = self.computation_tool.handle_task(tool_task)
426
- elif tool_name == "WebSearchTool":
427
- result = self.web_search_tool.handle_task(tool_task)
428
- elif tool_name == "QuantitiesFinder":
429
- result = self.quantities_finder.handle_task(tool_task)
430
- else:
431
- result = f"Unknown tool: {tool_name}"
432
- tool_results.append(f"{tool_name}: {result}")
433
- else:
434
- print("Missing tool_name or tool_task")
435
-
436
- elif action_type == "draft_plan":
437
- drafted_plan = action.get("drafted_plan")
438
- if drafted_plan:
439
- if "plans" not in memory:
440
- memory["plans"] = {}
441
- memory["plans"]["drafted_plan"] = drafted_plan
442
- result = "Plan drafted and stored in memory"
443
- tool_results.append(result)
444
  else:
445
- result = "Drafted plan not provided"
446
- tool_results.append(result)
447
-
448
- elif action_type == "provide_plan":
449
- final_plan = action.get("final_plan")
450
- if "error" in final_plan:
451
- print(f"\n📋 PLANNER AGENT ERROR: {final_plan}")
452
- return json.dumps(final_plan)
453
- else:
454
- final_plan = final_plan or memory["plans"].get("drafted_plan")
455
- if final_plan:
456
- memory["plans"]["current_plan"] = final_plan
457
- memory["plans"]["plan_timestamp"] = datetime.now().isoformat()
458
- if "drafted_plan" in memory["plans"]:
459
- del memory["plans"]["drafted_plan"]
460
- result = "Planning completed with validated plan"
461
- tool_results.append(result)
462
- log_data = {
463
- "task": task,
464
- "memory_input": relevant_memory,
465
- "tool_results": tool_results,
466
- "final_response": parsed,
467
- "timestamp": datetime.now().isoformat()
468
- }
469
- save_to_json(log_data, f'planner_agent_{datetime.now().isoformat()}.json', subdirectory='PlannerAgent')
470
- print(f"\n📋 PLANNER AGENT COMPLETED: {result}")
471
- return json.dumps(final_plan) if isinstance(final_plan, dict) else final_plan
472
- else:
473
- result = "Cannot finalize: missing plan"
474
- tool_results.append(result)
 
 
475
 
476
  else:
477
- print(f"Unknown action type: {action_type}")
478
  break
479
 
480
- iteration += 1
481
- memory_str = json.dumps({
482
- "user_profile": memory.get("user_profile", {}),
483
- "flags_and_assessments": memory.get("flags_and_assessments", {}),
484
- "plans": memory.get("plans", {})
485
- }, indent=2)
486
-
487
- result = f"Planning stopped after {max_iterations} iterations with {len(tool_results)} actions"
488
- print(f"📋 PLANNER AGENT Stopped (MAX ITERATIONS)")
489
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent implementations.
2
+
3
+ Phase 1: every agent's per-turn output is now a Pydantic model from
4
+ ``schemas``. The prompts are split so the static system rules sit in a
5
+ module-level constant (eligible for Gemini's implicit prompt cache) and only
6
+ the dynamic state changes per call.
7
+
8
+ The action-dispatch loops are still *inside* the agent classes — Phase 2 will
9
+ break them into LangGraph subgraphs with parallel tool nodes and the
10
+ ValidationAgent critic loop.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
  import json
16
+ from datetime import datetime
17
+ from typing import Any, Dict, List, Optional
18
+
19
+ from config import get_settings
20
+ from logging_setup import get_logger
21
+ from schemas import (
22
+ CoachDecision,
23
+ MedicalAssessmentDecision,
24
+ MedicalAssessmentResult,
25
+ PlannerDecision,
26
+ )
27
+ from tools import ComputationTool, QuantitiesFinder, WebSearchTool
28
+ from utils import save_to_json, should_debug, update_memory_partition
29
+
30
+ _coach_logger = get_logger("agents.coach")
31
+ _medical_logger = get_logger("agents.medical")
32
+ _planner_logger = get_logger("agents.planner")
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Coach
37
+ # ---------------------------------------------------------------------------
38
+ _COACH_SYSTEM_PROMPT = """\
39
+ You are the Coach Agent (central orchestrator) of a nutrition Multi-Agent System.
40
+
41
+ Primary responsibilities:
42
+ - Translate user intent into a concrete workflow of response_steps.
43
+ - Enforce system rules (MedicalAssessment must complete before Planner runs).
44
+ - Decide and perform exactly one action per turn: call_agent, call_tool,
45
+ ask_user, write_memory, or compose_response.
46
+
47
+ Inputs each turn:
48
+ - observation (string built from user query + memory + history)
49
+ - memory partitions: user_profile, medical_history, flags_and_assessments, plans
50
+ - response_steps (list, may be empty on the first turn)
51
+
52
+ Behaviour rules (mandatory):
53
+ 1. If response_steps is empty, generate ordered steps (max 7). Each step
54
+ must include id, actor, prerequisites, and status "pending".
55
+ Typical personal-workflow (when the user asks for a personalised plan):
56
+ 1) Validate required user data (height, weight, age, sex, activity_level,
57
+ allergies, goal). If missing -> ask_user.
58
+ 2) Update memory if the user provided new data [action: write_memory].
59
+ 3) Call MedicalAssessmentAgent with a task to assess the user.
60
+ 4) Wait for assessment to be completed and stored in memory.
61
+ 5) Call PlannerAgent with the relevant task.
62
+ 6) Call ValidationAgent to grade the plan.
63
+ 7) If validation verdict == "revise", re-call PlannerAgent with the
64
+ validation issues prepended to the task; otherwise compose_response.
65
+ 2. When calling any agent, set the called step status to "in_progress" and
66
+ include prerequisites satisfied by your observation.
67
+ 3. Only call PlannerAgent if memory.flags_and_assessments contains an
68
+ "assessment_status" of "assessment_complete". If missing, call
69
+ MedicalAssessmentAgent first.
70
+ 4. After every PlannerAgent run, you MUST call ValidationAgent before
71
+ composing the response. Inspect memory.flags_and_assessments.last_validation:
72
+ * verdict == "pass": proceed to compose_response.
73
+ * verdict == "revise": call PlannerAgent again with task =
74
+ "Revise the plan to address: " + each issue.description joined by "; ".
75
+ Cap revisions at 2; on the third attempt, compose_response with the
76
+ best plan available and append the unresolved issues as warnings.
77
+ * verdict == "reject": compose_response with a clear refusal explaining
78
+ the violation; do NOT show the plan. Append a HITL escalation chip
79
+ (text marker the UI will render).
80
+ 5. When new personal data appears in user input, add steps to: propose memory
81
+ update (write_memory), call MedicalAssessmentAgent if needed, re-plan if
82
+ needed.
83
+ 6. For any write_memory action, provide the full partition contents in
84
+ params.data (not diffs). The Coach is responsible for merging and storing.
85
+
86
+ Output JSON shape (enforced by schema):
87
+ {
88
+ "observation": "...",
89
+ "thought": "...",
90
+ "response_steps": [ ... ],
91
+ "action": "call_agent | call_tool | ask_user | write_memory | compose_response",
92
+ "params": { ... }
93
+ }
94
+
95
+ Required params per action:
96
+ - call_agent: {"agent_name": "...", "task": "..."}
97
+ - call_tool: {"tool_name": "...", "task": "..."}
98
+ - ask_user: {"prompt": "..."}
99
+ - write_memory: {"partition": "...", "data": {...}}
100
+ - compose_response: {"text": "...markdown..."}
101
+
102
+ Composition rules:
103
+ - When composing the response, extract relevant information from memory state
104
+ (calorie target, plan details, dietary restrictions, citations) in markdown.
105
+ - Always include a "trace" line summarising which agents/tools contributed.
106
+ - For high-risk profiles (requires_professional_consultation == true), append
107
+ a bold warning advising professional consultation before implementation.
108
+ """
109
+
110
 
111
class CoachAgent:
    """Orchestrating agent that asks the LLM which action to take next.

    Each call to :meth:`handle_task` renders the current graph state into a
    prompt, requests a ``CoachDecision`` from the LLM and returns a new state
    dict carrying that decision under ``current_action``.
    """

    def __init__(self, llm_instance):
        # The LLM wrapper must expose ``call_typed(prompt, schema)``.
        self.llm = llm_instance

    def handle_task(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run one coach turn and return the updated state.

        Args:
            state: Graph state. Reads ``memory``, ``conversation_history``,
                ``user_question`` and ``num_turns`` (all required) plus the
                optional ``response_steps`` and ``agent_result``.

        Returns:
            A shallow copy of ``state`` with ``current_action``,
            ``response_steps``, ``num_turns`` (incremented) and
            ``agent_result`` (reset to ``None``) replaced. When the LLM
            decision cannot be parsed, the fallback state from
            :meth:`_fallback_state` is returned instead.
        """
        settings = get_settings()
        memory_str = json.dumps(state["memory"], indent=2, default=str)

        prior_steps = state.get("response_steps", [])
        if prior_steps:
            response_steps_str = json.dumps(prior_steps, indent=2, default=str)
        else:
            response_steps_str = "None"

        # Long assistant turns are clipped to 200 characters to keep the prompt
        # small; every other message is passed through untouched.
        truncated_history: List[Dict[str, str]] = [
            {"role": "assistant", "content": msg["content"][:200] + "... (full response in memory)"}
            if msg["role"] == "assistant" and len(msg["content"]) > 200
            else msg
            for msg in state["conversation_history"]
        ]
        history_str = "\n".join(f"{m['role']}: {m['content']}" for m in truncated_history)

        observation = "\n".join(
            [
                f"User query: {state['user_question']}",
                f"Memory State: {memory_str}",
                f"Current Response Steps: {response_steps_str}",
                f"Previous Tool Result: {state.get('agent_result', 'None')}",
                f"Conversation history: {history_str}",
            ]
        )
        prompt = f"{_COACH_SYSTEM_PROMPT}\n\n--- Current State ---\n{observation}"

        if should_debug("agents", "CoachAgent"):
            _coach_logger.debug("--- Coach Agent Turn %d ---", state["num_turns"] + 1)
            if settings.debug_level == "full":
                _coach_logger.debug("Raw LLM input:\n%s", prompt)

        decision = self.llm.call_typed(prompt, CoachDecision)
        if decision is None:
            return self._fallback_state(state, "Coach decision could not be parsed.")

        if should_debug("agents", "CoachAgent"):
            _coach_logger.debug("Coach decision:\n%s", decision.model_dump_json(indent=2))

        # The one-line status message is only emitted outside debug mode.
        if not settings.debug_mode:
            self._log_user_mode_action(decision)

        current_action = {"action": decision.action, "params": decision.params}

        # Keep the previous steps when the LLM did not echo any back.
        new_steps = [step.model_dump() for step in decision.response_steps]
        if not new_steps:
            new_steps = state.get("response_steps", [])

        record = {
            "prompt": prompt,
            "decision": decision.model_dump(),
            "timestamp": datetime.now().isoformat(),
        }
        save_to_json(
            record,
            f"coach_agent_{datetime.now().isoformat()}.json",
            subdirectory="CoachAgent",
        )

        return {
            **state,
            "current_action": current_action,
            "response_steps": new_steps,
            "num_turns": state["num_turns"] + 1,
            "agent_result": None,
        }

    @staticmethod
    def _describe_action(action: Any, params: Dict[str, Any]) -> str:
        """Return the human-readable status text for one coach action."""
        if action == "call_agent":
            return f"Calling {params.get('agent_name')} with task '{params.get('task')}'"
        if action == "call_tool":
            return f"Using {params.get('tool_name')} with task '{params.get('task')}'"
        if action == "ask_user":
            return f"Asking user: {params.get('prompt')}"
        if action == "write_memory":
            return f"Writing to memory partition '{params.get('partition')}'"
        if action == "compose_response":
            return "Composing final response"
        return f"Unknown action: {action}"

    @staticmethod
    def _log_user_mode_action(decision: CoachDecision) -> None:
        """Log a one-line, user-facing summary of the chosen action at INFO."""
        msg = CoachAgent._describe_action(decision.action, decision.params or {})
        _coach_logger.info("\n🏋️‍♂️Coach Agent: %s", msg)

    @staticmethod
    def _fallback_state(state: Dict[str, Any], message: str) -> Dict[str, Any]:
        """Build the state returned when no usable decision is available.

        Logs ``message`` at ERROR and substitutes a ``compose_response`` action
        carrying an apology; ``_parse_error`` marks the action as synthetic.
        """
        _coach_logger.error(message)
        apology = {
            "action": "compose_response",
            "params": {"text": f"Sorry — I hit an internal error while planning. ({message})"},
            "_parse_error": True,
        }
        return {
            **state,
            "current_action": apology,
            "num_turns": state["num_turns"] + 1,
            "agent_result": None,
        }
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Medical Assessment
213
+ # ---------------------------------------------------------------------------
214
+ _MEDICAL_SYSTEM_PROMPT = """\
215
+ You are the Medical Assessment Agent. Produce an evidence-based assessment and
216
+ the clinical flags / calculations the Planner and Validation agents need.
217
+
218
+ Available tools: ComputationTool, WebSearchTool.
219
+
220
+ Mandatory behaviour (do not skip):
221
+ 1. Critical data check: confirm presence of age, sex, height, weight,
222
+ activity_level, allergies, medications. If any critical field is missing,
223
+ set action_type="ask_user" and list the missing names in ``fields``.
224
+ 2. Use ComputationTool for ALL numeric calculations (BMI, BMR, TDEE, calorie
225
+ targets, macro targets). Pass numeric inputs in tool_task.
226
+ 3. Use WebSearchTool to fetch authoritative guidelines (WHO, USDA, ADA,
227
+ EFSA). Capture source URLs with timestamps.
228
+ 4. Produce a compact assessment_plan (3-6 steps). Default sequence:
229
+ a) ComputationTool: BMI, BMR, TDEE, daily_target_calories (single int).
230
+ b) ComputationTool: macro_targets (protein_g, fat_g, carbohydrates_g - all
231
+ single ints, no ranges) optimised for the user's goal.
232
+ c) WebSearchTool: dietary guidelines for the user's conditions.
233
+ d-f) Optional follow-ups for specific risks.
234
+ 5. When complete, set action_type="assessment_complete" and populate
235
+ ``result`` (a MedicalAssessmentResult) with:
236
+ - assessment_summary
237
+ - calculations: { BMI, BMR, TDEE, daily_target_calories,
238
+ macro_targets: { protein_g, fat_g, carbohydrates_g } }
239
+ - flags_to_set (e.g. ["high_ldl", "diabetes_risk"])
240
+ - recommendations (clinical dietary constraints / urgent issues)
241
+ - requires_professional_consultation (True for medically sensitive cases)
242
+ - evidence_sources (list of URLs)
243
+ - trace (one paragraph summarising agent/tool usage)
244
+ 6. If any tool call fails, fall back to best-known values, set
245
+ data_confidence below 1.0, and mark requires_tool_retry=true.
246
+
247
+ Output JSON shape (enforced by schema):
248
+ {
249
+ "medical_reasoning": "...",
250
+ "observation": "...",
251
+ "risk_assessment_priorities": [...],
252
+ "assessment_plan": [...],
253
+ "action_type": "call_tool" | "ask_user" | "assessment_complete",
254
+ "tool_name": "ComputationTool" | "WebSearchTool" | null,
255
+ "tool_task": "..." | null,
256
+ "fields": [...], // only when ask_user
257
+ "result": { ... } // only when assessment_complete
258
+ }
259
+ """
260
+
261
+
262
class MedicalAssessmentAgent:
    """Agent that drives an LLM/tool loop to produce a medical assessment.

    The loop asks the LLM for a ``MedicalAssessmentDecision``, executes the
    requested tool (if any), and repeats until the LLM asks the user for
    missing data, completes the assessment, or ``MAX_ITERATIONS`` is reached.
    """

    # Upper bound on LLM round-trips for a single task.
    MAX_ITERATIONS = 15

    def __init__(
        self,
        llm_instance,
        computation_tool: ComputationTool,
        web_search_tool: WebSearchTool,
    ):
        self.llm = llm_instance
        self.computation_tool = computation_tool
        self.web_search_tool = web_search_tool

    def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
        """Run the assessment loop for ``task``.

        Args:
            task: Free-text instruction that is embedded in the prompt.
            memory: Shared memory dict. ``user_profile`` and
                ``medical_history`` are read; on successful completion the
                ``flags_and_assessments`` partition is written via
                ``update_memory_partition``.

        Returns:
            A status string: the assessment summary on success, a request for
            missing fields, or a message describing why the loop stopped.
        """
        _medical_logger.info("\n👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT STARTED")
        settings = get_settings()

        relevant_memory = {
            "user_profile": memory.get("user_profile", {}),
            "medical_history": memory.get("medical_history", {}),
        }
        memory_str = json.dumps(relevant_memory, indent=2, default=str)
        tool_results: List[str] = []
        assessment_plan: List[dict] = []

        for iteration in range(self.MAX_ITERATIONS):
            tool_results_str = (
                "\n".join(f"Tool Result {i+1}: {r}" for i, r in enumerate(tool_results)) or "None"
            )
            assessment_plan_str = (
                json.dumps(assessment_plan, indent=2, default=str) if assessment_plan else "None"
            )

            prompt = (
                f"{_MEDICAL_SYSTEM_PROMPT}\n\n--- Task & State ---\n"
                f"Task: {task}\n"
                f"Current Memory: {memory_str}\n"
                f"Current Assessment Plan: {assessment_plan_str}\n"
                f"Previous Tool Results: {tool_results_str}\n"
            )

            if should_debug("agents", "MedicalAssessmentAgent"):
                _medical_logger.debug("--- Medical Assessment Iteration %d ---", iteration + 1)
                if settings.debug_level == "full":
                    _medical_logger.debug("Raw LLM input:\n%s", prompt)

            decision = self.llm.call_typed(prompt, MedicalAssessmentDecision)
            if decision is None:
                _medical_logger.error("Medical decision parse failed at iteration %d", iteration + 1)
                return "Medical assessment failed: could not parse LLM decision."

            if should_debug("agents", "MedicalAssessmentAgent"):
                _medical_logger.debug("Medical decision:\n%s", decision.model_dump_json(indent=2))

            # Only overwrite the remembered plan when the LLM echoed one back,
            # so an empty list does not erase earlier progress.
            if decision.assessment_plan:
                assessment_plan = [s.model_dump() for s in decision.assessment_plan]

            if not settings.debug_mode:
                self._log_user_mode_action(decision)

            if decision.action_type == "call_tool":
                tool_results.append(f"{decision.tool_name}: {self._dispatch_tool(decision)}")

            elif decision.action_type == "ask_user":
                fields = decision.fields or []
                msg = (
                    f"Missing critical fields: {', '.join(fields)}. "
                    "Please provide the following information to continue the assessment."
                )
                _medical_logger.info("👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT: User query needed - %s", msg)
                return msg

            elif decision.action_type == "assessment_complete":
                return self._finalize(task, decision, memory, relevant_memory, tool_results)

            else:
                # Report the real cause instead of falling through to the
                # "MAX ITERATIONS" message, which would be untrue here.
                _medical_logger.error("Unknown action_type: %s", decision.action_type)
                return (
                    "Medical assessment failed: unknown action_type "
                    f"{decision.action_type!r}."
                )

        _medical_logger.warning("👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT Stopped (MAX ITERATIONS)")
        return f"Medical assessment stopped after {self.MAX_ITERATIONS} iterations"

    # ------------------------------------------------------------------
    def _dispatch_tool(self, decision: MedicalAssessmentDecision) -> str:
        """Execute the tool named in ``decision`` and return its result string.

        A missing ``tool_task`` or an unrecognised ``tool_name`` yields a
        descriptive message rather than an exception, so the LLM can see the
        problem in the next prompt.
        """
        tool_name = decision.tool_name
        tool_task = decision.tool_task
        if not tool_task:
            return f"Missing 'tool_task' for {tool_name}"
        if tool_name == "ComputationTool":
            return self.computation_tool.handle_task(tool_task)
        if tool_name == "WebSearchTool":
            return self.web_search_tool.handle_task(tool_task)
        return f"Unknown tool: {tool_name}"

    @staticmethod
    def _log_user_mode_action(decision: MedicalAssessmentDecision) -> None:
        """Log a one-line, user-facing summary of the decision at INFO."""
        if decision.action_type == "call_tool":
            _medical_logger.info(
                "👨🏻‍⚕️ Medical Assessment Agent: Using %s for '%s'",
                decision.tool_name,
                decision.tool_task,
            )
        elif decision.action_type == "ask_user":
            _medical_logger.info(
                "👨🏻‍⚕️ Medical Assessment Agent: Asking user for missing fields: %s",
                ", ".join(decision.fields or []),
            )
        elif decision.action_type == "assessment_complete":
            _medical_logger.info("👨🏻‍⚕️ Medical Assessment Agent: Completing assessment")

    def _finalize(
        self,
        task: str,
        decision: MedicalAssessmentDecision,
        memory: Dict[str, Any],
        relevant_memory: Dict[str, Any],
        tool_results: List[str],
    ) -> str:
        """Persist a completed assessment and return its summary.

        Returns a failure message without touching memory when the decision
        carries no ``result`` or when the result asks for a tool retry.
        """
        result: Optional[MedicalAssessmentResult] = decision.result
        if result is None:
            _medical_logger.error("assessment_complete decision missing result payload")
            return "Medical assessment failed: completion payload missing."

        if result.requires_tool_retry:
            msg = "Assessment requires tool retry due to tool failures."
            _medical_logger.warning("👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT: Tool retry needed - %s", msg)
            return msg

        # One timestamp for the memory record, the JSON payload and the file
        # name, so the three always agree.
        timestamp = datetime.now().isoformat()

        update_memory_partition(
            memory,
            "flags_and_assessments",
            {
                "assessment_summary": result.assessment_summary,
                "flags": result.flags_to_set,
                "recommendations": result.recommendations,
                "requires_professional_consultation": result.requires_professional_consultation,
                "calculations": result.calculations.model_dump(),
                "evidence_sources": result.evidence_sources,
                "data_confidence": result.data_confidence,
                "trace": result.trace,
                "assessment_status": "assessment_complete",
                "assessment_timestamp": timestamp,
            },
        )
        save_to_json(
            {
                "task": task,
                "memory_input": relevant_memory,
                "tool_results": tool_results,
                "result": result.model_dump(),
                "timestamp": timestamp,
            },
            f"medical_assessment_{timestamp}.json",
            subdirectory="MedicalAssessment",
        )
        _medical_logger.info("👨🏻‍⚕️ MEDICAL ASSESSMENT AGENT COMPLETED: %s", result.assessment_summary)
        return result.assessment_summary
+
420
+
421
+ # ---------------------------------------------------------------------------
422
+ # Planner
423
+ # ---------------------------------------------------------------------------
424
+ _PLANNER_SYSTEM_PROMPT = """\
425
+ You are the Planner Agent. Create personalised meal plans constrained by the
426
+ medical assessment.
427
+
428
+ Available tools: WebSearchTool, QuantitiesFinder, ComputationTool.
429
+
430
+ Mandatory behaviour & rules:
431
+ 1. Precondition: do NOT plan unless flags_and_assessments has an
432
+ "assessment_status" of "assessment_complete". If missing, return
433
+ action_type="provide_plan" with final_plan={"error": "..."} explaining the
434
+ blocker and suggesting MedicalAssessmentAgent.
435
+ 2. Batch tool calls: fetch nutrition facts for ALL foods in one WebSearchTool
436
+ call rather than one call per item.
437
+ 3. For each food in the draft, look up per-100g nutrition (calories, protein,
438
+ fat, carbohydrates). If WebSearchTool fails for >2 items, fall back to
439
+ internal knowledge.
440
+ 4. Tolerances: calories +/- 3%, each macro +/- 5% of target.
441
+ 5. Exclude allergens and disliked foods. Propose alternatives if necessary
442
+ for balance.
443
+ 6. Multi-day requests: emit a 1-2 day plan and instruct the user to rotate.
444
+ 7. QuantitiesFinder format: tool_task MUST be a JSON STRING containing
445
+ {"foods": [...], "targets": {...}}. Each food needs name, calories,
446
+ protein, fat, carbohydrates (per 100g) and estimated_g (your best guess).
447
+
448
+ Planning Steps Handling:
449
+ - If Current Planning Steps is empty/None, adopt this fixed 5-step plan:
450
+ 1. Draft a realistic plan; assign a realistic estimated_g per food.
451
+ 2. Batch-gather nutrition facts via WebSearchTool.
452
+ 3. Call QuantitiesFinder with foods + targets to compute precise grams.
453
+ 4. Update the draft with the solver's quantities.
454
+ 5. Provide the final plan via action_type="provide_plan".
455
+ - If steps are provided, you may iterate within a step until targets are met.
456
+
457
+ Output JSON shape (enforced by schema):
458
+ {
459
+ "observation": "...",
460
+ "thought": "...",
461
+ "planning_steps": [...],
462
+ "action_type": "call_tool" | "draft_plan" | "provide_plan",
463
+ "tool_name": "WebSearchTool" | "QuantitiesFinder" | "ComputationTool" | null,
464
+ "tool_task": "..." | null,
465
+ "drafted_plan": { ... } | null,
466
+ "final_plan": { ... } | null
467
+ }
468
+
469
+ Notes:
470
+ - Keep plans realistic and culturally appropriate (regional foods if provided).
471
+ - Include a "trace" line in the final plan summarising agents/tools used.
472
+ - Always echo the full updated planning_steps so they persist across turns.
473
+ """
474
 
 
 
 
 
475
 
476
class PlannerAgent:
    """Agent that drives an LLM/tool loop to draft and finalise a meal plan.

    The loop asks the LLM for a ``PlannerDecision``, executes the requested
    tool or stores a draft, and repeats until a final plan (or an error plan)
    is provided or ``MAX_ITERATIONS`` is reached.
    """

    # Upper bound on LLM round-trips for a single task.
    MAX_ITERATIONS = 15

    def __init__(
        self,
        llm_instance,
        computation_tool: ComputationTool,
        web_search_tool: WebSearchTool,
        quantities_finder: QuantitiesFinder,
    ):
        self.llm = llm_instance
        self.computation_tool = computation_tool
        self.web_search_tool = web_search_tool
        self.quantities_finder = quantities_finder

    def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
        """Run the planning loop for ``task``.

        Args:
            task: Free-text instruction that is embedded in the prompt.
            memory: Shared memory dict. ``user_profile``,
                ``flags_and_assessments`` and ``plans`` are read; the
                ``plans`` partition is mutated in place (``drafted_plan``
                while drafting, then ``current_plan`` and ``plan_timestamp``
                when finalised).

        Returns:
            The final plan serialised as JSON (or ``str(plan)`` when it is
            not a dict), an error plan serialised as JSON, or a message
            describing why the loop stopped.
        """
        _planner_logger.info("\n📋 PLANNER AGENT STARTED")
        settings = get_settings()

        # Captured once for the audit file written on completion.
        relevant_memory = {
            "user_profile": memory.get("user_profile", {}),
            "flags_and_assessments": memory.get("flags_and_assessments", {}),
        }
        tool_results: List[str] = []
        planning_steps: List[dict] = []

        for iteration in range(self.MAX_ITERATIONS):
            # Rebuilt every iteration because "plans" can change when a draft
            # is stored below.
            memory_str = json.dumps(
                {
                    "user_profile": memory.get("user_profile", {}),
                    "flags_and_assessments": memory.get("flags_and_assessments", {}),
                    "plans": memory.get("plans", {}),
                },
                indent=2,
                default=str,
            )
            tool_results_str = (
                "\n".join(f"Tool Result {i+1}: {r}" for i, r in enumerate(tool_results)) or "None"
            )
            planning_steps_str = (
                json.dumps(planning_steps, indent=2, default=str) if planning_steps else "None"
            )

            prompt = (
                f"{_PLANNER_SYSTEM_PROMPT}\n\n--- Task & State ---\n"
                f"Task: {task}\n"
                f"Current Memory: {memory_str}\n"
                f"Current Planning Steps: {planning_steps_str}\n"
                f"Previous Tool Results: {tool_results_str}\n"
            )

            if should_debug("agents", "PlannerAgent"):
                _planner_logger.debug("--- Planner Iteration %d ---", iteration + 1)
                if settings.debug_level == "full":
                    _planner_logger.debug("Raw LLM input:\n%s", prompt)

            decision = self.llm.call_typed(prompt, PlannerDecision)
            if decision is None:
                _planner_logger.error("Planner decision parse failed at iteration %d", iteration + 1)
                return "Planner failed: could not parse LLM decision."

            if should_debug("agents", "PlannerAgent"):
                _planner_logger.debug("Planner decision:\n%s", decision.model_dump_json(indent=2))

            # Only overwrite the remembered steps when the LLM echoed some back.
            if decision.planning_steps:
                planning_steps = [s.model_dump() for s in decision.planning_steps]

            if not settings.debug_mode:
                self._log_user_mode_action(decision)

            if decision.action_type == "call_tool":
                tool_results.append(f"{decision.tool_name}: {self._dispatch_tool(decision)}")

            elif decision.action_type == "draft_plan":
                if decision.drafted_plan:
                    memory.setdefault("plans", {})["drafted_plan"] = decision.drafted_plan
                    tool_results.append("Plan drafted and stored in memory")
                else:
                    tool_results.append("Drafted plan not provided")

            elif decision.action_type == "provide_plan":
                # Fall back to the stored draft when no final plan was sent.
                final = decision.final_plan or memory.get("plans", {}).get("drafted_plan")

                # Error escape hatch (e.g. precondition not met)
                if isinstance(final, dict) and "error" in final:
                    _planner_logger.error("📋 PLANNER AGENT ERROR: %s", final)
                    return json.dumps(final)

                if not final:
                    tool_results.append("Cannot finalize: missing plan")
                    continue  # let the loop try another iteration

                # One timestamp for the memory record, the JSON payload and the
                # file name, so the three always agree.
                timestamp = datetime.now().isoformat()

                memory.setdefault("plans", {})
                memory["plans"]["current_plan"] = final
                memory["plans"]["plan_timestamp"] = timestamp
                memory["plans"].pop("drafted_plan", None)

                save_to_json(
                    {
                        "task": task,
                        "memory_input": relevant_memory,
                        "tool_results": tool_results,
                        "final_response": decision.model_dump(),
                        "timestamp": timestamp,
                    },
                    f"planner_agent_{timestamp}.json",
                    subdirectory="PlannerAgent",
                )
                _planner_logger.info("\n📋 PLANNER AGENT COMPLETED")
                return json.dumps(final) if isinstance(final, dict) else str(final)

            else:
                # Report the real cause instead of falling through to the
                # "MAX ITERATIONS" message, which would be untrue here.
                _planner_logger.error("Unknown action_type: %s", decision.action_type)
                return f"Planner failed: unknown action_type {decision.action_type!r}."

        _planner_logger.warning("📋 PLANNER AGENT Stopped (MAX ITERATIONS)")
        return (
            f"Planning stopped after {self.MAX_ITERATIONS} iterations "
            f"with {len(tool_results)} actions"
        )

    # ------------------------------------------------------------------
    def _dispatch_tool(self, decision: PlannerDecision) -> str:
        """Execute the tool named in ``decision`` and return its result string.

        Missing or unrecognised tool details yield a descriptive message
        rather than an exception, so the LLM can see the problem in the next
        prompt.
        """
        tool_name = decision.tool_name
        tool_task = decision.tool_task
        if not tool_name or not tool_task:
            return "Missing tool_name or tool_task"
        if tool_name == "ComputationTool":
            return self.computation_tool.handle_task(tool_task)
        if tool_name == "WebSearchTool":
            return self.web_search_tool.handle_task(tool_task)
        if tool_name == "QuantitiesFinder":
            return self.quantities_finder.handle_task(tool_task)
        return f"Unknown tool: {tool_name}"

    @staticmethod
    def _log_user_mode_action(decision: PlannerDecision) -> None:
        """Log a one-line, user-facing summary of the decision at INFO."""
        if decision.action_type == "call_tool":
            _planner_logger.info(
                "📋 Planner Agent: Using %s for '%s'",
                decision.tool_name,
                decision.tool_task,
            )
        elif decision.action_type == "draft_plan":
            _planner_logger.info("📋 Planner Agent: Drafting plan")
        elif decision.action_type == "provide_plan":
            _planner_logger.info("📋 Planner Agent: Finalizing plan")
config.py CHANGED
@@ -1,6 +1,113 @@
1
- # Nutrition MAS Configuration
2
- LOG_DIR = None
3
- PERSISTENCE_DIR = None
4
- DEBUG_MODE = False
5
- DEBUG_LEVEL = 'full' # 'full' or 'output'
6
- DEBUG_SCOPES = {'agents': ['all'], 'tools': ['all']}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Nutrition MAS configuration.
2
+
3
+ This module exposes a Pydantic-Settings ``Settings`` singleton plus a small
4
+ backward-compatibility shim so existing code can still read ``config.DEBUG_MODE``,
5
+ ``config.LOG_DIR``, etc. New code should import :func:`get_settings` directly::
6
+
7
+ from config import get_settings
8
+ settings = get_settings()
9
+ if settings.debug_mode:
10
+ ...
11
+
12
+ Mutation must go through :func:`set_settings` (the legacy ``config.X = y`` write
13
+ pattern would otherwise silently shadow the Pydantic value).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ from pydantic import Field
21
+ from pydantic_settings import BaseSettings, SettingsConfigDict
22
+
23
+
24
class Settings(BaseSettings):
    """Process-wide configuration for Nutrition MAS.

    Fields are populated from environment variables carrying the
    ``NUTRITION_MAS_`` prefix and from a ``.env`` file, falling back to the
    defaults declared below. Assignments made after construction are
    validated (``validate_assignment=True``).
    """

    model_config = SettingsConfigDict(
        env_prefix="NUTRITION_MAS_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
        validate_assignment=True,
    )

    # --- Logging / persistence -------------------------------------------------
    # Both paths are unset by default.
    log_dir: Optional[str] = Field(default=None)
    persistence_dir: Optional[str] = Field(default=None)

    # --- Debug switches --------------------------------------------------------
    debug_mode: bool = Field(default=False)
    # Expected values: 'full' or 'output' (not enforced by the type).
    debug_level: str = Field(default="full")
    # Maps a scope kind ("agents" / "tools") to the names to debug.
    debug_scopes: Dict[str, List[str]] = Field(
        default_factory=lambda: dict(agents=["all"], tools=["all"])
    )

    # --- LLM / rate limiting ---------------------------------------------------
    enable_rate_limiting: bool = Field(default=True)
    gemini_api_keys: List[str] = Field(default_factory=list)
55
+
56
+
57
+ # Singleton holder. Instantiated lazily so tests can set env vars before first read.
58
+ _settings: Optional[Settings] = None
59
+
60
+
61
def get_settings() -> Settings:
    """Return the shared ``Settings`` object, building it on first use."""
    global _settings
    if _settings is not None:
        return _settings
    # First access: construct the singleton and cache it for later callers.
    _settings = Settings()
    return _settings
67
+
68
+
69
def reset_settings() -> None:
    """Forget the cached ``Settings`` instance.

    The next ``get_settings`` call constructs a fresh one. Meant for tests.
    """
    global _settings
    _settings = None
76
+
77
+
78
def set_settings(**updates: Any) -> Settings:
    """Update fields on the singleton ``Settings``.

    Accepts both legacy upper-case names (``DEBUG_MODE``) and Pydantic field
    names (``debug_mode``).

    Every key is resolved and checked against the declared model fields
    before anything is assigned, so an unknown key leaves the settings
    untouched. A value that fails validation on assignment can still leave
    earlier keys of the same call applied.

    Args:
        **updates: Setting names mapped to their new values.

    Returns:
        The updated settings instance.

    Raises:
        AttributeError: If a key does not name a declared settings field.
    """
    s = get_settings()
    # Check declared fields rather than ``hasattr``: the latter is also true
    # for model methods and ``model_config``, which are not settings.
    declared = type(s).model_fields

    resolved: Dict[str, Any] = {}
    for raw_key, value in updates.items():
        attr = _LEGACY_ATTR_MAP.get(raw_key, raw_key.lower())
        if attr not in declared:
            raise AttributeError(f"Settings has no attribute {attr!r}")
        resolved[attr] = value

    for attr, value in resolved.items():
        setattr(s, attr, value)
    return s
91
+
92
+
93
+ # --- Legacy attribute proxy ----------------------------------------------------
94
+ # Existing code does ``import config`` then reads ``config.DEBUG_MODE`` etc.
95
+ # PEP 562 ``__getattr__`` lets us forward those reads to the singleton.
96
+ _LEGACY_ATTR_MAP: Dict[str, str] = {
97
+ "DEBUG_MODE": "debug_mode",
98
+ "DEBUG_LEVEL": "debug_level",
99
+ "DEBUG_SCOPES": "debug_scopes",
100
+ "LOG_DIR": "log_dir",
101
+ "PERSISTENCE_DIR": "persistence_dir",
102
+ "ENABLE_RATE_LIMITING": "enable_rate_limiting",
103
+ }
104
+
105
+
106
def __getattr__(name: str) -> Any:  # noqa: D401 — module-level dunder
    """PEP 562 hook: serve legacy ``config.CONST`` reads from the singleton."""
    field_name = _LEGACY_ATTR_MAP.get(name)
    if field_name is None:
        # Not a legacy constant: behave like an ordinary missing attribute.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(get_settings(), field_name)
111
+
112
+
113
+ __all__ = ["Settings", "get_settings", "reset_settings", "set_settings"]
logging_setup.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Centralised logging for Nutrition MAS.
2
+
3
+ Agents and tools used to ``print`` directly to stdout. That worked in a notebook
4
+ but coupled the agentic system to the I/O layer. This module provides a single
5
+ ``get_logger`` entrypoint so:
6
+
7
+ * user-mode emoji status lines flow through ``logger.info`` (visible by default),
8
+ * debug-mode raw LLM dumps flow through ``logger.debug`` (hidden unless
9
+ ``settings.debug_mode`` is True),
10
+ * later phases can attach extra handlers (SSE event stream for the API, JSON
11
+ file handler for trace persistence, etc.) without touching agent code.
12
+
13
+ Idempotent: calling :func:`configure_logging` more than once is a no-op unless
14
+ ``force=True``.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ import sys
21
+
22
+ from config import get_settings
23
+
24
+ _BASE = "nutrition_mas"
25
+ _CONFIGURED = False
26
+
27
+
28
def configure_logging(*, force: bool = False) -> None:
    """Set up the ``nutrition_mas`` base logger.

    The level is DEBUG when ``settings.debug_mode`` is true and INFO
    otherwise. A single stdout handler that prints the bare message replaces
    any handlers already on the base logger, and propagation to the root
    logger is switched off. Repeat calls return immediately unless ``force``
    is true.
    """
    global _CONFIGURED
    if not force and _CONFIGURED:
        return

    if get_settings().debug_mode:
        level = logging.DEBUG
    else:
        level = logging.INFO

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(level)
    stream_handler.setFormatter(logging.Formatter("%(message)s"))

    base_logger = logging.getLogger(_BASE)
    base_logger.handlers = [stream_handler]
    base_logger.setLevel(level)
    base_logger.propagate = False

    _CONFIGURED = True
51
+
52
+
53
def get_logger(name: str) -> logging.Logger:
    """Return the logger ``nutrition_mas.<name>``, configuring logging if needed.

    Conventional names: ``agents.coach``, ``agents.medical``,
    ``tools.computation``, ``utils.api_pool``.
    """
    # Lazily configure so importing modules need no explicit setup call.
    if not _CONFIGURED:
        configure_logging()
    qualified_name = f"{_BASE}.{name}"
    return logging.getLogger(qualified_name)
62
+
63
+
64
def refresh_level() -> None:
    """Apply the current ``settings.debug_mode`` to the base logger in place.

    Sets DEBUG or INFO on the ``nutrition_mas`` logger and on every handler
    attached to it. Call after toggling debug mode at runtime.
    """
    if get_settings().debug_mode:
        level = logging.DEBUG
    else:
        level = logging.INFO

    base_logger = logging.getLogger(_BASE)
    base_logger.setLevel(level)
    for attached in base_logger.handlers:
        attached.setLevel(level)
75
+
76
+
77
+ __all__ = ["configure_logging", "get_logger", "refresh_level"]
nutritionmas.py CHANGED
@@ -1,45 +1,52 @@
 
1
  import os
2
- import config
3
- from state import NutritionState, initialize_empty_memory
4
- from utils import create_llm, save_to_json, APIPoolManager
5
- from tools import ComputationTool, WebSearchTool, QuantitiesFinder
 
6
  from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
 
 
 
 
 
 
7
  from workflow import setup_workflow as setup_workflow_workflow
8
- from datetime import datetime
9
- import random
10
- import json
11
- from typing import Optional, Dict, Any, List
12
- from IPython.display import display, Markdown
13
 
14
- def debug(level: str = 'full', scopes: Optional[Dict[str, List[str]]] = None):
15
- """
16
- Enable debug mode with specified level and scopes.
17
-
 
 
18
  Args:
19
  level: 'full' (default) to show inputs and outputs, or 'output' to show only outputs.
20
- scopes: Optional dict like {'agents': ['all'], 'tools': ['ComputationTool']}.
21
  If None, defaults to all agents and tools.
22
  """
23
- config.DEBUG_MODE = True
24
- config.DEBUG_LEVEL = level
25
  if scopes is None:
26
- config.DEBUG_SCOPES = {'agents': ['all'], 'tools': ['all']}
27
- else:
28
- config.DEBUG_SCOPES = scopes
29
 
30
- def logging(log_dir=None, persistence_dir=None):
31
- """
32
- Set the directories for logging and persistence.
33
- If log_dir is provided, logging will be enabled to that directory.
34
- If persistence_dir is provided, file-based persistence will be used for checkpoints.
35
- If not provided, logging is disabled, and in-memory persistence is used.
 
36
  """
 
37
  if log_dir is not None:
38
- config.LOG_DIR = log_dir
39
- os.makedirs(config.LOG_DIR, exist_ok=True)
40
  if persistence_dir is not None:
41
- config.PERSISTENCE_DIR = persistence_dir
42
- os.makedirs(config.PERSISTENCE_DIR, exist_ok=True)
 
 
43
 
44
  # Default model configurations (without API keys, as they will be provided by the user)
45
  DEFAULT_MODEL_CONFIGS = {
@@ -71,6 +78,13 @@ DEFAULT_MODEL_CONFIGS = {
71
  "thinking_budget": 600,
72
  "params": {"max_tokens": 5120, "temperature": 0.3}
73
  },
 
 
 
 
 
 
 
74
  "user_simulator": {
75
  "type": "gemini",
76
  "model_name": "gemini-2.5-flash",
@@ -115,25 +129,32 @@ def create_llm_instances(api_keys: list[str], model_overrides: Optional[Dict[str
115
  rate_limits = None
116
 
117
  manager = APIPoolManager(api_keys, rate_limits)
118
- print(f"APIPoolManager initialized with {'rate limiting enabled' if enable_rate_limiting else 'rate limiting disabled'} and {len(api_keys)} API keys.")
119
-
120
- model_configs = {}
 
 
 
 
 
 
121
  for key in DEFAULT_MODEL_CONFIGS:
122
- config = DEFAULT_MODEL_CONFIGS[key].copy()
123
  if model_overrides and key in model_overrides:
124
  override = model_overrides[key]
125
  if "model_name" in override:
126
- config["model_name"] = override["model_name"]
127
  if "params" in override:
128
- config["params"].update(override["params"])
129
- model_configs[key] = config
130
 
131
  LLM_INSTANCES = {
132
  "main": create_llm(model_configs["main"], manager),
133
  "agents_llm": create_llm(model_configs["agents_llm"], manager),
134
  "tools_llm": create_llm(model_configs["tools_llm"], manager),
135
  "planner_agent": create_llm(model_configs["planner_agent"], manager),
136
- "user_simulator": create_llm(model_configs["user_simulator"], manager)
 
137
  }
138
 
139
  def initialize_tools():
@@ -156,10 +177,16 @@ def initialize_agents():
156
  MAIN_LLM = LLM_INSTANCES["main"]
157
  AGENTS_LLM = LLM_INSTANCES["agents_llm"]
158
  PLANNER_LLM = LLM_INSTANCES["planner_agent"]
 
159
  AGENTS = {
160
  "CoachAgent": CoachAgent(MAIN_LLM),
161
- "MedicalAssessmentAgent": MedicalAssessmentAgent(AGENTS_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"]),
162
- "PlannerAgent": PlannerAgent(PLANNER_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"], TOOLS["QuantitiesFinder"])
 
 
 
 
 
163
  }
164
 
165
  def setup_workflow():
 
1
+ import json
2
  import os
3
+ from datetime import datetime
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from IPython.display import Markdown, display
7
+
8
  from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
9
+ from config import set_settings
10
+ from logging_setup import get_logger, refresh_level
11
+ from state import initialize_empty_memory
12
+ from tools import ComputationTool, QuantitiesFinder, WebSearchTool
13
+ from utils import APIPoolManager, create_llm
14
+ from validation import ValidationAgent
15
  from workflow import setup_workflow as setup_workflow_workflow
 
 
 
 
 
16
 
17
+ _logger = get_logger("nutritionmas")
18
+
19
+
20
+ def debug(level: str = "full", scopes: Optional[Dict[str, List[str]]] = None) -> None:
21
+ """Enable debug mode with the given level and scopes.
22
+
23
  Args:
24
  level: 'full' (default) to show inputs and outputs, or 'output' to show only outputs.
25
+ scopes: Optional dict like ``{'agents': ['all'], 'tools': ['ComputationTool']}``.
26
  If None, defaults to all agents and tools.
27
  """
 
 
28
  if scopes is None:
29
+ scopes = {"agents": ["all"], "tools": ["all"]}
30
+ set_settings(debug_mode=True, debug_level=level, debug_scopes=scopes)
31
+ refresh_level()
32
 
33
+
34
+ def logging(log_dir: Optional[str] = None, persistence_dir: Optional[str] = None) -> None: # noqa: A001 - public name kept for backwards compat
35
+ """Set directories for log files and LangGraph checkpoint persistence.
36
+
37
+ If ``log_dir`` is provided, agent/tool I/O is dumped there as JSON.
38
+ If ``persistence_dir`` is provided, LangGraph checkpoints are persisted to disk.
39
+ If neither is set, logging is disabled and persistence is in-memory.
40
  """
41
+ updates: Dict[str, Any] = {}
42
  if log_dir is not None:
43
+ os.makedirs(log_dir, exist_ok=True)
44
+ updates["log_dir"] = log_dir
45
  if persistence_dir is not None:
46
+ os.makedirs(persistence_dir, exist_ok=True)
47
+ updates["persistence_dir"] = persistence_dir
48
+ if updates:
49
+ set_settings(**updates)
50
 
51
  # Default model configurations (without API keys, as they will be provided by the user)
52
  DEFAULT_MODEL_CONFIGS = {
 
78
  "thinking_budget": 600,
79
  "params": {"max_tokens": 5120, "temperature": 0.3}
80
  },
81
+ "validation_agent": {
82
+ "type": "gemini",
83
+ "model_name": "gemini-2.5-flash",
84
+ "structured_output": True,
85
+ "thinking_budget": 300,
86
+ "params": {"max_tokens": 3072, "temperature": 0.2}
87
+ },
88
  "user_simulator": {
89
  "type": "gemini",
90
  "model_name": "gemini-2.5-flash",
 
129
  rate_limits = None
130
 
131
  manager = APIPoolManager(api_keys, rate_limits)
132
+ _logger.info(
133
+ "APIPoolManager initialized with %s and %d API keys.",
134
+ "rate limiting enabled" if enable_rate_limiting else "rate limiting disabled",
135
+ len(api_keys),
136
+ )
137
+
138
+ # Note: previously this loop used a local variable named ``config`` which
139
+ # shadowed the imported ``config`` module — now ``cfg`` to avoid the trap.
140
+ model_configs: Dict[str, Dict[str, Any]] = {}
141
  for key in DEFAULT_MODEL_CONFIGS:
142
+ cfg = DEFAULT_MODEL_CONFIGS[key].copy()
143
  if model_overrides and key in model_overrides:
144
  override = model_overrides[key]
145
  if "model_name" in override:
146
+ cfg["model_name"] = override["model_name"]
147
  if "params" in override:
148
+ cfg["params"] = {**cfg.get("params", {}), **override["params"]}
149
+ model_configs[key] = cfg
150
 
151
  LLM_INSTANCES = {
152
  "main": create_llm(model_configs["main"], manager),
153
  "agents_llm": create_llm(model_configs["agents_llm"], manager),
154
  "tools_llm": create_llm(model_configs["tools_llm"], manager),
155
  "planner_agent": create_llm(model_configs["planner_agent"], manager),
156
+ "validation_agent": create_llm(model_configs["validation_agent"], manager),
157
+ "user_simulator": create_llm(model_configs["user_simulator"], manager),
158
  }
159
 
160
  def initialize_tools():
 
177
  MAIN_LLM = LLM_INSTANCES["main"]
178
  AGENTS_LLM = LLM_INSTANCES["agents_llm"]
179
  PLANNER_LLM = LLM_INSTANCES["planner_agent"]
180
+ VALIDATION_LLM = LLM_INSTANCES["validation_agent"]
181
  AGENTS = {
182
  "CoachAgent": CoachAgent(MAIN_LLM),
183
+ "MedicalAssessmentAgent": MedicalAssessmentAgent(
184
+ AGENTS_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"]
185
+ ),
186
+ "PlannerAgent": PlannerAgent(
187
+ PLANNER_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"], TOOLS["QuantitiesFinder"]
188
+ ),
189
+ "ValidationAgent": ValidationAgent(VALIDATION_LLM),
190
  }
191
 
192
  def setup_workflow():
pyproject.toml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "nutrition-mas"
7
+ version = "0.2.0"
8
+ description = "Multi-agent system for personalised nutrition planning, built on LangGraph + Gemini."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "Moaz Eldegwy", email = "moazeldegwy@gmail.com" }]
13
+
14
+ dependencies = [
15
+ "langgraph>=0.2.50,<0.3",
16
+ "langchain-core>=0.3.20,<0.4",
17
+ "google-genai>=0.3.0",
18
+ "pydantic>=2.9,<3",
19
+ "pydantic-settings>=2.6,<3",
20
+ "pulp>=2.9,<3",
21
+ "ddgs>=6.3,<7",
22
+ "json-repair>=0.30",
23
+ "python-dotenv>=1.0,<2",
24
+ "ipython>=8.0",
25
+ ]
26
+
27
+ [project.optional-dependencies]
28
+ dev = [
29
+ "pytest>=8.0",
30
+ "pytest-asyncio>=0.24",
31
+ "pytest-cov>=5.0",
32
+ "ruff>=0.7",
33
+ ]
34
+
35
+ [tool.setuptools]
36
+ py-modules = ["agents", "config", "nutritionmas", "schemas", "state", "tools", "utils", "workflow", "logging_setup"]
37
+
38
+ [tool.ruff]
39
+ line-length = 110
40
+ target-version = "py310"
41
+
42
+ [tool.ruff.lint]
43
+ select = ["E", "F", "I", "B", "UP", "SIM"]
44
+ ignore = ["E501"]
45
+
46
+ [tool.pytest.ini_options]
47
+ testpaths = ["tests"]
48
+ addopts = "-ra --strict-markers"
49
+ markers = [
50
+ "integration: tests that hit a real LLM (skipped by default)",
51
+ "slow: long-running tests",
52
+ ]
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core agent framework
2
+ langgraph>=0.2.50,<0.3
3
+ langchain-core>=0.3.20,<0.4
4
+
5
+ # LLM provider (Gemini)
6
+ google-genai>=0.3.0
7
+
8
+ # Schemas / settings
9
+ pydantic>=2.9,<3
10
+ pydantic-settings>=2.6,<3
11
+
12
+ # Optimization (meal-quantities solver)
13
+ pulp>=2.9,<3
14
+
15
+ # Web search fallback
16
+ ddgs>=6.3,<7
17
+
18
+ # JSON repair fallback (kept until Phase 1 makes it a measured fallback)
19
+ json-repair>=0.30
20
+
21
+ # Env loading
22
+ python-dotenv>=1.0,<2
23
+
24
+ # Markdown rendering for notebook display (kept for backwards compat)
25
+ ipython>=8.0
26
+
27
+ # Tests / dev
28
+ pytest>=8.0
29
+ pytest-asyncio>=0.24
30
+ pytest-cov>=5.0
31
+ ruff>=0.7
schemas.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for agent inputs and outputs.
2
+
3
+ This module is the contract between the LLM, the orchestration layer, and the
4
+ test suite. Every agent's decision now passes through one of these models — so:
5
+
6
+ * Gemini's ``response_schema`` (constrained decoding) returns guaranteed-shape
7
+ JSON; we no longer rely on regex / ``json_repair`` for the high-stakes path.
8
+ * Tests can construct decisions directly without hand-crafted JSON strings.
9
+ * Phase 2 can split agent loops into LangGraph nodes that pass typed objects
10
+ between them.
11
+
12
+ Where Gemini's schema support is fussy (e.g. discriminated unions with
13
+ ``$ref``), we keep the outer envelope strict and leave per-action ``params``
14
+ as a free dict — the agent dispatcher validates it at use time.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any, Dict, List, Literal, Optional
20
+
21
+ from pydantic import BaseModel, Field
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Shared / leaf types
26
+ # ---------------------------------------------------------------------------
27
+ StepStatus = Literal["pending", "in_progress", "completed", "skipped", "failed"]
28
+
29
+
30
+ class ResponseStep(BaseModel):
31
+ """A single step in the Coach's response plan."""
32
+
33
+ id: int
34
+ actor: str = Field(
35
+ description="Who executes this step. Examples: 'CoachAgent', 'MedicalAssessmentAgent', "
36
+ "'PlannerAgent', 'ValidationAgent', 'user'.",
37
+ )
38
+ description: str
39
+ prerequisites: List[str] = Field(default_factory=list)
40
+ status: StepStatus = "pending"
41
+
42
+
43
+ class MacroTargets(BaseModel):
44
+ """Daily macronutrient targets in grams (single integer values)."""
45
+
46
+ protein_g: int = Field(ge=0)
47
+ fat_g: int = Field(ge=0)
48
+ carbohydrates_g: int = Field(ge=0)
49
+
50
+
51
+ class Calculations(BaseModel):
52
+ """Derived anthropometric + nutritional values from the assessment."""
53
+
54
+ BMI: float = Field(ge=0)
55
+ BMR: float = Field(ge=0)
56
+ TDEE: float = Field(ge=0)
57
+ daily_target_calories: int = Field(ge=0)
58
+ macro_targets: MacroTargets
59
+
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Coach Agent
63
+ # ---------------------------------------------------------------------------
64
+ CoachActionType = Literal[
65
+ "call_agent",
66
+ "call_tool",
67
+ "ask_user",
68
+ "write_memory",
69
+ "compose_response",
70
+ ]
71
+
72
+
73
+ class CoachDecision(BaseModel):
74
+ """Single turn of the Coach orchestrator.
75
+
76
+ Outer shape is strict; ``params`` is left as a dict because Gemini's
77
+ schema layer struggles with deeply discriminated unions. The dispatcher
78
+ in :mod:`workflow` validates ``params`` against the action type.
79
+ """
80
+
81
+ observation: str
82
+ thought: str
83
+ response_steps: List[ResponseStep] = Field(default_factory=list)
84
+ action: CoachActionType
85
+ params: Dict[str, Any] = Field(
86
+ default_factory=dict,
87
+ description=(
88
+ "Action-specific parameters. Required keys per action: "
89
+ "call_agent={agent_name, task}, call_tool={tool_name, task}, "
90
+ "ask_user={prompt}, write_memory={partition, data}, "
91
+ "compose_response={text}."
92
+ ),
93
+ )
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Medical Assessment Agent
98
+ # ---------------------------------------------------------------------------
99
+ MedicalActionType = Literal["call_tool", "ask_user", "assessment_complete"]
100
+
101
+
102
+ class MedicalAssessmentResult(BaseModel):
103
+ """Final payload stored in ``memory.flags_and_assessments``."""
104
+
105
+ assessment_summary: str
106
+ flags_to_set: List[str] = Field(default_factory=list)
107
+ recommendations: List[str] = Field(default_factory=list)
108
+ requires_professional_consultation: bool = False
109
+ calculations: Calculations
110
+ evidence_sources: List[str] = Field(default_factory=list)
111
+ trace: str = ""
112
+ requires_tool_retry: bool = False
113
+ data_confidence: float = Field(default=1.0, ge=0.0, le=1.0)
114
+
115
+
116
+ class MedicalAssessmentDecision(BaseModel):
117
+ """Per-iteration output of the Medical Assessment Agent loop."""
118
+
119
+ medical_reasoning: str
120
+ observation: str
121
+ risk_assessment_priorities: List[str] = Field(default_factory=list)
122
+ assessment_plan: List[ResponseStep] = Field(default_factory=list)
123
+
124
+ action_type: MedicalActionType
125
+ # action-specific fields (kept flat — see CoachDecision rationale)
126
+ tool_name: Optional[str] = None
127
+ tool_task: Optional[str] = None
128
+ fields: List[str] = Field(default_factory=list) # for ask_user
129
+ result: Optional[MedicalAssessmentResult] = None # for assessment_complete
130
+
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # Planner Agent
134
+ # ---------------------------------------------------------------------------
135
+ PlannerActionType = Literal["call_tool", "draft_plan", "provide_plan"]
136
+
137
+
138
+ class FoodItem(BaseModel):
139
+ """A single ingredient on the plan, post-solver."""
140
+
141
+ name: str
142
+ grams: float = Field(ge=0)
143
+ calories: float = Field(ge=0)
144
+ protein_g: float = Field(ge=0)
145
+ fat_g: float = Field(ge=0)
146
+ carbohydrates_g: float = Field(ge=0)
147
+ meal_group: Optional[str] = None
148
+
149
+
150
+ class FinalPlan(BaseModel):
151
+ """The shape stored in ``memory.plans.current_plan``."""
152
+
153
+ days: List[List[FoodItem]] = Field(
154
+ description="One inner list per day. Most plans return a single day.",
155
+ )
156
+ daily_totals: Dict[str, float] = Field(default_factory=dict)
157
+ notes: str = ""
158
+ sources: List[str] = Field(default_factory=list)
159
+ trace: str = ""
160
+
161
+
162
+ class PlannerDecision(BaseModel):
163
+ """Per-iteration output of the Planner Agent loop."""
164
+
165
+ observation: str
166
+ thought: str
167
+ planning_steps: List[ResponseStep] = Field(default_factory=list)
168
+
169
+ action_type: PlannerActionType
170
+ tool_name: Optional[str] = None
171
+ tool_task: Optional[str] = None
172
+ drafted_plan: Optional[Dict[str, Any]] = None # free shape pre-solver
173
+ final_plan: Optional[Dict[str, Any]] = None # free shape until validation lands in Phase 2
174
+
175
+
176
+ # ---------------------------------------------------------------------------
177
+ # Validation Agent (lands in Phase 2; defined here so Phase 1 schemas are
178
+ # the single source of truth)
179
+ # ---------------------------------------------------------------------------
180
+ ValidationVerdict = Literal["pass", "revise", "reject"]
181
+ ValidationSeverity = Literal["low", "medium", "high"]
182
+
183
+
184
+ class ValidationIssue(BaseModel):
185
+ code: str = Field(description="Stable error code, e.g. 'allergy_violation'.")
186
+ description: str
187
+ severity: ValidationSeverity = "medium"
188
+
189
+
190
+ class ValidationDecision(BaseModel):
191
+ verdict: ValidationVerdict
192
+ issues: List[ValidationIssue] = Field(default_factory=list)
193
+ notes: str = ""
194
+ requires_human_review: bool = False
195
+
196
+
197
+ __all__ = [
198
+ "Calculations",
199
+ "CoachActionType",
200
+ "CoachDecision",
201
+ "FinalPlan",
202
+ "FoodItem",
203
+ "MacroTargets",
204
+ "MedicalActionType",
205
+ "MedicalAssessmentDecision",
206
+ "MedicalAssessmentResult",
207
+ "PlannerActionType",
208
+ "PlannerDecision",
209
+ "ResponseStep",
210
+ "StepStatus",
211
+ "ValidationDecision",
212
+ "ValidationIssue",
213
+ "ValidationSeverity",
214
+ "ValidationVerdict",
215
+ ]
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared pytest fixtures.
2
+
3
+ The mock LLM here is the workhorse for offline tests — it lets us run agents
4
+ end-to-end without paying for Gemini calls. Canned responses may be raw JSON
5
+ strings, dicts, or Pydantic models; the typed path checks them against the
6
+ requested response schema.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from typing import Any, Dict, List
13
+
14
+ import pytest
15
+
16
+ from config import reset_settings, set_settings
17
+ from utils import APIPoolManager, LLM
18
+
19
+
20
+ class MockLLM(LLM):
21
+ """LLM stub that returns canned responses in order.
22
+
23
+ Tests construct it with a list of either:
24
+
25
+ * a JSON-string (returned as-is for the untyped __call__ path),
26
+ * a dict (JSON-serialised on push; validated against the requested schema
27
+ on call_typed),
28
+ * a Pydantic ``BaseModel`` instance (returned as-is from call_typed; its
29
+ ``.model_dump_json()`` is used for __call__).
30
+
31
+ Each call pops the next item. Out-of-script calls raise so missing
32
+ fixtures are noisy rather than silent.
33
+ """
34
+
35
+ def __init__(self, responses: List[Any]) -> None:
36
+ self._responses: List[Any] = list(responses)
37
+ self.calls: List[str] = []
38
+ self.typed_calls: List[tuple[str, type]] = []
39
+
40
+ def _next(self, prompt: str) -> Any:
41
+ self.calls.append(prompt)
42
+ if not self._responses:
43
+ raise AssertionError(
44
+ f"MockLLM ran out of canned responses (call #{len(self.calls)}). "
45
+ f"Last prompt:\n{prompt[:300]}"
46
+ )
47
+ return self._responses.pop(0)
48
+
49
+ def __call__(self, prompt: str, **_: Any) -> List[str]:
50
+ item = self._next(prompt)
51
+ if hasattr(item, "model_dump_json"):
52
+ return [item.model_dump_json()]
53
+ if isinstance(item, dict):
54
+ return [json.dumps(item)]
55
+ return [str(item)]
56
+
57
+ def call_typed(self, prompt: str, response_model: type, **_: Any):
58
+ from pydantic import BaseModel
59
+
60
+ self.typed_calls.append((prompt, response_model))
61
+ item = self._next(prompt)
62
+ if isinstance(item, BaseModel):
63
+ return item if isinstance(item, response_model) else None
64
+ if isinstance(item, dict):
65
+ try:
66
+ return response_model.model_validate(item)
67
+ except Exception:
68
+ return None
69
+ if isinstance(item, str):
70
+ try:
71
+ return response_model.model_validate_json(item)
72
+ except Exception:
73
+ return None
74
+ return None
75
+
76
+ def format_prompt(self, messages: List[Dict[str, str]]) -> str:
77
+ return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
78
+
79
+
80
+ @pytest.fixture
81
+ def mock_llm_factory():
82
+ """Factory to build a MockLLM from a list of canned responses."""
83
+ return MockLLM
84
+
85
+
86
+ @pytest.fixture(autouse=True)
87
+ def fresh_settings():
88
+ """Reset the Settings singleton before/after each test for isolation."""
89
+ reset_settings()
90
+ set_settings(debug_mode=False, log_dir=None, persistence_dir=None)
91
+ yield
92
+ reset_settings()
93
+
94
+
95
+ @pytest.fixture
96
+ def api_pool_no_limits():
97
+ """An APIPoolManager with rate limiting disabled — for unit tests that
98
+ don't care about throttling."""
99
+ return APIPoolManager(["test-key-1", "test-key-2"], rate_limits=None)
tests/test_api_pool.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for the APIPoolManager rate limiter."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+
7
+ import pytest
8
+
9
+ from utils import APIPoolManager
10
+
11
+
12
+ def test_round_robin_no_limits() -> None:
13
+ pool = APIPoolManager(["k1", "k2", "k3"], rate_limits=None)
14
+ seen = [pool.get_next_key("any-model") for _ in range(6)]
15
+ # With no limits we should walk through all keys at least twice.
16
+ assert set(seen) == {"k1", "k2", "k3"}
17
+
18
+
19
+ def test_rpm_spacing_enforced() -> None:
20
+ """With RPM=60 we expect a ~1s spacing between consecutive uses of the
21
+ same key. Two-key pool should let us avoid the wait."""
22
+ pool = APIPoolManager(["k1", "k2"], rate_limits={"m": (60, 1000)})
23
+
24
+ k_a = pool.get_next_key("m")
25
+ pool.record_usage(k_a, "m", time.time())
26
+ k_b = pool.get_next_key("m")
27
+ assert k_a != k_b, "Round-robin should pick the other key when one is hot"
28
+
29
+
30
+ def test_rpd_exhaustion_drops_key() -> None:
31
+ """A key that hits its daily limit must be removed from active pool."""
32
+ pool = APIPoolManager(["k1", "k2"], rate_limits={"m": (60, 2)})
33
+ for _ in range(2):
34
+ k = pool.get_next_key("m")
35
+ pool.record_usage(k, "m")
36
+
37
+ # By now both keys may have hit their RPD=2. Next call should still work
38
+ # if at least one key has capacity, else raise RuntimeError.
39
+ keys_left = list(pool.active_keys)
40
+ if not keys_left:
41
+ with pytest.raises(RuntimeError):
42
+ pool.get_next_key("m")
43
+ else:
44
+ # Drain the remaining one too.
45
+ for _ in range(2):
46
+ try:
47
+ k = pool.get_next_key("m")
48
+ pool.record_usage(k, "m")
49
+ except RuntimeError:
50
+ break
51
+ assert not pool.active_keys, "Both keys should be exhausted now"
tests/test_quantities_finder.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Smoke tests for the PuLP-backed QuantitiesFinder.
2
+
3
+ Pure deterministic tool — no LLM, no network. Should be the fastest test in
4
+ the suite and the one we trust most.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+
11
+ from tools import QuantitiesFinder
12
+
13
+
14
+ def test_basic_two_food_balance() -> None:
15
+ """Two foods, simple targets — solver must find quantities within a few
16
+ percent of the target."""
17
+ qf = QuantitiesFinder()
18
+ payload = {
19
+ "foods": [
20
+ {
21
+ "name": "chicken_breast",
22
+ "calories": 165,
23
+ "protein": 31,
24
+ "fat": 3.6,
25
+ "carbohydrates": 0,
26
+ "estimated_g": 200,
27
+ },
28
+ {
29
+ "name": "rice_cooked",
30
+ "calories": 130,
31
+ "protein": 2.7,
32
+ "fat": 0.3,
33
+ "carbohydrates": 28,
34
+ "estimated_g": 200,
35
+ },
36
+ ],
37
+ "targets": {"calories": 700, "protein": 65, "fat": 8, "carbohydrates": 60},
38
+ }
39
+ result = json.loads(qf.handle_task(json.dumps(payload)))
40
+ assert "quantities" in result and "achieved" in result, f"Bad shape: {result}"
41
+
42
+ achieved = result["achieved"]
43
+ # The solver minimises weighted deviation across all 4 nutrients. With only
44
+ # two foods (chicken and rice) it cannot hit every target tightly — it will
45
+ # nail fat/carbs (constrained by rice) and trade off calories/protein.
46
+ # We assert it lands within 20% of every target, which is the realistic
47
+ # feasibility envelope for a 2-food problem.
48
+ for nut, target in [("calories", 700), ("protein", 65), ("fat", 8), ("carbohydrates", 60)]:
49
+ deviation = abs(achieved[nut] - target) / target
50
+ assert deviation < 0.20, f"{nut} achieved={achieved[nut]} target={target} dev={deviation:.2%}"
51
+
52
+
53
+ def test_invalid_payload_returns_error() -> None:
54
+ qf = QuantitiesFinder()
55
+ bad = {"foods": [{"name": "x"}], "targets": {}} # missing required keys
56
+ result = json.loads(qf.handle_task(json.dumps(bad)))
57
+ assert "error" in result, f"Expected an error key, got {result}"
58
+
59
+
60
+ def test_min_max_bounds_respected() -> None:
61
+ qf = QuantitiesFinder()
62
+ payload = {
63
+ "foods": [
64
+ {
65
+ "name": "egg",
66
+ "calories": 155,
67
+ "protein": 13,
68
+ "fat": 11,
69
+ "carbohydrates": 1.1,
70
+ "estimated_g": 100,
71
+ "min_g": 50,
72
+ "max_g": 120,
73
+ },
74
+ {
75
+ "name": "oats",
76
+ "calories": 389,
77
+ "protein": 17,
78
+ "fat": 7,
79
+ "carbohydrates": 66,
80
+ "estimated_g": 80,
81
+ "min_g": 30,
82
+ "max_g": 150,
83
+ },
84
+ ],
85
+ "targets": {"calories": 500, "protein": 25, "fat": 15, "carbohydrates": 50},
86
+ }
87
+ result = json.loads(qf.handle_task(json.dumps(payload)))
88
+ qty = result["quantities"]
89
+ assert 50 <= qty["egg"] <= 120
90
+ assert 30 <= qty["oats"] <= 150
tests/test_schemas.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Validate the Pydantic schemas that anchor every agent decision in Phase 1."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ from pydantic import ValidationError
7
+
8
+ from schemas import (
9
+ Calculations,
10
+ CoachDecision,
11
+ FinalPlan,
12
+ FoodItem,
13
+ MacroTargets,
14
+ MedicalAssessmentDecision,
15
+ MedicalAssessmentResult,
16
+ PlannerDecision,
17
+ ResponseStep,
18
+ ValidationDecision,
19
+ )
20
+
21
+
22
+ # ---- Coach -----------------------------------------------------------------
23
+ def test_coach_decision_call_agent() -> None:
24
+ d = CoachDecision(
25
+ observation="user wants a plan",
26
+ thought="need assessment first",
27
+ response_steps=[
28
+ ResponseStep(id=1, actor="MedicalAssessmentAgent", description="assess"),
29
+ ],
30
+ action="call_agent",
31
+ params={"agent_name": "MedicalAssessmentAgent", "task": "assess user"},
32
+ )
33
+ assert d.action == "call_agent"
34
+ assert d.params["agent_name"] == "MedicalAssessmentAgent"
35
+
36
+
37
+ def test_coach_decision_invalid_action_rejected() -> None:
38
+ with pytest.raises(ValidationError):
39
+ CoachDecision(
40
+ observation="x",
41
+ thought="x",
42
+ response_steps=[],
43
+ action="not_a_real_action", # type: ignore[arg-type]
44
+ params={},
45
+ )
46
+
47
+
48
+ # ---- Medical assessment ----------------------------------------------------
49
+ def test_medical_assessment_complete_round_trip() -> None:
50
+ payload = {
51
+ "medical_reasoning": "BMI within normal range; protein target raised for muscle gain",
52
+ "observation": "all fields present",
53
+ "risk_assessment_priorities": ["maintain micronutrient adequacy"],
54
+ "assessment_plan": [],
55
+ "action_type": "assessment_complete",
56
+ "result": {
57
+ "assessment_summary": "healthy male, hypertrophy goal",
58
+ "flags_to_set": [],
59
+ "recommendations": ["maintain hydration"],
60
+ "requires_professional_consultation": False,
61
+ "calculations": {
62
+ "BMI": 23.4,
63
+ "BMR": 1750,
64
+ "TDEE": 2700,
65
+ "daily_target_calories": 2900,
66
+ "macro_targets": {"protein_g": 180, "fat_g": 70, "carbohydrates_g": 360},
67
+ },
68
+ },
69
+ }
70
+ decision = MedicalAssessmentDecision.model_validate(payload)
71
+ assert decision.action_type == "assessment_complete"
72
+ assert isinstance(decision.result, MedicalAssessmentResult)
73
+ assert decision.result.calculations.macro_targets.protein_g == 180
74
+
75
+
76
+ def test_calculations_negative_values_rejected() -> None:
77
+ with pytest.raises(ValidationError):
78
+ Calculations(
79
+ BMI=-1, # negative not allowed
80
+ BMR=1700,
81
+ TDEE=2500,
82
+ daily_target_calories=2200,
83
+ macro_targets=MacroTargets(protein_g=120, fat_g=60, carbohydrates_g=250),
84
+ )
85
+
86
+
87
+ # ---- Planner ---------------------------------------------------------------
88
+ def test_planner_provide_plan_with_dict_final_plan() -> None:
89
+ decision = PlannerDecision(
90
+ observation="all data ready",
91
+ thought="returning final plan",
92
+ planning_steps=[],
93
+ action_type="provide_plan",
94
+ final_plan={"days": [{"breakfast": "oats"}], "trace": "Coach->Planner"},
95
+ )
96
+ assert decision.action_type == "provide_plan"
97
+ assert decision.final_plan is not None
98
+
99
+
100
+ def test_food_item_strict_grams_non_negative() -> None:
101
+ with pytest.raises(ValidationError):
102
+ FoodItem(
103
+ name="oats",
104
+ grams=-10,
105
+ calories=389,
106
+ protein_g=17,
107
+ fat_g=7,
108
+ carbohydrates_g=66,
109
+ )
110
+
111
+
112
+ def test_final_plan_minimal() -> None:
113
+ plan = FinalPlan(
114
+ days=[
115
+ [
116
+ FoodItem(
117
+ name="oats",
118
+ grams=80,
119
+ calories=311,
120
+ protein_g=14,
121
+ fat_g=5,
122
+ carbohydrates_g=53,
123
+ )
124
+ ]
125
+ ],
126
+ daily_totals={"calories": 311},
127
+ )
128
+ assert plan.days[0][0].name == "oats"
129
+
130
+
131
+ # ---- Validation (Phase 2 schemas, declared in Phase 1) ---------------------
132
+ def test_validation_decision_default_pass() -> None:
133
+ v = ValidationDecision(verdict="pass")
134
+ assert v.verdict == "pass"
135
+ assert v.issues == []
tests/test_settings.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verify the new Pydantic ``Settings`` and the legacy ``config.X`` proxy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import config
6
+ from config import get_settings, reset_settings, set_settings
7
+
8
+
9
+ def test_defaults() -> None:
10
+ s = get_settings()
11
+ assert s.debug_mode is False
12
+ assert s.debug_level == "full"
13
+ assert s.enable_rate_limiting is True
14
+ assert s.log_dir is None
15
+ assert s.debug_scopes == {"agents": ["all"], "tools": ["all"]}
16
+
17
+
18
+ def test_set_settings_pydantic_names() -> None:
19
+ set_settings(debug_mode=True, log_dir="/tmp/x")
20
+ s = get_settings()
21
+ assert s.debug_mode is True
22
+ assert s.log_dir == "/tmp/x"
23
+
24
+
25
+ def test_set_settings_legacy_names() -> None:
26
+ set_settings(DEBUG_MODE=True, LOG_DIR="/tmp/y", DEBUG_LEVEL="output")
27
+ s = get_settings()
28
+ assert s.debug_mode is True
29
+ assert s.log_dir == "/tmp/y"
30
+ assert s.debug_level == "output"
31
+
32
+
33
+ def test_pep562_legacy_reads() -> None:
34
+ """Existing code that does ``config.DEBUG_MODE`` still works."""
35
+ set_settings(debug_mode=True, debug_scopes={"agents": ["CoachAgent"], "tools": ["all"]})
36
+ assert config.DEBUG_MODE is True
37
+ assert config.DEBUG_SCOPES == {"agents": ["CoachAgent"], "tools": ["all"]}
38
+
39
+
40
+ def test_unknown_attr_raises() -> None:
41
+ try:
42
+ _ = config.NOT_A_THING # type: ignore[attr-defined]
43
+ except AttributeError as e:
44
+ assert "NOT_A_THING" in str(e)
45
+ else:
46
+ raise AssertionError("Expected AttributeError")
47
+
48
+
49
+ def test_reset_settings_round_trip() -> None:
50
+ set_settings(debug_mode=True)
51
+ assert get_settings().debug_mode is True
52
+ reset_settings()
53
+ assert get_settings().debug_mode is False # back to default
tests/test_smoke.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Top-level smoke tests: every module must import cleanly outside Colab."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def test_imports_work_outside_colab() -> None:
    """Import every top-level project module in a plain Python process.

    The Phase 0 cleanup removed ``from google.colab import userdata``; any
    module that still depends on Colab would fail to import here.
    """
    from importlib import import_module

    # Same modules, same order, as the original explicit import statements.
    module_names = (
        "agents",
        "config",
        "logging_setup",
        "nutritionmas",
        "state",
        "tools",
        "utils",
        "workflow",
    )
    for module_name in module_names:
        import_module(module_name)
17
+
18
+
19
def test_only_one_geminillm_class_in_utils() -> None:
    """Phase 0 deleted the duplicate ``GeminiLLM`` definition. Make sure it
    doesn't sneak back.

    ``inspect.getmembers`` reports one object per attribute name, so a second
    ``class GeminiLLM`` in the same module would silently overwrite the first
    and still be counted as one. The duplicate check therefore counts
    ``ClassDef`` nodes in the module's source instead.
    """
    import ast
    import inspect

    import utils

    # The name must resolve to a class that is defined in ``utils`` itself
    # (not re-exported from elsewhere) -- the guarantee the original check gave.
    assert inspect.isclass(utils.GeminiLLM)
    assert utils.GeminiLLM.__module__ == "utils"

    # Count every ``class GeminiLLM`` statement in the source, at any nesting
    # level, so a re-introduced duplicate is detected.
    module_tree = ast.parse(inspect.getsource(utils))
    definitions = [
        node
        for node in ast.walk(module_tree)
        if isinstance(node, ast.ClassDef) and node.name == "GeminiLLM"
    ]
    assert len(definitions) == 1
32
+
33
+
34
def test_initialize_empty_memory_shape() -> None:
    """Fresh memory has exactly four partitions, each equal to ``{}``."""
    from state import initialize_empty_memory

    expected_partitions = {
        "user_profile",
        "medical_history",
        "flags_and_assessments",
        "plans",
    }
    memory = initialize_empty_memory()

    assert set(memory.keys()) == expected_partitions
    assert all(partition == {} for partition in memory.values())
40
+
41
+
42
def test_default_model_configs_present() -> None:
    """Pin the exact set of role names in ``DEFAULT_MODEL_CONFIGS``.

    Model topology is a contract the rest of the system depends on.
    Phase 2 adds 'validation_agent' (Gemini Flash; cheap critic loop).
    """
    from nutritionmas import DEFAULT_MODEL_CONFIGS

    expected_roles = {
        "main",
        "agents_llm",
        "tools_llm",
        "planner_agent",
        "validation_agent",
        "user_simulator",
    }
    configured_roles = set(DEFAULT_MODEL_CONFIGS.keys())
    assert configured_roles == expected_roles
56
+
57
+
58
def test_create_llm_instances_requires_keys() -> None:
    """An empty key list is rejected with a ``ValueError`` naming the cause."""
    import pytest

    from nutritionmas import create_llm_instances

    no_keys: list = []
    with pytest.raises(ValueError, match="At least one API key"):
        create_llm_instances(no_keys)
tests/test_typed_agents.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end-ish tests of the typed agent path with MockLLM.
2
+
3
+ These don't hit Gemini; they verify that an agent which received a typed
4
+ ``CoachDecision`` / ``MedicalAssessmentDecision`` / ``PlannerDecision`` from
5
+ its LLM produces the expected state mutations.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, Dict
11
+
12
+ import pytest
13
+
14
+ from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
15
+ from schemas import (
16
+ Calculations,
17
+ CoachDecision,
18
+ MacroTargets,
19
+ MedicalAssessmentDecision,
20
+ MedicalAssessmentResult,
21
+ PlannerDecision,
22
+ )
23
+ from state import initialize_empty_memory
24
+
25
+
26
+ # ---- Coach -----------------------------------------------------------------
27
def test_coach_emits_call_agent_action(mock_llm_factory) -> None:
    """A canned ``call_agent`` decision ends up as the state's current action."""
    decision = CoachDecision(
        observation="needs assessment",
        thought="route to medical",
        response_steps=[],
        action="call_agent",
        params={"agent_name": "MedicalAssessmentAgent", "task": "assess"},
    )
    coach = CoachAgent(mock_llm_factory([decision]))

    question = "make me a plan"
    initial_state: Dict[str, Any] = {
        "memory": initialize_empty_memory(),
        "user_question": question,
        "conversation_history": [{"role": "user", "content": question}],
        "current_action": None,
        "agent_result": None,
        "num_turns": 0,
        "max_turns": 10,
        "previous_actions": [],
        "response_steps": [],
    }

    result = coach.handle_task(initial_state)
    action = result["current_action"]
    assert action["action"] == "call_agent"
    assert action["params"]["agent_name"] == "MedicalAssessmentAgent"
    # The state started at turn 0, so one handled task yields turn 1.
    assert result["num_turns"] == 1
51
+
52
+
53
def test_coach_falls_back_when_decision_unparseable(mock_llm_factory) -> None:
    """A non-JSON LLM reply yields a flagged ``compose_response`` action."""
    garbage_reply = "{not even close to JSON"
    coach = CoachAgent(mock_llm_factory([garbage_reply]))

    initial_state: Dict[str, Any] = {
        "memory": initialize_empty_memory(),
        "user_question": "anything",
        "conversation_history": [],
        "current_action": None,
        "agent_result": None,
        "num_turns": 0,
        "max_turns": 10,
        "previous_actions": [],
        "response_steps": [],
    }

    result = coach.handle_task(initial_state)
    fallback_action = result["current_action"]
    # Coach injects a compose_response with _parse_error so the workflow can short-circuit
    assert fallback_action["action"] == "compose_response"
    assert fallback_action.get("_parse_error") is True
70
+
71
+
72
+ # ---- Medical ---------------------------------------------------------------
73
def test_medical_assessment_complete_writes_memory(mock_llm_factory) -> None:
    """A single assessment_complete decision should land in memory partition."""
    # Build the canned result bottom-up so each nested model has a name.
    macro_targets = MacroTargets(protein_g=150, fat_g=70, carbohydrates_g=300)
    calculations = Calculations(
        BMI=22.0,
        BMR=1600,
        TDEE=2400,
        daily_target_calories=2400,
        macro_targets=macro_targets,
    )
    assessment = MedicalAssessmentResult(
        assessment_summary="healthy adult",
        flags_to_set=["maintenance"],
        recommendations=["balanced diet"],
        requires_professional_consultation=False,
        calculations=calculations,
        evidence_sources=["who.int"],
        trace="Medical agent ran one iteration",
    )
    decision = MedicalAssessmentDecision(
        medical_reasoning="single-shot",
        observation="all data present",
        risk_assessment_priorities=["maintenance"],
        assessment_plan=[],
        action_type="assessment_complete",
        result=assessment,
    )

    # Need a stub for the tools (won't be called in single-iteration assessment_complete)
    class _StubTool:
        def handle_task(self, _: str) -> str:
            return ""

    agent = MedicalAssessmentAgent(mock_llm_factory([decision]), _StubTool(), _StubTool())

    memory = initialize_empty_memory()
    memory["user_profile"] = {
        "age": 30,
        "sex": "male",
        "height": 180,
        "weight": 75,
        "activity_level": "moderate",
        "allergies": [],
        "medications": [],
    }

    summary = agent.handle_task("assess this user", memory)

    assert summary == "healthy adult"
    stored = memory["flags_and_assessments"]
    assert stored["assessment_status"] == "assessment_complete"
    assert stored["calculations"]["macro_targets"]["protein_g"] == 150
120
+
121
+
122
def test_medical_ask_user_returns_field_list(mock_llm_factory) -> None:
    """An ``ask_user`` decision returns output naming each requested field."""
    requested_fields = ["weight", "height"]
    decision = MedicalAssessmentDecision(
        medical_reasoning="missing weight + height",
        observation="incomplete",
        risk_assessment_priorities=[],
        assessment_plan=[],
        action_type="ask_user",
        fields=requested_fields,
    )

    class _StubTool:
        """Placeholder tool that returns an empty string for any task."""

        def handle_task(self, _: str) -> str:
            return ""

    agent = MedicalAssessmentAgent(mock_llm_factory([decision]), _StubTool(), _StubTool())
    reply = agent.handle_task("assess", initialize_empty_memory())

    assert "weight" in reply and "height" in reply
139
+
140
+
141
+ # ---- Planner ---------------------------------------------------------------
142
def test_planner_provide_plan_stores_to_memory(mock_llm_factory) -> None:
    """A ``provide_plan`` decision is written to ``memory["plans"]``."""
    plan_payload = {
        "days": [{"breakfast": "oats", "lunch": "chicken+rice"}],
        "trace": "Planner one-shot",
    }
    decision = PlannerDecision(
        observation="ready",
        thought="finalising",
        planning_steps=[],
        action_type="provide_plan",
        final_plan=plan_payload,
    )

    class _StubTool:
        """Placeholder tool that returns an empty string for any task."""

        def handle_task(self, _: str) -> str:
            return ""

    agent = PlannerAgent(mock_llm_factory([decision]), _StubTool(), _StubTool(), _StubTool())

    memory = initialize_empty_memory()
    memory["flags_and_assessments"] = {"assessment_status": "assessment_complete"}
    reply = agent.handle_task("make me a one-day plan", memory)

    assert "trace" in reply
    plans = memory["plans"]
    assert plans["current_plan"]["days"][0]["breakfast"] == "oats"
    assert "plan_timestamp" in plans
166
+
167
+
168
def test_planner_error_payload_short_circuits(mock_llm_factory) -> None:
    """An error-only ``final_plan`` is surfaced in the planner's output."""
    error_payload = {"error": "flags_and_assessments empty; run MedicalAssessmentAgent first"}
    decision = PlannerDecision(
        observation="missing assessment",
        thought="precondition violated",
        planning_steps=[],
        action_type="provide_plan",
        final_plan=error_payload,
    )

    class _StubTool:
        """Placeholder tool that returns an empty string for any task."""

        def handle_task(self, _: str) -> str:
            return ""

    agent = PlannerAgent(mock_llm_factory([decision]), _StubTool(), _StubTool(), _StubTool())
    reply = agent.handle_task("make a plan", initialize_empty_memory())

    assert "error" in reply
    assert "MedicalAssessmentAgent" in reply
tests/test_validation_agent.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the ValidationAgent critic loop.
2
+
3
+ Most of the value lives in the deterministic checks — they are pure code,
4
+ require no LLM, and can be exercised cheaply across edge cases.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Dict
10
+
11
+ from schemas import ValidationDecision
12
+ from state import initialize_empty_memory
13
+ from validation import ValidationAgent
14
+
15
+
16
+ # A minimal deterministic-only stub LLM for the cases that should never need
17
+ # the LLM layer (allergy violation -> verdict "reject" short-circuits LLM).
18
+ class _NeverCalledLLM:
19
+ def call_typed(self, *args: Any, **kwargs: Any):
20
+ raise AssertionError("LLM should not be called when deterministic check rejects.")
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
def _build_memory(
    *,
    allergies=None,
    dislikes="",
    target_calories=2000,
    macros=(150, 70, 200),
) -> Dict[str, Any]:
    """Return a memory dict with a user profile and a completed assessment.

    Args:
        allergies: Value for ``user_profile["allergies"]``; any falsy value
            is stored as an empty list.
        dislikes: Value for ``user_profile["food_dislikes"]``.
        target_calories: Value for ``daily_target_calories``.
        macros: Indexable of (protein_g, fat_g, carbohydrates_g) targets.
    """
    profile = {
        "name": "Test",
        "country": "Egypt",
        "allergies": allergies or [],
        "food_dislikes": dislikes,
    }
    macro_targets = {
        "protein_g": macros[0],
        "fat_g": macros[1],
        "carbohydrates_g": macros[2],
    }
    assessment = {
        "assessment_status": "assessment_complete",
        "calculations": {
            "BMI": 22,
            "BMR": 1600,
            "TDEE": 2000,
            "daily_target_calories": target_calories,
            "macro_targets": macro_targets,
        },
        "flags": [],
        "recommendations": [],
        "requires_professional_consultation": False,
    }

    memory = initialize_empty_memory()
    memory["user_profile"] = profile
    memory["flags_and_assessments"] = assessment
    return memory
56
+
57
+
58
+ def _set_plan(memory: Dict[str, Any], plan: Dict[str, Any]) -> None:
59
+ memory.setdefault("plans", {})["current_plan"] = plan
60
+
61
+
62
+ # ---- deterministic-only paths ----------------------------------------------
63
def test_passes_when_plan_within_tolerances(mock_llm_factory) -> None:
    """A plan whose entries add up to the targets gets a clean ``pass``.

    The two entries total 2000 kcal and 150/70/200 g of protein/fat/carbs,
    the same numbers the memory is built with.
    """
    memory = _build_memory(target_calories=2000, macros=(150, 70, 200))
    chicken = {
        "name": "chicken_breast",
        "calories": 1000,
        "protein_g": 100,
        "fat_g": 30,
        "carbohydrates_g": 100,
    }
    rice = {
        "name": "rice",
        "calories": 1000,
        "protein_g": 50,
        "fat_g": 40,
        "carbohydrates_g": 100,
    }
    _set_plan(memory, {"days": [chicken, rice]})

    # Pre-supply a "no-issues" LLM verdict so the LLM layer is happy.
    llm = mock_llm_factory([ValidationDecision(verdict="pass", issues=[])])
    raw_verdict = ValidationAgent(llm).handle_task("validate plan", memory)

    parsed = ValidationDecision.model_validate_json(raw_verdict)
    assert parsed.verdict == "pass"
    assert parsed.issues == []
    # The verdict is also recorded on the assessment partition.
    assert memory["flags_and_assessments"]["last_validation"]["verdict"] == "pass"
95
+
96
+
97
def test_allergy_violation_rejected_without_llm() -> None:
    """A plan naming an allergen is rejected without calling the LLM."""
    memory = _build_memory(allergies=["peanut"])
    offending_meal = {
        "name": "peanut butter sandwich",
        "calories": 400,
        "protein_g": 15,
        "fat_g": 20,
        "carbohydrates_g": 40,
    }
    _set_plan(memory, {"days": [offending_meal]})

    # _NeverCalledLLM raises if the agent reaches for the LLM at all.
    agent = ValidationAgent(_NeverCalledLLM())
    raw_verdict = agent.handle_task("validate", memory)
    parsed = ValidationDecision.model_validate_json(raw_verdict)

    assert parsed.verdict == "reject"
    issue_codes = [issue.code for issue in parsed.issues]
    assert "allergy_violation" in issue_codes
119
+
120
+
121
def test_calorie_deviation_triggers_revise(mock_llm_factory) -> None:
    """A plan far below the calorie target yields ``revise``, even if the LLM passes it."""
    memory = _build_memory(target_calories=2000)
    undersized_meal = {
        "name": "tiny salad",
        "calories": 800,  # way under 2000 target -> 60% deviation
        "protein_g": 30,
        "fat_g": 20,
        "carbohydrates_g": 60,
    }
    _set_plan(memory, {"days": [undersized_meal]})

    llm = mock_llm_factory([ValidationDecision(verdict="pass", issues=[])])
    raw_verdict = ValidationAgent(llm).handle_task("validate", memory)
    parsed = ValidationDecision.model_validate_json(raw_verdict)

    assert parsed.verdict == "revise"
    issue_codes = [issue.code for issue in parsed.issues]
    assert "calorie_deviation" in issue_codes
144
+
145
+
146
def test_disliked_food_only_low_severity_still_passes(mock_llm_factory) -> None:
    """A disliked food is reported as a low-severity issue but the plan passes."""
    memory = _build_memory(dislikes="okra", target_calories=2000)
    disliked_meal = {
        "name": "okra stew",
        "calories": 2000,
        "protein_g": 150,
        "fat_g": 70,
        "carbohydrates_g": 200,
    }
    _set_plan(memory, {"days": [disliked_meal]})

    llm = mock_llm_factory([ValidationDecision(verdict="pass", issues=[])])
    agent = ValidationAgent(llm)
    raw_verdict = agent.handle_task("validate", memory)
    parsed = ValidationDecision.model_validate_json(raw_verdict)

    # Low-severity issues alone don't escalate the verdict.
    assert parsed.verdict == "pass"
    reported = [(issue.code, issue.severity) for issue in parsed.issues]
    assert ("disliked_food", "low") in reported
169
+
170
+
171
def test_missing_plan_rejected() -> None:
    """Validating memory with no ``current_plan`` is rejected without the LLM."""
    # Intentionally no current_plan set
    memory = _build_memory()

    agent = ValidationAgent(_NeverCalledLLM())
    raw_verdict = agent.handle_task("validate", memory)
    parsed = ValidationDecision.model_validate_json(raw_verdict)

    assert parsed.verdict == "reject"
    issue_codes = [issue.code for issue in parsed.issues]
    assert "missing_plan" in issue_codes
178
+
179
+
180
def test_requires_human_review_propagates(mock_llm_factory) -> None:
    """The assessment's consultation flag surfaces as ``requires_human_review``."""
    memory = _build_memory()
    memory["flags_and_assessments"]["requires_professional_consultation"] = True

    on_target_meal = {
        "name": "balanced meal",
        "calories": 2000,
        "protein_g": 150,
        "fat_g": 70,
        "carbohydrates_g": 200,
    }
    _set_plan(memory, {"days": [on_target_meal]})

    llm = mock_llm_factory([ValidationDecision(verdict="pass", issues=[])])
    agent = ValidationAgent(llm)
    raw_verdict = agent.handle_task("validate", memory)
    parsed = ValidationDecision.model_validate_json(raw_verdict)

    # Even on a clean pass, HITL flag must propagate from the assessment.
    assert parsed.requires_human_review is True
tools.py CHANGED
@@ -1,29 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import subprocess
3
  import tempfile
4
- import time
5
- from time import sleep
6
- import os
7
  from datetime import datetime
8
- from utils import save_to_json, should_debug
 
 
9
  from ddgs import DDGS
10
- import config
11
- import json
12
- from pulp import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
14
  class QuantitiesFinder:
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- def __init__(self):
17
  pass
18
 
19
  @staticmethod
20
- def _round(v):
21
  if v is None:
22
  return 0.0
23
  return round(float(v), 2)
24
 
25
  @staticmethod
26
- def _round_structure(obj):
27
  if isinstance(obj, dict):
28
  return {k: QuantitiesFinder._round_structure(v) for k, v in obj.items()}
29
  if isinstance(obj, list):
@@ -33,303 +79,281 @@ class QuantitiesFinder:
33
  return obj
34
 
35
  def handle_task(self, task: str) -> str:
36
- print(f"\n📊 ENHANCED QUANTITIES FINDER (V3) TOOL STARTED")
37
- # --- Define Weights ---
38
- W_NUTRITION = 1.0 # Priority 1: Hitting daily totals
39
- W_ESTIMATE_DEFAULT = 0.1 # Priority 2: Default "soft" estimate penalty
40
 
41
  try:
42
  data = json.loads(task)
43
  foods = data["foods"]
44
  targets = data["targets"]
45
 
46
- # --- 1. VALIDATION ---
47
  required_nutrients = ["calories", "protein", "fat", "carbohydrates"]
48
  for food in foods:
49
- if not all(
50
- key in food
51
- for key in ["name"] + required_nutrients + ["estimated_g"]
52
- ):
53
  raise ValueError(
54
  "Each food must have name, calories, protein, fat, carbohydrates, and estimated_g."
55
  )
56
-
57
  if not all(key in targets for key in required_nutrients):
58
- raise ValueError(
59
- "Targets must include calories, protein, fat, carbohydrates."
60
- )
61
 
62
  prob = LpProblem("Nutrient_Optimization", LpMinimize)
63
 
64
- # --- 2. VARIABLES (Unchanged from V2) ---
65
  g = {}
66
  for food in foods:
67
- food_name = food["name"]
68
- min_bound = food.get("min_g", 0)
69
- max_bound = food.get("max_g")
70
- g[food_name] = LpVariable(
71
- f"g_{food_name}", lowBound=min_bound, upBound=max_bound
72
- )
73
-
74
- # --- 3. NUTRITION DEVIATIONS (Unchanged) ---
75
- nutrients = required_nutrients
76
- totals = {}
77
- for nut in nutrients:
78
- totals[nut] = lpSum(
79
- (g[food["name"]] / 100) * food[nut] for food in foods
80
  )
81
 
82
- d_pos = {nut: LpVariable(f"d_pos_{nut}", lowBound=0) for nut in nutrients}
83
- d_neg = {nut: LpVariable(f"d_neg_{nut}", lowBound=0) for nut in nutrients}
84
-
85
- for nut in nutrients:
 
 
 
86
  prob += totals[nut] - targets[nut] <= d_pos[nut]
87
  prob += targets[nut] - totals[nut] <= d_neg[nut]
88
 
89
- # --- 3.5 MEAL-LEVEL CONSTRAINTS (Unchanged from V2) ---
90
- meal_constraints = data.get("meal_constraints", [])
91
- if meal_constraints:
92
- print("Applying meal-level constraints...")
93
- for constraint in meal_constraints:
94
- group_name = constraint.get("group_name")
95
- if not group_name:
96
- continue
97
- group_foods = [
98
- f for f in foods if f.get("meal_group") == group_name
99
- ]
100
- if not group_foods:
101
- print(f"Warning: No foods found for meal_group '{group_name}'")
102
- continue
103
-
104
- for nut in nutrients:
105
- max_val = constraint.get(f"max_{nut}")
106
- if max_val is not None:
107
- meal_total = lpSum(
108
- (g[f["name"]] / 100) * f[nut] for f in group_foods
109
- )
110
- prob += (
111
- meal_total <= max_val,
112
- f"Meal_{group_name}_max_{nut}",
113
- )
114
- print(f" -> Constraint: {group_name} max {nut} <= {max_val}")
115
-
116
- min_val = constraint.get(f"min_{nut}")
117
- if min_val is not None:
118
- meal_total = lpSum(
119
- (g[f["name"]] / 100) * f[nut] for f in group_foods
120
- )
121
- prob += (
122
- meal_total >= min_val,
123
- f"Meal_{group_name}_min_{nut}",
124
- )
125
- print(f" -> Constraint: {group_name} min {nut} >= {min_val}")
126
-
127
- # --- 4. ESTIMATE DEVIATIONS (ENHANCED) ---
128
- # This section now reads a per-item 'estimate_weight'
129
- dev_est_pos = {
130
- food["name"]: LpVariable(f"dev_est_pos_{food['name']}", lowBound=0)
131
- for food in foods
132
- }
133
- dev_est_neg = {
134
- food["name"]: LpVariable(f"dev_est_neg_{food['name']}", lowBound=0)
135
- for food in foods
136
- }
137
-
138
  for food in foods:
139
- food_name = food["name"]
140
- estimate = food["estimated_g"]
141
- prob += g[food_name] - estimate <= dev_est_pos[food_name]
142
- prob += estimate - g[food_name] <= dev_est_neg[food_name]
143
 
144
- # --- 5. OBJECTIVE FUNCTION (ENHANCED) ---
145
- # Goal 1: (Unchanged)
146
  nutrition_objective = lpSum(
147
- (d_pos[nut] + d_neg[nut]) / max(targets[nut], 1) for nut in nutrients
148
  )
149
-
150
- # Goal 2: (ENHANCED)
151
- # Now uses the per-item 'estimate_weight' if provided,
152
- # otherwise, it falls back to the default.
153
  estimate_objective = lpSum(
154
- (
155
- f.get("estimate_weight", W_ESTIMATE_DEFAULT)
156
- * (dev_est_pos[f["name"]] + dev_est_neg[f["name"]])
157
- )
158
  / max(f["estimated_g"], 1)
159
  for f in foods
160
  if f["estimated_g"] > 0
161
  )
162
-
163
- # Combined objective
164
  prob += (W_NUTRITION * nutrition_objective) + estimate_objective
165
 
166
- # --- 6. SOLVE & RETURN (Unchanged) ---
167
  prob.solve(PULP_CBC_CMD(msg=0))
168
-
169
  if LpStatus[prob.status] != "Optimal":
170
  raise ValueError(
171
  "No optimal solution found (problem may be infeasible). Check your targets and constraints."
172
  )
173
 
174
  quantities = {name: value(g[name]) for name in g}
175
- achieved = {nut: value(totals[nut]) for nut in nutrients}
176
-
177
- result = {"quantities": quantities, "achieved": achieved}
178
- result = QuantitiesFinder._round_structure(result)
179
-
180
- print(f"Solution Status: {LpStatus[prob.status]}")
181
- print(f"Quantities (g): {json.dumps(result['quantities'], indent=2)}")
182
- print(
183
- f"Achieved Nutrition (around): {json.dumps(result['achieved'], indent=2)}"
184
  )
185
- print(
186
- f"Target Nutrition: {json.dumps(QuantitiesFinder._round_structure(targets), indent=2)}"
 
187
  )
188
-
189
- print(f"\n📊 QUANTITIES FINDER COMPLETED")
190
  return json.dumps(result)
191
 
192
- except Exception as e:
193
- error_result = {"error": str(e)}
194
- print(f"QuantitiesFinder Error: {str(e)}")
195
- return json.dumps(error_result)
196
 
 
 
 
197
class ComputationTool:
    """Tool that asks an LLM for Python code and runs it via ``execute_python_code_raw``."""

    def __init__(self, llm_instance):
        # Callable LLM; handle_task uses element [0] of what it returns.
        self.llm = llm_instance

    def handle_task(self, task_description: str) -> str:
        """Generate code for ``task_description``, execute it, log it, and return the result.

        Returns:
            The string produced by ``execute_python_code_raw`` for the
            generated code.
        """
        print("\n🤖 COMPUTATION TOOL STARTED")
        instruction = "You are a Python coding assistant. Generate only the Python code required to perform the given task. Do not forget to print the result. Do not add explanations."
        prompt = f"{instruction}\n\nTask: {task_description}\n\nCode:"

        if should_debug('tools', 'ComputationTool') and config.DEBUG_LEVEL == 'full':
            print(f"Computation Tool Prompt:\n{prompt}")
        llm_reply = self.llm(prompt)[0]
        if should_debug('tools', 'ComputationTool'):
            print(f"Computation Tool Response:\n{llm_reply}")

        # Prefer a ```python fenced block, then a bare fenced block; if the
        # reply has neither, treat the whole reply as code.
        fenced = None
        for fence_pattern in (r"```python\n(.*?)\n```", r"```\n(.*?)\n```"):
            fenced = re.search(fence_pattern, llm_reply, re.DOTALL)
            if fenced:
                break
        code_to_execute = (fenced.group(1) if fenced else llm_reply).strip()

        execution_result = execute_python_code_raw(code_to_execute)

        # Persist the full exchange; the record and the file name each take
        # their own timestamp.
        record = {
            "instruction": instruction,
            "input": task_description,
            "output": code_to_execute,
            "execution_result": execution_result,
            "timestamp": datetime.now().isoformat(),
        }
        log_name = f'computation_tool_{datetime.now().isoformat()}.json'
        save_to_json(record, log_name, subdirectory='ComputationTool')

        print(f"🤖 COMPUTATION COMPLETED\n{execution_result}")
        return execution_result
237
 
238
 
 
 
 
239
class WebSearchTool:
    """Tool that turns a question into web searches and an LLM-synthesized answer."""

    def __init__(self, llm_instance):
        # Callable LLM; handle_task uses element [0] of what it returns.
        self.llm = llm_instance

    def handle_task(self, research_task: str) -> str:
        """Search the web for ``research_task`` and return a synthesized answer.

        ``research_task`` is either plain text or a JSON object with a
        ``queries`` list; in the latter case the list is joined into a single
        question. The LLM proposes search queries, each one is run through
        ``search_web_raw``, and the LLM then synthesizes an answer from the
        combined results. Both LLM exchanges are saved with ``save_to_json``.
        """
        print("\n🌐 WEB SEARCH TOOL STARTED")

        try:
            task_data = json.loads(research_task)
            queries = task_data.get('queries') if isinstance(task_data, dict) else None
            if isinstance(queries, list):
                print("JSON query list detected. Converting to single text task.")
                # A non-string item raises TypeError here, which falls
                # through to the plain-text branch below.
                research_question = " ".join(queries)
            else:
                print("Single question mode detected (non-query JSON). Generating queries.")
                research_question = research_task
        except (json.JSONDecodeError, TypeError):
            print("Single question mode detected (plain text). Generating queries.")
            research_question = research_task

        query_instruction = "Formulate concise search queries for DuckDuckGo based on the given question. Output only the queries, one per line."
        query_prompt = f"{query_instruction}\n\nQuestion: {research_question}\n\nQueries:"

        if should_debug('tools', 'WebSearchTool') and config.DEBUG_LEVEL == 'full':
            print(f"Web Search Query Prompt:\n{query_prompt}")
        search_queries_text = self.llm(query_prompt)[0]
        if should_debug('tools', 'WebSearchTool'):
            print(f"Web Search Query Response:\n{search_queries_text}")

        # One query per non-blank line; fall back to the question itself when
        # the LLM returned nothing usable.
        stripped_lines = (line.strip() for line in search_queries_text.split('\n'))
        search_queries = [line for line in stripped_lines if line] or [research_question]
        if should_debug('tools', 'WebSearchTool'):
            print(f"Parsed queries: {search_queries}")

        result_sections = []
        for query in search_queries:
            raw_results = search_web_raw(query, num_results=10)
            print(f"Search results:\n{raw_results[:200]}...")
            result_sections.append(f"Results for '{query}':\n{raw_results}")
            sleep(1)  # pause after every search

        raw_search_output = "\n\n".join(result_sections)

        synthesis_instruction = f"""Synthesize a concise answer to:
Question: {research_question}
Based on:
---
{raw_search_output}
---
"""

        if should_debug('tools', 'WebSearchTool') and config.DEBUG_LEVEL == 'full':
            print(f"Web Search Synthesis Instruction:\n{synthesis_instruction}")
        synthesized_answer = self.llm(synthesis_instruction)[0]
        if should_debug('tools', 'WebSearchTool'):
            print(f"Web Search Synthesis Response:\n{synthesized_answer}")

        # Both log files share one timestamp.
        timestamp = datetime.now().isoformat()
        query_record = {
            "instruction": query_instruction,
            "input": research_question,
            "output": search_queries_text,
            "timestamp": timestamp,
        }
        save_to_json(query_record, f'web_search_tool_queries_{timestamp}.json', subdirectory='WebSearchTool')

        synthesis_record = {
            "instruction": synthesis_instruction,
            "input": raw_search_output,
            "output": synthesized_answer,
            "timestamp": timestamp,
        }
        save_to_json(synthesis_record, f'web_search_tool_synthesis_{timestamp}.json', subdirectory='WebSearchTool')

        print(f"\n🌐 WEB SEARCH TOOL Result:\n{synthesized_answer}\n")
        return synthesized_answer
311
 
 
 
 
 
312
def execute_python_code_raw(code_string: str) -> str:
    """Run ``code_string`` as a standalone Python script and report the outcome.

    The code is written to a temporary ``.py`` file and executed in a
    subprocess with a 30 second timeout.

    Args:
        code_string: Python source to execute.

    Returns:
        ``"Output:\\n..."`` with the captured stdout (or a fixed success
        message when stdout is empty) if the script exits with status 0,
        ``"Error:\\n..."`` with the captured stderr otherwise, or
        ``"Execution Exception: ..."`` if running the script raised
        (including the timeout).

    NOTE(security): this executes arbitrary, typically LLM-generated, code
    with the privileges of the current process and no sandboxing.
    """
    import sys  # local import: only needed to locate the running interpreter

    if should_debug('tools', 'ComputationTool') and config.DEBUG_LEVEL == 'full':
        print(f"🐍 Executing Code (raw):\n{code_string}")

    # Bound before the try block so the cleanup below never sees an unbound
    # name when creating or writing the temporary file fails.
    script_path = None
    try:
        # Python source files are UTF-8 by default, so write the script as
        # UTF-8 regardless of the locale encoding.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as tmp_script:
            # Record the path first so a failed write still gets cleaned up.
            script_path = tmp_script.name
            tmp_script.write(code_string)

        # Use the interpreter running this process; fall back to "python" on
        # PATH when sys.executable is unavailable.
        interpreter = sys.executable or "python"
        process = subprocess.run(
            [interpreter, script_path], capture_output=True, text=True, timeout=30
        )
        if process.returncode == 0:
            return f"Output:\n{process.stdout if process.stdout else 'Code executed successfully.'}"
        return f"Error:\n{process.stderr}"
    except Exception as e:
        return f"Execution Exception: {str(e)}"
    finally:
        # Single cleanup point; a failed delete must not replace the result
        # being returned.
        if script_path is not None:
            try:
                os.remove(script_path)
            except OSError:
                pass
330
 
 
331
  def search_web_raw(query: str, num_results: int = 3) -> str:
332
- print(f"🌐 Searching Web (raw) for: {query}")
333
  max_retries = 3
334
  for attempt in range(max_retries):
335
  try:
@@ -337,12 +361,13 @@ def search_web_raw(query: str, num_results: int = 3) -> str:
337
  results = list(ddgs.text(query, max_results=num_results, timelimit="m"))
338
  if not results:
339
  return "No search results found."
340
- return "\n".join([f"Title: {r.get('title')}\nURL: {r.get('href')}\nSnippet: {r.get('body')}" for r in results])
341
- except Exception as e:
 
 
 
342
  if attempt < max_retries - 1:
343
  sleep(1)
344
  continue
345
  return f"Search Exception after {max_retries} attempts: {str(e)}"
346
-
347
-
348
-
 
1
+ """Tools layer.
2
+
3
+ Phase 1 cleanup notes:
4
+
5
+ * Replaced ``print`` with namespaced loggers so user-mode emoji output is
6
+ filterable and the API/UI in Phase 7 can subscribe to it as events.
7
+ * Reads ``settings.debug_mode`` via :func:`config.get_settings` instead of the
8
+ legacy module-level globals.
9
+
10
+ The :class:`ComputationTool` still shells out to ``subprocess.run(['python', ...])``
11
+ - **this is a known security issue**, fixed in Phase 4 by either deterministic
12
+ formula functions or a ``RestrictedPython`` sandbox.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import os
19
  import re
20
  import subprocess
21
  import tempfile
 
 
 
22
  from datetime import datetime
23
+ from time import sleep
24
+ from typing import Any
25
+
26
  from ddgs import DDGS
27
+ from pulp import (
28
+ LpMinimize,
29
+ LpProblem,
30
+ LpStatus,
31
+ LpVariable,
32
+ PULP_CBC_CMD,
33
+ lpSum,
34
+ value,
35
+ )
36
+
37
+ from config import get_settings
38
+ from logging_setup import get_logger
39
+ from utils import save_to_json, should_debug
40
+
41
+ _qf_logger = get_logger("tools.quantities_finder")
42
+ _comp_logger = get_logger("tools.computation")
43
+ _web_logger = get_logger("tools.web_search")
44
 
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # QuantitiesFinder (PuLP LP solver)
48
+ # ---------------------------------------------------------------------------
49
  class QuantitiesFinder:
50
+ """Linear-program solver that turns an LLM-drafted plan into precise grams.
51
+
52
+ The schema is:
53
+
54
+ {
55
+ "foods": [{name, calories, protein, fat, carbohydrates,
56
+ estimated_g, [min_g, max_g, meal_group, estimate_weight]}, ...],
57
+ "targets": {calories, protein, fat, carbohydrates},
58
+ "meal_constraints": [{group_name, [max_<nut>], [min_<nut>]}, ...] # optional
59
+ }
60
+ """
61
 
62
+ def __init__(self) -> None:
63
  pass
64
 
65
  @staticmethod
66
+ def _round(v: Any) -> float:
67
  if v is None:
68
  return 0.0
69
  return round(float(v), 2)
70
 
71
  @staticmethod
72
+ def _round_structure(obj: Any) -> Any:
73
  if isinstance(obj, dict):
74
  return {k: QuantitiesFinder._round_structure(v) for k, v in obj.items()}
75
  if isinstance(obj, list):
 
79
  return obj
80
 
81
  def handle_task(self, task: str) -> str:
82
+ _qf_logger.info("\n📊 ENHANCED QUANTITIES FINDER (V3) TOOL STARTED")
83
+ # Priority 1: hit daily totals; Priority 2: stay close to per-item estimates.
84
+ W_NUTRITION = 1.0
85
+ W_ESTIMATE_DEFAULT = 0.1
86
 
87
  try:
88
  data = json.loads(task)
89
  foods = data["foods"]
90
  targets = data["targets"]
91
 
92
+ # 1. Validation
93
  required_nutrients = ["calories", "protein", "fat", "carbohydrates"]
94
  for food in foods:
95
+ if not all(key in food for key in ["name"] + required_nutrients + ["estimated_g"]):
 
 
 
96
  raise ValueError(
97
  "Each food must have name, calories, protein, fat, carbohydrates, and estimated_g."
98
  )
 
99
  if not all(key in targets for key in required_nutrients):
100
+ raise ValueError("Targets must include calories, protein, fat, carbohydrates.")
 
 
101
 
102
  prob = LpProblem("Nutrient_Optimization", LpMinimize)
103
 
104
+ # 2. Variables
105
  g = {}
106
  for food in foods:
107
+ g[food["name"]] = LpVariable(
108
+ f"g_{food['name']}",
109
+ lowBound=food.get("min_g", 0),
110
+ upBound=food.get("max_g"),
 
 
 
 
 
 
 
 
 
111
  )
112
 
113
+ # 3. Nutrition deviations
114
+ totals = {
115
+ nut: lpSum((g[f["name"]] / 100) * f[nut] for f in foods) for nut in required_nutrients
116
+ }
117
+ d_pos = {nut: LpVariable(f"d_pos_{nut}", lowBound=0) for nut in required_nutrients}
118
+ d_neg = {nut: LpVariable(f"d_neg_{nut}", lowBound=0) for nut in required_nutrients}
119
+ for nut in required_nutrients:
120
  prob += totals[nut] - targets[nut] <= d_pos[nut]
121
  prob += targets[nut] - totals[nut] <= d_neg[nut]
122
 
123
+ # 3.5 Optional meal-level constraints
124
+ for constraint in data.get("meal_constraints", []) or []:
125
+ group_name = constraint.get("group_name")
126
+ if not group_name:
127
+ continue
128
+ group_foods = [f for f in foods if f.get("meal_group") == group_name]
129
+ if not group_foods:
130
+ _qf_logger.warning("No foods found for meal_group '%s'", group_name)
131
+ continue
132
+ for nut in required_nutrients:
133
+ meal_total = lpSum((g[f["name"]] / 100) * f[nut] for f in group_foods)
134
+ if (max_val := constraint.get(f"max_{nut}")) is not None:
135
+ prob += (meal_total <= max_val, f"Meal_{group_name}_max_{nut}")
136
+ _qf_logger.debug("Constraint: %s max %s <= %s", group_name, nut, max_val)
137
+ if (min_val := constraint.get(f"min_{nut}")) is not None:
138
+ prob += (meal_total >= min_val, f"Meal_{group_name}_min_{nut}")
139
+ _qf_logger.debug("Constraint: %s min %s >= %s", group_name, nut, min_val)
140
+
141
+ # 4. Estimate deviations (per-item soft anchor)
142
+ dev_est_pos = {f["name"]: LpVariable(f"dev_est_pos_{f['name']}", lowBound=0) for f in foods}
143
+ dev_est_neg = {f["name"]: LpVariable(f"dev_est_neg_{f['name']}", lowBound=0) for f in foods}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  for food in foods:
145
+ name = food["name"]
146
+ est = food["estimated_g"]
147
+ prob += g[name] - est <= dev_est_pos[name]
148
+ prob += est - g[name] <= dev_est_neg[name]
149
 
150
+ # 5. Objective
 
151
  nutrition_objective = lpSum(
152
+ (d_pos[nut] + d_neg[nut]) / max(targets[nut], 1) for nut in required_nutrients
153
  )
 
 
 
 
154
  estimate_objective = lpSum(
155
+ f.get("estimate_weight", W_ESTIMATE_DEFAULT)
156
+ * (dev_est_pos[f["name"]] + dev_est_neg[f["name"]])
 
 
157
  / max(f["estimated_g"], 1)
158
  for f in foods
159
  if f["estimated_g"] > 0
160
  )
 
 
161
  prob += (W_NUTRITION * nutrition_objective) + estimate_objective
162
 
163
+ # 6. Solve
164
  prob.solve(PULP_CBC_CMD(msg=0))
 
165
  if LpStatus[prob.status] != "Optimal":
166
  raise ValueError(
167
  "No optimal solution found (problem may be infeasible). Check your targets and constraints."
168
  )
169
 
170
  quantities = {name: value(g[name]) for name in g}
171
+ achieved = {nut: value(totals[nut]) for nut in required_nutrients}
172
+ result = QuantitiesFinder._round_structure({"quantities": quantities, "achieved": achieved})
173
+
174
+ _qf_logger.info("Solution Status: %s", LpStatus[prob.status])
175
+ _qf_logger.info("Quantities (g): %s", json.dumps(result["quantities"], indent=2))
176
+ _qf_logger.info(
177
+ "Achieved Nutrition (around): %s",
178
+ json.dumps(result["achieved"], indent=2),
 
179
  )
180
+ _qf_logger.info(
181
+ "Target Nutrition: %s",
182
+ json.dumps(QuantitiesFinder._round_structure(targets), indent=2),
183
  )
184
+ _qf_logger.info("\n📊 QUANTITIES FINDER COMPLETED")
 
185
  return json.dumps(result)
186
 
187
+ except Exception as e: # noqa: BLE001
188
+ _qf_logger.error("QuantitiesFinder Error: %s", str(e))
189
+ return json.dumps({"error": str(e)})
190
+
191
 
192
+ # ---------------------------------------------------------------------------
193
+ # ComputationTool (LLM-generated Python; ⚠ replace in Phase 4)
194
+ # ---------------------------------------------------------------------------
195
class ComputationTool:
    """Solve a computational task by asking an LLM for Python code and running it.

    ``llm_instance`` is a callable that takes a prompt string and returns a
    sequence whose first element is the response text.

    ⚠ The generated code is executed via ``execute_python_code_raw``, which
    runs it in a subprocess without sandboxing. Treat this tool as unsafe for
    untrusted task descriptions.
    """

    # Tried in order: a python-tagged fence wins over any other fenced block.
    # The closing fence does not need its own line and CRLF is accepted, so a
    # slightly malformed reply is still unwrapped instead of executed verbatim.
    _FENCE_PATTERNS = (
        re.compile(r"```(?:python|py)[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE),
        re.compile(r"```[^\n`]*\r?\n(.*?)```", re.DOTALL),
    )

    def __init__(self, llm_instance):
        self.llm = llm_instance

    @classmethod
    def _extract_code(cls, code_response: str) -> str:
        """Return the code inside the first matching Markdown fence.

        Falls back to the whole (stripped) response when no fence is found.
        """
        for pattern in cls._FENCE_PATTERNS:
            match = pattern.search(code_response)
            if match:
                return match.group(1).strip()
        return code_response.strip()

    def handle_task(self, task_description: str) -> str:
        """Generate code for ``task_description``, execute it, return the result.

        The return value is whatever ``execute_python_code_raw`` returns for
        the extracted code. The prompt, code and result are also passed to
        ``save_to_json`` under the ``ComputationTool`` subdirectory.
        """
        _comp_logger.info("\n🤖 COMPUTATION TOOL STARTED")
        settings = get_settings()
        instruction = (
            "You are a Python coding assistant. Generate only the Python code required "
            "to perform the given task. Do not forget to print the result. Do not add explanations."
        )
        prompt = f"{instruction}\n\nTask: {task_description}\n\nCode:"

        if should_debug("tools", "ComputationTool") and settings.debug_level == "full":
            _comp_logger.debug("Computation Tool Prompt:\n%s", prompt)
        code_response = self.llm(prompt)[0]
        if should_debug("tools", "ComputationTool"):
            _comp_logger.debug("Computation Tool Response:\n%s", code_response)

        code_to_execute = self._extract_code(code_response)
        execution_result = execute_python_code_raw(code_to_execute)

        # One timestamp for both the record and its filename so they agree.
        timestamp = datetime.now().isoformat()
        save_to_json(
            {
                "instruction": instruction,
                "input": task_description,
                "output": code_to_execute,
                "execution_result": execution_result,
                "timestamp": timestamp,
            },
            f"computation_tool_{timestamp}.json",
            subdirectory="ComputationTool",
        )
        _comp_logger.info("🤖 COMPUTATION COMPLETED\n%s", execution_result)
        return execution_result
234
 
235
 
236
+ # ---------------------------------------------------------------------------
237
+ # WebSearchTool (DuckDuckGo + LLM synthesis)
238
+ # ---------------------------------------------------------------------------
239
class WebSearchTool:
    """Answer a research question from web search results summarised by an LLM.

    ``llm_instance`` is a callable that takes a prompt string and returns a
    sequence whose first element is the response text. It is used twice per
    task: once to draft search queries and once to synthesise the answer.
    """

    # Leading list markers an LLM tends to add ("1. ", "2) ", "- ", "* ").
    # Whitespace after the marker is required, so "1.5 g protein" and
    # "-keto" are left untouched.
    _LIST_MARKER = re.compile(r"^(?:[-*•]|\d+[.)])\s+")

    def __init__(self, llm_instance):
        self.llm = llm_instance

    @staticmethod
    def _resolve_question(research_task: str) -> str:
        """Turn the incoming task into a single research question.

        A JSON object with a ``queries`` list is flattened into one
        space-joined string; anything else is used verbatim.
        """
        try:
            task_data = json.loads(research_task)
            if (
                isinstance(task_data, dict)
                and "queries" in task_data
                and isinstance(task_data["queries"], list)
            ):
                _web_logger.info("JSON query list detected. Converting to single text task.")
                return " ".join(task_data["queries"])
            _web_logger.info("Single question mode (non-query JSON). Generating queries.")
            return research_task
        except (json.JSONDecodeError, TypeError):
            _web_logger.info("Single question mode (plain text). Generating queries.")
            return research_task

    @classmethod
    def _parse_queries(cls, search_queries_text: str, fallback: str) -> list[str]:
        """Split the LLM reply into clean, unique search queries.

        One query per non-empty line, list markers removed, duplicates dropped
        with order preserved. Returns ``[fallback]`` when nothing usable is left.
        """
        cleaned = (
            cls._LIST_MARKER.sub("", line.strip()).strip()
            for line in search_queries_text.split("\n")
        )
        unique = list(dict.fromkeys(query for query in cleaned if query))
        return unique or [fallback]

    def handle_task(self, research_task: str) -> str:
        """Research ``research_task`` and return the synthesised answer text.

        Each generated query is passed to ``search_web_raw`` with
        ``num_results=10``. Both LLM exchanges are passed to ``save_to_json``
        under the ``WebSearchTool`` subdirectory.
        """
        _web_logger.info("\n🌐 WEB SEARCH TOOL STARTED")
        settings = get_settings()

        research_question = self._resolve_question(research_task)

        query_instruction = (
            "Formulate concise search queries for DuckDuckGo based on the given question. "
            "Output only the queries, one per line."
        )
        query_prompt = f"{query_instruction}\n\nQuestion: {research_question}\n\nQueries:"

        if should_debug("tools", "WebSearchTool") and settings.debug_level == "full":
            _web_logger.debug("Web Search Query Prompt:\n%s", query_prompt)
        search_queries_text = self.llm(query_prompt)[0]
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Web Search Query Response:\n%s", search_queries_text)

        search_queries = self._parse_queries(search_queries_text, research_question)
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Parsed queries: %s", search_queries)

        all_raw_results = []
        for position, query in enumerate(search_queries):
            if position:
                # Throttle between consecutive searches only; there is nothing
                # to wait for after the final one.
                sleep(1)
            raw_results = search_web_raw(query, num_results=10)
            _web_logger.info("Search results: %s...", raw_results[:200])
            all_raw_results.append(f"Results for '{query}':\n{raw_results}")

        raw_search_output = "\n\n".join(all_raw_results)
        synthesis_instruction = (
            f"Synthesize a concise answer to:\n"
            f"Question: {research_question}\n"
            f"Based on:\n---\n{raw_search_output}\n---\n"
        )

        if should_debug("tools", "WebSearchTool") and settings.debug_level == "full":
            _web_logger.debug("Web Search Synthesis Instruction:\n%s", synthesis_instruction)
        synthesized_answer = self.llm(synthesis_instruction)[0]
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Web Search Synthesis Response:\n%s", synthesized_answer)

        # Shared timestamp ties the two log records of one task together.
        timestamp = datetime.now().isoformat()
        save_to_json(
            {
                "instruction": query_instruction,
                "input": research_question,
                "output": search_queries_text,
                "timestamp": timestamp,
            },
            f"web_search_tool_queries_{timestamp}.json",
            subdirectory="WebSearchTool",
        )
        save_to_json(
            {
                "instruction": synthesis_instruction,
                "input": raw_search_output,
                "output": synthesized_answer,
                "timestamp": timestamp,
            },
            f"web_search_tool_synthesis_{timestamp}.json",
            subdirectory="WebSearchTool",
        )
        _web_logger.info("\n🌐 WEB SEARCH TOOL Result:\n%s\n", synthesized_answer)
        return synthesized_answer
324
 
325
+
326
+ # ---------------------------------------------------------------------------
327
+ # Helpers
328
+ # ---------------------------------------------------------------------------
329
def execute_python_code_raw(code_string: str) -> str:
    """Run ``code_string`` as a script in a subprocess and report the outcome.

    ⚠ The code runs unsandboxed with this process's privileges; only the
    30-second timeout limits it.

    Returns one of:

    * ``"Output:\\n<stdout>"`` on exit code 0 (a placeholder sentence when
      stdout is empty),
    * ``"Error:\\n<stderr>"`` on a non-zero exit code,
    * ``"Execution Exception: <message>"`` when launching or waiting fails
      (including the timeout).
    """
    # Local import keeps this block self-contained; ``sys`` is only needed here.
    import sys

    settings = get_settings()
    if should_debug("tools", "ComputationTool") and settings.debug_level == "full":
        _comp_logger.debug("🐍 Executing Code (raw):\n%s", code_string)
    script_path = ""
    try:
        # UTF-8 matches Python's default source encoding, so non-ASCII code
        # survives regardless of the platform locale.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as tmp_script:
            # Record the path before writing so a failed write is still cleaned up.
            script_path = tmp_script.name
            tmp_script.write(code_string)
        process = subprocess.run(
            # Use the interpreter running this process rather than whatever
            # "python" happens to be on PATH (fall back if it is unknown).
            [sys.executable or "python", script_path],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if process.returncode == 0:
            return f"Output:\n{process.stdout if process.stdout else 'Code executed successfully.'}"
        return f"Error:\n{process.stderr}"
    except Exception as e:  # noqa: BLE001
        # Tool boundary: surface any failure as text instead of raising.
        return f"Execution Exception: {str(e)}"
    finally:
        if script_path and os.path.exists(script_path):
            os.remove(script_path)
353
 
354
+
355
  def search_web_raw(query: str, num_results: int = 3) -> str:
356
+ _web_logger.info("🌐 Searching Web (raw) for: %s", query)
357
  max_retries = 3
358
  for attempt in range(max_retries):
359
  try:
 
361
  results = list(ddgs.text(query, max_results=num_results, timelimit="m"))
362
  if not results:
363
  return "No search results found."
364
+ return "\n".join(
365
+ f"Title: {r.get('title')}\nURL: {r.get('href')}\nSnippet: {r.get('body')}"
366
+ for r in results
367
+ )
368
+ except Exception as e: # noqa: BLE001
369
  if attempt < max_retries - 1:
370
  sleep(1)
371
  continue
372
  return f"Search Exception after {max_retries} attempts: {str(e)}"
373
+ return "Search Exception: exhausted retries"
 
 
utils.py CHANGED
@@ -1,139 +1,134 @@
1
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import json
3
- import re
4
- import config
5
- from typing import TypedDict, List, Optional, Dict, Any, Tuple
6
  import pickle
7
- from langgraph.checkpoint.base import BaseCheckpointSaver
8
- from google import genai
9
- from google.genai import types
10
- from datetime import datetime, date
11
  import time
12
- from google.colab import userdata
13
- from json_repair import repair_json
14
  from collections import deque
 
 
15
  from threading import Lock
 
16
 
17
- # LANGSMITH SETUP FOR DEBUGGING
18
- # os.environ["LANGCHAIN_TRACING_V2"] = "true"
19
- # os.environ["LANGCHAIN_API_KEY"] = userdata.get("LANGCHAIN_API_KEY")
20
- # os.environ["LANGCHAIN_PROJECT"] = "Nutrition-MAS-v1"
 
21
 
22
- def should_debug(scope: str, name: str) -> bool:
23
- if not config.DEBUG_MODE:
24
- return False
25
- if scope not in config.DEBUG_SCOPES:
26
- return False
27
- scopes_list = config.DEBUG_SCOPES[scope]
28
- return 'all' in scopes_list or name in scopes_list
29
 
30
- def save_to_json(data: Dict[str, Any], filename: str, subdirectory: str = None):
31
- if config.LOG_DIR is None:
32
- # print("Logging is disabled. Skipping save_to_json.")
33
- return
34
- if subdirectory:
35
- log_dir = os.path.join(config.LOG_DIR, subdirectory)
36
- else:
37
- log_dir = config.LOG_DIR
38
- os.makedirs(log_dir, exist_ok=True)
39
- filepath = os.path.join(log_dir, filename)
40
- with open(filepath, 'w') as f:
41
- json.dump(data, f, indent=2)
42
 
43
- class LLM:
44
- def __call__(self, prompt: str, **kwargs) -> list[str]:
45
- pass
46
 
47
- def format_prompt(self, messages: List[Dict[str, str]]) -> str:
48
- pass
49
 
50
- class GeminiLLM(LLM):
51
- def __init__(self, model_name: str, structured_output: bool = False, thinking_budget: int = 300, manager=None, **kwargs):
52
- self.model_name = model_name
53
- self.structured_output = structured_output
54
- self.thinking_budget = thinking_budget
55
- self.kwargs = kwargs
56
- self.manager = manager
57
 
58
- def __call__(self, prompt: str, **kwargs) -> list[str]:
59
- if self.manager is None:
60
- raise ValueError("APIPoolManager must be provided for rate limiting.")
61
 
62
- merged_kwargs = {**self.kwargs, **kwargs}
 
 
 
63
 
64
- # Get next available API key
65
- api_key = self.manager.get_next_key(self.model_name)
 
 
 
 
 
 
 
66
 
67
- try:
68
- client = genai.Client(api_key=api_key)
69
 
70
- contents = [
71
- types.Content(
72
- role="user",
73
- parts=[types.Part.from_text(text=prompt)],
74
- )
75
- ]
76
-
77
- if self.structured_output:
78
- generate_content_config = types.GenerateContentConfig(
79
- thinking_config=types.ThinkingConfig(
80
- thinking_budget=self.thinking_budget,
81
- ),
82
- response_mime_type="application/json",
83
- max_output_tokens=merged_kwargs.get("max_tokens", 5120),
84
- temperature=merged_kwargs.get("temperature", 0.3),
85
- )
86
- else:
87
- generate_content_config = types.GenerateContentConfig(
88
- thinking_config=types.ThinkingConfig(
89
- thinking_budget=self.thinking_budget,
90
- ),
91
- response_mime_type="text/plain",
92
- max_output_tokens=merged_kwargs.get("max_tokens", 5120),
93
- temperature=merged_kwargs.get("temperature", 0.3),
94
- )
95
 
96
- response_text = ""
97
- start_time = time.time()
98
- for chunk in client.models.generate_content_stream(
99
- model=self.model_name,
100
- contents=contents,
101
- config=generate_content_config,
102
- ):
103
- if chunk.text:
104
- response_text += chunk.text
105
-
106
- # Record usage only on successful completion
107
- completion_time = time.time()
108
- if self.manager.rate_limits is not None:
109
- self.manager.record_usage(api_key, self.model_name, completion_time)
110
 
111
- if config.DEBUG_MODE:
112
- print(f"LLM call completed for {self.model_name} using key {api_key[-4:]} in {completion_time - start_time:.2f}s")
 
113
 
114
- return [response_text.strip()]
115
 
116
- except Exception as e:
117
- # Do not record usage on error to avoid inflating limits for failed calls
118
- # print(f"LLM call failed for {self.model_name} using key {api_key[-4:]}: {str(e)}")
119
- return [f"Error: LLM call failed - {str(e)}"]
 
 
 
 
 
 
120
 
121
- def format_prompt(self, messages: List[Dict[str, str]]) -> str:
122
- prompt = ""
123
- for msg in messages:
124
- if msg["role"] == "system":
125
- prompt += f"System: {msg['content']}\n"
126
- elif msg["role"] == "user":
127
- prompt += f"User: {msg['content']}\n"
128
- elif msg["role"] == "assistant":
129
- prompt += f"Assistant: {msg['content']}\n"
130
- prompt += "Assistant:"
131
- return prompt
132
 
133
- # In utils.py, update the GeminiLLM class as follows:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  class GeminiLLM(LLM):
136
- def __init__(self, model_name: str, structured_output: bool = False, thinking_budget: int = 300, manager=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  self.model_name = model_name
138
  self.structured_output = structured_output
139
  self.thinking_budget = thinking_budget
@@ -141,79 +136,167 @@ class GeminiLLM(LLM):
141
  self.manager = manager
142
  self.is_gemma = "gemma" in model_name.lower()
143
  if self.is_gemma:
 
144
  self.structured_output = False
145
  self.thinking_budget = None
146
- # No self.client or self.api_key; created dynamically
147
 
148
- def __call__(self, prompt: str, **kwargs) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  if self.manager is None:
150
  raise ValueError("APIPoolManager must be provided for rate limiting.")
151
 
152
  merged_kwargs = {**self.kwargs, **kwargs}
153
-
154
- # Get next available API key
155
  api_key = self.manager.get_next_key(self.model_name)
156
 
157
  try:
158
  client = genai.Client(api_key=api_key)
 
 
159
 
160
- contents = [
161
- types.Content(
162
- role="user",
163
- parts=[types.Part.from_text(text=prompt)],
164
- )
165
- ]
166
-
167
- if self.is_gemma:
168
- generate_content_config = types.GenerateContentConfig(
169
- response_mime_type="text/plain",
170
- max_output_tokens=merged_kwargs.get("max_tokens", 5120),
171
- temperature=merged_kwargs.get("temperature", 0.3),
172
  )
 
 
173
  else:
174
- if self.structured_output:
175
- generate_content_config = types.GenerateContentConfig(
176
- thinking_config=types.ThinkingConfig(
177
- thinking_budget=self.thinking_budget,
178
- ),
179
- response_mime_type="application/json",
180
- max_output_tokens=merged_kwargs.get("max_tokens", 5120),
181
- temperature=merged_kwargs.get("temperature", 0.3),
182
- )
183
- else:
184
- generate_content_config = types.GenerateContentConfig(
185
- thinking_config=types.ThinkingConfig(
186
- thinking_budget=self.thinking_budget,
187
- ),
188
- response_mime_type="text/plain",
189
- max_output_tokens=merged_kwargs.get("max_tokens", 5120),
190
- temperature=merged_kwargs.get("temperature", 0.3),
191
- )
192
-
193
- response_text = ""
194
- start_time = time.time()
195
- for chunk in client.models.generate_content_stream(
196
- model=self.model_name,
197
- contents=contents,
198
- config=generate_content_config,
199
- ):
200
- if chunk.text:
201
- response_text += chunk.text
202
-
203
- # Record usage only on successful completion
204
  completion_time = time.time()
205
  if self.manager.rate_limits is not None:
206
  self.manager.record_usage(api_key, self.model_name, completion_time)
207
 
208
- if config.DEBUG_MODE:
209
- print(f"LLM call completed for {self.model_name} using key {api_key[-4:]} in {completion_time - start_time:.2f}s")
210
-
211
- return [response_text.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- except Exception as e:
214
- # Do not record usage on error to avoid inflating limits for failed calls
215
- # print(f"LLM call failed for {self.model_name} using key {api_key[-4:]}: {str(e)}")
216
- return [f"Error: LLM call failed - {str(e)}"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  def format_prompt(self, messages: List[Dict[str, str]]) -> str:
219
  prompt = ""
@@ -228,12 +311,19 @@ class GeminiLLM(LLM):
228
  return prompt
229
 
230
 
 
231
  class APIPoolManager:
232
- def __init__(self, api_keys: List[str], rate_limits: Optional[Dict[str, Tuple[int, int]]] = None):
233
- """
234
- rate_limits: { model_name: (RPM, RPD) }
235
- usage: { api_key: { model: { "timestamps": deque(maxlen=rpm), "daily_requests": int, "last_day": date } } }
236
- """
 
 
 
 
 
 
237
  self.api_keys = list(api_keys)
238
  self.active_keys = list(api_keys)
239
  self.rate_limits = rate_limits
@@ -244,16 +334,15 @@ class APIPoolManager:
244
  if rate_limits is not None:
245
  for key in api_keys:
246
  self.usage[key] = {}
247
- for model, (rpm, rpd) in rate_limits.items():
248
  self.usage[key][model] = {
249
  "timestamps": deque(maxlen=max(1, rpm)),
250
  "daily_requests": 0,
251
- "last_day": date.today()
252
  }
253
- else:
254
- self.usage = {}
255
 
256
- def _refresh_daily(self, key: str, model: str):
 
257
  usage = self.usage[key][model]
258
  today = date.today()
259
  if usage["last_day"] < today:
@@ -266,25 +355,18 @@ class APIPoolManager:
266
  self._refresh_daily(key, model)
267
  _, rpd = self.rate_limits[model]
268
  if self.usage[key][model]["daily_requests"] >= rpd:
269
- # drop this api key
270
  if key in self.active_keys:
271
  self.active_keys.remove(key)
272
  return False
273
  return True
274
 
275
  def _key_wait_info(self, key: str, model: str) -> Tuple[float, float]:
276
- """
277
- Return tuple (wait_slot_seconds, wait_spacing_seconds)
278
- - wait_slot_seconds: time until an RPM slot frees because deque is full (0 if slot available)
279
- - wait_spacing_seconds: time until spacing interval satisfied relative to last timestamp (0 if spacing ok)
280
- """
281
  if self.rate_limits is None:
282
  return 0.0, 0.0
283
  rpm, _ = self.rate_limits[model]
284
  usage = self.usage[key][model]
285
  now = time.time()
286
 
287
- # Clean old timestamps > 60s
288
  timestamps = usage["timestamps"]
289
  while timestamps and now - timestamps[0] > 60:
290
  timestamps.popleft()
@@ -295,7 +377,7 @@ class APIPoolManager:
295
  wait_slot = max(0.0, 60.0 - (now - oldest))
296
 
297
  wait_spacing = 0.0
298
- if len(timestamps) > 0:
299
  time_since_last = now - timestamps[-1]
300
  min_interval = 60.0 / rpm if rpm > 0 else 0.0
301
  wait_spacing = max(0.0, min_interval - time_since_last)
@@ -303,9 +385,6 @@ class APIPoolManager:
303
  return wait_slot, wait_spacing
304
 
305
  def can_use_now(self, key: str, model: str) -> bool:
306
- """
307
- True if key is active, RPD ok, and both slot and spacing waits are zero.
308
- """
309
  if key not in self.active_keys:
310
  return False
311
  if not self._key_is_rpd_ok(key, model):
@@ -313,30 +392,22 @@ class APIPoolManager:
313
  wait_slot, wait_spacing = self._key_wait_info(key, model)
314
  return wait_slot <= 0.0 and wait_spacing <= 0.0
315
 
 
316
  def get_next_key(self, model: str, max_sleep_once: bool = True) -> str:
317
- """
318
- Choose an API key that can be used immediately for the given model.
319
- If none available now, compute minimum sleep needed across all keys, sleep once,
320
- then re-evaluate. Loop until a key is found or no keys left.
321
- """
322
  with self.lock:
323
  if not self.active_keys:
324
  raise RuntimeError("No available API keys left due to rate limits.")
325
 
326
- # Quick pass: try to find an immediately-available key starting from current_index
327
  n = len(self.active_keys)
328
  for i in range(n):
329
  idx = (self.current_index + i) % n
330
  key = self.active_keys[idx]
331
  if self.can_use_now(key, model):
332
- # advance pointer fairly to next key for next call
333
  self.current_index = (idx + 1) % max(1, len(self.active_keys))
334
  return key
335
 
336
- # If we reach here: no key is available *right now*
337
- # compute minimal wait across active keys
338
- min_wait = None
339
- for key in list(self.active_keys): # list() to be safe if removal happens
340
  if not self._key_is_rpd_ok(key, model):
341
  continue
342
  wait_slot, wait_spacing = self._key_wait_info(key, model)
@@ -345,139 +416,145 @@ class APIPoolManager:
345
  min_wait = wait
346
 
347
  if min_wait is None:
348
- # No keys left after RPD filtering
349
  raise RuntimeError("No available API keys left (RPD exhausted).")
350
 
351
  if min_wait and min_wait > 0:
352
- if max_sleep_once:
353
- time.sleep(min_wait)
354
- else:
355
- time.sleep(min_wait)
356
  return self.get_next_key(model, max_sleep_once=True)
357
 
358
- def record_usage(self, key: str, model: str, timestamp: Optional[float] = None):
359
- """
360
- Call this after you receive the response to record actual usage/time.
361
- timestamp default is now (time of completion).
362
- """
363
  if self.rate_limits is None:
364
  return
365
  t = timestamp or time.time()
366
  with self.lock:
367
  if key not in self.active_keys:
368
- # safety - if key was removed in-between, ignore or re-add depending on policy
369
  return
370
  self._refresh_daily(key, model)
371
  self.usage[key][model]["timestamps"].append(t)
372
  self.usage[key][model]["daily_requests"] += 1
373
- # Remove if daily limit reached
374
  _, rpd = self.rate_limits[model]
375
  if self.usage[key][model]["daily_requests"] >= rpd:
376
  if key in self.active_keys:
377
  self.active_keys.remove(key)
378
 
379
 
380
- def create_llm(config: dict, manager) -> LLM:
 
 
381
  if config["type"] == "gemini":
382
- structured_output = config.get("structured_output", False)
383
- thinking_budget = config.get("thinking_budget", 300)
384
- llm = GeminiLLM(
385
  model_name=config["model_name"],
386
- structured_output=structured_output,
387
- thinking_budget=thinking_budget,
388
  manager=manager,
389
- **config.get("params", {})
390
  )
391
- return llm
392
- else:
393
- raise ValueError(f"Unknown LLM type: {config['type']}")
394
 
 
395
  def extract_and_parse_json(text: str) -> Dict[str, Any]:
396
- """Enhanced JSON extraction and parsing with multiple fallback strategies"""
 
 
 
 
 
397
  try:
398
  return json.loads(text.strip())
399
- except:
400
  pass
401
 
402
- json_match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
403
- if json_match:
404
  try:
405
- return json.loads(json_match.group(1))
406
- except:
407
  pass
408
 
409
- json_match = re.search(r'\{.*\}', text, re.DOTALL)
410
- if json_match:
411
  try:
412
- repaired_json = repair_json(json_match.group(0))
413
- return json.loads(repaired_json)
414
- except:
415
  pass
416
 
417
  try:
418
- repaired_json = repair_json(text)
419
- return json.loads(repaired_json)
420
  except Exception as e:
421
- print(f"All JSON parsing strategies failed: {str(e)}")
422
  return {
423
  "thought": f"JSON parsing failed: {str(e)}",
424
  "action": "compose_response",
425
- "params": {
426
- "text": f"I encountered an error processing your request. Original response: {text[:200]}..."
427
- },
428
  "_parse_error": True,
429
- "_original_text": text
430
  }
431
 
432
- def set_nested(d: Dict[str, Any], key: str, value: Any):
433
- keys = key.split('.')
 
 
434
  for k in keys[:-1]:
435
  d = d.setdefault(k, {})
436
  d[keys[-1]] = value
437
 
438
- def get_memory_summary(memory: Dict[str, Any], partitions: List[str] = None) -> str:
439
- """Get a formatted summary of specific memory partitions"""
 
440
  if partitions is None:
441
  partitions = ["user_profile", "medical_history", "flags_and_assessments", "plans"]
442
-
443
- summary = {}
444
  for partition in partitions:
445
- if partition in memory and memory[partition]:
446
- summary[partition] = memory[partition]
447
- else:
448
- summary[partition] = "empty"
449
 
450
- return json.dumps(summary, indent=2)
451
 
452
  def update_memory_partition(memory: Dict[str, Any], partition: str, data: Any) -> None:
453
- """Safely update a memory partition with new data"""
454
  if partition not in memory:
455
  memory[partition] = {}
456
-
457
  if isinstance(data, dict) and isinstance(memory[partition], dict):
458
  memory[partition].update(data)
459
  else:
460
  memory[partition] = data
461
- if config.DEBUG_MODE:
462
- print(f"Updated memory partition '{partition}' with new data")
463
 
 
464
  class FileCheckpointSaver(BaseCheckpointSaver):
465
- def __init__(self, directory: str):
 
 
466
  self.directory = directory
467
  os.makedirs(directory, exist_ok=True)
468
 
469
  def put(self, config: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
470
- """Save checkpoint to file"""
471
  thread_id = config.get("configurable", {}).get("thread_id", "default")
472
  filepath = os.path.join(self.directory, f"checkpoint_{thread_id}.pkl")
473
- with open(filepath, 'wb') as f:
474
  pickle.dump(checkpoint, f)
475
 
476
  def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
477
- """Load checkpoint from file"""
478
  thread_id = config.get("configurable", {}).get("thread_id", "default")
479
  filepath = os.path.join(self.directory, f"checkpoint_{thread_id}.pkl")
480
  if os.path.exists(filepath):
481
- with open(filepath, 'rb') as f:
482
  return pickle.load(f)
483
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities: LLM wrapper, API-key pool with rate limiting, JSON helpers, and a
2
+ LangGraph file checkpointer.
3
+
4
+ Phase 0 cleanup notes:
5
+
6
+ * Removed the duplicate ``GeminiLLM`` definition (the second class silently
7
+ shadowed the first; both remained import-visible).
8
+ * Dropped ``from google.colab import userdata`` so the module imports cleanly
9
+ outside Colab. API keys come in via ``create_llm_instances`` or env.
10
+ * Replaced ``print(...)`` calls with module loggers under ``nutrition_mas.*``.
11
+ * Routed all reads of ``config.X`` through :func:`config.get_settings`.
12
+
13
+ Larger refactors (Pydantic-typed agent IO, native Gemini ``response_schema``,
14
+ async ``acall``) land in Phase 1.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
  import json
20
+ import os
 
 
21
  import pickle
22
+ import re
 
 
 
23
  import time
 
 
24
  from collections import deque
25
+ from dataclasses import dataclass, field
26
+ from datetime import date, datetime
27
  from threading import Lock
28
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
29
 
30
+ from google import genai
31
+ from google.genai import types
32
+ from json_repair import repair_json
33
+ from langgraph.checkpoint.base import BaseCheckpointSaver
34
+ from pydantic import BaseModel, ValidationError
35
 
36
+ from config import get_settings
37
+ from logging_setup import get_logger
 
 
 
 
 
38
 
39
+ _logger = get_logger("utils")
40
+ _llm_logger = get_logger("llm.gemini")
41
+ _pool_logger = get_logger("utils.api_pool")
 
 
 
 
 
 
 
 
 
42
 
43
+ T = TypeVar("T", bound=BaseModel)
 
 
44
 
 
 
45
 
46
+ # --- Phase 1 fallback metrics --------------------------------------------------
47
@dataclass
class ParseMetrics:
    """Counts native-vs-fallback parses across the process.

    Phase 1's goal is to drive ``fallback_parses`` to zero. Phase 2 will surface
    these via the eval harness.

    The three totals and the per-model breakdown in ``by_model`` are always
    updated together by :meth:`record`, so they stay consistent.
    """

    native_parses: int = 0  # response.parsed worked first try
    fallback_parses: int = 0  # had to invoke extract_and_parse_json
    schema_failures: int = 0  # output failed Pydantic validation altogether
    # {model_name: {"native": n, "fallback": n, "failure": n}}
    by_model: Dict[str, Dict[str, int]] = field(default_factory=dict)

    def record(self, model: str, kind: str) -> None:
        """Count one parse outcome for ``model``.

        Args:
            model: Model name used as the key in ``by_model``.
            kind: One of ``"native"``, ``"fallback"`` or ``"failure"``.

        Raises:
            ValueError: If ``kind`` is not one of the three known outcomes.
                Previously an unknown kind silently created a stray bucket in
                ``by_model`` without touching any total, which could hide
                failures from whoever reads the metrics.
        """
        if kind == "native":
            self.native_parses += 1
        elif kind == "fallback":
            self.fallback_parses += 1
        elif kind == "failure":
            self.schema_failures += 1
        else:
            raise ValueError(
                f"Unknown parse kind {kind!r}; expected 'native', 'fallback' or 'failure'"
            )
        slot = self.by_model.setdefault(model, {"native": 0, "fallback": 0, "failure": 0})
        slot[kind] = slot.get(kind, 0) + 1
69
 
 
 
70
 
71
+ _parse_metrics = ParseMetrics()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
def get_parse_metrics() -> ParseMetrics:
    """Expose the module-wide :class:`ParseMetrics` instance.

    The same object is returned on every call; callers are expected to read
    it rather than mutate it.
    """
    metrics = _parse_metrics
    return metrics
77
 
 
78
 
79
+ # --- Debug-scope helper --------------------------------------------------------
80
def should_debug(scope: str, name: str) -> bool:
    """Tell whether debug output is switched on for ``name`` within ``scope``.

    Debugging is on only when ``settings.debug_mode`` is truthy, ``scope`` is a
    key of ``settings.debug_scopes``, and that scope's list contains either
    ``name`` or the wildcard ``"all"``.
    """
    cfg = get_settings()
    if cfg.debug_mode and scope in cfg.debug_scopes:
        enabled = cfg.debug_scopes[scope]
        return "all" in enabled or name in enabled
    return False
89
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # --- Filesystem logging --------------------------------------------------------
92
def save_to_json(data: Dict[str, Any], filename: str, subdirectory: Optional[str] = None) -> None:
    """Persist a structured payload under ``settings.log_dir`` if logging is on.

    Args:
        data: JSON-serialisable payload; values ``json`` cannot encode are
            stringified via ``default=str``.
        filename: Target file name. ``:`` is replaced with ``-`` because ISO
            timestamps contain colons, which are invalid on Windows.
        subdirectory: Optional folder below the log directory.

    File logging is treated as disabled when ``settings.log_dir`` is ``None``
    *or* empty. An empty value (e.g. ``NUTRITION_MAS_LOG_DIR=`` left blank)
    used to fall through to ``os.makedirs("")`` and raise.
    """
    settings = get_settings()
    if not settings.log_dir:
        return
    log_dir = os.path.join(settings.log_dir, subdirectory) if subdirectory else settings.log_dir
    os.makedirs(log_dir, exist_ok=True)
    safe_name = filename.replace(":", "-")
    filepath = os.path.join(log_dir, safe_name)
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, default=str)
104
+
105
+
106
+ # --- LLM abstractions ----------------------------------------------------------
107
class LLM:
    """Interface every LLM backend implements.

    Calling an instance with a prompt yields a one-element list holding the
    response text; ``format_prompt`` flattens chat messages into a prompt.
    Both methods must be overridden by subclasses.
    """

    def __call__(self, prompt: str, **kwargs: Any) -> list[str]:  # pragma: no cover - interface
        """Run the model on ``prompt``; subclasses provide the implementation."""
        raise NotImplementedError

    def format_prompt(self, messages: List[Dict[str, str]]) -> str:  # pragma: no cover - interface
        """Turn chat ``messages`` into a prompt string; subclasses implement this."""
        raise NotImplementedError
115
+
116
 
117
  class GeminiLLM(LLM):
118
+ """Synchronous Gemini wrapper with API-key pooling.
119
+
120
+ Phase 1 will add an ``acall`` async path and replace the JSON-in-text
121
+ contract with native ``response_schema`` Pydantic models.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ model_name: str,
127
+ structured_output: bool = False,
128
+ thinking_budget: int = 300,
129
+ manager: Optional["APIPoolManager"] = None,
130
+ **kwargs: Any,
131
+ ) -> None:
132
  self.model_name = model_name
133
  self.structured_output = structured_output
134
  self.thinking_budget = thinking_budget
 
136
  self.manager = manager
137
  self.is_gemma = "gemma" in model_name.lower()
138
  if self.is_gemma:
139
+ # Gemma family doesn't support thinking_config or JSON response schema.
140
  self.structured_output = False
141
  self.thinking_budget = None
 
142
 
143
def __call__(self, prompt: str, **kwargs: Any) -> list[str]:
    """Free-text call without a response schema. Returns ``[response_text]``.

    This is the backwards-compatible entry point for callers that still parse
    JSON out of the raw text; use :meth:`call_typed` when a Pydantic schema
    is available. The parsed half of ``_invoke``'s result is discarded.
    """
    response_text, _unused_parsed = self._invoke(prompt, response_schema=None, **kwargs)
    return [response_text]
151
+
152
def call_typed(
    self,
    prompt: str,
    response_model: Type[T],
    **kwargs: Any,
) -> Optional[T]:
    """Call Gemini with constrained-decoded JSON matching ``response_model``.

    Strategies are tried in order:

    1. ``response.parsed`` is already a ``response_model`` instance.
    2. ``response.parsed`` is a dict that validates against the model.
    3. The raw text is run through :func:`extract_and_parse_json` and the
       result validated (counted as a "fallback" parse).

    Args:
        prompt: Prompt text sent to the model.
        response_model: Pydantic model class describing the expected output.
        **kwargs: Per-call generation overrides forwarded to ``_invoke``.

    Returns:
        A parsed instance of ``response_model``, or ``None`` if every parse
        strategy failed (in which case the parse-metrics ``schema_failures``
        counter is incremented so the eval harness can spot it).
    """
    text, parsed = self._invoke(prompt, response_schema=response_model, **kwargs)

    # Strategy 1: SDK already parsed it for us via response_schema.
    if isinstance(parsed, response_model):
        _parse_metrics.record(self.model_name, "native")
        return parsed

    # Strategy 2: SDK gave us a dict; try to validate it.
    if isinstance(parsed, dict):
        try:
            instance = response_model.model_validate(parsed)
        except ValidationError as e:
            _llm_logger.debug("response.parsed dict failed Pydantic validation: %s", e)
        else:
            _parse_metrics.record(self.model_name, "native")
            return instance

    # Strategy 3: regex / json_repair fallback on the raw text.
    try:
        data = extract_and_parse_json(text)
        # extract_and_parse_json never raises: on total failure it returns a
        # placeholder dict tagged ``_parse_error``. A schema with matching
        # fields could validate that placeholder and report it as a success,
        # so treat the tag as a hard failure.
        if isinstance(data, dict) and data.get("_parse_error"):
            raise ValueError("response text contained no parseable JSON")
        instance = response_model.model_validate(data)
    except Exception as e:  # noqa: BLE001 - any failure here means "could not parse"
        _parse_metrics.record(self.model_name, "failure")
        _llm_logger.error(
            "Failed to parse %s from %s response: %s",
            response_model.__name__,
            self.model_name,
            str(e),
        )
        return None

    _parse_metrics.record(self.model_name, "fallback")
    _llm_logger.warning(
        "Used JSON-repair fallback for %s on model %s — fix the prompt or schema",
        response_model.__name__,
        self.model_name,
    )
    return instance
200
+
201
def _invoke(
    self,
    prompt: str,
    response_schema: Optional[Type[BaseModel]] = None,
    **kwargs: Any,
) -> Tuple[str, Any]:
    """Perform one Gemini request and return ``(text, response.parsed)``.

    With a ``response_schema`` the non-streaming endpoint is used, because the
    streaming API does not populate ``response.parsed``; without one the
    response is streamed and ``parsed`` is ``None``.

    Any exception raised while talking to the API is logged and converted
    into an ``"Error: LLM call failed - ..."`` text with ``parsed=None``.
    A missing pool manager raises ``ValueError`` instead.
    """
    if self.manager is None:
        raise ValueError("APIPoolManager must be provided for rate limiting.")

    call_options = {**self.kwargs, **kwargs}
    api_key = self.manager.get_next_key(self.model_name)

    try:
        client = genai.Client(api_key=api_key)
        user_turn = types.Content(role="user", parts=[types.Part.from_text(text=prompt)])
        request = {
            "model": self.model_name,
            "contents": [user_turn],
            "config": self._build_config(call_options, response_schema=response_schema),
        }

        started_at = time.time()
        if response_schema is None:
            # Plain-text path: stream and stitch the non-empty chunks together.
            parsed = None
            pieces = [
                chunk.text
                for chunk in client.models.generate_content_stream(**request)
                if chunk.text
            ]
            response_text = "".join(pieces)
        else:
            response = client.models.generate_content(**request)
            response_text = response.text or ""
            parsed = getattr(response, "parsed", None)

        finished_at = time.time()
        if self.manager.rate_limits is not None:
            self.manager.record_usage(api_key, self.model_name, finished_at)

        _llm_logger.debug(
            "LLM call completed for %s using key …%s in %.2fs (schema=%s)",
            self.model_name,
            api_key[-4:],
            finished_at - started_at,
            response_schema.__name__ if response_schema else "none",
        )
        return response_text.strip(), parsed

    except Exception as e:  # noqa: BLE001 — narrow this in Phase 4 (per-error retries)
        _llm_logger.warning(
            "LLM call failed for %s using key …%s: %s",
            self.model_name,
            api_key[-4:],
            str(e),
        )
        return f"Error: LLM call failed - {str(e)}", None
267
+
268
def _build_config(
    self,
    merged_kwargs: Dict[str, Any],
    response_schema: Optional[Type[BaseModel]] = None,
) -> types.GenerateContentConfig:
    """Assemble the ``GenerateContentConfig`` for one request.

    ``max_tokens`` (default 5120) and ``temperature`` (default 0.3) come from
    ``merged_kwargs``. Gemma models get a plain-text config only. Other models
    also get a thinking config, plus either the JSON schema (when
    ``response_schema`` is given) or a MIME type chosen by
    ``self.structured_output``.
    """
    options: Dict[str, Any] = {
        "max_output_tokens": merged_kwargs.get("max_tokens", 5120),
        "temperature": merged_kwargs.get("temperature", 0.3),
    }

    if self.is_gemma:
        # Gemma can't do thinking_config or response_schema.
        options["response_mime_type"] = "text/plain"
        return types.GenerateContentConfig(**options)

    options["thinking_config"] = types.ThinkingConfig(thinking_budget=self.thinking_budget)
    if response_schema is not None:
        options["response_mime_type"] = "application/json"
        options["response_schema"] = response_schema
    elif self.structured_output:
        options["response_mime_type"] = "application/json"
    else:
        options["response_mime_type"] = "text/plain"
    return types.GenerateContentConfig(**options)
300
 
301
  def format_prompt(self, messages: List[Dict[str, str]]) -> str:
302
  prompt = ""
 
311
  return prompt
312
 
313
 
314
+ # --- API key pool with optional rate limiting ----------------------------------
315
  class APIPoolManager:
316
+ """Round-robin Gemini API keys with per-key RPM/RPD enforcement.
317
+
318
+ ``rate_limits`` is ``{model_name: (rpm, rpd)}``. When ``None``, the pool
319
+ just rotates keys without any throttling.
320
+ """
321
+
322
+ def __init__(
323
+ self,
324
+ api_keys: List[str],
325
+ rate_limits: Optional[Dict[str, Tuple[int, int]]] = None,
326
+ ) -> None:
327
  self.api_keys = list(api_keys)
328
  self.active_keys = list(api_keys)
329
  self.rate_limits = rate_limits
 
334
  if rate_limits is not None:
335
  for key in api_keys:
336
  self.usage[key] = {}
337
+ for model, (rpm, _rpd) in rate_limits.items():
338
  self.usage[key][model] = {
339
  "timestamps": deque(maxlen=max(1, rpm)),
340
  "daily_requests": 0,
341
+ "last_day": date.today(),
342
  }
 
 
343
 
344
+ # --- internal helpers ------------------------------------------------------
345
+ def _refresh_daily(self, key: str, model: str) -> None:
346
  usage = self.usage[key][model]
347
  today = date.today()
348
  if usage["last_day"] < today:
 
355
  self._refresh_daily(key, model)
356
  _, rpd = self.rate_limits[model]
357
  if self.usage[key][model]["daily_requests"] >= rpd:
 
358
  if key in self.active_keys:
359
  self.active_keys.remove(key)
360
  return False
361
  return True
362
 
363
  def _key_wait_info(self, key: str, model: str) -> Tuple[float, float]:
 
 
 
 
 
364
  if self.rate_limits is None:
365
  return 0.0, 0.0
366
  rpm, _ = self.rate_limits[model]
367
  usage = self.usage[key][model]
368
  now = time.time()
369
 
 
370
  timestamps = usage["timestamps"]
371
  while timestamps and now - timestamps[0] > 60:
372
  timestamps.popleft()
 
377
  wait_slot = max(0.0, 60.0 - (now - oldest))
378
 
379
  wait_spacing = 0.0
380
+ if timestamps:
381
  time_since_last = now - timestamps[-1]
382
  min_interval = 60.0 / rpm if rpm > 0 else 0.0
383
  wait_spacing = max(0.0, min_interval - time_since_last)
 
385
  return wait_slot, wait_spacing
386
 
387
  def can_use_now(self, key: str, model: str) -> bool:
 
 
 
388
  if key not in self.active_keys:
389
  return False
390
  if not self._key_is_rpd_ok(key, model):
 
392
  wait_slot, wait_spacing = self._key_wait_info(key, model)
393
  return wait_slot <= 0.0 and wait_spacing <= 0.0
394
 
395
+ # --- public API ------------------------------------------------------------
396
  def get_next_key(self, model: str, max_sleep_once: bool = True) -> str:
 
 
 
 
 
397
  with self.lock:
398
  if not self.active_keys:
399
  raise RuntimeError("No available API keys left due to rate limits.")
400
 
 
401
  n = len(self.active_keys)
402
  for i in range(n):
403
  idx = (self.current_index + i) % n
404
  key = self.active_keys[idx]
405
  if self.can_use_now(key, model):
 
406
  self.current_index = (idx + 1) % max(1, len(self.active_keys))
407
  return key
408
 
409
+ min_wait: Optional[float] = None
410
+ for key in list(self.active_keys):
 
 
411
  if not self._key_is_rpd_ok(key, model):
412
  continue
413
  wait_slot, wait_spacing = self._key_wait_info(key, model)
 
416
  min_wait = wait
417
 
418
  if min_wait is None:
 
419
  raise RuntimeError("No available API keys left (RPD exhausted).")
420
 
421
  if min_wait and min_wait > 0:
422
+ _pool_logger.debug("Waiting %.2fs for next API slot", min_wait)
423
+ time.sleep(min_wait)
 
 
424
  return self.get_next_key(model, max_sleep_once=True)
425
 
426
def record_usage(self, key: str, model: str, timestamp: Optional[float] = None) -> None:
    """Record one completed request for ``key`` on ``model``.

    Appends the completion time to the key's per-minute window, bumps its
    daily counter, and retires the key from ``active_keys`` once the daily
    limit (RPD) for ``model`` is reached. Does nothing when rate limiting is
    disabled (``rate_limits is None``) or the key is no longer active.

    Args:
        key: API key that served the request.
        model: Model name; must be a key of ``rate_limits``.
        timestamp: Completion time in epoch seconds. Defaults to "now" only
            when omitted — an explicit ``0.0`` is honoured (``timestamp or
            time.time()`` used to discard it).
    """
    if self.rate_limits is None:
        return
    t = time.time() if timestamp is None else timestamp
    with self.lock:
        if key not in self.active_keys:
            return
        self._refresh_daily(key, model)
        stats = self.usage[key][model]
        stats["timestamps"].append(t)
        stats["daily_requests"] += 1
        _, rpd = self.rate_limits[model]
        if stats["daily_requests"] >= rpd and key in self.active_keys:
            self.active_keys.remove(key)
440
 
441
 
442
+ # --- Factory -------------------------------------------------------------------
443
def create_llm(config: dict, manager: APIPoolManager) -> LLM:
    """Build the LLM described by ``config``.

    Only ``config["type"] == "gemini"`` is supported; anything else raises
    ``ValueError``. ``structured_output`` and ``thinking_budget`` are optional
    (defaults ``False`` and ``300``), and ``config["params"]`` is forwarded as
    extra keyword arguments.
    """
    if config["type"] != "gemini":
        raise ValueError(f"Unknown LLM type: {config['type']}")

    extra_params = config.get("params", {})
    return GeminiLLM(
        model_name=config["model_name"],
        structured_output=config.get("structured_output", False),
        thinking_budget=config.get("thinking_budget", 300),
        manager=manager,
        **extra_params,
    )
454
+
 
455
 
456
+ # --- JSON helpers --------------------------------------------------------------
457
def extract_and_parse_json(text: str) -> Dict[str, Any]:
    """Best-effort extraction of a JSON *object* from model output.

    Candidates are tried in order until one decodes to a ``dict``:

    1. the whole text, stripped;
    2. the contents of the first fenced code block (with or without a
       ``json`` tag, case-insensitive);
    3. the outermost ``{...}`` span, passed through ``repair_json``;
    4. the whole text, passed through ``repair_json``.

    Candidates that decode to something other than a dict (a list, string or
    number) are skipped so the result always honours the return annotation.
    ``None`` or non-string input is tolerated rather than raising.

    Returns:
        The decoded dict, or — when every strategy fails — a placeholder
        "compose_response" action dict carrying ``"_parse_error": True`` and
        the original text. This function does not raise on parse failures.
    """
    # Normalise once so the regex searches below cannot hit a TypeError.
    raw = text if isinstance(text, str) else ("" if text is None else str(text))

    def _candidates():
        """Yield ``(candidate_text, needs_repair)`` lazily, cheapest first."""
        yield raw.strip(), False
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL | re.IGNORECASE)
        if fenced:
            yield fenced.group(1), False
        braces = re.search(r"\{.*\}", raw, re.DOTALL)
        if braces:
            yield braces.group(0), True
        yield raw, True

    failure: Optional[Exception] = None
    for candidate, needs_repair in _candidates():
        try:
            value = json.loads(repair_json(candidate) if needs_repair else candidate)
        except Exception as exc:  # noqa: BLE001 - repair_json may raise arbitrary errors
            failure = exc
            continue
        if isinstance(value, dict):
            return value
        failure = TypeError(f"expected a JSON object, got {type(value).__name__}")

    _logger.warning("All JSON parsing strategies failed: %s", str(failure))
    return {
        "thought": f"JSON parsing failed: {str(failure)}",
        "action": "compose_response",
        "params": {"text": f"I encountered an error processing your request. Original response: {raw[:200]}..."},
        "_parse_error": True,
        "_original_text": raw,
    }
494
 
495
+
496
def set_nested(d: Dict[str, Any], key: str, value: Any) -> None:
    """Assign ``value`` at a dotted-path key inside a nested dict.

    ``set_nested(d, "a.b.c", 1)`` results in ``d["a"]["b"]["c"] == 1``,
    creating intermediate dicts as needed.

    Args:
        d: Dict to modify in place.
        key: Dotted path; a key without dots assigns directly on ``d``.
        value: Value stored at the final path segment.

    Raises:
        TypeError: If an intermediate segment already holds a non-dict,
            non-``None`` value. Overwriting it would silently discard data.
    """
    *parents, leaf = key.split(".")
    node = d
    for depth, part in enumerate(parents):
        child = node.get(part)
        if child is None:
            # Missing or explicitly-null partitions are treated as empty
            # containers instead of crashing on ``None[...] = value``.
            child = node[part] = {}
        elif not isinstance(child, dict):
            path = ".".join(parents[: depth + 1])
            raise TypeError(
                f"Cannot set {key!r}: {path!r} holds a {type(child).__name__}, not a dict"
            )
        node = child
    node[leaf] = value
502
 
503
+
504
def get_memory_summary(memory: Dict[str, Any], partitions: Optional[List[str]] = None) -> str:
    """Render chosen memory partitions as an indented JSON string.

    When ``partitions`` is omitted, the four standard partitions are used.
    A partition that is missing or falsy is reported as the string
    ``"empty"``; values ``json`` cannot encode are stringified.
    """
    selected = (
        ["user_profile", "medical_history", "flags_and_assessments", "plans"]
        if partitions is None
        else partitions
    )
    summary: Dict[str, Any] = {name: memory.get(name) or "empty" for name in selected}
    return json.dumps(summary, indent=2, default=str)
 
 
512
 
 
513
 
514
def update_memory_partition(memory: Dict[str, Any], partition: str, data: Any) -> None:
    """Fold ``data`` into ``memory[partition]`` in place.

    A missing partition starts out as an empty dict. When both the stored
    value and ``data`` are dicts, the keys of ``data`` are merged in;
    otherwise ``data`` replaces the stored value outright.
    """
    current = memory.setdefault(partition, {})
    both_dicts = isinstance(data, dict) and isinstance(current, dict)
    if both_dicts:
        current.update(data)
    else:
        memory[partition] = data
    _logger.debug("Updated memory partition %r with new data", partition)
523
+
524
 
525
+ # --- Checkpointer --------------------------------------------------------------
526
class FileCheckpointSaver(BaseCheckpointSaver):
    """Pickle LangGraph checkpoints to ``directory/checkpoint_<thread_id>.pkl``.

    The thread id comes from ``config["configurable"]["thread_id"]`` and
    defaults to ``"default"``. Characters outside ``[A-Za-z0-9_.-]`` are
    replaced with ``_`` before the id is used in a file name, so an id cannot
    point outside ``directory``. Ids that already use only those characters
    map to the same file names as before.

    NOTE(security): checkpoints are read with ``pickle.load``, which executes
    arbitrary code from the file. Only point this at a directory that
    untrusted parties cannot write to.
    """

    def __init__(self, directory: str) -> None:
        """Remember ``directory`` and create it if it does not exist yet."""
        self.directory = directory
        os.makedirs(directory, exist_ok=True)

    def _checkpoint_path(self, config: Dict[str, Any]) -> str:
        """Return the checkpoint file path for the thread named in ``config``."""
        thread_id = config.get("configurable", {}).get("thread_id", "default")
        safe_id = re.sub(r"[^A-Za-z0-9_.-]", "_", str(thread_id))
        return os.path.join(self.directory, f"checkpoint_{safe_id}.pkl")

    def put(self, config: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
        """Write ``checkpoint`` for the thread in ``config``, replacing any old one."""
        filepath = self._checkpoint_path(config)
        # Write to a sibling temp file and swap it in, so a crash mid-write
        # cannot leave a truncated pickle behind.
        tmp_path = f"{filepath}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_path, filepath)

    def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Load the thread's checkpoint, or return ``None`` if none was saved."""
        filepath = self._checkpoint_path(config)
        try:
            with open(filepath, "rb") as f:
                return pickle.load(f)
        except FileNotFoundError:
            return None
546
+
547
+
548
# Public API of this module. ``ParseMetrics`` / ``get_parse_metrics`` are
# listed because they are the documented way to read the parse counters.
__all__ = [
    "APIPoolManager",
    "FileCheckpointSaver",
    "GeminiLLM",
    "LLM",
    "ParseMetrics",
    "create_llm",
    "extract_and_parse_json",
    "get_memory_summary",
    "get_parse_metrics",
    "save_to_json",
    "set_nested",
    "should_debug",
    "update_memory_partition",
]
validation.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ValidationAgent — the critic in the generator-critic loop.
2
+
3
+ Why this exists
4
+ ----------------
5
+ The original README promised a ``ValidationAgent`` but it was never
6
+ implemented; the system shipped plans straight from the Planner to the user.
7
+ Modern multi-agent literature (Anthropic's research-system writeup, every
8
+ LangGraph reflection-pattern tutorial) is unanimous that a separate critic
9
+ node materially raises output quality on tasks with hard constraints.
10
+
11
+ Design
12
+ ------
13
+ We combine two layers:
14
+
15
+ 1. **Deterministic checks** (no LLM, no cost, instant):
16
+ * allergy violations,
17
+ * calorie deviation > 3 % of daily target,
18
+ * each macro deviation > 5 % of its target,
19
+ * disliked foods present (advisory),
20
+ * professional-consultation flag set without disclaimer.
21
+
22
+ 2. **LLM-graded checks** (one Gemini round-trip, structured output):
23
+ * medical-flag respect (e.g., diabetes user should avoid high-GL meals),
24
+ * citation presence for clinical recommendations,
25
+ * cultural appropriateness against user's country/cuisine preference.
26
+
27
+ Verdict semantics
28
+ -----------------
29
+ * ``pass`` — Coach proceeds to ``compose_response``.
30
+ * ``revise`` — Issues are bundled into the next Planner task; Coach loops back
31
+ to ``call_agent('PlannerAgent', task=...)``. Capped at 2
32
+ revisions (enforced by Coach prompt) to avoid infinite loops.
33
+ * ``reject`` — Hard stop with ``severity='high'``. Coach must compose a
34
+ warning + HITL escalation chip (Phase 4 wires up the chip).
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import json
40
+ from datetime import datetime
41
+ from typing import Any, Dict, List, Optional, Tuple
42
+
43
+ from logging_setup import get_logger
44
+ from schemas import ValidationDecision, ValidationIssue
45
+ from utils import save_to_json
46
+
47
+ _logger = get_logger("agents.validation")
48
+
49
+
50
+ # Tolerances are module-level constants so tests/configs can override them.
51
+ CALORIE_TOLERANCE = 0.03 # +/- 3 %
52
+ MACRO_TOLERANCE = 0.05 # +/- 5 %
53
+
54
+
55
+ _VALIDATION_SYSTEM_PROMPT = """\
56
+ You are the Validation Agent. You receive a meal plan and the medical
57
+ assessment context. Your job is to grade the plan, NOT redesign it.
58
+
59
+ Mandatory checks (in addition to the deterministic ones already supplied):
60
+ 1. Medical-flag respect: for each flag in flags_and_assessments.flags
61
+ (e.g., "diabetes_risk", "high_ldl"), confirm the plan does not contain
62
+ foods that contraindicate the flag. Cite which food fails which flag.
63
+ 2. Evidence: clinical recommendations in flags_and_assessments.recommendations
64
+ must be reflected in the plan or notes. Mention any unaddressed item.
65
+ 3. Cultural appropriateness: if user_profile.country is set, confirm at
66
+ least 60 % of foods are commonly available / culturally familiar there.
67
+ Otherwise emit a low-severity issue suggesting substitutions.
68
+
69
+ Output JSON shape (enforced by schema):
70
+ {
71
+ "verdict": "pass" | "revise" | "reject",
72
+ "issues": [
73
+ {"code": "...", "description": "...",
74
+ "severity": "low" | "medium" | "high"}
75
+ ],
76
+ "notes": "...",
77
+ "requires_human_review": false
78
+ }
79
+
80
+ Rules:
81
+ - Mark requires_human_review=true if any issue has severity="high" OR if
82
+ flags_and_assessments.requires_professional_consultation is true.
83
+ - Use verdict="reject" only for hard safety violations (allergy made it
84
+ through, food explicitly contraindicated by medication).
85
+ - Use verdict="revise" for fixable problems (over-budget calories, missing
86
+ guideline citation, monotonous menu).
87
+ - Use verdict="pass" only when issues is empty OR all issues are severity="low".
88
+ """
89
+
90
+
91
+ class ValidationAgent:
92
+ """Generator-critic gate for the Planner's output."""
93
+
94
def __init__(self, llm_instance):
    """Keep a reference to the LLM client on ``self.llm``."""
    self.llm = llm_instance
96
+
97
+ # ------------------------------------------------------------------
98
+ # Public API
99
+ # ------------------------------------------------------------------
100
def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
    """Validate the current plan in ``memory["plans"]["current_plan"]``.

    Flow:

    1. No plan in memory -> immediate ``reject`` with a ``missing_plan`` issue.
    2. Run the deterministic checks.
    3. Run the LLM review, unless a deterministic issue is already
       ``severity="high"``.
    4. Merge issues and notes, compute the verdict, and hand the decision to
       ``_save_and_return``.

    ``requires_human_review`` is set when the LLM review asks for it or when
    ``flags_and_assessments.requires_professional_consultation`` is truthy.

    Partitions that are present but ``None`` are treated as empty, matching
    ``_deterministic_checks``; previously that raised ``AttributeError``.

    Returns:
        The result of ``_save_and_return`` — per the original contract, a JSON
        string of ``ValidationDecision.model_dump()`` so the Coach can read
        structured fields back out (``verdict``, ``issues``).
    """
    _logger.info("\n🛡️ VALIDATION AGENT STARTED")

    plan = (memory.get("plans") or {}).get("current_plan")
    if plan is None:
        _logger.warning("No current_plan in memory; nothing to validate.")
        verdict = ValidationDecision(
            verdict="reject",
            issues=[
                ValidationIssue(
                    code="missing_plan",
                    description="No current_plan in memory; Planner did not finalise.",
                    severity="high",
                )
            ],
            notes="Validator received no plan. Re-run PlannerAgent.",
            requires_human_review=False,
        )
        return self._save_and_return(task, memory, verdict)

    # 1. Deterministic checks
    det_issues = self._deterministic_checks(plan, memory)

    # 2. LLM-graded checks (only if deterministic ones don't already reject)
    llm_decision: Optional[ValidationDecision] = None
    hard_block = any(i.severity == "high" for i in det_issues)
    if not hard_block:
        llm_decision = self._llm_review(plan, memory, det_issues)

    # 3. Merge
    all_issues = list(det_issues)
    notes_parts: List[str] = []
    requires_hr = False
    if llm_decision is not None:
        all_issues.extend(llm_decision.issues)
        if llm_decision.notes:
            notes_parts.append(llm_decision.notes)
        requires_hr |= llm_decision.requires_human_review

    # Force human review when the medical assessment said so.
    assessments = memory.get("flags_and_assessments") or {}
    if assessments.get("requires_professional_consultation"):
        requires_hr = True

    verdict = self._compute_verdict(all_issues)
    decision = ValidationDecision(
        verdict=verdict,
        issues=all_issues,
        notes=" | ".join(notes_parts) if notes_parts else "",
        requires_human_review=requires_hr,
    )
    _logger.info("🛡️ Validation verdict: %s (%d issue(s))", verdict, len(all_issues))
    return self._save_and_return(task, memory, decision)
157
+
158
+ # ------------------------------------------------------------------
159
+ # Deterministic layer
160
+ # ------------------------------------------------------------------
161
@staticmethod
def _deterministic_checks(plan: Dict[str, Any], memory: Dict[str, Any]) -> List[ValidationIssue]:
    """Run the rule-based plan checks that need no LLM call.

    Allergies (an iterable of strings) and dislikes (one comma-separated
    string) are read from ``memory["user_profile"]``; calorie and macro
    targets come from ``memory["flags_and_assessments"]["calculations"]``.
    Foods and nutrient totals are pulled out of ``plan`` with
    ``_extract_foods_and_totals``.

    Issues are returned in a fixed order:

    1. ``allergy_violation`` (``high``) for every food/allergen match,
    2. ``disliked_food`` (``low``) for every food/dislike match,
    3. ``calorie_deviation`` (``medium``) when the relative deviation
       exceeds ``CALORIE_TOLERANCE``,
    4. ``<macro>_deviation`` (``medium``) when a macro's relative
       deviation exceeds ``MACRO_TOLERANCE``.

    Name matching is a plain substring test on the lower-cased food name.
    Calorie/macro checks are skipped when either side is missing or zero.
    """
    profile = memory.get("user_profile", {}) or {}
    allergy_terms = {
        entry.strip().lower() for entry in (profile.get("allergies", []) or []) if entry
    }
    raw_dislikes = profile.get("food_dislikes", "") or ""
    dislike_terms = {
        piece.strip().lower() for piece in raw_dislikes.split(",") if piece.strip()
    }

    assessments = memory.get("flags_and_assessments", {}) or {}
    calculations = assessments.get("calculations", {}) or {}
    calorie_goal = calculations.get("daily_target_calories")
    macro_goals = calculations.get("macro_targets") or {}

    # Walk plan, accumulating foods and totals.
    foods, totals = ValidationAgent._extract_foods_and_totals(plan)
    found: List[ValidationIssue] = []

    def _flag_name_matches(terms, code, severity, describe) -> None:
        # One pass over every (food, term) pair; empty terms never match.
        for food in foods:
            food_name = (food.get("name") or "").lower()
            for term in terms:
                if not term or term not in food_name:
                    continue
                found.append(
                    ValidationIssue(
                        code=code,
                        description=describe(food_name, term),
                        severity=severity,
                    )
                )

    # 1. Allergy violations (severity high — never let these through)
    _flag_name_matches(
        allergy_terms,
        "allergy_violation",
        "high",
        lambda name, allergen: f"Food '{name}' matches allergen '{allergen}'.",
    )

    # 2. Disliked foods (advisory)
    _flag_name_matches(
        dislike_terms,
        "disliked_food",
        "low",
        lambda name, d: f"Food '{name}' matches user dislike '{d}'.",
    )

    # 3. Calorie tolerance
    plan_calories = totals.get("calories")
    if calorie_goal and plan_calories:
        calorie_dev = abs(plan_calories - calorie_goal) / calorie_goal
        if calorie_dev > CALORIE_TOLERANCE:
            found.append(
                ValidationIssue(
                    code="calorie_deviation",
                    description=(
                        f"Plan total {plan_calories:.0f} kcal vs target "
                        f"{calorie_goal} kcal ({calorie_dev*100:.1f}% deviation)."
                    ),
                    severity="medium",
                )
            )

    # 4. Macro tolerances (target key in memory -> totals key in the plan)
    macro_pairs = (
        ("protein_g", "protein"),
        ("fat_g", "fat"),
        ("carbohydrates_g", "carbohydrates"),
    )
    for goal_key, plan_key in macro_pairs:
        goal = macro_goals.get(goal_key)
        measured = totals.get(plan_key)
        if not (goal and measured):
            continue
        macro_dev = abs(measured - goal) / goal
        if macro_dev > MACRO_TOLERANCE:
            found.append(
                ValidationIssue(
                    code=f"{plan_key}_deviation",
                    description=(
                        f"{plan_key} total {measured:.0f}g vs target {goal}g "
                        f"({macro_dev*100:.1f}% deviation)."
                    ),
                    severity="medium",
                )
            )

    return found
239
+
240
@staticmethod
def _extract_foods_and_totals(
    plan: Dict[str, Any],
) -> Tuple[List[Dict[str, Any]], Dict[str, float]]:
    """Best-effort extraction of food entries and nutrient totals from a plan.

    The plan is walked recursively through lists and dicts. A dict counts
    as a food when it has a ``name`` key and at least one calorie key
    (``calories``, ``kcal`` or ``calories_g``); food dicts are not descended
    into, so a food's own nested values are never double counted.

    For each nutrient the first listed key holding a numeric value wins:
    ``calories``/``kcal``/``calories_g``, ``protein_g``/``protein``,
    ``fat_g``/``fat``, ``carbohydrates_g``/``carbohydrates``. Values that
    cannot be converted with ``float()`` (``None``, free text, containers)
    contribute nothing instead of aborting the walk.

    If ``plan`` is a dict whose ``daily_totals`` is itself a dict, any of
    ``calories``/``protein``/``fat``/``carbohydrates`` found there with a
    numeric value overrides the accumulated sum for that nutrient.

    Returns:
        ``(foods, totals)`` where ``foods`` is the list of matched dicts
        (the original objects, not copies) and ``totals`` maps the four
        nutrient names to floats.

    NOTE(review): totals are summed over every food found anywhere in the
    plan, while ``_deterministic_checks`` compares them with *daily*
    targets — confirm that plans reaching the validator cover one day.
    """
    foods: List[Dict[str, Any]] = []
    totals: Dict[str, float] = {"calories": 0.0, "protein": 0.0, "fat": 0.0, "carbohydrates": 0.0}

    calorie_keys = ("calories", "kcal", "calories_g")
    nutrient_keys = {
        "calories": calorie_keys,
        "protein": ("protein_g", "protein"),
        "fat": ("fat_g", "fat"),
        "carbohydrates": ("carbohydrates_g", "carbohydrates"),
    }

    def _to_float(value: Any) -> Optional[float]:
        # LLM output is free-form: treat anything non-numeric as "no value".
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _first_number(node: Dict[str, Any], keys: Tuple[str, ...]) -> float:
        # First key (in priority order) that is present AND numeric wins.
        for key in keys:
            if key in node:
                number = _to_float(node[key])
                if number is not None:
                    return number
        return 0.0

    def _walk(node: Any) -> None:
        if isinstance(node, list):
            for item in node:
                _walk(item)
        elif isinstance(node, dict):
            if "name" in node and any(k in node for k in calorie_keys):
                foods.append(node)
                for nutrient, keys in nutrient_keys.items():
                    totals[nutrient] += _first_number(node, keys)
            else:
                for value in node.values():
                    _walk(value)

    _walk(plan)

    # Plans may also surface daily_totals directly — prefer those when present.
    declared = plan.get("daily_totals") if isinstance(plan, dict) else None
    if isinstance(declared, dict):
        for nutrient in ("calories", "protein", "fat", "carbohydrates"):
            number = _to_float(declared.get(nutrient))
            if number is not None:
                totals[nutrient] = number
    return foods, totals
278
+
279
+ # ------------------------------------------------------------------
280
+ # LLM layer
281
+ # ------------------------------------------------------------------
282
def _llm_review(
    self,
    plan: Dict[str, Any],
    memory: Dict[str, Any],
    deterministic_issues: List[ValidationIssue],
) -> Optional[ValidationDecision]:
    """Ask the LLM for a typed ``ValidationDecision`` on the plan.

    The prompt contains, in order: the validator system prompt, the plan,
    the user profile and the ``flags_and_assessments`` partition (each as
    indented JSON, with ``default=str`` for non-serialisable values), and a
    bullet list of the deterministic findings (or ``None`` when there are
    none), followed by an instruction to add only new issues.

    Returns whatever ``self.llm.call_typed`` returns; a warning is logged
    when that is ``None``.
    """
    finding_lines = [
        f"- [{i.severity}] {i.code}: {i.description}" for i in deterministic_issues
    ]
    det_summary = "\n".join(finding_lines) or "None"

    plan_json = json.dumps(plan, indent=2, default=str)
    profile_json = json.dumps(memory.get('user_profile', {}), indent=2, default=str)
    assessment_json = json.dumps(memory.get('flags_and_assessments', {}), indent=2, default=str)

    prompt = (
        f"{_VALIDATION_SYSTEM_PROMPT}\n\n--- Plan ---\n{plan_json}\n\n"
        f"--- User profile ---\n{profile_json}\n\n"
        f"--- Medical assessment ---\n"
        f"{assessment_json}\n\n"
        f"--- Deterministic findings already raised ---\n{det_summary}\n\n"
        "Add only NEW issues. Do not repeat the deterministic ones."
    )

    reviewed = self.llm.call_typed(prompt, ValidationDecision)
    if reviewed is not None:
        return reviewed
    _logger.warning("Validator LLM call returned no parseable decision; skipping LLM layer.")
    return None
301
+
302
+ # ------------------------------------------------------------------
303
@staticmethod
def _compute_verdict(issues: List[ValidationIssue]) -> str:
    """Map the collected issues to a single verdict string.

    ``"reject"`` if any issue is ``high``, otherwise ``"revise"`` if any
    issue is ``medium``, otherwise ``"pass"`` (including when there are no
    issues at all).
    """
    severities = {issue.severity for issue in issues}
    if "high" in severities:
        return "reject"
    return "revise" if "medium" in severities else "pass"
310
+
311
@staticmethod
def _save_and_return(task: str, memory: Dict[str, Any], decision: ValidationDecision) -> str:
    """Record the decision in memory, dump it to the log, and serialise it.

    Side effects:
        * ``memory["flags_and_assessments"]["last_validation"]`` is set to
          ``decision.model_dump()`` and ``["last_validation_at"]`` to the
          current ISO timestamp, so the Coach can inspect the verdict on
          its next turn. The partition is created when it is missing or
          ``None``.
        * ``save_to_json`` is called with the task, the dumped decision and
          the same timestamp, under the ``ValidationAgent`` subdirectory.

    Returns:
        ``decision.model_dump_json()``.
    """
    # One timestamp for the memory record, the log payload and the file
    # name, so the three can be correlated exactly.
    # NOTE(review): ISO timestamps contain ':' which is not a valid file
    # name character on Windows — confirm save_to_json sanitises the name.
    stamp = datetime.now().isoformat()

    # setdefault() alone would keep an explicit None partition and the item
    # assignment below would then raise TypeError.
    if memory.get("flags_and_assessments") is None:
        memory["flags_and_assessments"] = {}
    memory["flags_and_assessments"]["last_validation"] = decision.model_dump()
    memory["flags_and_assessments"]["last_validation_at"] = stamp

    save_to_json(
        {
            "task": task,
            "decision": decision.model_dump(),
            "timestamp": stamp,
        },
        f"validation_agent_{stamp}.json",
        subdirectory="ValidationAgent",
    )
    return decision.model_dump_json()
workflow.py CHANGED
@@ -1,121 +1,135 @@
1
- from langgraph.graph import StateGraph, END
 
 
 
 
 
 
 
 
 
 
 
2
  from langgraph.checkpoint.memory import MemorySaver
 
 
 
 
3
  from state import NutritionState
4
- from utils import extract_and_parse_json, set_nested, FileCheckpointSaver
5
- from datetime import datetime
6
- import json
7
- import config
8
 
9
  def should_continue(state: NutritionState) -> str:
10
- if state["current_action"] and state["current_action"]["action"] in ["compose_response", "ask_user"]:
 
 
11
  return "end"
12
  if state["num_turns"] >= state["max_turns"]:
13
  return "end"
14
  return "execute_action"
15
 
 
16
  def coach_node(state: NutritionState, coach_agent) -> NutritionState:
17
  return coach_agent.handle_task(state)
18
 
19
- def execute_action_node(state: NutritionState, agents, tools) -> NutritionState:
20
- action = state["current_action"]
21
- if not action or not action.get("action"):
 
 
 
 
22
  return state
23
- if config.DEBUG_MODE:
24
- print(f"Executing Action: {action['action']}")
25
-
26
- # Add more specific high-level print for user mode
27
- if not config.DEBUG_MODE:
28
- if action["action"] == "call_agent":
29
- agent_name = action["params"]["agent_name"]
30
- task = action["params"]["task"]
31
- elif action["action"] == "call_tool":
32
- tool_name = action["params"]["tool_name"]
33
- task = action["params"]["task"]
34
- elif action["action"] == "ask_user":
35
- print(f"❓Asking user: {action['params']['prompt']}")
36
- elif action["action"] == "write_memory":
37
- print(f"Writing to memory partition: {action['params']['partition']}")
38
-
39
- # Handle JSON parsing errors
40
  if action.get("_parse_error"):
41
  error_message = "I encountered an error processing the request. Let me try a different approach."
42
  state["conversation_history"].append({"role": "assistant", "content": error_message})
43
  return {**state, "agent_result": error_message}
44
 
45
- # Initialize previous_actions if not present
46
- if 'previous_actions' not in state:
47
- state['previous_actions'] = []
48
 
49
  try:
50
- if action["action"] == "call_agent":
51
- agent_name = action["params"]["agent_name"]
52
- task = action["params"]["task"]
53
  agent_result = agents[agent_name].handle_task(task, state["memory"])
54
- # Set success message instead of full result
55
- success_message = f"{agent_name} task completed and stored in the memory successfully" if agent_result else f"{agent_name} task failed"
56
- action_description = f"Called agent {agent_name} with task: {task}"
57
- state['previous_actions'].append(action_description)
 
 
58
  return {**state, "agent_result": success_message}
59
 
60
- elif action["action"] == "call_tool":
61
- tool_name = action["params"]["tool_name"]
62
- task = action["params"]["task"]
63
- tool_result = tools[tool_name].handle_task(task) if tool_name in tools else f"Unknown tool: {tool_name}"
64
- action_description = f"Called tool {tool_name} with task: {task}"
65
- state['previous_actions'].append(action_description)
 
66
  return {**state, "agent_result": tool_result}
67
 
68
- elif action["action"] == "write_memory":
69
- partition = action["params"]["partition"]
70
- data = action["params"]["data"]
71
  updated_data = {**data, "last_updated": datetime.now().isoformat()}
72
  set_nested(state["memory"], partition, updated_data)
73
- action_description = f"Wrote to memory partition: {partition}"
74
- state['previous_actions'].append(action_description)
75
  return {**state, "agent_result": "Memory updated successfully"}
76
 
77
- elif action["action"] == "compose_response":
78
- response_text = action["params"].get("text") or action["params"].get("response")
79
  if not response_text:
80
  raise ValueError("Missing 'text' or 'response' in params for compose_response")
81
  state["conversation_history"].append({"role": "assistant", "content": response_text})
82
- action_description = "Composed response to user"
83
- state['previous_actions'].append(action_description)
84
  return {**state, "agent_result": response_text}
85
 
86
- elif action["action"] == "ask_user":
87
- prompt_text = action["params"]["prompt"]
88
  state["conversation_history"].append({"role": "assistant", "content": prompt_text})
89
- action_description = f"Asked user: {prompt_text}"
90
- state['previous_actions'].append(action_description)
91
  return {**state, "agent_result": "User prompted for input"}
92
 
93
- else:
94
- action_description = f"Executed {action['action']} with params: {action.get('params', {})}"
95
- state['previous_actions'].append(action_description)
96
- return {**state, "agent_result": f"Unknown action: {action['action']}"}
97
 
98
- except Exception as e:
99
- error_result = f"Error executing {action['action']}: {str(e)}"
100
- action_description = f"Attempted {action['action']} with params: {action.get('params', {})}"
101
- state['previous_actions'].append(action_description)
102
- return {**state, "agent_result": error_result}
103
 
104
- def setup_workflow(coach_agent, agents, tools, persistence_dir=None):
 
105
  workflow = StateGraph(NutritionState)
106
  workflow.add_node("coach", lambda state: coach_node(state, coach_agent))
107
  workflow.add_node("execute_action", lambda state: execute_action_node(state, agents, tools))
108
  workflow.set_entry_point("coach")
109
  workflow.add_edge("coach", "execute_action")
110
- workflow.add_conditional_edges("execute_action", should_continue, {"execute_action": "coach", "end": END})
111
-
 
 
 
 
112
  if persistence_dir:
113
  checkpointer = FileCheckpointSaver(persistence_dir)
114
- print(f"MAS workflow compiled with file-based persistence at {persistence_dir}.")
115
  else:
116
  checkpointer = MemorySaver()
117
- print("MAS workflow compiled with in-memory persistence.")
118
-
119
- app = workflow.compile(checkpointer=checkpointer)
120
- return app
121
 
 
 
1
+ """LangGraph wiring for the Coach <-> action loop.
2
+
3
+ Phase 1 keeps the same two-node graph (``coach`` -> ``execute_action`` -> loop)
4
+ so the public contract is unchanged. Phase 2 will explode this into subgraphs
5
+ with parallel branches and a Validator critic loop.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from datetime import datetime
11
+ from typing import Any, Dict
12
+
13
  from langgraph.checkpoint.memory import MemorySaver
14
+ from langgraph.graph import END, StateGraph
15
+
16
+ from config import get_settings
17
+ from logging_setup import get_logger
18
  from state import NutritionState
19
+ from utils import FileCheckpointSaver, set_nested
20
+
21
+ _logger = get_logger("workflow")
22
+
23
 
24
def should_continue(state: NutritionState) -> str:
    """Edge predicate deciding whether the Coach loop runs another round.

    Returns ``"end"`` when the current action is ``compose_response`` or
    ``ask_user``, or when ``num_turns`` has reached ``max_turns``;
    otherwise returns ``"execute_action"``. A missing or ``None``
    ``current_action`` is treated as non-terminal.
    """
    pending = state.get("current_action") or {}
    is_terminal = pending.get("action") in {"compose_response", "ask_user"}
    # The turn counters are only consulted when the action is not terminal.
    if is_terminal or state["num_turns"] >= state["max_turns"]:
        return "end"
    return "execute_action"
32
 
33
+
34
def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Graph node for the Coach: hand the state to the agent.

    Returns whatever ``coach_agent.handle_task(state)`` returns, unchanged.
    """
    next_state = coach_agent.handle_task(state)
    return next_state
36
 
37
+
38
def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Execute the action currently stored in ``state["current_action"]``.

    Supported actions and their ``agent_result``:

    * ``call_agent``  – runs ``agents[agent_name].handle_task(task, memory)``
      and reports success or failure based on the truthiness of the
      result; an unregistered agent yields ``"Unknown agent: <name>"``.
    * ``call_tool``   – runs ``tools[tool_name].handle_task(task)`` and
      returns its result; an unregistered tool yields
      ``"Unknown tool: <name>"``.
    * ``write_memory`` – stores ``data`` plus a ``last_updated`` timestamp
      at ``partition`` via ``set_nested``.
    * ``compose_response`` / ``ask_user`` – append the text to
      ``conversation_history`` as an assistant message.
    * anything else – ``"Unknown action: <name>"``.

    Every handled action appends a one-line description to
    ``state["previous_actions"]`` (created on demand). ``state`` is mutated
    in place for the history lists and a shallow copy carrying
    ``agent_result`` is returned. If there is no action name the state is
    returned untouched. Any exception raised while executing is logged and
    converted into an ``"Error executing ..."`` result instead of
    propagating.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params", {}) or {}

    if not action_name:
        return state

    settings = get_settings()
    if settings.debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    else:
        if action_name == "ask_user":
            _logger.info("❓ Asking user: %s", params.get("prompt"))
        elif action_name == "write_memory":
            _logger.info("Writing to memory partition: %s", params.get("partition"))

    # The Coach flags actions it could not parse; answer with a canned
    # apology rather than trying to execute a malformed action.
    if action.get("_parse_error"):
        error_message = "I encountered an error processing the request. Let me try a different approach."
        state["conversation_history"].append({"role": "assistant", "content": error_message})
        return {**state, "agent_result": error_message}

    if "previous_actions" not in state:
        state["previous_actions"] = []

    try:
        if action_name == "call_agent":
            agent_name = params["agent_name"]
            task = params["task"]
            # Mirror the unknown-tool handling: report a readable message
            # instead of letting a KeyError surface as "'<name>'" plus a
            # logged traceback.
            if agent_name not in agents:
                state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
                return {**state, "agent_result": f"Unknown agent: {agent_name}"}
            agent_result = agents[agent_name].handle_task(task, state["memory"])
            success_message = (
                f"{agent_name} task completed and stored in the memory successfully"
                if agent_result
                else f"{agent_name} task failed"
            )
            state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
            return {**state, "agent_result": success_message}

        if action_name == "call_tool":
            tool_name = params["tool_name"]
            task = params["task"]
            tool_result = (
                tools[tool_name].handle_task(task) if tool_name in tools else f"Unknown tool: {tool_name}"
            )
            state["previous_actions"].append(f"Called tool {tool_name} with task: {task}")
            return {**state, "agent_result": tool_result}

        if action_name == "write_memory":
            partition = params["partition"]
            data = params["data"]
            updated_data = {**data, "last_updated": datetime.now().isoformat()}
            set_nested(state["memory"], partition, updated_data)
            state["previous_actions"].append(f"Wrote to memory partition: {partition}")
            return {**state, "agent_result": "Memory updated successfully"}

        if action_name == "compose_response":
            response_text = params.get("text") or params.get("response")
            if not response_text:
                raise ValueError("Missing 'text' or 'response' in params for compose_response")
            state["conversation_history"].append({"role": "assistant", "content": response_text})
            state["previous_actions"].append("Composed response to user")
            return {**state, "agent_result": response_text}

        if action_name == "ask_user":
            prompt_text = params["prompt"]
            state["conversation_history"].append({"role": "assistant", "content": prompt_text})
            state["previous_actions"].append(f"Asked user: {prompt_text}")
            return {**state, "agent_result": "User prompted for input"}

        state["previous_actions"].append(f"Executed {action_name} with params: {params}")
        return {**state, "agent_result": f"Unknown action: {action_name}"}

    except Exception as e:  # noqa: BLE001
        # Boundary handler: the graph must keep running, so the failure is
        # logged with its traceback and handed back to the Coach as text.
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        return {**state, "agent_result": f"Error executing {action_name}: {str(e)}"}
 
114
 
115
+
116
def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the two-node Coach/action graph.

    The graph enters at ``coach``, always proceeds to ``execute_action``,
    and then follows ``should_continue``: ``"execute_action"`` loops back
    to ``coach`` and ``"end"`` terminates.

    Args:
        coach_agent: Object passed to ``coach_node`` on every coach step.
        agents: Registry handed to ``execute_action_node``.
        tools: Registry handed to ``execute_action_node``.
        persistence_dir: When truthy, checkpoints go through
            ``FileCheckpointSaver(persistence_dir)``; otherwise a
            ``MemorySaver`` is used.

    Returns:
        The result of ``StateGraph.compile`` with the chosen checkpointer.
    """
    graph = StateGraph(NutritionState)
    graph.add_node("coach", lambda state: coach_node(state, coach_agent))
    graph.add_node("execute_action", lambda state: execute_action_node(state, agents, tools))
    graph.set_entry_point("coach")
    graph.add_edge("coach", "execute_action")

    # Route names returned by should_continue -> next node.
    routes = {"execute_action": "coach", "end": END}
    graph.add_conditional_edges("execute_action", should_continue, routes)

    if not persistence_dir:
        saver = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    else:
        saver = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)

    return graph.compile(checkpointer=saver)