Spaces:
Sleeping
Sleeping
Merge Phase 2 into Phase 3 base
Browse files- .env.example +23 -0
- .gitignore +60 -0
- Makefile +26 -0
- agents.py +569 -436
- config.py +113 -6
- logging_setup.py +77 -0
- nutritionmas.py +66 -39
- pyproject.toml +52 -0
- requirements.txt +31 -0
- schemas.py +215 -0
- tests/__init__.py +0 -0
- tests/conftest.py +99 -0
- tests/test_api_pool.py +51 -0
- tests/test_quantities_finder.py +90 -0
- tests/test_schemas.py +135 -0
- tests/test_settings.py +53 -0
- tests/test_smoke.py +64 -0
- tests/test_typed_agents.py +184 -0
- tests/test_validation_agent.py +201 -0
- tools.py +236 -211
- utils.py +335 -258
- validation.py +327 -0
- workflow.py +86 -72
.env.example
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Comma-separated list of Gemini API keys. The system rotates through them
|
| 2 |
+
# and respects per-key RPM/RPD limits when NUTRITION_MAS_ENABLE_RATE_LIMITING=true.
|
| 3 |
+
NUTRITION_MAS_GEMINI_API_KEYS=key_one,key_two
|
| 4 |
+
|
| 5 |
+
# Optional infra paths.
|
| 6 |
+
# When set, every agent/tool I/O is dumped to LOG_DIR/<subdir>/<timestamp>.json
|
| 7 |
+
NUTRITION_MAS_LOG_DIR=
|
| 8 |
+
# When set, LangGraph checkpoints are persisted to disk (instead of memory).
|
| 9 |
+
NUTRITION_MAS_PERSISTENCE_DIR=
|
| 10 |
+
|
| 11 |
+
# Debug switches (default: off)
|
| 12 |
+
NUTRITION_MAS_DEBUG_MODE=false
|
| 13 |
+
NUTRITION_MAS_DEBUG_LEVEL=full # 'full' | 'output'
|
| 14 |
+
# JSON-encoded dict, see config.Settings for shape. Defaults to all/all.
|
| 15 |
+
# NUTRITION_MAS_DEBUG_SCOPES={"agents": ["CoachAgent"], "tools": ["all"]}
|
| 16 |
+
|
| 17 |
+
# Rate limiting (default: on)
|
| 18 |
+
NUTRITION_MAS_ENABLE_RATE_LIMITING=true
|
| 19 |
+
|
| 20 |
+
# Optional LangSmith tracing (Phase 6 will wire this up properly)
|
| 21 |
+
# LANGCHAIN_TRACING_V2=true
|
| 22 |
+
# LANGCHAIN_API_KEY=
|
| 23 |
+
# LANGCHAIN_PROJECT=Nutrition-MAS
|
.gitignore
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
MANIFEST
|
| 23 |
+
|
| 24 |
+
# Virtual environments
|
| 25 |
+
.venv/
|
| 26 |
+
venv/
|
| 27 |
+
env/
|
| 28 |
+
ENV/
|
| 29 |
+
.env
|
| 30 |
+
|
| 31 |
+
# IDE
|
| 32 |
+
.vscode/
|
| 33 |
+
.idea/
|
| 34 |
+
.claude/
|
| 35 |
+
*.swp
|
| 36 |
+
*.swo
|
| 37 |
+
*~
|
| 38 |
+
.DS_Store
|
| 39 |
+
|
| 40 |
+
# Project
|
| 41 |
+
logs/
|
| 42 |
+
checkpoints/
|
| 43 |
+
data/cache/
|
| 44 |
+
*.sqlite
|
| 45 |
+
*.sqlite3
|
| 46 |
+
*.db
|
| 47 |
+
.cache/
|
| 48 |
+
|
| 49 |
+
# Pytest / coverage
|
| 50 |
+
.pytest_cache/
|
| 51 |
+
.coverage
|
| 52 |
+
htmlcov/
|
| 53 |
+
.mypy_cache/
|
| 54 |
+
.ruff_cache/
|
| 55 |
+
|
| 56 |
+
# Notebook
|
| 57 |
+
.ipynb_checkpoints/
|
| 58 |
+
|
| 59 |
+
# LangSmith / tracing
|
| 60 |
+
.langsmith/
|
Makefile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# All targets are commands, not files; test-cov must be listed too so a stray
# file or directory named "test-cov" cannot make `make test-cov` a no-op.
.PHONY: install dev test test-cov lint format clean run
|
| 2 |
+
|
| 3 |
+
install:
|
| 4 |
+
pip install -r requirements.txt
|
| 5 |
+
|
| 6 |
+
dev:
|
| 7 |
+
pip install -e ".[dev]"
|
| 8 |
+
|
| 9 |
+
test:
|
| 10 |
+
pytest -ra -q
|
| 11 |
+
|
| 12 |
+
test-cov:
|
| 13 |
+
pytest -ra --cov=. --cov-report=term-missing --cov-report=html
|
| 14 |
+
|
| 15 |
+
lint:
|
| 16 |
+
ruff check .
|
| 17 |
+
|
| 18 |
+
format:
|
| 19 |
+
ruff format .
|
| 20 |
+
|
| 21 |
+
clean:
|
| 22 |
+
rm -rf .pytest_cache .ruff_cache .mypy_cache htmlcov .coverage
|
| 23 |
+
find . -type d -name __pycache__ -exec rm -rf {} +
|
| 24 |
+
|
| 25 |
+
run:
|
| 26 |
+
python -c "import nutritionmas; print('Module imports OK')"
|
agents.py
CHANGED
|
@@ -1,489 +1,622 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import json
|
| 6 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class CoachAgent:
|
| 9 |
def __init__(self, llm_instance):
|
| 10 |
self.llm = llm_instance
|
| 11 |
|
| 12 |
def handle_task(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
| 13 |
-
|
|
|
|
| 14 |
response_steps = state.get("response_steps", [])
|
| 15 |
-
response_steps_str =
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
for msg in state["conversation_history"]:
|
| 18 |
if msg["role"] == "assistant" and len(msg["content"]) > 200:
|
| 19 |
-
|
| 20 |
-
|
|
|
|
| 21 |
else:
|
| 22 |
truncated_history.append(msg)
|
| 23 |
-
history_str = "\n".join(
|
| 24 |
-
|
| 25 |
-
observation =
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
Action outputs: respond with a JSON object:
|
| 63 |
-
{{
|
| 64 |
-
"observation": "...",
|
| 65 |
-
"thought": "...",
|
| 66 |
-
"response_steps": [ ... ],
|
| 67 |
-
"action": "call_agent | call_tool | ask_user | write_memory | compose_response",
|
| 68 |
-
"params": {{ ... }}
|
| 69 |
-
}}
|
| 70 |
-
|
| 71 |
-
Examples:
|
| 72 |
-
- call_agent params: {{"agent_name":"MedicalAssessmentAgent", "task":"task description"}}
|
| 73 |
-
- compose_response params:{{"text":"Complete response in markdown"}}
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
Rules:
|
| 77 |
-
- When composing the response, extract and include relevant information from the memory state (e.g., calorie target, plan details, dietary restrictions) in markdown format for readability.
|
| 78 |
-
- Always include a "trace" field in composed responses summarizing which agents/tools were called for and which sources were used.
|
| 79 |
-
- For high-risk profiles (e.g., requires_professional_consultation: true); in such cases append a bold warning at the end of the diet plan response advising professional consultation before implementation.
|
| 80 |
-
"""
|
| 81 |
-
|
| 82 |
-
if should_debug('agents', 'CoachAgent'):
|
| 83 |
-
print(f"\n--- Coach Agent Turn {state['num_turns'] + 1} ---")
|
| 84 |
-
if should_debug('agents', 'CoachAgent') and config.DEBUG_LEVEL == 'full':
|
| 85 |
-
print(f"Raw LLM input:\n{prompt}")
|
| 86 |
-
response = self.llm(prompt)[0]
|
| 87 |
-
if should_debug('agents', 'CoachAgent'):
|
| 88 |
-
print(f"Coach Raw Response:\n{response}")
|
| 89 |
-
|
| 90 |
-
parsed = extract_and_parse_json(response)
|
| 91 |
-
|
| 92 |
-
# Add high-level print for user mode
|
| 93 |
-
if not config.DEBUG_MODE:
|
| 94 |
-
action = parsed.get("action")
|
| 95 |
-
params = parsed.get("params", {})
|
| 96 |
-
print_str = "\n🏋️♂️Coach Agent: "
|
| 97 |
-
if action == "call_agent":
|
| 98 |
-
print_str += f"Calling {params.get('agent_name')} with task '{params.get('task')}'"
|
| 99 |
-
elif action == "call_tool":
|
| 100 |
-
print_str += f"Using {params.get('tool_name')} with task '{params.get('task')}'"
|
| 101 |
-
elif action == "ask_user":
|
| 102 |
-
print_str += f"Asking user: {params.get('prompt')}"
|
| 103 |
-
elif action == "write_memory":
|
| 104 |
-
print_str += f"Writing to memory partition '{params.get('partition')}'"
|
| 105 |
-
elif action == "compose_response":
|
| 106 |
-
print_str += "Composing final response"
|
| 107 |
-
print(print_str)
|
| 108 |
-
|
| 109 |
-
current_action = {
|
| 110 |
-
"action": parsed.get("action"),
|
| 111 |
-
"params": parsed.get("params", {})
|
| 112 |
-
}
|
| 113 |
|
| 114 |
-
response_steps = parsed.get("response_steps", state.get("response_steps", []))
|
| 115 |
-
|
| 116 |
-
log_data = {
|
| 117 |
-
"prompt": prompt,
|
| 118 |
-
"output":response,
|
| 119 |
-
"parsed": parsed,
|
| 120 |
-
"timestamp": datetime.now().isoformat()
|
| 121 |
-
}
|
| 122 |
-
save_to_json(log_data, f'coach_agent_{datetime.now().isoformat()}.json', subdirectory='CoachAgent')
|
| 123 |
-
|
| 124 |
return {
|
| 125 |
**state,
|
| 126 |
"current_action": current_action,
|
| 127 |
-
"response_steps":
|
| 128 |
"num_turns": state["num_turns"] + 1,
|
| 129 |
-
"agent_result": None
|
| 130 |
}
|
| 131 |
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
class MedicalAssessmentAgent:
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
self.llm = llm_instance
|
| 136 |
self.computation_tool = computation_tool
|
| 137 |
self.web_search_tool = web_search_tool
|
| 138 |
|
| 139 |
def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
-
# Build relevant memory context
|
| 143 |
relevant_memory = {
|
| 144 |
"user_profile": memory.get("user_profile", {}),
|
| 145 |
"medical_history": memory.get("medical_history", {}),
|
| 146 |
}
|
| 147 |
-
memory_str = json.dumps(relevant_memory, indent=2)
|
| 148 |
-
tool_results = []
|
| 149 |
-
assessment_plan = []
|
| 150 |
-
|
| 151 |
-
iteration
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
print(f"Medical Assessment Raw Response:\n{response}")
|
| 201 |
-
|
| 202 |
-
parsed = extract_and_parse_json(response)
|
| 203 |
-
|
| 204 |
-
# Add high-level print for user mode
|
| 205 |
-
if not config.DEBUG_MODE:
|
| 206 |
-
action_type = parsed.get("action", {}).get("type")
|
| 207 |
-
if action_type == "call_tool":
|
| 208 |
-
tool_name = parsed["action"].get("tool_name")
|
| 209 |
-
tool_task = parsed["action"].get("tool_task")
|
| 210 |
-
print(f"👨🏻⚕️ Medical Assessment Agent: Using {tool_name} for '{tool_task}'")
|
| 211 |
-
elif action_type == "ask_user":
|
| 212 |
-
fields = parsed["action"].get("fields", [])
|
| 213 |
-
print(f"👨🏻⚕️ Medical Assessment Agent: Asking user for missing fields: {', '.join(fields)}")
|
| 214 |
-
elif action_type == "assessment_complete":
|
| 215 |
-
print("👨🏻⚕️ Medical Assessment Agent: Completing assessment")
|
| 216 |
-
|
| 217 |
-
if "assessment_plan" in parsed:
|
| 218 |
-
assessment_plan = parsed["assessment_plan"]
|
| 219 |
-
|
| 220 |
-
action = parsed.get("action", {})
|
| 221 |
-
action_type = action.get("type")
|
| 222 |
-
|
| 223 |
-
if action_type == "call_tool":
|
| 224 |
-
tool_name = action.get("tool_name")
|
| 225 |
-
tool_task = action.get("tool_task")
|
| 226 |
-
|
| 227 |
-
if tool_name == "ComputationTool":
|
| 228 |
-
if tool_task:
|
| 229 |
-
result = self.computation_tool.handle_task(tool_task)
|
| 230 |
-
else:
|
| 231 |
-
result = "Missing 'tool_task' for ComputationTool"
|
| 232 |
-
elif tool_name == "WebSearchTool":
|
| 233 |
-
if tool_task:
|
| 234 |
-
result = self.web_search_tool.handle_task(tool_task)
|
| 235 |
-
else:
|
| 236 |
-
result = "Missing 'tool_task' for WebSearchTool"
|
| 237 |
-
else:
|
| 238 |
-
result = f"Unknown tool: {tool_name}"
|
| 239 |
-
|
| 240 |
-
tool_results.append(f"{tool_name}: {result}")
|
| 241 |
-
|
| 242 |
-
elif action_type == "ask_user":
|
| 243 |
-
fields = action.get("fields", [])
|
| 244 |
-
result = f"Missing critical fields: {', '.join(fields)}. Please provide the following information to continue the assessment."
|
| 245 |
-
print(f"👨🏻⚕️ MEDICAL ASSESSMENT AGENT: User query needed - {result}")
|
| 246 |
-
return result
|
| 247 |
-
|
| 248 |
-
elif action_type == "assessment_complete":
|
| 249 |
-
assessment_summary = action.get("assessment_summary")
|
| 250 |
-
flags_to_set = action.get("flags_to_set", [])
|
| 251 |
-
recommendations = action.get("recommendations", [])
|
| 252 |
-
requires_professional_consultation = action.get("requires_professional_consultation", False)
|
| 253 |
-
calculations = action.get("calculations", {}) # Now a dict as per new prompt
|
| 254 |
-
evidence_sources = action.get("evidence_sources", [])
|
| 255 |
-
trace = action.get("trace", "")
|
| 256 |
-
|
| 257 |
-
if action.get("requires_tool_retry", False):
|
| 258 |
-
result = "Assessment requires tool retry due to failures. Please re-run with fixed tools."
|
| 259 |
-
print(f"👨🏻⚕️ MEDICAL ASSESSMENT AGENT: Tool retry needed - {result}")
|
| 260 |
-
return result # Return early without updating memory
|
| 261 |
-
|
| 262 |
-
# Update memory using update_memory_partition
|
| 263 |
-
update_memory_partition(memory, "flags_and_assessments", {
|
| 264 |
-
"assessment_summary": assessment_summary,
|
| 265 |
-
"flags": flags_to_set,
|
| 266 |
-
"recommendations": recommendations,
|
| 267 |
-
"requires_professional_consultation": requires_professional_consultation,
|
| 268 |
-
"calculations": calculations,
|
| 269 |
-
"evidence_sources": evidence_sources,
|
| 270 |
-
"trace": trace,
|
| 271 |
-
"assessment_timestamp": datetime.now().isoformat() # Retained timestamp
|
| 272 |
-
})
|
| 273 |
-
|
| 274 |
-
# Log the assessment (updated to include new fields)
|
| 275 |
-
log_data = {
|
| 276 |
-
"task": task,
|
| 277 |
-
"memory_input": relevant_memory,
|
| 278 |
-
"tool_results": tool_results,
|
| 279 |
-
"assessment_summary": assessment_summary,
|
| 280 |
-
"flags_set": flags_to_set,
|
| 281 |
-
"recommendations": recommendations,
|
| 282 |
-
"requires_professional_consultation": requires_professional_consultation,
|
| 283 |
-
"evidence_sources": evidence_sources,
|
| 284 |
-
"trace": trace,
|
| 285 |
-
"timestamp": datetime.now().isoformat()
|
| 286 |
-
}
|
| 287 |
-
save_to_json(log_data, f'medical_assessment_{datetime.now().isoformat()}.json', subdirectory='MedicalAssessment')
|
| 288 |
-
|
| 289 |
-
result = assessment_summary
|
| 290 |
-
print(f"👨🏻⚕️ MEDICAL ASSESSMENT AGENT COMPLETED: {result}")
|
| 291 |
-
return result
|
| 292 |
|
| 293 |
else:
|
| 294 |
-
|
| 295 |
break
|
| 296 |
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
-
# Fallback if max iterations reached
|
| 300 |
-
result = f"Medical assessment stopped after {max_iterations} iterations"
|
| 301 |
-
print(f"👨🏻⚕️ MEDICAL ASSESSMENT AGENT Stopped (MAX ITERATIONS)")
|
| 302 |
-
return result
|
| 303 |
|
| 304 |
class PlannerAgent:
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
self.llm = llm_instance
|
| 307 |
self.computation_tool = computation_tool
|
| 308 |
self.web_search_tool = web_search_tool
|
| 309 |
self.quantities_finder = quantities_finder
|
| 310 |
|
| 311 |
def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
|
| 312 |
-
|
|
|
|
| 313 |
|
| 314 |
relevant_memory = {
|
| 315 |
"user_profile": memory.get("user_profile", {}),
|
| 316 |
"flags_and_assessments": memory.get("flags_and_assessments", {}),
|
| 317 |
}
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
- If Current Planning Steps is provided... You may remain in a step for multiple iterations if necessary to meet all targets, as outlined in the Iterative Correction Loop rule.
|
| 374 |
-
|
| 375 |
-
Return JSON:
|
| 376 |
-
- observation, thought
|
| 377 |
-
- planning_steps (full list of response_step objects)
|
| 378 |
-
- action: one of {{
|
| 379 |
-
"type":"call_tool","tool_name":...,"tool_task":...,
|
| 380 |
-
"type":"draft_plan","drafted_plan":{{...}},
|
| 381 |
-
"type":"provide_plan","final_plan":{{...}}
|
| 382 |
-
}}
|
| 383 |
-
|
| 384 |
-
Notes:
|
| 385 |
-
- Keep each plan realistic and culturally appropriate (regional foods if provided).
|
| 386 |
-
- Trace: at the end of the plan, summarize which agents/tools were called.
|
| 387 |
-
- Always include the full updated planning_steps in your response JSON to persist across iterations.
|
| 388 |
-
"""
|
| 389 |
-
|
| 390 |
-
if should_debug('agents', 'PlannerAgent'):
|
| 391 |
-
print(f"\n--- Planner Agent Iteration {iteration + 1} ---")
|
| 392 |
-
if should_debug('agents', 'PlannerAgent') and config.DEBUG_LEVEL == 'full':
|
| 393 |
-
print(f"Raw LLM input:\n{prompt}")
|
| 394 |
-
response = self.llm(prompt)[0]
|
| 395 |
-
if should_debug('agents', 'PlannerAgent'):
|
| 396 |
-
print(f"Planner Raw Response:\n{response}")
|
| 397 |
-
|
| 398 |
-
parsed = extract_and_parse_json(response)
|
| 399 |
-
|
| 400 |
-
# Add high-level print for user mode
|
| 401 |
-
if not config.DEBUG_MODE:
|
| 402 |
-
action_type = parsed.get("action", {}).get("type")
|
| 403 |
-
print_str = "📋 Planner Agent: "
|
| 404 |
-
if action_type == "call_tool":
|
| 405 |
-
tool_name = parsed["action"].get("tool_name")
|
| 406 |
-
tool_task = parsed["action"].get("tool_task")
|
| 407 |
-
print_str += f"Using {tool_name} for '{tool_task}'"
|
| 408 |
-
elif action_type == "draft_plan":
|
| 409 |
-
print_str += "Drafting plan"
|
| 410 |
-
elif action_type == "provide_plan":
|
| 411 |
-
print_str += "Finalizing plan"
|
| 412 |
-
print(print_str)
|
| 413 |
-
|
| 414 |
-
planning_steps = parsed.get("planning_steps", planning_steps)
|
| 415 |
-
|
| 416 |
-
action = parsed.get("action", {})
|
| 417 |
-
action_type = action.get("type")
|
| 418 |
-
|
| 419 |
-
if action_type == "call_tool":
|
| 420 |
-
tool_name = action.get("tool_name")
|
| 421 |
-
tool_task = action.get("tool_task")
|
| 422 |
-
if tool_name and tool_task:
|
| 423 |
-
print(f"Calling {tool_name} with task: {tool_task}")
|
| 424 |
-
if tool_name == "ComputationTool":
|
| 425 |
-
result = self.computation_tool.handle_task(tool_task)
|
| 426 |
-
elif tool_name == "WebSearchTool":
|
| 427 |
-
result = self.web_search_tool.handle_task(tool_task)
|
| 428 |
-
elif tool_name == "QuantitiesFinder":
|
| 429 |
-
result = self.quantities_finder.handle_task(tool_task)
|
| 430 |
-
else:
|
| 431 |
-
result = f"Unknown tool: {tool_name}"
|
| 432 |
-
tool_results.append(f"{tool_name}: {result}")
|
| 433 |
-
else:
|
| 434 |
-
print("Missing tool_name or tool_task")
|
| 435 |
-
|
| 436 |
-
elif action_type == "draft_plan":
|
| 437 |
-
drafted_plan = action.get("drafted_plan")
|
| 438 |
-
if drafted_plan:
|
| 439 |
-
if "plans" not in memory:
|
| 440 |
-
memory["plans"] = {}
|
| 441 |
-
memory["plans"]["drafted_plan"] = drafted_plan
|
| 442 |
-
result = "Plan drafted and stored in memory"
|
| 443 |
-
tool_results.append(result)
|
| 444 |
else:
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
|
|
|
|
|
|
| 475 |
|
| 476 |
else:
|
| 477 |
-
|
| 478 |
break
|
| 479 |
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Agent implementations.
|
| 2 |
+
|
| 3 |
+
Phase 1: every agent's per-turn output is now a Pydantic model from
|
| 4 |
+
``schemas``. The prompts are split so the static system rules sit in a
|
| 5 |
+
module-level constant (eligible for Gemini's implicit prompt cache) and only
|
| 6 |
+
the dynamic state changes per call.
|
| 7 |
+
|
| 8 |
+
The action-dispatch loops are still *inside* the agent classes — Phase 2 will
|
| 9 |
+
break them into LangGraph subgraphs with parallel tool nodes and the
|
| 10 |
+
ValidationAgent critic loop.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
import json
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from typing import Any, Dict, List, Optional
|
| 18 |
+
|
| 19 |
+
from config import get_settings
|
| 20 |
+
from logging_setup import get_logger
|
| 21 |
+
from schemas import (
|
| 22 |
+
CoachDecision,
|
| 23 |
+
MedicalAssessmentDecision,
|
| 24 |
+
MedicalAssessmentResult,
|
| 25 |
+
PlannerDecision,
|
| 26 |
+
)
|
| 27 |
+
from tools import ComputationTool, QuantitiesFinder, WebSearchTool
|
| 28 |
+
from utils import save_to_json, should_debug, update_memory_partition
|
| 29 |
+
|
| 30 |
+
_coach_logger = get_logger("agents.coach")
|
| 31 |
+
_medical_logger = get_logger("agents.medical")
|
| 32 |
+
_planner_logger = get_logger("agents.planner")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Coach
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
_COACH_SYSTEM_PROMPT = """\
|
| 39 |
+
You are the Coach Agent (central orchestrator) of a nutrition Multi-Agent System.
|
| 40 |
+
|
| 41 |
+
Primary responsibilities:
|
| 42 |
+
- Translate user intent into a concrete workflow of response_steps.
|
| 43 |
+
- Enforce system rules (MedicalAssessment must complete before Planner runs).
|
| 44 |
+
- Decide and perform exactly one action per turn: call_agent, call_tool,
|
| 45 |
+
ask_user, write_memory, or compose_response.
|
| 46 |
+
|
| 47 |
+
Inputs each turn:
|
| 48 |
+
- observation (string built from user query + memory + history)
|
| 49 |
+
- memory partitions: user_profile, medical_history, flags_and_assessments, plans
|
| 50 |
+
- response_steps (list, may be empty on the first turn)
|
| 51 |
+
|
| 52 |
+
Behaviour rules (mandatory):
|
| 53 |
+
1. If response_steps is empty, generate ordered steps (max 7). Each step
|
| 54 |
+
must include id, actor, prerequisites, and status "pending".
|
| 55 |
+
Typical personal-workflow (when the user asks for a personalised plan):
|
| 56 |
+
1) Validate required user data (height, weight, age, sex, activity_level,
|
| 57 |
+
allergies, goal). If missing -> ask_user.
|
| 58 |
+
2) Update memory if the user provided new data [action: write_memory].
|
| 59 |
+
3) Call MedicalAssessmentAgent with a task to assess the user.
|
| 60 |
+
4) Wait for assessment to be completed and stored in memory.
|
| 61 |
+
5) Call PlannerAgent with the relevant task.
|
| 62 |
+
6) Call ValidationAgent to grade the plan.
|
| 63 |
+
7) If validation verdict == "revise", re-call PlannerAgent with the
|
| 64 |
+
validation issues prepended to the task; otherwise compose_response.
|
| 65 |
+
2. When calling any agent, set the called step status to "in_progress" and
|
| 66 |
+
include prerequisites satisfied by your observation.
|
| 67 |
+
3. Only call PlannerAgent if memory.flags_and_assessments contains an
|
| 68 |
+
"assessment_status" of "assessment_complete". If missing, call
|
| 69 |
+
MedicalAssessmentAgent first.
|
| 70 |
+
4. After every PlannerAgent run, you MUST call ValidationAgent before
|
| 71 |
+
composing the response. Inspect memory.flags_and_assessments.last_validation:
|
| 72 |
+
* verdict == "pass": proceed to compose_response.
|
| 73 |
+
* verdict == "revise": call PlannerAgent again with task =
|
| 74 |
+
"Revise the plan to address: " + each issue.description joined by "; ".
|
| 75 |
+
Cap revisions at 2; on the third attempt, compose_response with the
|
| 76 |
+
best plan available and append the unresolved issues as warnings.
|
| 77 |
+
* verdict == "reject": compose_response with a clear refusal explaining
|
| 78 |
+
the violation; do NOT show the plan. Append a HITL escalation chip
|
| 79 |
+
(text marker the UI will render).
|
| 80 |
+
5. When new personal data appears in user input, add steps to: propose memory
|
| 81 |
+
update (write_memory), call MedicalAssessmentAgent if needed, re-plan if
|
| 82 |
+
needed.
|
| 83 |
+
6. For any write_memory action, provide the full partition contents in
|
| 84 |
+
params.data (not diffs). The Coach is responsible for merging and storing.
|
| 85 |
+
|
| 86 |
+
Output JSON shape (enforced by schema):
|
| 87 |
+
{
|
| 88 |
+
"observation": "...",
|
| 89 |
+
"thought": "...",
|
| 90 |
+
"response_steps": [ ... ],
|
| 91 |
+
"action": "call_agent | call_tool | ask_user | write_memory | compose_response",
|
| 92 |
+
"params": { ... }
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
Required params per action:
|
| 96 |
+
- call_agent: {"agent_name": "...", "task": "..."}
|
| 97 |
+
- call_tool: {"tool_name": "...", "task": "..."}
|
| 98 |
+
- ask_user: {"prompt": "..."}
|
| 99 |
+
- write_memory: {"partition": "...", "data": {...}}
|
| 100 |
+
- compose_response: {"text": "...markdown..."}
|
| 101 |
+
|
| 102 |
+
Composition rules:
|
| 103 |
+
- When composing the response, extract relevant information from memory state
|
| 104 |
+
(calorie target, plan details, dietary restrictions, citations) in markdown.
|
| 105 |
+
- Always include a "trace" line summarising which agents/tools contributed.
|
| 106 |
+
- For high-risk profiles (requires_professional_consultation == true), append
|
| 107 |
+
a bold warning advising professional consultation before implementation.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
|
| 111 |
class CoachAgent:
|
| 112 |
def __init__(self, llm_instance):
|
| 113 |
self.llm = llm_instance
|
| 114 |
|
| 115 |
def handle_task(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
| 116 |
+
settings = get_settings()
|
| 117 |
+
memory_str = json.dumps(state["memory"], indent=2, default=str)
|
| 118 |
response_steps = state.get("response_steps", [])
|
| 119 |
+
response_steps_str = (
|
| 120 |
+
json.dumps(response_steps, indent=2, default=str) if response_steps else "None"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
truncated_history: List[Dict[str, str]] = []
|
| 124 |
for msg in state["conversation_history"]:
|
| 125 |
if msg["role"] == "assistant" and len(msg["content"]) > 200:
|
| 126 |
+
truncated_history.append(
|
| 127 |
+
{"role": "assistant", "content": msg["content"][:200] + "... (full response in memory)"}
|
| 128 |
+
)
|
| 129 |
else:
|
| 130 |
truncated_history.append(msg)
|
| 131 |
+
history_str = "\n".join(f"{m['role']}: {m['content']}" for m in truncated_history)
|
| 132 |
+
|
| 133 |
+
observation = (
|
| 134 |
+
f"User query: {state['user_question']}\n"
|
| 135 |
+
f"Memory State: {memory_str}\n"
|
| 136 |
+
f"Current Response Steps: {response_steps_str}\n"
|
| 137 |
+
f"Previous Tool Result: {state.get('agent_result', 'None')}\n"
|
| 138 |
+
f"Conversation history: {history_str}"
|
| 139 |
+
)
|
| 140 |
+
prompt = f"{_COACH_SYSTEM_PROMPT}\n\n--- Current State ---\n{observation}"
|
| 141 |
+
|
| 142 |
+
if should_debug("agents", "CoachAgent"):
|
| 143 |
+
_coach_logger.debug("--- Coach Agent Turn %d ---", state["num_turns"] + 1)
|
| 144 |
+
if settings.debug_level == "full":
|
| 145 |
+
_coach_logger.debug("Raw LLM input:\n%s", prompt)
|
| 146 |
+
|
| 147 |
+
decision = self.llm.call_typed(prompt, CoachDecision)
|
| 148 |
+
if decision is None:
|
| 149 |
+
return self._fallback_state(state, "Coach decision could not be parsed.")
|
| 150 |
+
|
| 151 |
+
if should_debug("agents", "CoachAgent"):
|
| 152 |
+
_coach_logger.debug("Coach decision:\n%s", decision.model_dump_json(indent=2))
|
| 153 |
+
|
| 154 |
+
if not settings.debug_mode:
|
| 155 |
+
self._log_user_mode_action(decision)
|
| 156 |
+
|
| 157 |
+
current_action = {"action": decision.action, "params": decision.params}
|
| 158 |
+
new_steps = [s.model_dump() for s in decision.response_steps] or state.get("response_steps", [])
|
| 159 |
+
|
| 160 |
+
save_to_json(
|
| 161 |
+
{
|
| 162 |
+
"prompt": prompt,
|
| 163 |
+
"decision": decision.model_dump(),
|
| 164 |
+
"timestamp": datetime.now().isoformat(),
|
| 165 |
+
},
|
| 166 |
+
f"coach_agent_{datetime.now().isoformat()}.json",
|
| 167 |
+
subdirectory="CoachAgent",
|
| 168 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
return {
|
| 171 |
**state,
|
| 172 |
"current_action": current_action,
|
| 173 |
+
"response_steps": new_steps,
|
| 174 |
"num_turns": state["num_turns"] + 1,
|
| 175 |
+
"agent_result": None,
|
| 176 |
}
|
| 177 |
|
| 178 |
+
@staticmethod
|
| 179 |
+
def _log_user_mode_action(decision: CoachDecision) -> None:
|
| 180 |
+
params = decision.params or {}
|
| 181 |
+
action = decision.action
|
| 182 |
+
if action == "call_agent":
|
| 183 |
+
msg = f"Calling {params.get('agent_name')} with task '{params.get('task')}'"
|
| 184 |
+
elif action == "call_tool":
|
| 185 |
+
msg = f"Using {params.get('tool_name')} with task '{params.get('task')}'"
|
| 186 |
+
elif action == "ask_user":
|
| 187 |
+
msg = f"Asking user: {params.get('prompt')}"
|
| 188 |
+
elif action == "write_memory":
|
| 189 |
+
msg = f"Writing to memory partition '{params.get('partition')}'"
|
| 190 |
+
elif action == "compose_response":
|
| 191 |
+
msg = "Composing final response"
|
| 192 |
+
else:
|
| 193 |
+
msg = f"Unknown action: {action}"
|
| 194 |
+
_coach_logger.info("\n🏋️♂️Coach Agent: %s", msg)
|
| 195 |
+
|
| 196 |
+
@staticmethod
def _fallback_state(state: Dict[str, Any], message: str) -> Dict[str, Any]:
    """Build the state returned when the coach cannot produce a decision.

    Logs ``message`` as an error, then returns a copy of ``state`` whose
    ``current_action`` is a ``compose_response`` apology (flagged with
    ``_parse_error``), with ``num_turns`` incremented and ``agent_result``
    cleared.
    """
    _coach_logger.error(message)

    apology = f"Sorry — I hit an internal error while planning. ({message})"
    recovery_action = {
        "action": "compose_response",
        "params": {"text": apology},
        "_parse_error": True,
    }

    next_state = dict(state)
    next_state["current_action"] = recovery_action
    next_state["num_turns"] = state["num_turns"] + 1
    next_state["agent_result"] = None
    return next_state
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
# Medical Assessment
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
_MEDICAL_SYSTEM_PROMPT = """\
|
| 215 |
+
You are the Medical Assessment Agent. Produce an evidence-based assessment and
|
| 216 |
+
the clinical flags / calculations the Planner and Validation agents need.
|
| 217 |
+
|
| 218 |
+
Available tools: ComputationTool, WebSearchTool.
|
| 219 |
+
|
| 220 |
+
Mandatory behaviour (do not skip):
|
| 221 |
+
1. Critical data check: confirm presence of age, sex, height, weight,
|
| 222 |
+
activity_level, allergies, medications. If any critical field is missing,
|
| 223 |
+
set action_type="ask_user" and list the missing names in ``fields``.
|
| 224 |
+
2. Use ComputationTool for ALL numeric calculations (BMI, BMR, TDEE, calorie
|
| 225 |
+
targets, macro targets). Pass numeric inputs in tool_task.
|
| 226 |
+
3. Use WebSearchTool to fetch authoritative guidelines (WHO, USDA, ADA,
|
| 227 |
+
EFSA). Capture source URLs with timestamps.
|
| 228 |
+
4. Produce a compact assessment_plan (3-6 steps). Default sequence:
|
| 229 |
+
a) ComputationTool: BMI, BMR, TDEE, daily_target_calories (single int).
|
| 230 |
+
b) ComputationTool: macro_targets (protein_g, fat_g, carbohydrates_g - all
|
| 231 |
+
single ints, no ranges) optimised for the user's goal.
|
| 232 |
+
c) WebSearchTool: dietary guidelines for the user's conditions.
|
| 233 |
+
d-f) Optional follow-ups for specific risks.
|
| 234 |
+
5. When complete, set action_type="assessment_complete" and populate
|
| 235 |
+
``result`` (a MedicalAssessmentResult) with:
|
| 236 |
+
- assessment_summary
|
| 237 |
+
- calculations: { BMI, BMR, TDEE, daily_target_calories,
|
| 238 |
+
macro_targets: { protein_g, fat_g, carbohydrates_g } }
|
| 239 |
+
- flags_to_set (e.g. ["high_ldl", "diabetes_risk"])
|
| 240 |
+
- recommendations (clinical dietary constraints / urgent issues)
|
| 241 |
+
- requires_professional_consultation (True for medically sensitive cases)
|
| 242 |
+
- evidence_sources (list of URLs)
|
| 243 |
+
- trace (one paragraph summarising agent/tool usage)
|
| 244 |
+
6. If any tool call fails, fall back to best-known values, set
|
| 245 |
+
data_confidence below 1.0, and mark requires_tool_retry=true.
|
| 246 |
+
|
| 247 |
+
Output JSON shape (enforced by schema):
|
| 248 |
+
{
|
| 249 |
+
"medical_reasoning": "...",
|
| 250 |
+
"observation": "...",
|
| 251 |
+
"risk_assessment_priorities": [...],
|
| 252 |
+
"assessment_plan": [...],
|
| 253 |
+
"action_type": "call_tool" | "ask_user" | "assessment_complete",
|
| 254 |
+
"tool_name": "ComputationTool" | "WebSearchTool" | null,
|
| 255 |
+
"tool_task": "..." | null,
|
| 256 |
+
"fields": [...], // only when ask_user
|
| 257 |
+
"result": { ... } // only when assessment_complete
|
| 258 |
+
}
|
| 259 |
+
"""
|
| 260 |
+
|
| 261 |
+
|
| 262 |
class MedicalAssessmentAgent:
    """LLM-driven agent that produces a medical assessment for the planner.

    Each iteration sends the system prompt plus the current task, memory,
    assessment plan and accumulated tool results to the LLM and receives a
    typed ``MedicalAssessmentDecision``. The decision either triggers a tool
    call, asks the user for missing fields, or completes the assessment.
    """

    # Upper bound on LLM round-trips per ``handle_task`` call.
    MAX_ITERATIONS = 15

    def __init__(
        self,
        llm_instance,
        computation_tool: ComputationTool,
        web_search_tool: WebSearchTool,
    ):
        """Store the LLM wrapper and the two tools this agent may dispatch to."""
        self.llm = llm_instance
        self.computation_tool = computation_tool
        self.web_search_tool = web_search_tool

    def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
        """Run the assessment loop for ``task`` and return a status/summary string.

        Args:
            task: Free-text task description forwarded to the LLM.
            memory: Shared memory dict; ``user_profile`` and ``medical_history``
                are read, and ``flags_and_assessments`` is written on success.

        Returns:
            The assessment summary on success, otherwise a message describing
            why the assessment stopped (missing user fields, parse failure,
            tool retry needed, unknown action, or iteration limit reached).
        """
        _medical_logger.info("\n👨🏻⚕️ MEDICAL ASSESSMENT AGENT STARTED")
        settings = get_settings()

        relevant_memory = {
            "user_profile": memory.get("user_profile", {}),
            "medical_history": memory.get("medical_history", {}),
        }
        memory_str = json.dumps(relevant_memory, indent=2, default=str)
        tool_results: List[str] = []
        assessment_plan: List[dict] = []

        for iteration in range(self.MAX_ITERATIONS):
            tool_results_str = (
                "\n".join(f"Tool Result {i+1}: {r}" for i, r in enumerate(tool_results)) or "None"
            )
            assessment_plan_str = (
                json.dumps(assessment_plan, indent=2, default=str) if assessment_plan else "None"
            )

            prompt = (
                f"{_MEDICAL_SYSTEM_PROMPT}\n\n--- Task & State ---\n"
                f"Task: {task}\n"
                f"Current Memory: {memory_str}\n"
                f"Current Assessment Plan: {assessment_plan_str}\n"
                f"Previous Tool Results: {tool_results_str}\n"
            )

            if should_debug("agents", "MedicalAssessmentAgent"):
                _medical_logger.debug("--- Medical Assessment Iteration %d ---", iteration + 1)
                if settings.debug_level == "full":
                    _medical_logger.debug("Raw LLM input:\n%s", prompt)

            decision = self.llm.call_typed(prompt, MedicalAssessmentDecision)
            if decision is None:
                _medical_logger.error("Medical decision parse failed at iteration %d", iteration + 1)
                return "Medical assessment failed: could not parse LLM decision."

            if should_debug("agents", "MedicalAssessmentAgent"):
                _medical_logger.debug("Medical decision:\n%s", decision.model_dump_json(indent=2))

            # Keep the most recent non-empty plan so it is echoed back next turn.
            if decision.assessment_plan:
                assessment_plan = [s.model_dump() for s in decision.assessment_plan]

            if not settings.debug_mode:
                self._log_user_mode_action(decision)

            if decision.action_type == "call_tool":
                tool_results.append(f"{decision.tool_name}: {self._dispatch_tool(decision)}")

            elif decision.action_type == "ask_user":
                fields = decision.fields or []
                msg = (
                    f"Missing critical fields: {', '.join(fields)}. "
                    "Please provide the following information to continue the assessment."
                )
                _medical_logger.info("👨🏻⚕️ MEDICAL ASSESSMENT AGENT: User query needed - %s", msg)
                return msg

            elif decision.action_type == "assessment_complete":
                return self._finalize(task, decision, memory, relevant_memory, tool_results)

            else:
                # Return here instead of falling through to the iteration-limit
                # path, which would misreport this as "stopped after N iterations".
                _medical_logger.error("Unknown action_type: %s", decision.action_type)
                return (
                    f"Medical assessment failed: unknown action_type "
                    f"{decision.action_type!r}."
                )

        _medical_logger.warning("👨🏻⚕️ MEDICAL ASSESSMENT AGENT Stopped (MAX ITERATIONS)")
        return f"Medical assessment stopped after {self.MAX_ITERATIONS} iterations"

    # ------------------------------------------------------------------
    def _dispatch_tool(self, decision: MedicalAssessmentDecision) -> str:
        """Route a ``call_tool`` decision to the matching tool and return its output.

        Returns an explanatory string (rather than raising) when ``tool_task``
        is missing or ``tool_name`` is not one of the two supported tools.
        """
        tool_name = decision.tool_name
        tool_task = decision.tool_task
        if not tool_task:
            return f"Missing 'tool_task' for {tool_name}"
        if tool_name == "ComputationTool":
            return self.computation_tool.handle_task(tool_task)
        if tool_name == "WebSearchTool":
            return self.web_search_tool.handle_task(tool_task)
        return f"Unknown tool: {tool_name}"

    @staticmethod
    def _log_user_mode_action(decision: MedicalAssessmentDecision) -> None:
        """Log a short INFO line describing the decision (used when debug mode is off)."""
        if decision.action_type == "call_tool":
            _medical_logger.info(
                "👨🏻⚕️ Medical Assessment Agent: Using %s for '%s'",
                decision.tool_name,
                decision.tool_task,
            )
        elif decision.action_type == "ask_user":
            _medical_logger.info(
                "👨🏻⚕️ Medical Assessment Agent: Asking user for missing fields: %s",
                ", ".join(decision.fields or []),
            )
        elif decision.action_type == "assessment_complete":
            _medical_logger.info("👨🏻⚕️ Medical Assessment Agent: Completing assessment")

    def _finalize(
        self,
        task: str,
        decision: MedicalAssessmentDecision,
        memory: Dict[str, Any],
        relevant_memory: Dict[str, Any],
        tool_results: List[str],
    ) -> str:
        """Persist a completed assessment and return its summary.

        Returns a failure message without touching memory when the decision
        carries no ``result`` payload or the result asks for a tool retry.
        Otherwise writes the result into the ``flags_and_assessments`` memory
        partition, dumps a JSON trace via ``save_to_json`` and returns
        ``result.assessment_summary``.
        """
        result: Optional[MedicalAssessmentResult] = decision.result
        if result is None:
            _medical_logger.error("assessment_complete decision missing result payload")
            return "Medical assessment failed: completion payload missing."

        if result.requires_tool_retry:
            msg = "Assessment requires tool retry due to tool failures."
            _medical_logger.warning("👨🏻⚕️ MEDICAL ASSESSMENT AGENT: Tool retry needed - %s", msg)
            return msg

        update_memory_partition(
            memory,
            "flags_and_assessments",
            {
                "assessment_summary": result.assessment_summary,
                "flags": result.flags_to_set,
                "recommendations": result.recommendations,
                "requires_professional_consultation": result.requires_professional_consultation,
                "calculations": result.calculations.model_dump(),
                "evidence_sources": result.evidence_sources,
                "data_confidence": result.data_confidence,
                "trace": result.trace,
                "assessment_status": "assessment_complete",
                "assessment_timestamp": datetime.now().isoformat(),
            },
        )
        save_to_json(
            {
                "task": task,
                "memory_input": relevant_memory,
                "tool_results": tool_results,
                "result": result.model_dump(),
                "timestamp": datetime.now().isoformat(),
            },
            f"medical_assessment_{datetime.now().isoformat()}.json",
            subdirectory="MedicalAssessment",
        )
        _medical_logger.info("👨🏻⚕️ MEDICAL ASSESSMENT AGENT COMPLETED: %s", result.assessment_summary)
        return result.assessment_summary
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# ---------------------------------------------------------------------------
|
| 422 |
+
# Planner
|
| 423 |
+
# ---------------------------------------------------------------------------
|
| 424 |
+
_PLANNER_SYSTEM_PROMPT = """\
|
| 425 |
+
You are the Planner Agent. Create personalised meal plans constrained by the
|
| 426 |
+
medical assessment.
|
| 427 |
+
|
| 428 |
+
Available tools: WebSearchTool, QuantitiesFinder, ComputationTool.
|
| 429 |
+
|
| 430 |
+
Mandatory behaviour & rules:
|
| 431 |
+
1. Precondition: do NOT plan unless flags_and_assessments has an
|
| 432 |
+
"assessment_status" of "assessment_complete". If missing, return
|
| 433 |
+
action_type="provide_plan" with final_plan={"error": "..."} explaining the
|
| 434 |
+
blocker and suggesting MedicalAssessmentAgent.
|
| 435 |
+
2. Batch tool calls: fetch nutrition facts for ALL foods in one WebSearchTool
|
| 436 |
+
call rather than one call per item.
|
| 437 |
+
3. For each food in the draft, look up per-100g nutrition (calories, protein,
|
| 438 |
+
fat, carbohydrates). If WebSearchTool fails for >2 items, fall back to
|
| 439 |
+
internal knowledge.
|
| 440 |
+
4. Tolerances: calories +/- 3%, each macro +/- 5% of target.
|
| 441 |
+
5. Exclude allergens and disliked foods. Propose alternatives if necessary
|
| 442 |
+
for balance.
|
| 443 |
+
6. Multi-day requests: emit a 1-2 day plan and instruct the user to rotate.
|
| 444 |
+
7. QuantitiesFinder format: tool_task MUST be a JSON STRING containing
|
| 445 |
+
{"foods": [...], "targets": {...}}. Each food needs name, calories,
|
| 446 |
+
protein, fat, carbohydrates (per 100g) and estimated_g (your best guess).
|
| 447 |
+
|
| 448 |
+
Planning Steps Handling:
|
| 449 |
+
- If Current Planning Steps is empty/None, adopt this fixed 5-step plan:
|
| 450 |
+
1. Draft a realistic plan; assign a realistic estimated_g per food.
|
| 451 |
+
2. Batch-gather nutrition facts via WebSearchTool.
|
| 452 |
+
3. Call QuantitiesFinder with foods + targets to compute precise grams.
|
| 453 |
+
4. Update the draft with the solver's quantities.
|
| 454 |
+
5. Provide the final plan via action_type="provide_plan".
|
| 455 |
+
- If steps are provided, you may iterate within a step until targets are met.
|
| 456 |
+
|
| 457 |
+
Output JSON shape (enforced by schema):
|
| 458 |
+
{
|
| 459 |
+
"observation": "...",
|
| 460 |
+
"thought": "...",
|
| 461 |
+
"planning_steps": [...],
|
| 462 |
+
"action_type": "call_tool" | "draft_plan" | "provide_plan",
|
| 463 |
+
"tool_name": "WebSearchTool" | "QuantitiesFinder" | "ComputationTool" | null,
|
| 464 |
+
"tool_task": "..." | null,
|
| 465 |
+
"drafted_plan": { ... } | null,
|
| 466 |
+
"final_plan": { ... } | null
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
Notes:
|
| 470 |
+
- Keep plans realistic and culturally appropriate (regional foods if provided).
|
| 471 |
+
- Include a "trace" line in the final plan summarising agents/tools used.
|
| 472 |
+
- Always echo the full updated planning_steps so they persist across turns.
|
| 473 |
+
"""
|
| 474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
class PlannerAgent:
    """LLM-driven agent that drafts and finalises a meal plan.

    Each iteration sends the system prompt plus the task, selected memory
    partitions, current planning steps and accumulated tool results to the
    LLM and receives a typed ``PlannerDecision``. The decision either calls
    a tool, stores a draft plan in memory, or provides the final plan.
    """

    # Upper bound on LLM round-trips per ``handle_task`` call.
    MAX_ITERATIONS = 15

    def __init__(
        self,
        llm_instance,
        computation_tool: ComputationTool,
        web_search_tool: WebSearchTool,
        quantities_finder: QuantitiesFinder,
    ):
        """Store the LLM wrapper and the three tools this agent may dispatch to."""
        self.llm = llm_instance
        self.computation_tool = computation_tool
        self.web_search_tool = web_search_tool
        self.quantities_finder = quantities_finder

    def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
        """Run the planning loop for ``task`` and return the plan or a status string.

        Args:
            task: Free-text task description forwarded to the LLM.
            memory: Shared memory dict; ``user_profile``,
                ``flags_and_assessments`` and ``plans`` are read, and
                ``plans`` is updated with drafts and the final plan.

        Returns:
            The final plan (JSON-encoded when it is a dict), a JSON-encoded
            error dict when the LLM reports a blocker, or a message describing
            why planning stopped (parse failure, unknown action, or iteration
            limit reached).
        """
        _planner_logger.info("\n📋 PLANNER AGENT STARTED")
        settings = get_settings()

        # Snapshot of the inputs as they were at start, kept for the JSON trace.
        relevant_memory = {
            "user_profile": memory.get("user_profile", {}),
            "flags_and_assessments": memory.get("flags_and_assessments", {}),
        }
        tool_results: List[str] = []
        planning_steps: List[dict] = []

        for iteration in range(self.MAX_ITERATIONS):
            # Rebuilt every turn because ``plans`` may have been updated by a draft.
            memory_str = json.dumps(
                {
                    "user_profile": memory.get("user_profile", {}),
                    "flags_and_assessments": memory.get("flags_and_assessments", {}),
                    "plans": memory.get("plans", {}),
                },
                indent=2,
                default=str,
            )
            tool_results_str = (
                "\n".join(f"Tool Result {i+1}: {r}" for i, r in enumerate(tool_results)) or "None"
            )
            planning_steps_str = (
                json.dumps(planning_steps, indent=2, default=str) if planning_steps else "None"
            )

            prompt = (
                f"{_PLANNER_SYSTEM_PROMPT}\n\n--- Task & State ---\n"
                f"Task: {task}\n"
                f"Current Memory: {memory_str}\n"
                f"Current Planning Steps: {planning_steps_str}\n"
                f"Previous Tool Results: {tool_results_str}\n"
            )

            if should_debug("agents", "PlannerAgent"):
                _planner_logger.debug("--- Planner Iteration %d ---", iteration + 1)
                if settings.debug_level == "full":
                    _planner_logger.debug("Raw LLM input:\n%s", prompt)

            decision = self.llm.call_typed(prompt, PlannerDecision)
            if decision is None:
                _planner_logger.error("Planner decision parse failed at iteration %d", iteration + 1)
                return "Planner failed: could not parse LLM decision."

            if should_debug("agents", "PlannerAgent"):
                _planner_logger.debug("Planner decision:\n%s", decision.model_dump_json(indent=2))

            # Keep the most recent non-empty steps so they are echoed back next turn.
            if decision.planning_steps:
                planning_steps = [s.model_dump() for s in decision.planning_steps]

            if not settings.debug_mode:
                self._log_user_mode_action(decision)

            if decision.action_type == "call_tool":
                tool_results.append(f"{decision.tool_name}: {self._dispatch_tool(decision)}")

            elif decision.action_type == "draft_plan":
                if decision.drafted_plan:
                    memory.setdefault("plans", {})["drafted_plan"] = decision.drafted_plan
                    tool_results.append("Plan drafted and stored in memory")
                else:
                    tool_results.append("Drafted plan not provided")

            elif decision.action_type == "provide_plan":
                # Fall back to the stored draft when no explicit final plan is given.
                final = decision.final_plan or memory.get("plans", {}).get("drafted_plan")

                # Error escape hatch (e.g. precondition not met)
                if isinstance(final, dict) and "error" in final:
                    _planner_logger.error("📋 PLANNER AGENT ERROR: %s", final)
                    return json.dumps(final)

                if not final:
                    tool_results.append("Cannot finalize: missing plan")
                    continue  # let the loop try another iteration

                memory.setdefault("plans", {})
                memory["plans"]["current_plan"] = final
                memory["plans"]["plan_timestamp"] = datetime.now().isoformat()
                memory["plans"].pop("drafted_plan", None)

                save_to_json(
                    {
                        "task": task,
                        "memory_input": relevant_memory,
                        "tool_results": tool_results,
                        "final_response": decision.model_dump(),
                        "timestamp": datetime.now().isoformat(),
                    },
                    f"planner_agent_{datetime.now().isoformat()}.json",
                    subdirectory="PlannerAgent",
                )
                _planner_logger.info("\n📋 PLANNER AGENT COMPLETED")
                return json.dumps(final) if isinstance(final, dict) else str(final)

            else:
                # Return here instead of falling through to the iteration-limit
                # path, which would misreport this as "stopped after N iterations".
                _planner_logger.error("Unknown action_type: %s", decision.action_type)
                return f"Planner failed: unknown action_type {decision.action_type!r}."

        _planner_logger.warning("📋 PLANNER AGENT Stopped (MAX ITERATIONS)")
        return (
            f"Planning stopped after {self.MAX_ITERATIONS} iterations "
            f"with {len(tool_results)} actions"
        )

    # ------------------------------------------------------------------
    def _dispatch_tool(self, decision: PlannerDecision) -> str:
        """Route a ``call_tool`` decision to the matching tool and return its output.

        Returns an explanatory string (rather than raising) when the tool name
        or task is missing, or when the tool name is not supported.
        """
        tool_name = decision.tool_name
        tool_task = decision.tool_task
        if not tool_name or not tool_task:
            return "Missing tool_name or tool_task"
        if tool_name == "ComputationTool":
            return self.computation_tool.handle_task(tool_task)
        if tool_name == "WebSearchTool":
            return self.web_search_tool.handle_task(tool_task)
        if tool_name == "QuantitiesFinder":
            return self.quantities_finder.handle_task(tool_task)
        return f"Unknown tool: {tool_name}"

    @staticmethod
    def _log_user_mode_action(decision: PlannerDecision) -> None:
        """Log a short INFO line describing the decision (used when debug mode is off)."""
        if decision.action_type == "call_tool":
            _planner_logger.info(
                "📋 Planner Agent: Using %s for '%s'",
                decision.tool_name,
                decision.tool_task,
            )
        elif decision.action_type == "draft_plan":
            _planner_logger.info("📋 Planner Agent: Drafting plan")
        elif decision.action_type == "provide_plan":
            _planner_logger.info("📋 Planner Agent: Finalizing plan")
|
config.py
CHANGED
|
@@ -1,6 +1,113 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Nutrition MAS configuration.
|
| 2 |
+
|
| 3 |
+
This module exposes a Pydantic-Settings ``Settings`` singleton plus a small
|
| 4 |
+
backward-compatibility shim so existing code can still read ``config.DEBUG_MODE``,
|
| 5 |
+
``config.LOG_DIR``, etc. New code should import :func:`get_settings` directly::
|
| 6 |
+
|
| 7 |
+
from config import get_settings
|
| 8 |
+
settings = get_settings()
|
| 9 |
+
if settings.debug_mode:
|
| 10 |
+
...
|
| 11 |
+
|
| 12 |
+
Mutation must go through :func:`set_settings` (the legacy ``config.X = y`` write
|
| 13 |
+
pattern would otherwise silently shadow the Pydantic value).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations

import json
from typing import Annotated, Any, Dict, List, Optional

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Settings(BaseSettings):
    """Process-wide configuration.

    Values are loaded from (in order of precedence): direct ``set_settings``
    calls, environment variables prefixed with ``NUTRITION_MAS_``, the ``.env``
    file in the project root, and the defaults declared here.
    """

    model_config = SettingsConfigDict(
        env_prefix="NUTRITION_MAS_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
        validate_assignment=True,
    )

    # --- Logging / persistence -------------------------------------------------
    log_dir: Optional[str] = None
    persistence_dir: Optional[str] = None

    # --- Debug switches --------------------------------------------------------
    debug_mode: bool = False
    debug_level: str = "full"  # 'full' or 'output'
    debug_scopes: Dict[str, List[str]] = Field(
        default_factory=lambda: {"agents": ["all"], "tools": ["all"]}
    )

    # --- LLM / rate limiting ---------------------------------------------------
    enable_rate_limiting: bool = True
    # ``NoDecode`` (pydantic-settings >= 2.7) stops the env source from
    # JSON-decoding this value, so the comma-separated form documented in
    # ``.env.example`` reaches the validator below instead of raising
    # ``SettingsError`` at load time.
    gemini_api_keys: Annotated[List[str], NoDecode] = Field(default_factory=list)

    @field_validator("gemini_api_keys", mode="before")
    @classmethod
    def _split_gemini_api_keys(cls, value: Any) -> Any:
        """Accept API keys as a list, a JSON array string, or a comma-separated string.

        Non-string values are passed through untouched for normal validation.
        Blank entries (e.g. from a trailing comma) are dropped.
        """
        if not isinstance(value, str):
            return value
        text = value.strip()
        if text.startswith("["):
            # Keep the JSON-array form working for existing deployments.
            return json.loads(text)
        return [key.strip() for key in text.split(",") if key.strip()]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Singleton holder. Instantiated lazily so tests can set env vars before first read.
|
| 58 |
+
_settings: Optional[Settings] = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_settings() -> Settings:
    """Return the shared ``Settings`` object, building it on first use."""
    global _settings
    current = _settings
    if current is None:
        # First access: construct the instance and cache it at module level.
        current = _settings = Settings()
    return current
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def reset_settings() -> None:
    """Forget the cached ``Settings`` instance.

    The next ``get_settings`` call constructs a fresh instance. Meant for
    test code.
    """
    global _settings
    _settings = None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def set_settings(**updates: Any) -> Settings:
    """Update fields on the singleton ``Settings``.

    Accepts both legacy upper-case names (``DEBUG_MODE``) and Pydantic field
    names (``debug_mode``). Returns the updated settings instance.

    Raises:
        AttributeError: If any key does not resolve to a declared settings
            field. All keys are checked before anything is assigned, so an
            unknown name never leaves the settings partially updated.
    """
    s = get_settings()
    # Check against declared fields only; ``hasattr`` would also accept
    # methods and other non-field attributes such as ``model_dump``.
    declared_fields = type(s).model_fields

    resolved: Dict[str, Any] = {}
    for raw_key, value in updates.items():
        attr = _LEGACY_ATTR_MAP.get(raw_key, raw_key.lower())
        if attr not in declared_fields:
            raise AttributeError(f"Settings has no attribute {attr!r}")
        resolved[attr] = value

    for attr, value in resolved.items():
        setattr(s, attr, value)
    return s
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# --- Legacy attribute proxy ----------------------------------------------------
|
| 94 |
+
# Existing code does ``import config`` then reads ``config.DEBUG_MODE`` etc.
|
| 95 |
+
# PEP 562 ``__getattr__`` lets us forward those reads to the singleton.
|
| 96 |
+
_LEGACY_ATTR_MAP: Dict[str, str] = {
|
| 97 |
+
"DEBUG_MODE": "debug_mode",
|
| 98 |
+
"DEBUG_LEVEL": "debug_level",
|
| 99 |
+
"DEBUG_SCOPES": "debug_scopes",
|
| 100 |
+
"LOG_DIR": "log_dir",
|
| 101 |
+
"PERSISTENCE_DIR": "persistence_dir",
|
| 102 |
+
"ENABLE_RATE_LIMITING": "enable_rate_limiting",
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def __getattr__(name: str) -> Any:  # noqa: D401 — module-level dunder
    """PEP 562 hook: resolve legacy CONST-style names via the Settings singleton."""
    field_name = _LEGACY_ATTR_MAP.get(name)
    if field_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(get_settings(), field_name)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
__all__ = ["Settings", "get_settings", "reset_settings", "set_settings"]
|
logging_setup.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Centralised logging for Nutrition MAS.
|
| 2 |
+
|
| 3 |
+
Agents and tools used to ``print`` directly to stdout. That worked in a notebook
|
| 4 |
+
but coupled the agentic system to the I/O layer. This module provides a single
|
| 5 |
+
``get_logger`` entrypoint so:
|
| 6 |
+
|
| 7 |
+
* user-mode emoji status lines flow through ``logger.info`` (visible by default),
|
| 8 |
+
* debug-mode raw LLM dumps flow through ``logger.debug`` (hidden unless
|
| 9 |
+
``settings.debug_mode`` is True),
|
| 10 |
+
* later phases can attach extra handlers (SSE event stream for the API, JSON
|
| 11 |
+
file handler for trace persistence, etc.) without touching agent code.
|
| 12 |
+
|
| 13 |
+
Idempotent: calling :func:`configure_logging` more than once is a no-op unless
|
| 14 |
+
``force=True``.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
from config import get_settings
|
| 23 |
+
|
| 24 |
+
_BASE = "nutrition_mas"
|
| 25 |
+
_CONFIGURED = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def configure_logging(*, force: bool = False) -> None:
    """Set up the ``nutrition_mas`` logger with a single stdout handler.

    The level is DEBUG when ``settings.debug_mode`` is true and INFO otherwise.
    After the first successful call this function does nothing unless
    ``force=True`` is passed, in which case the logger's handler list is
    replaced with a freshly built handler.
    """
    global _CONFIGURED
    if force or not _CONFIGURED:
        threshold = logging.INFO
        if get_settings().debug_mode:
            threshold = logging.DEBUG

        stream = logging.StreamHandler(sys.stdout)
        stream.setFormatter(logging.Formatter("%(message)s"))
        stream.setLevel(threshold)

        base_logger = logging.getLogger(_BASE)
        base_logger.propagate = False
        base_logger.setLevel(threshold)
        base_logger.handlers = [stream]

        _CONFIGURED = True
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_logger(name: str) -> logging.Logger:
    """Return the logger named ``nutrition_mas.<name>``.

    Typical names are ``agents.coach``, ``agents.medical``,
    ``tools.computation`` and ``utils.api_pool``. Logging is configured on
    demand if that has not happened yet.
    """
    # ``configure_logging`` returns immediately once configuration has run.
    configure_logging()
    qualified_name = f"{_BASE}.{name}"
    return logging.getLogger(qualified_name)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def refresh_level() -> None:
    """Apply the current ``settings.debug_mode`` to the logger and its handlers.

    Intended to be called after debug mode has been switched at runtime.
    """
    verbose = get_settings().debug_mode
    threshold = logging.DEBUG if verbose else logging.INFO

    base_logger = logging.getLogger(_BASE)
    base_logger.setLevel(threshold)
    for attached in base_logger.handlers:
        attached.setLevel(threshold)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
__all__ = ["configure_logging", "get_logger", "refresh_level"]
|
nutritionmas.py
CHANGED
|
@@ -1,45 +1,52 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
import
|
| 3 |
-
from
|
| 4 |
-
|
| 5 |
-
from
|
|
|
|
| 6 |
from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from workflow import setup_workflow as setup_workflow_workflow
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
import random
|
| 10 |
-
import json
|
| 11 |
-
from typing import Optional, Dict, Any, List
|
| 12 |
-
from IPython.display import display, Markdown
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
Args:
|
| 19 |
level: 'full' (default) to show inputs and outputs, or 'output' to show only outputs.
|
| 20 |
-
scopes: Optional dict like {'agents': ['all'], 'tools': ['ComputationTool']}.
|
| 21 |
If None, defaults to all agents and tools.
|
| 22 |
"""
|
| 23 |
-
config.DEBUG_MODE = True
|
| 24 |
-
config.DEBUG_LEVEL = level
|
| 25 |
if scopes is None:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
Set
|
| 33 |
-
|
| 34 |
-
If
|
| 35 |
-
If
|
|
|
|
| 36 |
"""
|
|
|
|
| 37 |
if log_dir is not None:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
if persistence_dir is not None:
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
# Default model configurations (without API keys, as they will be provided by the user)
|
| 45 |
DEFAULT_MODEL_CONFIGS = {
|
|
@@ -71,6 +78,13 @@ DEFAULT_MODEL_CONFIGS = {
|
|
| 71 |
"thinking_budget": 600,
|
| 72 |
"params": {"max_tokens": 5120, "temperature": 0.3}
|
| 73 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"user_simulator": {
|
| 75 |
"type": "gemini",
|
| 76 |
"model_name": "gemini-2.5-flash",
|
|
@@ -115,25 +129,32 @@ def create_llm_instances(api_keys: list[str], model_overrides: Optional[Dict[str
|
|
| 115 |
rate_limits = None
|
| 116 |
|
| 117 |
manager = APIPoolManager(api_keys, rate_limits)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
for key in DEFAULT_MODEL_CONFIGS:
|
| 122 |
-
|
| 123 |
if model_overrides and key in model_overrides:
|
| 124 |
override = model_overrides[key]
|
| 125 |
if "model_name" in override:
|
| 126 |
-
|
| 127 |
if "params" in override:
|
| 128 |
-
|
| 129 |
-
model_configs[key] =
|
| 130 |
|
| 131 |
LLM_INSTANCES = {
|
| 132 |
"main": create_llm(model_configs["main"], manager),
|
| 133 |
"agents_llm": create_llm(model_configs["agents_llm"], manager),
|
| 134 |
"tools_llm": create_llm(model_configs["tools_llm"], manager),
|
| 135 |
"planner_agent": create_llm(model_configs["planner_agent"], manager),
|
| 136 |
-
"
|
|
|
|
| 137 |
}
|
| 138 |
|
| 139 |
def initialize_tools():
|
|
@@ -156,10 +177,16 @@ def initialize_agents():
|
|
| 156 |
MAIN_LLM = LLM_INSTANCES["main"]
|
| 157 |
AGENTS_LLM = LLM_INSTANCES["agents_llm"]
|
| 158 |
PLANNER_LLM = LLM_INSTANCES["planner_agent"]
|
|
|
|
| 159 |
AGENTS = {
|
| 160 |
"CoachAgent": CoachAgent(MAIN_LLM),
|
| 161 |
-
"MedicalAssessmentAgent": MedicalAssessmentAgent(
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
}
|
| 164 |
|
| 165 |
def setup_workflow():
|
|
|
|
| 1 |
+
import json
|
| 2 |
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from IPython.display import Markdown, display
|
| 7 |
+
|
| 8 |
from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
|
| 9 |
+
from config import set_settings
|
| 10 |
+
from logging_setup import get_logger, refresh_level
|
| 11 |
+
from state import initialize_empty_memory
|
| 12 |
+
from tools import ComputationTool, QuantitiesFinder, WebSearchTool
|
| 13 |
+
from utils import APIPoolManager, create_llm
|
| 14 |
+
from validation import ValidationAgent
|
| 15 |
from workflow import setup_workflow as setup_workflow_workflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
_logger = get_logger("nutritionmas")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def debug(level: str = "full", scopes: Optional[Dict[str, List[str]]] = None) -> None:
    """Switch debug mode on and refresh the logger level to match.

    Args:
        level: Stored as the ``debug_level`` setting. 'full' (default) shows
            inputs and outputs, 'output' shows only outputs.
        scopes: Stored as the ``debug_scopes`` setting, e.g.
            ``{'agents': ['all'], 'tools': ['ComputationTool']}``. When None,
            every agent and every tool is in scope.
    """
    effective_scopes = {"agents": ["all"], "tools": ["all"]} if scopes is None else scopes
    set_settings(debug_mode=True, debug_level=level, debug_scopes=effective_scopes)
    # Settings changed, so the logger/handler levels must be re-derived from them.
    refresh_level()
|
| 32 |
|
| 33 |
+
|
| 34 |
+
def logging(log_dir: Optional[str] = None, persistence_dir: Optional[str] = None) -> None:  # noqa: A001 - public name kept for backwards compat
    """Configure where log files and LangGraph checkpoints are written.

    Each directory that is provided is created if missing and recorded in the
    settings under the same name. Arguments left as None are not touched, and
    when both are None the settings are not updated at all.

    Args:
        log_dir: Directory for agent/tool I/O dumps (JSON).
        persistence_dir: Directory for on-disk LangGraph checkpoints.
    """
    requested = (("log_dir", log_dir), ("persistence_dir", persistence_dir))
    updates: Dict[str, Any] = {}
    for setting_name, directory in requested:
        if directory is None:
            continue
        os.makedirs(directory, exist_ok=True)
        updates[setting_name] = directory
    if updates:
        set_settings(**updates)
|
| 50 |
|
| 51 |
# Default model configurations (without API keys, as they will be provided by the user)
|
| 52 |
DEFAULT_MODEL_CONFIGS = {
|
|
|
|
| 78 |
"thinking_budget": 600,
|
| 79 |
"params": {"max_tokens": 5120, "temperature": 0.3}
|
| 80 |
},
|
| 81 |
+
"validation_agent": {
|
| 82 |
+
"type": "gemini",
|
| 83 |
+
"model_name": "gemini-2.5-flash",
|
| 84 |
+
"structured_output": True,
|
| 85 |
+
"thinking_budget": 300,
|
| 86 |
+
"params": {"max_tokens": 3072, "temperature": 0.2}
|
| 87 |
+
},
|
| 88 |
"user_simulator": {
|
| 89 |
"type": "gemini",
|
| 90 |
"model_name": "gemini-2.5-flash",
|
|
|
|
| 129 |
rate_limits = None
|
| 130 |
|
| 131 |
manager = APIPoolManager(api_keys, rate_limits)
|
| 132 |
+
_logger.info(
|
| 133 |
+
"APIPoolManager initialized with %s and %d API keys.",
|
| 134 |
+
"rate limiting enabled" if enable_rate_limiting else "rate limiting disabled",
|
| 135 |
+
len(api_keys),
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Note: previously this loop used a local variable named ``config`` which
|
| 139 |
+
# shadowed the imported ``config`` module — now ``cfg`` to avoid the trap.
|
| 140 |
+
model_configs: Dict[str, Dict[str, Any]] = {}
|
| 141 |
for key in DEFAULT_MODEL_CONFIGS:
|
| 142 |
+
cfg = DEFAULT_MODEL_CONFIGS[key].copy()
|
| 143 |
if model_overrides and key in model_overrides:
|
| 144 |
override = model_overrides[key]
|
| 145 |
if "model_name" in override:
|
| 146 |
+
cfg["model_name"] = override["model_name"]
|
| 147 |
if "params" in override:
|
| 148 |
+
cfg["params"] = {**cfg.get("params", {}), **override["params"]}
|
| 149 |
+
model_configs[key] = cfg
|
| 150 |
|
| 151 |
LLM_INSTANCES = {
|
| 152 |
"main": create_llm(model_configs["main"], manager),
|
| 153 |
"agents_llm": create_llm(model_configs["agents_llm"], manager),
|
| 154 |
"tools_llm": create_llm(model_configs["tools_llm"], manager),
|
| 155 |
"planner_agent": create_llm(model_configs["planner_agent"], manager),
|
| 156 |
+
"validation_agent": create_llm(model_configs["validation_agent"], manager),
|
| 157 |
+
"user_simulator": create_llm(model_configs["user_simulator"], manager),
|
| 158 |
}
|
| 159 |
|
| 160 |
def initialize_tools():
|
|
|
|
| 177 |
MAIN_LLM = LLM_INSTANCES["main"]
|
| 178 |
AGENTS_LLM = LLM_INSTANCES["agents_llm"]
|
| 179 |
PLANNER_LLM = LLM_INSTANCES["planner_agent"]
|
| 180 |
+
VALIDATION_LLM = LLM_INSTANCES["validation_agent"]
|
| 181 |
AGENTS = {
|
| 182 |
"CoachAgent": CoachAgent(MAIN_LLM),
|
| 183 |
+
"MedicalAssessmentAgent": MedicalAssessmentAgent(
|
| 184 |
+
AGENTS_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"]
|
| 185 |
+
),
|
| 186 |
+
"PlannerAgent": PlannerAgent(
|
| 187 |
+
PLANNER_LLM, TOOLS["ComputationTool"], TOOLS["WebSearchTool"], TOOLS["QuantitiesFinder"]
|
| 188 |
+
),
|
| 189 |
+
"ValidationAgent": ValidationAgent(VALIDATION_LLM),
|
| 190 |
}
|
| 191 |
|
| 192 |
def setup_workflow():
|
pyproject.toml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "nutrition-mas"
|
| 7 |
+
version = "0.2.0"
|
| 8 |
+
description = "Multi-agent system for personalised nutrition planning, built on LangGraph + Gemini."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
authors = [{ name = "Moaz Eldegwy", email = "moazeldegwy@gmail.com" }]
|
| 13 |
+
|
| 14 |
+
dependencies = [
|
| 15 |
+
"langgraph>=0.2.50,<0.3",
|
| 16 |
+
"langchain-core>=0.3.20,<0.4",
|
| 17 |
+
"google-genai>=0.3.0",
|
| 18 |
+
"pydantic>=2.9,<3",
|
| 19 |
+
"pydantic-settings>=2.6,<3",
|
| 20 |
+
"pulp>=2.9,<3",
|
| 21 |
+
"ddgs>=6.3,<7",
|
| 22 |
+
"json-repair>=0.30",
|
| 23 |
+
"python-dotenv>=1.0,<2",
|
| 24 |
+
"ipython>=8.0",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
[project.optional-dependencies]
|
| 28 |
+
dev = [
|
| 29 |
+
"pytest>=8.0",
|
| 30 |
+
"pytest-asyncio>=0.24",
|
| 31 |
+
"pytest-cov>=5.0",
|
| 32 |
+
"ruff>=0.7",
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
[tool.setuptools]
# Flat (non-package) layout: every top-level module must be listed here or it
# is left out of the built distribution. "schemas" and "validation" are
# required at import time (nutritionmas imports validation).
py-modules = ["agents", "config", "nutritionmas", "state", "tools", "utils", "workflow", "logging_setup", "schemas", "validation"]
|
| 37 |
+
|
| 38 |
+
[tool.ruff]
|
| 39 |
+
line-length = 110
|
| 40 |
+
target-version = "py310"
|
| 41 |
+
|
| 42 |
+
[tool.ruff.lint]
|
| 43 |
+
select = ["E", "F", "I", "B", "UP", "SIM"]
|
| 44 |
+
ignore = ["E501"]
|
| 45 |
+
|
| 46 |
+
[tool.pytest.ini_options]
|
| 47 |
+
testpaths = ["tests"]
|
| 48 |
+
addopts = "-ra --strict-markers"
|
| 49 |
+
markers = [
|
| 50 |
+
"integration: tests that hit a real LLM (skipped by default)",
|
| 51 |
+
"slow: long-running tests",
|
| 52 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core agent framework
|
| 2 |
+
langgraph>=0.2.50,<0.3
|
| 3 |
+
langchain-core>=0.3.20,<0.4
|
| 4 |
+
|
| 5 |
+
# LLM provider (Gemini)
|
| 6 |
+
google-genai>=0.3.0
|
| 7 |
+
|
| 8 |
+
# Schemas / settings
|
| 9 |
+
pydantic>=2.9,<3
|
| 10 |
+
pydantic-settings>=2.6,<3
|
| 11 |
+
|
| 12 |
+
# Optimization (meal-quantities solver)
|
| 13 |
+
pulp>=2.9,<3
|
| 14 |
+
|
| 15 |
+
# Web search fallback
|
| 16 |
+
ddgs>=6.3,<7
|
| 17 |
+
|
| 18 |
+
# JSON repair fallback (kept until Phase 1 makes it a measured fallback)
|
| 19 |
+
json-repair>=0.30
|
| 20 |
+
|
| 21 |
+
# Env loading
|
| 22 |
+
python-dotenv>=1.0,<2
|
| 23 |
+
|
| 24 |
+
# Markdown rendering for notebook display (kept for backwards compat)
|
| 25 |
+
ipython>=8.0
|
| 26 |
+
|
| 27 |
+
# Tests / dev
|
| 28 |
+
pytest>=8.0
|
| 29 |
+
pytest-asyncio>=0.24
|
| 30 |
+
pytest-cov>=5.0
|
| 31 |
+
ruff>=0.7
|
schemas.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic models for agent inputs and outputs.
|
| 2 |
+
|
| 3 |
+
This module is the contract between the LLM, the orchestration layer, and the
|
| 4 |
+
test suite. Every agent's decision now passes through one of these models — so:
|
| 5 |
+
|
| 6 |
+
* Gemini's ``response_schema`` (constrained decoding) returns guaranteed-shape
|
| 7 |
+
JSON; we no longer rely on regex / ``json_repair`` for the high-stakes path.
|
| 8 |
+
* Tests can construct decisions directly without hand-crafted JSON strings.
|
| 9 |
+
* Phase 2 can split agent loops into LangGraph nodes that pass typed objects
|
| 10 |
+
between them.
|
| 11 |
+
|
| 12 |
+
Where Gemini's schema support is fussy (e.g. discriminated unions with
|
| 13 |
+
``$ref``), we keep the outer envelope strict and leave per-action ``params``
|
| 14 |
+
as a free dict — the agent dispatcher validates it at use time.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 20 |
+
|
| 21 |
+
from pydantic import BaseModel, Field
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Shared / leaf types
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
StepStatus = Literal["pending", "in_progress", "completed", "skipped", "failed"]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ResponseStep(BaseModel):
    """A single step in the Coach's response plan."""

    # Required integer identifier (no default).
    id: int
    # Required; the description text is attached to the field's schema metadata.
    actor: str = Field(
        description="Who executes this step. Examples: 'CoachAgent', 'MedicalAssessmentAgent', "
        "'PlannerAgent', 'ValidationAgent', 'user'.",
    )
    # Required string.
    description: str
    # Defaults to a new empty list per instance.
    # NOTE(review): entries are typed str while ``id`` is int — confirm what format
    # a prerequisite reference is expected to use.
    prerequisites: List[str] = Field(default_factory=list)
    # Restricted to the StepStatus literals; a new step starts as "pending".
    status: StepStatus = "pending"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class MacroTargets(BaseModel):
    """Daily macronutrient targets in grams (single integer values)."""

    # All three fields are required and validated as >= 0 (``ge=0``).
    protein_g: int = Field(ge=0)
    fat_g: int = Field(ge=0)
    carbohydrates_g: int = Field(ge=0)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Calculations(BaseModel):
    """Derived anthropometric + nutritional values from the assessment."""

    # Every numeric field is required and validated as >= 0 (``ge=0``).
    # NOTE(review): units are not stated here — presumably kcal/day for
    # BMR/TDEE/daily_target_calories; confirm against the agent prompt.
    BMI: float = Field(ge=0)
    BMR: float = Field(ge=0)
    TDEE: float = Field(ge=0)
    daily_target_calories: int = Field(ge=0)
    # Required nested model holding the per-macro gram targets.
    macro_targets: MacroTargets
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Coach Agent
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
CoachActionType = Literal[
|
| 65 |
+
"call_agent",
|
| 66 |
+
"call_tool",
|
| 67 |
+
"ask_user",
|
| 68 |
+
"write_memory",
|
| 69 |
+
"compose_response",
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class CoachDecision(BaseModel):
    """Single turn of the Coach orchestrator.

    Outer shape is strict; ``params`` is left as a dict because Gemini's
    schema layer struggles with deeply discriminated unions. The dispatcher
    in :mod:`workflow` validates ``params`` against the action type.
    """

    # Required free-text fields.
    observation: str
    thought: str
    # Optional; defaults to a new empty list per instance.
    response_steps: List[ResponseStep] = Field(default_factory=list)
    # Required; must be one of the CoachActionType literals.
    action: CoachActionType
    # Untyped payload for the chosen action. Nothing in this model enforces the
    # per-action keys listed in the description; that check happens elsewhere.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Action-specific parameters. Required keys per action: "
            "call_agent={agent_name, task}, call_tool={tool_name, task}, "
            "ask_user={prompt}, write_memory={partition, data}, "
            "compose_response={text}."
        ),
    )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
# Medical Assessment Agent
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
MedicalActionType = Literal["call_tool", "ask_user", "assessment_complete"]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MedicalAssessmentResult(BaseModel):
    """Final payload stored in ``memory.flags_and_assessments``."""

    # Required summary text.
    assessment_summary: str
    # List fields default to a new empty list per instance.
    flags_to_set: List[str] = Field(default_factory=list)
    recommendations: List[str] = Field(default_factory=list)
    requires_professional_consultation: bool = False
    # Required nested model; a result without calculations fails validation.
    calculations: Calculations
    evidence_sources: List[str] = Field(default_factory=list)
    trace: str = ""
    requires_tool_retry: bool = False
    # Constrained to the closed interval [0.0, 1.0]; defaults to 1.0.
    data_confidence: float = Field(default=1.0, ge=0.0, le=1.0)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class MedicalAssessmentDecision(BaseModel):
    """Per-iteration output of the Medical Assessment Agent loop."""

    # Required free-text fields.
    medical_reasoning: str
    observation: str
    # Optional lists; each defaults to a new empty list per instance.
    risk_assessment_priorities: List[str] = Field(default_factory=list)
    assessment_plan: List[ResponseStep] = Field(default_factory=list)

    # Required; must be one of the MedicalActionType literals.
    action_type: MedicalActionType
    # action-specific fields (kept flat — see CoachDecision rationale)
    # The model does not enforce that the fields matching ``action_type`` are
    # populated; all of them are optional at the schema level.
    tool_name: Optional[str] = None
    tool_task: Optional[str] = None
    fields: List[str] = Field(default_factory=list)  # for ask_user
    result: Optional[MedicalAssessmentResult] = None  # for assessment_complete
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
# Planner Agent
|
| 134 |
+
# ---------------------------------------------------------------------------
|
| 135 |
+
PlannerActionType = Literal["call_tool", "draft_plan", "provide_plan"]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FoodItem(BaseModel):
    """A single ingredient on the plan, post-solver."""

    # Required name.
    name: str
    # All numeric fields are required floats validated as >= 0 (``ge=0``).
    grams: float = Field(ge=0)
    calories: float = Field(ge=0)
    protein_g: float = Field(ge=0)
    fat_g: float = Field(ge=0)
    carbohydrates_g: float = Field(ge=0)
    # Optional label; None when the item is not assigned to a group.
    meal_group: Optional[str] = None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class FinalPlan(BaseModel):
    """The shape stored in ``memory.plans.current_plan``."""

    # Required (no default): a list of days, each day a list of FoodItem.
    days: List[List[FoodItem]] = Field(
        description="One inner list per day. Most plans return a single day.",
    )
    # Free-form name -> number mapping; the keys are not constrained by this model.
    daily_totals: Dict[str, float] = Field(default_factory=dict)
    notes: str = ""
    sources: List[str] = Field(default_factory=list)
    trace: str = ""
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class PlannerDecision(BaseModel):
    """Per-iteration output of the Planner Agent loop."""

    # Required free-text fields.
    observation: str
    thought: str
    # Optional; defaults to a new empty list per instance.
    planning_steps: List[ResponseStep] = Field(default_factory=list)

    # Required; must be one of the PlannerActionType literals.
    action_type: PlannerActionType
    # Action-specific fields, all optional at the schema level.
    tool_name: Optional[str] = None
    tool_task: Optional[str] = None
    drafted_plan: Optional[Dict[str, Any]] = None  # free shape pre-solver
    # Typed as a plain dict, so it is not validated against FinalPlan here.
    final_plan: Optional[Dict[str, Any]] = None  # free shape until validation lands in Phase 2
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
# Validation Agent (lands in Phase 2; defined here so Phase 1 schemas are
|
| 178 |
+
# the single source of truth)
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
ValidationVerdict = Literal["pass", "revise", "reject"]
|
| 181 |
+
ValidationSeverity = Literal["low", "medium", "high"]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ValidationIssue(BaseModel):
    # One entry of ``ValidationDecision.issues``.
    # Required machine-readable code.
    code: str = Field(description="Stable error code, e.g. 'allergy_violation'.")
    # Required description text.
    description: str
    # One of the ValidationSeverity literals; defaults to "medium".
    severity: ValidationSeverity = "medium"
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class ValidationDecision(BaseModel):
    # Required; one of the ValidationVerdict literals ("pass", "revise", "reject").
    verdict: ValidationVerdict
    # Optional; defaults to a new empty list per instance.
    issues: List[ValidationIssue] = Field(default_factory=list)
    notes: str = ""
    # Defaults to False.
    requires_human_review: bool = False
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
__all__ = [
|
| 198 |
+
"Calculations",
|
| 199 |
+
"CoachActionType",
|
| 200 |
+
"CoachDecision",
|
| 201 |
+
"FinalPlan",
|
| 202 |
+
"FoodItem",
|
| 203 |
+
"MacroTargets",
|
| 204 |
+
"MedicalActionType",
|
| 205 |
+
"MedicalAssessmentDecision",
|
| 206 |
+
"MedicalAssessmentResult",
|
| 207 |
+
"PlannerActionType",
|
| 208 |
+
"PlannerDecision",
|
| 209 |
+
"ResponseStep",
|
| 210 |
+
"StepStatus",
|
| 211 |
+
"ValidationDecision",
|
| 212 |
+
"ValidationIssue",
|
| 213 |
+
"ValidationSeverity",
|
| 214 |
+
"ValidationVerdict",
|
| 215 |
+
]
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared pytest fixtures.
|
| 2 |
+
|
| 3 |
+
The mock LLM here is the workhorse for offline tests — it lets us run agents
|
| 4 |
+
end-to-end without paying for Gemini calls. Canned responses may be raw JSON
|
| 5 |
+
strings, dicts, or Pydantic model instances, so both the untyped ``__call__``
|
| 6 |
+
path and the schema-typed ``call_typed`` path can be exercised.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
from typing import Any, Dict, List
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
|
| 16 |
+
from config import reset_settings, set_settings
|
| 17 |
+
from utils import APIPoolManager, LLM
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MockLLM(LLM):
    """Scripted LLM stand-in that replays canned responses in order.

    Each scripted item may be:

    * a string — returned as-is by ``__call__`` and parsed as JSON against the
      requested model by ``call_typed``,
    * a dict — JSON-serialised by ``__call__`` and validated against the
      requested model by ``call_typed``,
    * a Pydantic ``BaseModel`` instance — dumped to JSON by ``__call__`` and
      returned unchanged by ``call_typed`` when it is an instance of the
      requested model.

    Every call consumes the next item. Calling past the end of the script
    raises ``AssertionError`` so a missing fixture fails loudly.
    """

    def __init__(self, responses: List[Any]) -> None:
        # Copy so the caller's list is not consumed by the pops below.
        self._responses: List[Any] = list(responses)
        self.calls: List[str] = []
        self.typed_calls: List[tuple[str, type]] = []

    def _next(self, prompt: str) -> Any:
        """Record ``prompt`` and pop the next scripted item."""
        self.calls.append(prompt)
        if self._responses:
            return self._responses.pop(0)
        raise AssertionError(
            f"MockLLM ran out of canned responses (call #{len(self.calls)}). "
            f"Last prompt:\n{prompt[:300]}"
        )

    def __call__(self, prompt: str, **_: Any) -> List[str]:
        """Return the next scripted item as a one-element list of text."""
        scripted = self._next(prompt)
        if hasattr(scripted, "model_dump_json"):
            text = scripted.model_dump_json()
        elif isinstance(scripted, dict):
            text = json.dumps(scripted)
        else:
            text = str(scripted)
        return [text]

    def call_typed(self, prompt: str, response_model: type, **_: Any):
        """Return the next scripted item as a ``response_model`` instance, or None.

        None is returned when the item cannot be validated against
        ``response_model`` or is of an unsupported type.
        """
        from pydantic import BaseModel

        self.typed_calls.append((prompt, response_model))
        scripted = self._next(prompt)
        if isinstance(scripted, BaseModel):
            return scripted if isinstance(scripted, response_model) else None
        try:
            if isinstance(scripted, dict):
                return response_model.model_validate(scripted)
            if isinstance(scripted, str):
                return response_model.model_validate_json(scripted)
        except Exception:
            # Any validation failure is reported to the caller as None.
            return None
        return None

    def format_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Render messages as ``role: content`` lines joined by newlines."""
        rendered = [f"{m['role']}: {m['content']}" for m in messages]
        return "\n".join(rendered)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@pytest.fixture
def mock_llm_factory():
    """Factory to build a MockLLM from a list of canned responses.

    Returns the ``MockLLM`` class itself, so a test calls
    ``mock_llm_factory([...])`` to get a stub with its own script.
    """
    return MockLLM
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@pytest.fixture(autouse=True)
def fresh_settings():
    """Reset the Settings singleton before/after each test for isolation.

    ``autouse=True`` applies this to every test without it being requested.
    """
    # Setup: discard existing settings, then force debug off and both directories unset.
    reset_settings()
    set_settings(debug_mode=False, log_dir=None, persistence_dir=None)
    yield
    # Teardown: discard whatever the test changed.
    reset_settings()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@pytest.fixture
def api_pool_no_limits():
    """An APIPoolManager with rate limiting disabled — for unit tests that
    don't care about throttling.

    Built with two placeholder keys and ``rate_limits=None``.
    """
    return APIPoolManager(["test-key-1", "test-key-2"], rate_limits=None)
|
tests/test_api_pool.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for the APIPoolManager rate limiter."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from utils import APIPoolManager
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_round_robin_no_limits() -> None:
    """Without rate limits, six requests must hand out every key in the pool."""
    pool = APIPoolManager(["k1", "k2", "k3"], rate_limits=None)
    handed_out = set()
    for _ in range(6):
        handed_out.add(pool.get_next_key("any-model"))
    # Six draws over three keys: each key is expected to appear.
    assert handed_out == {"k1", "k2", "k3"}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_rpm_spacing_enforced() -> None:
    """A just-used key must be passed over in favour of the other key.

    The pool is configured with limits ``(60, 1000)`` for model "m"; after
    recording a use of the first key at the current time, the next request
    should return the second key rather than reuse the first.
    """
    pool = APIPoolManager(["k1", "k2"], rate_limits={"m": (60, 1000)})

    first_key = pool.get_next_key("m")
    pool.record_usage(first_key, "m", time.time())
    second_key = pool.get_next_key("m")

    assert first_key != second_key, "Round-robin should pick the other key when one is hot"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_rpd_exhaustion_drops_key() -> None:
    """Keys that reach their daily limit must leave the active pool."""
    pool = APIPoolManager(["k1", "k2"], rate_limits={"m": (60, 2)})
    for _ in range(2):
        key = pool.get_next_key("m")
        pool.record_usage(key, "m")

    # Depending on how the two uses were spread over the keys, the pool may
    # already be empty; in that case the next request must fail.
    if not list(pool.active_keys):
        with pytest.raises(RuntimeError):
            pool.get_next_key("m")
        return

    # Otherwise spend the remaining capacity, stopping early once the pool
    # reports that no key is available.
    for _ in range(2):
        try:
            key = pool.get_next_key("m")
            pool.record_usage(key, "m")
        except RuntimeError:
            break
    assert not pool.active_keys, "Both keys should be exhausted now"
|
tests/test_quantities_finder.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke tests for the PuLP-backed QuantitiesFinder.
|
| 2 |
+
|
| 3 |
+
Pure deterministic tool — no LLM, no network. Should be the fastest test in
|
| 4 |
+
the suite and the one we trust most.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
from tools import QuantitiesFinder
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_basic_two_food_balance() -> None:
    """Solve a chicken-and-rice problem and check each nutrient is within 20% of target."""
    chicken = {
        "name": "chicken_breast",
        "calories": 165,
        "protein": 31,
        "fat": 3.6,
        "carbohydrates": 0,
        "estimated_g": 200,
    }
    rice = {
        "name": "rice_cooked",
        "calories": 130,
        "protein": 2.7,
        "fat": 0.3,
        "carbohydrates": 28,
        "estimated_g": 200,
    }
    targets = {"calories": 700, "protein": 65, "fat": 8, "carbohydrates": 60}
    request = json.dumps({"foods": [chicken, rice], "targets": targets})

    result = json.loads(QuantitiesFinder().handle_task(request))
    assert "quantities" in result and "achieved" in result, f"Bad shape: {result}"

    achieved = result["achieved"]
    # With only two foods the solver cannot be expected to hit all four
    # targets tightly, so the tolerance is a deliberately loose 20%.
    for nut, target in targets.items():
        deviation = abs(achieved[nut] - target) / target
        assert deviation < 0.20, f"{nut} achieved={achieved[nut]} target={target} dev={deviation:.2%}"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_invalid_payload_returns_error() -> None:
    """A payload missing the nutrient keys and targets must yield an ``error`` key."""
    # The single food has only a name and the targets dict is empty.
    malformed = json.dumps({"foods": [{"name": "x"}], "targets": {}})
    result = json.loads(QuantitiesFinder().handle_task(malformed))
    assert "error" in result, f"Expected an error key, got {result}"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def test_min_max_bounds_respected() -> None:
    """Solved quantities must stay inside each food's ``min_g``/``max_g`` window."""
    foods = [
        {
            "name": "egg",
            "calories": 155,
            "protein": 13,
            "fat": 11,
            "carbohydrates": 1.1,
            "estimated_g": 100,
            "min_g": 50,
            "max_g": 120,
        },
        {
            "name": "oats",
            "calories": 389,
            "protein": 17,
            "fat": 7,
            "carbohydrates": 66,
            "estimated_g": 80,
            "min_g": 30,
            "max_g": 150,
        },
    ]
    targets = {"calories": 500, "protein": 25, "fat": 15, "carbohydrates": 50}
    request = json.dumps({"foods": foods, "targets": targets})

    qty = json.loads(QuantitiesFinder().handle_task(request))["quantities"]
    # Check the bounds declared on each food rather than repeating the numbers.
    for food in foods:
        assert food["min_g"] <= qty[food["name"]] <= food["max_g"]
|
tests/test_schemas.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validate the Pydantic schemas that anchor every agent decision in Phase 1."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from pydantic import ValidationError
|
| 7 |
+
|
| 8 |
+
from schemas import (
|
| 9 |
+
Calculations,
|
| 10 |
+
CoachDecision,
|
| 11 |
+
FinalPlan,
|
| 12 |
+
FoodItem,
|
| 13 |
+
MacroTargets,
|
| 14 |
+
MedicalAssessmentDecision,
|
| 15 |
+
MedicalAssessmentResult,
|
| 16 |
+
PlannerDecision,
|
| 17 |
+
ResponseStep,
|
| 18 |
+
ValidationDecision,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ---- Coach -----------------------------------------------------------------
|
| 23 |
+
def test_coach_decision_call_agent() -> None:
    """A ``call_agent`` decision keeps the action and params it was built with."""
    target_agent = "MedicalAssessmentAgent"
    first_step = ResponseStep(id=1, actor=target_agent, description="assess")
    decision = CoachDecision(
        observation="user wants a plan",
        thought="need assessment first",
        response_steps=[first_step],
        action="call_agent",
        params={"agent_name": target_agent, "task": "assess user"},
    )

    assert decision.action == "call_agent"
    assert decision.params["agent_name"] == target_agent
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_coach_decision_invalid_action_rejected() -> None:
    """Building a ``CoachDecision`` with an unrecognised action must raise ``ValidationError``."""
    with pytest.raises(ValidationError):
        CoachDecision(
            action="not_a_real_action",  # type: ignore[arg-type]
            params={},
            response_steps=[],
            observation="x",
            thought="x",
        )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---- Medical assessment ----------------------------------------------------
|
| 49 |
+
def test_medical_assessment_complete_round_trip() -> None:
    """A complete ``assessment_complete`` payload validates into typed nested models."""
    # Build the payload from the innermost dict outwards.
    macro_targets = {"protein_g": 180, "fat_g": 70, "carbohydrates_g": 360}
    calculations = {
        "BMI": 23.4,
        "BMR": 1750,
        "TDEE": 2700,
        "daily_target_calories": 2900,
        "macro_targets": macro_targets,
    }
    assessment_result = {
        "assessment_summary": "healthy male, hypertrophy goal",
        "flags_to_set": [],
        "recommendations": ["maintain hydration"],
        "requires_professional_consultation": False,
        "calculations": calculations,
    }
    raw_decision = {
        "medical_reasoning": "BMI within normal range; protein target raised for muscle gain",
        "observation": "all fields present",
        "risk_assessment_priorities": ["maintain micronutrient adequacy"],
        "assessment_plan": [],
        "action_type": "assessment_complete",
        "result": assessment_result,
    }

    parsed = MedicalAssessmentDecision.model_validate(raw_decision)

    assert parsed.action_type == "assessment_complete"
    assert isinstance(parsed.result, MedicalAssessmentResult)
    assert parsed.result.calculations.macro_targets.protein_g == 180
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_calculations_negative_values_rejected() -> None:
    """A negative ``BMI`` must make ``Calculations`` raise ``ValidationError``."""
    with pytest.raises(ValidationError):
        Calculations(
            macro_targets=MacroTargets(protein_g=120, fat_g=60, carbohydrates_g=250),
            daily_target_calories=2200,
            TDEE=2500,
            BMR=1700,
            BMI=-1,  # the only invalid field: negative not allowed
        )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---- Planner ---------------------------------------------------------------
|
| 88 |
+
def test_planner_provide_plan_with_dict_final_plan() -> None:
    """``PlannerDecision`` accepts a plain dict as ``final_plan`` for ``provide_plan``."""
    plan_as_dict = {"days": [{"breakfast": "oats"}], "trace": "Coach->Planner"}

    planner_decision = PlannerDecision(
        observation="all data ready",
        thought="returning final plan",
        planning_steps=[],
        action_type="provide_plan",
        final_plan=plan_as_dict,
    )

    assert planner_decision.action_type == "provide_plan"
    assert planner_decision.final_plan is not None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_food_item_strict_grams_non_negative() -> None:
    """A ``FoodItem`` with negative ``grams`` must raise ``ValidationError``."""
    with pytest.raises(ValidationError):
        FoodItem(
            grams=-10,  # the only invalid field
            name="oats",
            calories=389,
            protein_g=17,
            fat_g=7,
            carbohydrates_g=66,
        )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def test_final_plan_minimal() -> None:
    """The smallest plan exercised here: one day holding a single ``FoodItem``."""
    oats = FoodItem(
        name="oats",
        grams=80,
        calories=311,
        protein_g=14,
        fat_g=5,
        carbohydrates_g=53,
    )
    single_day = [oats]

    final_plan = FinalPlan(days=[single_day], daily_totals={"calories": 311})

    assert final_plan.days[0][0].name == "oats"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---- Validation (Phase 2 schemas, declared in Phase 1) ---------------------
|
| 132 |
+
def test_validation_decision_default_pass() -> None:
    """A decision built with only ``verdict="pass"`` has an empty ``issues`` list."""
    decision = ValidationDecision(verdict="pass")

    assert decision.verdict == "pass"
    assert decision.issues == []
|
tests/test_settings.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Verify the new Pydantic ``Settings`` and the legacy ``config.X`` proxy."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import config
|
| 6 |
+
from config import get_settings, reset_settings, set_settings
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_defaults() -> None:
    """Check the values ``get_settings()`` reports when nothing has been overridden."""
    current = get_settings()

    # Debug switches
    assert current.debug_mode is False
    assert current.debug_level == "full"
    # Rate limiting and infra paths
    assert current.enable_rate_limiting is True
    assert current.log_dir is None
    # Debug scope selection
    assert current.debug_scopes == {"agents": ["all"], "tools": ["all"]}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_set_settings_pydantic_names() -> None:
    """Overrides passed under the lower-case field names show up in ``get_settings()``."""
    set_settings(debug_mode=True, log_dir="/tmp/x")

    updated = get_settings()

    assert updated.debug_mode is True
    assert updated.log_dir == "/tmp/x"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_set_settings_legacy_names() -> None:
    """Overrides passed under the upper-case legacy names land on the lower-case fields."""
    set_settings(DEBUG_MODE=True, LOG_DIR="/tmp/y", DEBUG_LEVEL="output")

    updated = get_settings()

    assert updated.debug_mode is True
    assert updated.log_dir == "/tmp/y"
    assert updated.debug_level == "output"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def test_pep562_legacy_reads() -> None:
    """Module-level reads such as ``config.DEBUG_MODE`` reflect the current settings."""
    requested_scopes = {"agents": ["CoachAgent"], "tools": ["all"]}
    # A separate literal, so the comparison does not depend on object identity.
    expected_scopes = {"agents": ["CoachAgent"], "tools": ["all"]}

    set_settings(debug_mode=True, debug_scopes=requested_scopes)

    assert config.DEBUG_MODE is True
    assert config.DEBUG_SCOPES == expected_scopes
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_unknown_attr_raises() -> None:
    """Reading an unknown name from ``config`` raises ``AttributeError`` that names it.

    Uses ``pytest.raises`` instead of a hand-written try/except/else so that a
    missing exception and a wrong message are both reported by pytest itself.
    """
    # Imported locally because this test module does not import pytest at the top.
    import pytest

    with pytest.raises(AttributeError, match="NOT_A_THING"):
        _ = config.NOT_A_THING  # type: ignore[attr-defined]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_reset_settings_round_trip() -> None:
    """An override is visible until ``reset_settings()`` is called, then the default returns."""
    set_settings(debug_mode=True)
    overridden = get_settings().debug_mode
    assert overridden is True

    reset_settings()
    restored = get_settings().debug_mode
    assert restored is False  # back to default
|
tests/test_smoke.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Top-level smoke tests: every module must import cleanly outside Colab."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_imports_work_outside_colab() -> None:
    """Import every project module in a plain Python process.

    Phase 0 removed ``from google.colab import userdata``; this test fails if
    any of the modules below can no longer be imported without Colab.
    """
    # Import order is kept as-is in case any module has import-time side effects.
    import agents  # noqa: F401
    import config  # noqa: F401
    import logging_setup  # noqa: F401
    import nutritionmas  # noqa: F401
    import state  # noqa: F401
    import tools  # noqa: F401
    import utils  # noqa: F401
    import workflow  # noqa: F401
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_only_one_geminillm_class_in_utils() -> None:
    """Phase 0 deleted the duplicate ``GeminiLLM`` definition. Make sure it
    doesn't sneak back.

    ``inspect.getmembers`` reports each module attribute name once, so a second
    ``class GeminiLLM`` statement in the same file would simply rebind the name
    and never be counted. The module source is therefore parsed as well, and
    the ``class GeminiLLM`` statements themselves are counted.
    """
    import ast
    import inspect

    import utils

    # The class must exist and be defined in ``utils`` itself, not re-exported.
    geminis = [
        cls
        for name, cls in inspect.getmembers(utils, inspect.isclass)
        if name == "GeminiLLM" and cls.__module__ == "utils"
    ]
    assert len(geminis) == 1

    # Count the class statements in the source to catch a re-introduced duplicate.
    tree = ast.parse(inspect.getsource(utils))
    definitions = [
        node
        for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef) and node.name == "GeminiLLM"
    ]
    assert len(definitions) == 1, (
        f"Expected exactly one 'class GeminiLLM' in utils, found {len(definitions)}"
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_initialize_empty_memory_shape() -> None:
    """Fresh memory has exactly the four expected partitions, each an empty dict."""
    from state import initialize_empty_memory

    expected_partitions = {"user_profile", "medical_history", "flags_and_assessments", "plans"}

    memory = initialize_empty_memory()

    assert set(memory.keys()) == expected_partitions
    assert all(partition == {} for partition in memory.values())
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_default_model_configs_present() -> None:
    """Model topology is a contract the rest of the system depends on.

    Phase 2 adds 'validation_agent' (Gemini Flash; cheap critic loop), so the
    key set must match the six names below exactly.
    """
    from nutritionmas import DEFAULT_MODEL_CONFIGS

    required_roles = {
        "main",
        "agents_llm",
        "tools_llm",
        "planner_agent",
        "validation_agent",
        "user_simulator",
    }
    configured_roles = set(DEFAULT_MODEL_CONFIGS.keys())

    assert configured_roles == required_roles
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_create_llm_instances_requires_keys() -> None:
    """An empty key list must raise ``ValueError`` mentioning "At least one API key"."""
    import pytest

    from nutritionmas import create_llm_instances

    no_keys = []
    with pytest.raises(ValueError, match="At least one API key"):
        create_llm_instances(no_keys)
|
tests/test_typed_agents.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end-ish tests of the typed agent path with MockLLM.
|
| 2 |
+
|
| 3 |
+
These don't hit Gemini; they verify that an agent which received a typed
|
| 4 |
+
``CoachDecision`` / ``MedicalAssessmentDecision`` / ``PlannerDecision`` from
|
| 5 |
+
its LLM produces the expected state mutations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
|
| 12 |
+
import pytest
|
| 13 |
+
|
| 14 |
+
from agents import CoachAgent, MedicalAssessmentAgent, PlannerAgent
|
| 15 |
+
from schemas import (
|
| 16 |
+
Calculations,
|
| 17 |
+
CoachDecision,
|
| 18 |
+
MacroTargets,
|
| 19 |
+
MedicalAssessmentDecision,
|
| 20 |
+
MedicalAssessmentResult,
|
| 21 |
+
PlannerDecision,
|
| 22 |
+
)
|
| 23 |
+
from state import initialize_empty_memory
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---- Coach -----------------------------------------------------------------
|
| 27 |
+
def test_coach_emits_call_agent_action(mock_llm_factory) -> None:
    """A canned ``call_agent`` decision becomes the current action and counts one turn."""
    question = "make me a plan"
    scripted_decision = CoachDecision(
        observation="needs assessment",
        thought="route to medical",
        response_steps=[],
        action="call_agent",
        params={"agent_name": "MedicalAssessmentAgent", "task": "assess"},
    )
    agent = CoachAgent(mock_llm_factory([scripted_decision]))
    initial_state: Dict[str, Any] = {
        "memory": initialize_empty_memory(),
        "user_question": question,
        "conversation_history": [{"role": "user", "content": question}],
        "current_action": None,
        "agent_result": None,
        "num_turns": 0,
        "max_turns": 10,
        "previous_actions": [],
        "response_steps": [],
    }

    new_state = agent.handle_task(initial_state)
    emitted = new_state["current_action"]

    assert emitted["action"] == "call_agent"
    assert emitted["params"]["agent_name"] == "MedicalAssessmentAgent"
    assert new_state["num_turns"] == 1
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_coach_falls_back_when_decision_unparseable(mock_llm_factory) -> None:
    """Malformed LLM output yields a ``compose_response`` action flagged ``_parse_error``."""
    garbage_reply = "{not even close to JSON"
    agent = CoachAgent(mock_llm_factory([garbage_reply]))
    initial_state: Dict[str, Any] = {
        "memory": initialize_empty_memory(),
        "user_question": "anything",
        "conversation_history": [],
        "current_action": None,
        "agent_result": None,
        "num_turns": 0,
        "max_turns": 10,
        "previous_actions": [],
        "response_steps": [],
    }

    new_state = agent.handle_task(initial_state)
    fallback = new_state["current_action"]

    # Coach injects a compose_response with _parse_error so the workflow can short-circuit
    assert fallback["action"] == "compose_response"
    assert fallback.get("_parse_error") is True
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---- Medical ---------------------------------------------------------------
|
| 73 |
+
def test_medical_assessment_complete_writes_memory(mock_llm_factory) -> None:
    """One ``assessment_complete`` decision is written to ``flags_and_assessments``."""
    macro_targets = MacroTargets(protein_g=150, fat_g=70, carbohydrates_g=300)
    calculations = Calculations(
        BMI=22.0,
        BMR=1600,
        TDEE=2400,
        daily_target_calories=2400,
        macro_targets=macro_targets,
    )
    assessment = MedicalAssessmentResult(
        assessment_summary="healthy adult",
        flags_to_set=["maintenance"],
        recommendations=["balanced diet"],
        requires_professional_consultation=False,
        calculations=calculations,
        evidence_sources=["who.int"],
        trace="Medical agent ran one iteration",
    )
    scripted_decision = MedicalAssessmentDecision(
        medical_reasoning="single-shot",
        observation="all data present",
        risk_assessment_priorities=["maintenance"],
        assessment_plan=[],
        action_type="assessment_complete",
        result=assessment,
    )

    # Need a stub for the tools (won't be called in single-iteration assessment_complete)
    class _StubTool:
        def handle_task(self, _: str) -> str:
            return ""

    llm = mock_llm_factory([scripted_decision])
    agent = MedicalAssessmentAgent(llm, _StubTool(), _StubTool())

    memory = initialize_empty_memory()
    memory["user_profile"] = {
        "age": 30,
        "sex": "male",
        "height": 180,
        "weight": 75,
        "activity_level": "moderate",
        "allergies": [],
        "medications": [],
    }

    returned_summary = agent.handle_task("assess this user", memory)

    assert returned_summary == "healthy adult"
    stored = memory["flags_and_assessments"]
    assert stored["assessment_status"] == "assessment_complete"
    assert stored["calculations"]["macro_targets"]["protein_g"] == 150
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_medical_ask_user_returns_field_list(mock_llm_factory) -> None:
    """An ``ask_user`` decision makes the agent's output mention every requested field."""
    requested_fields = ["weight", "height"]
    scripted_decision = MedicalAssessmentDecision(
        medical_reasoning="missing weight + height",
        observation="incomplete",
        risk_assessment_priorities=[],
        assessment_plan=[],
        action_type="ask_user",
        fields=requested_fields,
    )

    class _StubTool:
        def handle_task(self, _: str) -> str:
            return ""

    llm = mock_llm_factory([scripted_decision])
    agent = MedicalAssessmentAgent(llm, _StubTool(), _StubTool())

    reply = agent.handle_task("assess", initialize_empty_memory())

    assert "weight" in reply and "height" in reply
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---- Planner ---------------------------------------------------------------
|
| 142 |
+
def test_planner_provide_plan_stores_to_memory(mock_llm_factory) -> None:
    """A ``provide_plan`` decision is stored as the current plan, with a timestamp."""
    plan_payload = {
        "days": [{"breakfast": "oats", "lunch": "chicken+rice"}],
        "trace": "Planner one-shot",
    }
    scripted_decision = PlannerDecision(
        observation="ready",
        thought="finalising",
        planning_steps=[],
        action_type="provide_plan",
        final_plan=plan_payload,
    )

    class _StubTool:
        def handle_task(self, _: str) -> str:
            return ""

    llm = mock_llm_factory([scripted_decision])
    agent = PlannerAgent(llm, _StubTool(), _StubTool(), _StubTool())

    memory = initialize_empty_memory()
    memory["flags_and_assessments"] = {"assessment_status": "assessment_complete"}

    reply = agent.handle_task("make me a one-day plan", memory)

    assert "trace" in reply
    stored_plans = memory["plans"]
    assert stored_plans["current_plan"]["days"][0]["breakfast"] == "oats"
    assert "plan_timestamp" in stored_plans
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def test_planner_error_payload_short_circuits(mock_llm_factory) -> None:
    """A ``final_plan`` carrying an ``error`` key is surfaced in the planner's output."""
    error_payload = {"error": "flags_and_assessments empty; run MedicalAssessmentAgent first"}
    scripted_decision = PlannerDecision(
        observation="missing assessment",
        thought="precondition violated",
        planning_steps=[],
        action_type="provide_plan",
        final_plan=error_payload,
    )

    class _StubTool:
        def handle_task(self, _: str) -> str:
            return ""

    llm = mock_llm_factory([scripted_decision])
    agent = PlannerAgent(llm, _StubTool(), _StubTool(), _StubTool())

    reply = agent.handle_task("make a plan", initialize_empty_memory())

    assert "error" in reply
    assert "MedicalAssessmentAgent" in reply
|
tests/test_validation_agent.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the ValidationAgent critic loop.
|
| 2 |
+
|
| 3 |
+
Most of the value lives in the deterministic checks — they are pure code,
|
| 4 |
+
require no LLM, and can be exercised cheaply across edge cases.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
|
| 11 |
+
from schemas import ValidationDecision
|
| 12 |
+
from state import initialize_empty_memory
|
| 13 |
+
from validation import ValidationAgent
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# A minimal deterministic-only stub LLM for the cases that should never need
|
| 17 |
+
# the LLM layer (allergy violation -> verdict "reject" short-circuits LLM).
|
| 18 |
+
class _NeverCalledLLM:
|
| 19 |
+
def call_typed(self, *args: Any, **kwargs: Any):
|
| 20 |
+
raise AssertionError("LLM should not be called when deterministic check rejects.")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
def _build_memory(
    *,
    allergies=None,
    dislikes="",
    target_calories=2000,
    macros=(150, 70, 200),
) -> Dict[str, Any]:
    """Return fresh memory pre-filled with a test profile and a completed assessment.

    ``macros`` is indexed as (protein_g, fat_g, carbohydrates_g). A falsy
    ``allergies`` value is stored as an empty list.
    """
    memory = initialize_empty_memory()

    profile = {
        "name": "Test",
        "country": "Egypt",
        "allergies": allergies or [],
        "food_dislikes": dislikes,
    }
    macro_targets = {
        "protein_g": macros[0],
        "fat_g": macros[1],
        "carbohydrates_g": macros[2],
    }
    calculations = {
        "BMI": 22,
        "BMR": 1600,
        "TDEE": 2000,
        "daily_target_calories": target_calories,
        "macro_targets": macro_targets,
    }
    assessment = {
        "assessment_status": "assessment_complete",
        "calculations": calculations,
        "flags": [],
        "recommendations": [],
        "requires_professional_consultation": False,
    }

    memory["user_profile"] = profile
    memory["flags_and_assessments"] = assessment
    return memory
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _set_plan(memory: Dict[str, Any], plan: Dict[str, Any]) -> None:
|
| 59 |
+
memory.setdefault("plans", {})["current_plan"] = plan
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ---- deterministic-only paths ----------------------------------------------
|
| 63 |
+
def test_passes_when_plan_within_tolerances(mock_llm_factory) -> None:
    """A plan whose items add up to the calorie and macro targets gets a clean ``pass``."""
    chicken = {
        "name": "chicken_breast",
        "calories": 1000,
        "protein_g": 100,
        "fat_g": 30,
        "carbohydrates_g": 100,
    }
    rice = {
        "name": "rice",
        "calories": 1000,
        "protein_g": 50,
        "fat_g": 40,
        "carbohydrates_g": 100,
    }
    memory = _build_memory(target_calories=2000, macros=(150, 70, 200))
    _set_plan(memory, {"days": [chicken, rice]})

    # Pre-supply a "no-issues" LLM verdict so the LLM layer is happy.
    clean_verdict = ValidationDecision(verdict="pass", issues=[])
    validator = ValidationAgent(mock_llm_factory([clean_verdict]))

    raw_reply = validator.handle_task("validate plan", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    assert outcome.verdict == "pass"
    assert outcome.issues == []
    assert memory["flags_and_assessments"]["last_validation"]["verdict"] == "pass"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_allergy_violation_rejected_without_llm() -> None:
    """A plan naming an allergen is rejected, and the LLM stub is never called."""
    offending_item = {
        "name": "peanut butter sandwich",
        "calories": 400,
        "protein_g": 15,
        "fat_g": 20,
        "carbohydrates_g": 40,
    }
    memory = _build_memory(allergies=["peanut"])
    _set_plan(memory, {"days": [offending_item]})

    # The stub raises if call_typed is invoked, so reaching the asserts proves the LLM was skipped.
    validator = ValidationAgent(_NeverCalledLLM())
    raw_reply = validator.handle_task("validate", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    assert outcome.verdict == "reject"
    assert any(issue.code == "allergy_violation" for issue in outcome.issues)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def test_calorie_deviation_triggers_revise(mock_llm_factory) -> None:
    """A plan far below the calorie target is marked ``revise`` with a ``calorie_deviation`` issue."""
    undersized_item = {
        "name": "tiny salad",
        "calories": 800,  # way under 2000 target -> 60% deviation
        "protein_g": 30,
        "fat_g": 20,
        "carbohydrates_g": 60,
    }
    memory = _build_memory(target_calories=2000)
    _set_plan(memory, {"days": [undersized_item]})

    # The LLM itself reports no issues.
    clean_verdict = ValidationDecision(verdict="pass", issues=[])
    validator = ValidationAgent(mock_llm_factory([clean_verdict]))

    raw_reply = validator.handle_task("validate", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    assert outcome.verdict == "revise"
    assert any(issue.code == "calorie_deviation" for issue in outcome.issues)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_disliked_food_only_low_severity_still_passes(mock_llm_factory) -> None:
    """A disliked food is reported as a low-severity issue without changing the ``pass`` verdict."""
    disliked_item = {
        "name": "okra stew",
        "calories": 2000,
        "protein_g": 150,
        "fat_g": 70,
        "carbohydrates_g": 200,
    }
    memory = _build_memory(dislikes="okra", target_calories=2000)
    _set_plan(memory, {"days": [disliked_item]})

    clean_verdict = ValidationDecision(verdict="pass", issues=[])
    validator = ValidationAgent(mock_llm_factory([clean_verdict]))

    raw_reply = validator.handle_task("validate", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    # Low-severity issues alone don't escalate the verdict.
    assert outcome.verdict == "pass"
    assert any(
        issue.code == "disliked_food" and issue.severity == "low"
        for issue in outcome.issues
    )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def test_missing_plan_rejected() -> None:
    """With no ``current_plan`` in memory the verdict is ``reject`` with a ``missing_plan`` issue."""
    # Intentionally no current_plan set
    memory = _build_memory()
    validator = ValidationAgent(_NeverCalledLLM())

    raw_reply = validator.handle_task("validate", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    assert outcome.verdict == "reject"
    assert any(issue.code == "missing_plan" for issue in outcome.issues)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def test_requires_human_review_propagates(mock_llm_factory) -> None:
    """The assessment's professional-consultation flag sets ``requires_human_review``."""
    on_target_item = {
        "name": "balanced meal",
        "calories": 2000,
        "protein_g": 150,
        "fat_g": 70,
        "carbohydrates_g": 200,
    }
    memory = _build_memory()
    memory["flags_and_assessments"]["requires_professional_consultation"] = True
    _set_plan(memory, {"days": [on_target_item]})

    clean_verdict = ValidationDecision(verdict="pass", issues=[])
    validator = ValidationAgent(mock_llm_factory([clean_verdict]))

    raw_reply = validator.handle_task("validate", memory)
    outcome = ValidationDecision.model_validate_json(raw_reply)

    # Even on a clean pass, HITL flag must propagate from the assessment.
    assert outcome.requires_human_review is True
|
tools.py
CHANGED
|
@@ -1,29 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
import subprocess
|
| 3 |
import tempfile
|
| 4 |
-
import time
|
| 5 |
-
from time import sleep
|
| 6 |
-
import os
|
| 7 |
from datetime import datetime
|
| 8 |
-
from
|
|
|
|
|
|
|
| 9 |
from ddgs import DDGS
|
| 10 |
-
import
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class QuantitiesFinder:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
def __init__(self):
|
| 17 |
pass
|
| 18 |
|
| 19 |
@staticmethod
|
| 20 |
-
def _round(v):
|
| 21 |
if v is None:
|
| 22 |
return 0.0
|
| 23 |
return round(float(v), 2)
|
| 24 |
|
| 25 |
@staticmethod
|
| 26 |
-
def _round_structure(obj):
|
| 27 |
if isinstance(obj, dict):
|
| 28 |
return {k: QuantitiesFinder._round_structure(v) for k, v in obj.items()}
|
| 29 |
if isinstance(obj, list):
|
|
@@ -33,303 +79,281 @@ class QuantitiesFinder:
|
|
| 33 |
return obj
|
| 34 |
|
| 35 |
def handle_task(self, task: str) -> str:
|
| 36 |
-
|
| 37 |
-
#
|
| 38 |
-
W_NUTRITION = 1.0
|
| 39 |
-
W_ESTIMATE_DEFAULT = 0.1
|
| 40 |
|
| 41 |
try:
|
| 42 |
data = json.loads(task)
|
| 43 |
foods = data["foods"]
|
| 44 |
targets = data["targets"]
|
| 45 |
|
| 46 |
-
#
|
| 47 |
required_nutrients = ["calories", "protein", "fat", "carbohydrates"]
|
| 48 |
for food in foods:
|
| 49 |
-
if not all(
|
| 50 |
-
key in food
|
| 51 |
-
for key in ["name"] + required_nutrients + ["estimated_g"]
|
| 52 |
-
):
|
| 53 |
raise ValueError(
|
| 54 |
"Each food must have name, calories, protein, fat, carbohydrates, and estimated_g."
|
| 55 |
)
|
| 56 |
-
|
| 57 |
if not all(key in targets for key in required_nutrients):
|
| 58 |
-
raise ValueError(
|
| 59 |
-
"Targets must include calories, protein, fat, carbohydrates."
|
| 60 |
-
)
|
| 61 |
|
| 62 |
prob = LpProblem("Nutrient_Optimization", LpMinimize)
|
| 63 |
|
| 64 |
-
#
|
| 65 |
g = {}
|
| 66 |
for food in foods:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
f"g_{food_name}", lowBound=min_bound, upBound=max_bound
|
| 72 |
-
)
|
| 73 |
-
|
| 74 |
-
# --- 3. NUTRITION DEVIATIONS (Unchanged) ---
|
| 75 |
-
nutrients = required_nutrients
|
| 76 |
-
totals = {}
|
| 77 |
-
for nut in nutrients:
|
| 78 |
-
totals[nut] = lpSum(
|
| 79 |
-
(g[food["name"]] / 100) * food[nut] for food in foods
|
| 80 |
)
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
| 86 |
prob += totals[nut] - targets[nut] <= d_pos[nut]
|
| 87 |
prob += targets[nut] - totals[nut] <= d_neg[nut]
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
]
|
| 100 |
-
if not
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
prob += (
|
| 111 |
-
meal_total <= max_val,
|
| 112 |
-
f"Meal_{group_name}_max_{nut}",
|
| 113 |
-
)
|
| 114 |
-
print(f" -> Constraint: {group_name} max {nut} <= {max_val}")
|
| 115 |
-
|
| 116 |
-
min_val = constraint.get(f"min_{nut}")
|
| 117 |
-
if min_val is not None:
|
| 118 |
-
meal_total = lpSum(
|
| 119 |
-
(g[f["name"]] / 100) * f[nut] for f in group_foods
|
| 120 |
-
)
|
| 121 |
-
prob += (
|
| 122 |
-
meal_total >= min_val,
|
| 123 |
-
f"Meal_{group_name}_min_{nut}",
|
| 124 |
-
)
|
| 125 |
-
print(f" -> Constraint: {group_name} min {nut} >= {min_val}")
|
| 126 |
-
|
| 127 |
-
# --- 4. ESTIMATE DEVIATIONS (ENHANCED) ---
|
| 128 |
-
# This section now reads a per-item 'estimate_weight'
|
| 129 |
-
dev_est_pos = {
|
| 130 |
-
food["name"]: LpVariable(f"dev_est_pos_{food['name']}", lowBound=0)
|
| 131 |
-
for food in foods
|
| 132 |
-
}
|
| 133 |
-
dev_est_neg = {
|
| 134 |
-
food["name"]: LpVariable(f"dev_est_neg_{food['name']}", lowBound=0)
|
| 135 |
-
for food in foods
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
for food in foods:
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
prob += g[
|
| 142 |
-
prob +=
|
| 143 |
|
| 144 |
-
#
|
| 145 |
-
# Goal 1: (Unchanged)
|
| 146 |
nutrition_objective = lpSum(
|
| 147 |
-
(d_pos[nut] + d_neg[nut]) / max(targets[nut], 1) for nut in
|
| 148 |
)
|
| 149 |
-
|
| 150 |
-
# Goal 2: (ENHANCED)
|
| 151 |
-
# Now uses the per-item 'estimate_weight' if provided,
|
| 152 |
-
# otherwise, it falls back to the default.
|
| 153 |
estimate_objective = lpSum(
|
| 154 |
-
(
|
| 155 |
-
|
| 156 |
-
* (dev_est_pos[f["name"]] + dev_est_neg[f["name"]])
|
| 157 |
-
)
|
| 158 |
/ max(f["estimated_g"], 1)
|
| 159 |
for f in foods
|
| 160 |
if f["estimated_g"] > 0
|
| 161 |
)
|
| 162 |
-
|
| 163 |
-
# Combined objective
|
| 164 |
prob += (W_NUTRITION * nutrition_objective) + estimate_objective
|
| 165 |
|
| 166 |
-
#
|
| 167 |
prob.solve(PULP_CBC_CMD(msg=0))
|
| 168 |
-
|
| 169 |
if LpStatus[prob.status] != "Optimal":
|
| 170 |
raise ValueError(
|
| 171 |
"No optimal solution found (problem may be infeasible). Check your targets and constraints."
|
| 172 |
)
|
| 173 |
|
| 174 |
quantities = {name: value(g[name]) for name in g}
|
| 175 |
-
achieved = {nut: value(totals[nut]) for nut in
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
f"Achieved Nutrition (around): {json.dumps(result['achieved'], indent=2)}"
|
| 184 |
)
|
| 185 |
-
|
| 186 |
-
|
|
|
|
| 187 |
)
|
| 188 |
-
|
| 189 |
-
print(f"\n📊 QUANTITIES FINDER COMPLETED")
|
| 190 |
return json.dumps(result)
|
| 191 |
|
| 192 |
-
except Exception as e:
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
|
|
|
|
|
|
|
|
|
|
| 197 |
class ComputationTool:
|
| 198 |
def __init__(self, llm_instance):
|
| 199 |
self.llm = llm_instance
|
| 200 |
|
| 201 |
def handle_task(self, task_description: str) -> str:
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
prompt = f"{instruction}\n\nTask: {task_description}\n\nCode:"
|
| 205 |
|
| 206 |
-
if should_debug(
|
| 207 |
-
|
| 208 |
code_response = self.llm(prompt)[0]
|
| 209 |
-
if should_debug(
|
| 210 |
-
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
if code_match:
|
| 218 |
-
code_to_execute = code_match.group(1).strip()
|
| 219 |
-
# print(f"Extracted code from markdown blocks")
|
| 220 |
-
else:
|
| 221 |
-
code_to_execute = code_response.strip()
|
| 222 |
-
# print(f"Using raw response as code")
|
| 223 |
|
| 224 |
execution_result = execute_python_code_raw(code_to_execute)
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
|
|
|
|
|
|
| 236 |
return execution_result
|
| 237 |
|
| 238 |
|
|
|
|
|
|
|
|
|
|
| 239 |
class WebSearchTool:
|
| 240 |
def __init__(self, llm_instance):
|
| 241 |
self.llm = llm_instance
|
| 242 |
|
| 243 |
def handle_task(self, research_task: str) -> str:
|
| 244 |
-
|
|
|
|
| 245 |
|
| 246 |
try:
|
| 247 |
task_data = json.loads(research_task)
|
| 248 |
-
if
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
else:
|
| 252 |
-
|
| 253 |
research_question = research_task
|
| 254 |
except (json.JSONDecodeError, TypeError):
|
| 255 |
-
|
| 256 |
research_question = research_task
|
| 257 |
|
| 258 |
-
query_instruction =
|
|
|
|
|
|
|
|
|
|
| 259 |
query_prompt = f"{query_instruction}\n\nQuestion: {research_question}\n\nQueries:"
|
| 260 |
|
| 261 |
-
if should_debug(
|
| 262 |
-
|
| 263 |
search_queries_text = self.llm(query_prompt)[0]
|
| 264 |
-
if should_debug(
|
| 265 |
-
|
| 266 |
|
| 267 |
-
search_queries = [q.strip() for q in search_queries_text.split(
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
| 270 |
|
| 271 |
all_raw_results = []
|
| 272 |
-
for
|
| 273 |
raw_results = search_web_raw(query, num_results=10)
|
| 274 |
-
|
| 275 |
all_raw_results.append(f"Results for '{query}':\n{raw_results}")
|
| 276 |
sleep(1)
|
| 277 |
|
| 278 |
raw_search_output = "\n\n".join(all_raw_results)
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
if should_debug('tools', 'WebSearchTool') and config.DEBUG_LEVEL == 'full':
|
| 289 |
-
print(f"Web Search Synthesis Instruction:\n{synthesis_instruction}")
|
| 290 |
synthesized_answer = self.llm(synthesis_instruction)[0]
|
| 291 |
-
if should_debug(
|
| 292 |
-
|
| 293 |
|
| 294 |
timestamp = datetime.now().isoformat()
|
| 295 |
-
save_to_json(
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
"
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
return synthesized_answer
|
| 311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
def execute_python_code_raw(code_string: str) -> str:
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
| 315 |
try:
|
| 316 |
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_script:
|
| 317 |
tmp_script.write(code_string)
|
| 318 |
script_path = tmp_script.name
|
| 319 |
-
process = subprocess.run(
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
if process.returncode == 0:
|
| 322 |
return f"Output:\n{process.stdout if process.stdout else 'Code executed successfully.'}"
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
except Exception as e:
|
| 326 |
return f"Execution Exception: {str(e)}"
|
| 327 |
finally:
|
| 328 |
-
if os.path.exists(script_path):
|
| 329 |
os.remove(script_path)
|
| 330 |
|
|
|
|
| 331 |
def search_web_raw(query: str, num_results: int = 3) -> str:
|
| 332 |
-
|
| 333 |
max_retries = 3
|
| 334 |
for attempt in range(max_retries):
|
| 335 |
try:
|
|
@@ -337,12 +361,13 @@ def search_web_raw(query: str, num_results: int = 3) -> str:
|
|
| 337 |
results = list(ddgs.text(query, max_results=num_results, timelimit="m"))
|
| 338 |
if not results:
|
| 339 |
return "No search results found."
|
| 340 |
-
return "\n".join(
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
| 342 |
if attempt < max_retries - 1:
|
| 343 |
sleep(1)
|
| 344 |
continue
|
| 345 |
return f"Search Exception after {max_retries} attempts: {str(e)}"
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
|
|
|
| 1 |
+
"""Tools layer.
|
| 2 |
+
|
| 3 |
+
Phase 1 cleanup notes:
|
| 4 |
+
|
| 5 |
+
* Replaced ``print`` with namespaced loggers so user-mode emoji output is
|
| 6 |
+
filterable and the API/UI in Phase 7 can subscribe to it as events.
|
| 7 |
+
* Reads ``settings.debug_mode`` via :func:`config.get_settings` instead of the
|
| 8 |
+
legacy module-level globals.
|
| 9 |
+
|
| 10 |
+
The :class:`ComputationTool` still shells out to ``subprocess.run(['python', ...])``
|
| 11 |
+
- **this is a known security issue**, fixed in Phase 4 by either deterministic
|
| 12 |
+
formula functions or a ``RestrictedPython`` sandbox.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
import re
|
| 20 |
import subprocess
|
| 21 |
import tempfile
|
|
|
|
|
|
|
|
|
|
| 22 |
from datetime import datetime
|
| 23 |
+
from time import sleep
|
| 24 |
+
from typing import Any
|
| 25 |
+
|
| 26 |
from ddgs import DDGS
|
| 27 |
+
from pulp import (
|
| 28 |
+
LpMinimize,
|
| 29 |
+
LpProblem,
|
| 30 |
+
LpStatus,
|
| 31 |
+
LpVariable,
|
| 32 |
+
PULP_CBC_CMD,
|
| 33 |
+
lpSum,
|
| 34 |
+
value,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
from config import get_settings
|
| 38 |
+
from logging_setup import get_logger
|
| 39 |
+
from utils import save_to_json, should_debug
|
| 40 |
+
|
| 41 |
+
_qf_logger = get_logger("tools.quantities_finder")
|
| 42 |
+
_comp_logger = get_logger("tools.computation")
|
| 43 |
+
_web_logger = get_logger("tools.web_search")
|
| 44 |
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# QuantitiesFinder (PuLP LP solver)
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
class QuantitiesFinder:
|
| 50 |
+
"""Linear-program solver that turns an LLM-drafted plan into precise grams.
|
| 51 |
+
|
| 52 |
+
The schema is:
|
| 53 |
+
|
| 54 |
+
{
|
| 55 |
+
"foods": [{name, calories, protein, fat, carbohydrates,
|
| 56 |
+
estimated_g, [min_g, max_g, meal_group, estimate_weight]}, ...],
|
| 57 |
+
"targets": {calories, protein, fat, carbohydrates},
|
| 58 |
+
"meal_constraints": [{group_name, [max_<nut>], [min_<nut>]}, ...] # optional
|
| 59 |
+
}
|
| 60 |
+
"""
|
| 61 |
|
| 62 |
+
def __init__(self) -> None:
    """Create the solver; it keeps no state, so there is nothing to set up."""
|
| 64 |
|
| 65 |
@staticmethod
|
| 66 |
+
def _round(v: Any) -> float:
|
| 67 |
if v is None:
|
| 68 |
return 0.0
|
| 69 |
return round(float(v), 2)
|
| 70 |
|
| 71 |
@staticmethod
|
| 72 |
+
def _round_structure(obj: Any) -> Any:
|
| 73 |
if isinstance(obj, dict):
|
| 74 |
return {k: QuantitiesFinder._round_structure(v) for k, v in obj.items()}
|
| 75 |
if isinstance(obj, list):
|
|
|
|
| 79 |
return obj
|
| 80 |
|
| 81 |
def handle_task(self, task: str) -> str:
    """Solve for the gram amount of each food that best matches the nutrition targets.

    Args:
        task: JSON text with ``foods`` (each entry needs ``name``, the four
            nutrients and ``estimated_g``; ``min_g``, ``max_g``, ``meal_group``
            and ``estimate_weight`` are optional), ``targets`` (the four
            nutrients) and an optional ``meal_constraints`` list. Nutrient
            values are applied per 100 units of the solved quantity.

    Returns:
        A JSON string. On success it holds ``quantities`` (keyed by food name)
        and ``achieved`` (keyed by nutrient), both passed through
        ``_round_structure``. On any failure it holds a single ``error``
        message; this method does not raise.
    """
    _qf_logger.info("\n📊 ENHANCED QUANTITIES FINDER (V3) TOOL STARTED")
    # Priority 1: hit daily totals; Priority 2: stay close to per-item estimates.
    W_NUTRITION = 1.0
    W_ESTIMATE_DEFAULT = 0.1

    def estimate_weight(food: Any) -> Any:
        # An explicit JSON null means "not provided": use the default weight
        # instead of multiplying an LP expression by None.
        weight = food.get("estimate_weight")
        return W_ESTIMATE_DEFAULT if weight is None else weight

    try:
        data = json.loads(task)
        foods = data["foods"]
        targets = data["targets"]

        # 1. Validation
        required_nutrients = ["calories", "protein", "fat", "carbohydrates"]
        required_food_keys = ["name", *required_nutrients, "estimated_g"]
        for food in foods:
            if not all(key in food for key in required_food_keys):
                raise ValueError(
                    "Each food must have name, calories, protein, fat, carbohydrates, and estimated_g."
                )
        if not all(key in targets for key in required_nutrients):
            raise ValueError("Targets must include calories, protein, fat, carbohydrates.")

        # Every per-food dict below is keyed by name, so a repeated name would
        # silently merge two foods into one variable. Reject it explicitly.
        seen_names = set()
        for food in foods:
            if food["name"] in seen_names:
                raise ValueError(
                    f"Duplicate food name: {food['name']!r}. Food names must be unique."
                )
            seen_names.add(food["name"])

        prob = LpProblem("Nutrient_Optimization", LpMinimize)

        # 2. Variables
        g = {}
        for food in foods:
            # A null min_g must not lift the lower bound, otherwise the solver
            # could return negative grams.
            min_g = food.get("min_g")
            g[food["name"]] = LpVariable(
                f"g_{food['name']}",
                lowBound=0 if min_g is None else min_g,
                upBound=food.get("max_g"),
            )

        # 3. Nutrition deviations
        totals = {
            nut: lpSum((g[f["name"]] / 100) * f[nut] for f in foods) for nut in required_nutrients
        }
        d_pos = {nut: LpVariable(f"d_pos_{nut}", lowBound=0) for nut in required_nutrients}
        d_neg = {nut: LpVariable(f"d_neg_{nut}", lowBound=0) for nut in required_nutrients}
        for nut in required_nutrients:
            prob += totals[nut] - targets[nut] <= d_pos[nut]
            prob += targets[nut] - totals[nut] <= d_neg[nut]

        # 3.5 Optional meal-level constraints
        for position, constraint in enumerate(data.get("meal_constraints", []) or []):
            group_name = constraint.get("group_name")
            if not group_name:
                continue
            group_foods = [f for f in foods if f.get("meal_group") == group_name]
            if not group_foods:
                _qf_logger.warning("No foods found for meal_group '%s'", group_name)
                continue
            for nut in required_nutrients:
                meal_total = lpSum((g[f["name"]] / 100) * f[nut] for f in group_foods)
                # The list position keeps constraint names unique even when the
                # same group is listed twice; PuLP rejects overlapping names.
                if (max_val := constraint.get(f"max_{nut}")) is not None:
                    prob += (meal_total <= max_val, f"Meal_{position}_{group_name}_max_{nut}")
                    _qf_logger.debug("Constraint: %s max %s <= %s", group_name, nut, max_val)
                if (min_val := constraint.get(f"min_{nut}")) is not None:
                    prob += (meal_total >= min_val, f"Meal_{position}_{group_name}_min_{nut}")
                    _qf_logger.debug("Constraint: %s min %s >= %s", group_name, nut, min_val)

        # 4. Estimate deviations (per-item soft anchor)
        dev_est_pos = {f["name"]: LpVariable(f"dev_est_pos_{f['name']}", lowBound=0) for f in foods}
        dev_est_neg = {f["name"]: LpVariable(f"dev_est_neg_{f['name']}", lowBound=0) for f in foods}
        for food in foods:
            name = food["name"]
            est = food["estimated_g"]
            prob += g[name] - est <= dev_est_pos[name]
            prob += est - g[name] <= dev_est_neg[name]

        # 5. Objective: both terms are relative deviations, so nutrients and
        # foods of different magnitudes are comparable.
        nutrition_objective = lpSum(
            (d_pos[nut] + d_neg[nut]) / max(targets[nut], 1) for nut in required_nutrients
        )
        estimate_objective = lpSum(
            estimate_weight(f)
            * (dev_est_pos[f["name"]] + dev_est_neg[f["name"]])
            / max(f["estimated_g"], 1)
            for f in foods
            if f["estimated_g"] > 0
        )
        prob += (W_NUTRITION * nutrition_objective) + estimate_objective

        # 6. Solve
        prob.solve(PULP_CBC_CMD(msg=0))
        if LpStatus[prob.status] != "Optimal":
            raise ValueError(
                "No optimal solution found (problem may be infeasible). Check your targets and constraints."
            )

        quantities = {name: value(g[name]) for name in g}
        achieved = {nut: value(totals[nut]) for nut in required_nutrients}
        result = QuantitiesFinder._round_structure({"quantities": quantities, "achieved": achieved})

        _qf_logger.info("Solution Status: %s", LpStatus[prob.status])
        _qf_logger.info("Quantities (g): %s", json.dumps(result["quantities"], indent=2))
        _qf_logger.info(
            "Achieved Nutrition (around): %s",
            json.dumps(result["achieved"], indent=2),
        )
        _qf_logger.info(
            "Target Nutrition: %s",
            json.dumps(QuantitiesFinder._round_structure(targets), indent=2),
        )
        _qf_logger.info("\n📊 QUANTITIES FINDER COMPLETED")
        return json.dumps(result)

    except Exception as e:  # noqa: BLE001
        # Tool boundary: callers receive the failure as JSON, never as an exception.
        _qf_logger.error("QuantitiesFinder Error: %s", str(e))
        return json.dumps({"error": str(e)})
|
| 190 |
+
|
| 191 |
|
| 192 |
+
# ---------------------------------------------------------------------------
|
| 193 |
+
# ComputationTool (LLM-generated Python; ⚠ replace in Phase 4)
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
class ComputationTool:
    """Ask the LLM for Python code that performs a task, run it, and return the output.

    ⚠ Security: the generated code is handed to ``execute_python_code_raw``
    unmodified. Do not expose this tool to untrusted task descriptions.
    """

    # Fence patterns, tried in this order: a closed python-tagged block, a
    # closed untagged block, then a python-tagged block whose closing fence is
    # missing (truncated response). Opening fences tolerate `py`/`python3`
    # tags, any letter case, trailing blanks and CRLF line endings.
    _FENCE_PATTERNS = (
        re.compile(
            r"```(?:python3?|py)[ \t]*\r?\n(.*?)\r?\n[ \t]*```",
            re.DOTALL | re.IGNORECASE,
        ),
        re.compile(r"```[ \t]*\r?\n(.*?)\r?\n[ \t]*```", re.DOTALL),
        re.compile(r"```(?:python3?|py)[ \t]*\r?\n(.*)\Z", re.DOTALL | re.IGNORECASE),
    )

    def __init__(self, llm_instance):
        # Callable that takes a prompt; the first element of its result is used.
        self.llm = llm_instance

    @staticmethod
    def _extract_code(response: str) -> str:
        """Return the code inside the first matching fence, else the stripped response."""
        for pattern in ComputationTool._FENCE_PATTERNS:
            match = pattern.search(response)
            if match:
                return match.group(1).strip()
        return response.strip()

    def handle_task(self, task_description: str) -> str:
        """Generate code for *task_description*, execute it and return the result text.

        Args:
            task_description: Natural-language description of the computation.

        Returns:
            The string produced by ``execute_python_code_raw`` for the
            generated code.
        """
        _comp_logger.info("\n🤖 COMPUTATION TOOL STARTED")
        settings = get_settings()
        instruction = (
            "You are a Python coding assistant. Generate only the Python code required "
            "to perform the given task. Do not forget to print the result. Do not add explanations."
        )
        prompt = f"{instruction}\n\nTask: {task_description}\n\nCode:"

        if should_debug("tools", "ComputationTool") and settings.debug_level == "full":
            _comp_logger.debug("Computation Tool Prompt:\n%s", prompt)
        code_response = self.llm(prompt)[0]
        if should_debug("tools", "ComputationTool"):
            _comp_logger.debug("Computation Tool Response:\n%s", code_response)

        code_to_execute = self._extract_code(code_response)

        execution_result = execute_python_code_raw(code_to_execute)

        # One timestamp for both the record and its filename, so the two agree.
        timestamp = datetime.now().isoformat()
        save_to_json(
            {
                "instruction": instruction,
                "input": task_description,
                "output": code_to_execute,
                "execution_result": execution_result,
                "timestamp": timestamp,
            },
            f"computation_tool_{timestamp}.json",
            subdirectory="ComputationTool",
        )
        _comp_logger.info("🤖 COMPUTATION COMPLETED\n%s", execution_result)
        return execution_result
|
| 234 |
|
| 235 |
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
# WebSearchTool (DuckDuckGo + LLM synthesis)
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
class WebSearchTool:
    """Answer a research question from web search results summarised by the LLM.

    The LLM is called twice: once to turn the question into search queries and
    once to synthesize an answer from the raw results of ``search_web_raw``.
    """

    # Bullet or numbering prefix the LLM may put in front of a query
    # ("- q", "* q", "• q", "1. q", "2) q"). Whitespace after the marker is
    # required, so "3.5 oz ..." and "-site:..." are left untouched.
    _LIST_MARKER = re.compile(r"^(?:[-*•]|\d+[.)])\s+")

    def __init__(self, llm_instance):
        # Callable that takes a prompt; the first element of its result is used.
        self.llm = llm_instance

    @staticmethod
    def _parse_queries(queries_text: str, fallback: str) -> list[str]:
        """Split the LLM reply into unique, marker-free queries.

        Returns ``[fallback]`` when the reply contains no usable line.
        """
        queries = []
        for line in queries_text.split("\n"):
            query = WebSearchTool._LIST_MARKER.sub("", line.strip()).strip()
            if query and query not in queries:
                queries.append(query)
        return queries or [fallback]

    def handle_task(self, research_task: str) -> str:
        """Research *research_task* on the web and return a synthesized answer.

        Args:
            research_task: Plain-text question, or JSON text holding a
                ``queries`` list, which is joined into a single question.

        Returns:
            The text produced by the LLM for the synthesis prompt.
        """
        _web_logger.info("\n🌐 WEB SEARCH TOOL STARTED")
        settings = get_settings()

        try:
            task_data = json.loads(research_task)
            if (
                isinstance(task_data, dict)
                and "queries" in task_data
                and isinstance(task_data["queries"], list)
            ):
                _web_logger.info("JSON query list detected. Converting to single text task.")
                # str() keeps a stray non-string entry from aborting this branch.
                research_question = " ".join(str(q) for q in task_data["queries"])
            else:
                _web_logger.info("Single question mode (non-query JSON). Generating queries.")
                research_question = research_task
        except (json.JSONDecodeError, TypeError):
            _web_logger.info("Single question mode (plain text). Generating queries.")
            research_question = research_task

        query_instruction = (
            "Formulate concise search queries for DuckDuckGo based on the given question. "
            "Output only the queries, one per line."
        )
        query_prompt = f"{query_instruction}\n\nQuestion: {research_question}\n\nQueries:"

        if should_debug("tools", "WebSearchTool") and settings.debug_level == "full":
            _web_logger.debug("Web Search Query Prompt:\n%s", query_prompt)
        search_queries_text = self.llm(query_prompt)[0]
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Web Search Query Response:\n%s", search_queries_text)

        search_queries = self._parse_queries(search_queries_text, research_question)
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Parsed queries: %s", search_queries)

        all_raw_results = []
        for position, query in enumerate(search_queries):
            if position:
                # Pause between consecutive searches only; no wait after the last.
                sleep(1)
            raw_results = search_web_raw(query, num_results=10)
            _web_logger.info("Search results: %s...", raw_results[:200])
            all_raw_results.append(f"Results for '{query}':\n{raw_results}")

        raw_search_output = "\n\n".join(all_raw_results)
        synthesis_instruction = (
            f"Synthesize a concise answer to:\n"
            f"Question: {research_question}\n"
            f"Based on:\n---\n{raw_search_output}\n---\n"
        )

        if should_debug("tools", "WebSearchTool") and settings.debug_level == "full":
            _web_logger.debug("Web Search Synthesis Instruction:\n%s", synthesis_instruction)
        synthesized_answer = self.llm(synthesis_instruction)[0]
        if should_debug("tools", "WebSearchTool"):
            _web_logger.debug("Web Search Synthesis Response:\n%s", synthesized_answer)

        # Both records share one timestamp so they can be paired afterwards.
        timestamp = datetime.now().isoformat()
        save_to_json(
            {
                "instruction": query_instruction,
                "input": research_question,
                "output": search_queries_text,
                "timestamp": timestamp,
            },
            f"web_search_tool_queries_{timestamp}.json",
            subdirectory="WebSearchTool",
        )
        save_to_json(
            {
                "instruction": synthesis_instruction,
                "input": raw_search_output,
                "output": synthesized_answer,
                "timestamp": timestamp,
            },
            f"web_search_tool_synthesis_{timestamp}.json",
            subdirectory="WebSearchTool",
        )
        _web_logger.info("\n🌐 WEB SEARCH TOOL Result:\n%s\n", synthesized_answer)
        return synthesized_answer
|
| 324 |
|
| 325 |
+
|
| 326 |
+
# ---------------------------------------------------------------------------
|
| 327 |
+
# Helpers
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
def execute_python_code_raw(code_string: str) -> str:
    """Run *code_string* as a script in a child Python process and report the outcome.

    ⚠ Phase 4 will replace this with a sandbox or deterministic functions.
    The code runs unsandboxed with the privileges of this process; never pass
    it text from an untrusted source.

    Args:
        code_string: Python source to execute.

    Returns:
        ``"Output:\\n..."`` with captured stdout when the exit status is 0,
        ``"Error:\\n..."`` with captured stderr otherwise, or
        ``"Execution Exception: ..."`` when the script could not be run or
        exceeded the time limit. This function does not raise.
    """
    import sys  # local import: sys is not among the module-level imports

    timeout_seconds = 30
    settings = get_settings()
    if should_debug("tools", "ComputationTool") and settings.debug_level == "full":
        _comp_logger.debug("🐍 Executing Code (raw):\n%s", code_string)
    script_path = ""
    try:
        # UTF-8 is what the interpreter assumes for source files; the locale
        # default could garble non-ASCII code.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as tmp_script:
            # Record the path before writing so a failed write is still cleaned up.
            script_path = tmp_script.name
            tmp_script.write(code_string)
        process = subprocess.run(
            # Use the interpreter running this process rather than whatever
            # "python" happens to be on PATH; fall back if it is unknown.
            [sys.executable or "python", script_path],
            capture_output=True,
            text=True,
            timeout=timeout_seconds,
        )
        if process.returncode == 0:
            return f"Output:\n{process.stdout if process.stdout else 'Code executed successfully.'}"
        return f"Error:\n{process.stderr}"
    except subprocess.TimeoutExpired:
        return f"Execution Exception: code did not finish within {timeout_seconds} seconds"
    except Exception as e:  # noqa: BLE001
        return f"Execution Exception: {str(e)}"
    finally:
        if script_path and os.path.exists(script_path):
            os.remove(script_path)
|
| 353 |
|
| 354 |
+
|
| 355 |
def search_web_raw(query: str, num_results: int = 3) -> str:
|
| 356 |
+
_web_logger.info("🌐 Searching Web (raw) for: %s", query)
|
| 357 |
max_retries = 3
|
| 358 |
for attempt in range(max_retries):
|
| 359 |
try:
|
|
|
|
| 361 |
results = list(ddgs.text(query, max_results=num_results, timelimit="m"))
|
| 362 |
if not results:
|
| 363 |
return "No search results found."
|
| 364 |
+
return "\n".join(
|
| 365 |
+
f"Title: {r.get('title')}\nURL: {r.get('href')}\nSnippet: {r.get('body')}"
|
| 366 |
+
for r in results
|
| 367 |
+
)
|
| 368 |
+
except Exception as e: # noqa: BLE001
|
| 369 |
if attempt < max_retries - 1:
|
| 370 |
sleep(1)
|
| 371 |
continue
|
| 372 |
return f"Search Exception after {max_retries} attempts: {str(e)}"
|
| 373 |
+
return "Search Exception: exhausted retries"
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -1,139 +1,134 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
-
import
|
| 4 |
-
import config
|
| 5 |
-
from typing import TypedDict, List, Optional, Dict, Any, Tuple
|
| 6 |
import pickle
|
| 7 |
-
|
| 8 |
-
from google import genai
|
| 9 |
-
from google.genai import types
|
| 10 |
-
from datetime import datetime, date
|
| 11 |
import time
|
| 12 |
-
from google.colab import userdata
|
| 13 |
-
from json_repair import repair_json
|
| 14 |
from collections import deque
|
|
|
|
|
|
|
| 15 |
from threading import Lock
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
return False
|
| 25 |
-
if scope not in config.DEBUG_SCOPES:
|
| 26 |
-
return False
|
| 27 |
-
scopes_list = config.DEBUG_SCOPES[scope]
|
| 28 |
-
return 'all' in scopes_list or name in scopes_list
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
return
|
| 34 |
-
if subdirectory:
|
| 35 |
-
log_dir = os.path.join(config.LOG_DIR, subdirectory)
|
| 36 |
-
else:
|
| 37 |
-
log_dir = config.LOG_DIR
|
| 38 |
-
os.makedirs(log_dir, exist_ok=True)
|
| 39 |
-
filepath = os.path.join(log_dir, filename)
|
| 40 |
-
with open(filepath, 'w') as f:
|
| 41 |
-
json.dump(data, f, indent=2)
|
| 42 |
|
| 43 |
-
|
| 44 |
-
def __call__(self, prompt: str, **kwargs) -> list[str]:
|
| 45 |
-
pass
|
| 46 |
|
| 47 |
-
def format_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 48 |
-
pass
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
self.thinking_budget = thinking_budget
|
| 55 |
-
self.kwargs = kwargs
|
| 56 |
-
self.manager = manager
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
-
try:
|
| 68 |
-
client = genai.Client(api_key=api_key)
|
| 69 |
|
| 70 |
-
|
| 71 |
-
types.Content(
|
| 72 |
-
role="user",
|
| 73 |
-
parts=[types.Part.from_text(text=prompt)],
|
| 74 |
-
)
|
| 75 |
-
]
|
| 76 |
-
|
| 77 |
-
if self.structured_output:
|
| 78 |
-
generate_content_config = types.GenerateContentConfig(
|
| 79 |
-
thinking_config=types.ThinkingConfig(
|
| 80 |
-
thinking_budget=self.thinking_budget,
|
| 81 |
-
),
|
| 82 |
-
response_mime_type="application/json",
|
| 83 |
-
max_output_tokens=merged_kwargs.get("max_tokens", 5120),
|
| 84 |
-
temperature=merged_kwargs.get("temperature", 0.3),
|
| 85 |
-
)
|
| 86 |
-
else:
|
| 87 |
-
generate_content_config = types.GenerateContentConfig(
|
| 88 |
-
thinking_config=types.ThinkingConfig(
|
| 89 |
-
thinking_budget=self.thinking_budget,
|
| 90 |
-
),
|
| 91 |
-
response_mime_type="text/plain",
|
| 92 |
-
max_output_tokens=merged_kwargs.get("max_tokens", 5120),
|
| 93 |
-
temperature=merged_kwargs.get("temperature", 0.3),
|
| 94 |
-
)
|
| 95 |
|
| 96 |
-
response_text = ""
|
| 97 |
-
start_time = time.time()
|
| 98 |
-
for chunk in client.models.generate_content_stream(
|
| 99 |
-
model=self.model_name,
|
| 100 |
-
contents=contents,
|
| 101 |
-
config=generate_content_config,
|
| 102 |
-
):
|
| 103 |
-
if chunk.text:
|
| 104 |
-
response_text += chunk.text
|
| 105 |
-
|
| 106 |
-
# Record usage only on successful completion
|
| 107 |
-
completion_time = time.time()
|
| 108 |
-
if self.manager.rate_limits is not None:
|
| 109 |
-
self.manager.record_usage(api_key, self.model_name, completion_time)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
|
|
|
| 113 |
|
| 114 |
-
return [response_text.strip()]
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
def format_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 122 |
-
prompt = ""
|
| 123 |
-
for msg in messages:
|
| 124 |
-
if msg["role"] == "system":
|
| 125 |
-
prompt += f"System: {msg['content']}\n"
|
| 126 |
-
elif msg["role"] == "user":
|
| 127 |
-
prompt += f"User: {msg['content']}\n"
|
| 128 |
-
elif msg["role"] == "assistant":
|
| 129 |
-
prompt += f"Assistant: {msg['content']}\n"
|
| 130 |
-
prompt += "Assistant:"
|
| 131 |
-
return prompt
|
| 132 |
|
| 133 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
class GeminiLLM(LLM):
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
self.model_name = model_name
|
| 138 |
self.structured_output = structured_output
|
| 139 |
self.thinking_budget = thinking_budget
|
|
@@ -141,79 +136,167 @@ class GeminiLLM(LLM):
|
|
| 141 |
self.manager = manager
|
| 142 |
self.is_gemma = "gemma" in model_name.lower()
|
| 143 |
if self.is_gemma:
|
|
|
|
| 144 |
self.structured_output = False
|
| 145 |
self.thinking_budget = None
|
| 146 |
-
# No self.client or self.api_key; created dynamically
|
| 147 |
|
| 148 |
-
def __call__(self, prompt: str, **kwargs) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
if self.manager is None:
|
| 150 |
raise ValueError("APIPoolManager must be provided for rate limiting.")
|
| 151 |
|
| 152 |
merged_kwargs = {**self.kwargs, **kwargs}
|
| 153 |
-
|
| 154 |
-
# Get next available API key
|
| 155 |
api_key = self.manager.get_next_key(self.model_name)
|
| 156 |
|
| 157 |
try:
|
| 158 |
client = genai.Client(api_key=api_key)
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
generate_content_config = types.GenerateContentConfig(
|
| 169 |
-
response_mime_type="text/plain",
|
| 170 |
-
max_output_tokens=merged_kwargs.get("max_tokens", 5120),
|
| 171 |
-
temperature=merged_kwargs.get("temperature", 0.3),
|
| 172 |
)
|
|
|
|
|
|
|
| 173 |
else:
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
generate_content_config = types.GenerateContentConfig(
|
| 185 |
-
thinking_config=types.ThinkingConfig(
|
| 186 |
-
thinking_budget=self.thinking_budget,
|
| 187 |
-
),
|
| 188 |
-
response_mime_type="text/plain",
|
| 189 |
-
max_output_tokens=merged_kwargs.get("max_tokens", 5120),
|
| 190 |
-
temperature=merged_kwargs.get("temperature", 0.3),
|
| 191 |
-
)
|
| 192 |
-
|
| 193 |
-
response_text = ""
|
| 194 |
-
start_time = time.time()
|
| 195 |
-
for chunk in client.models.generate_content_stream(
|
| 196 |
-
model=self.model_name,
|
| 197 |
-
contents=contents,
|
| 198 |
-
config=generate_content_config,
|
| 199 |
-
):
|
| 200 |
-
if chunk.text:
|
| 201 |
-
response_text += chunk.text
|
| 202 |
-
|
| 203 |
-
# Record usage only on successful completion
|
| 204 |
completion_time = time.time()
|
| 205 |
if self.manager.rate_limits is not None:
|
| 206 |
self.manager.record_usage(api_key, self.model_name, completion_time)
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
#
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def format_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 219 |
prompt = ""
|
|
@@ -228,12 +311,19 @@ class GeminiLLM(LLM):
|
|
| 228 |
return prompt
|
| 229 |
|
| 230 |
|
|
|
|
| 231 |
class APIPoolManager:
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
self.api_keys = list(api_keys)
|
| 238 |
self.active_keys = list(api_keys)
|
| 239 |
self.rate_limits = rate_limits
|
|
@@ -244,16 +334,15 @@ class APIPoolManager:
|
|
| 244 |
if rate_limits is not None:
|
| 245 |
for key in api_keys:
|
| 246 |
self.usage[key] = {}
|
| 247 |
-
for model, (rpm,
|
| 248 |
self.usage[key][model] = {
|
| 249 |
"timestamps": deque(maxlen=max(1, rpm)),
|
| 250 |
"daily_requests": 0,
|
| 251 |
-
"last_day": date.today()
|
| 252 |
}
|
| 253 |
-
else:
|
| 254 |
-
self.usage = {}
|
| 255 |
|
| 256 |
-
|
|
|
|
| 257 |
usage = self.usage[key][model]
|
| 258 |
today = date.today()
|
| 259 |
if usage["last_day"] < today:
|
|
@@ -266,25 +355,18 @@ class APIPoolManager:
|
|
| 266 |
self._refresh_daily(key, model)
|
| 267 |
_, rpd = self.rate_limits[model]
|
| 268 |
if self.usage[key][model]["daily_requests"] >= rpd:
|
| 269 |
-
# drop this api key
|
| 270 |
if key in self.active_keys:
|
| 271 |
self.active_keys.remove(key)
|
| 272 |
return False
|
| 273 |
return True
|
| 274 |
|
| 275 |
def _key_wait_info(self, key: str, model: str) -> Tuple[float, float]:
|
| 276 |
-
"""
|
| 277 |
-
Return tuple (wait_slot_seconds, wait_spacing_seconds)
|
| 278 |
-
- wait_slot_seconds: time until an RPM slot frees because deque is full (0 if slot available)
|
| 279 |
-
- wait_spacing_seconds: time until spacing interval satisfied relative to last timestamp (0 if spacing ok)
|
| 280 |
-
"""
|
| 281 |
if self.rate_limits is None:
|
| 282 |
return 0.0, 0.0
|
| 283 |
rpm, _ = self.rate_limits[model]
|
| 284 |
usage = self.usage[key][model]
|
| 285 |
now = time.time()
|
| 286 |
|
| 287 |
-
# Clean old timestamps > 60s
|
| 288 |
timestamps = usage["timestamps"]
|
| 289 |
while timestamps and now - timestamps[0] > 60:
|
| 290 |
timestamps.popleft()
|
|
@@ -295,7 +377,7 @@ class APIPoolManager:
|
|
| 295 |
wait_slot = max(0.0, 60.0 - (now - oldest))
|
| 296 |
|
| 297 |
wait_spacing = 0.0
|
| 298 |
-
if
|
| 299 |
time_since_last = now - timestamps[-1]
|
| 300 |
min_interval = 60.0 / rpm if rpm > 0 else 0.0
|
| 301 |
wait_spacing = max(0.0, min_interval - time_since_last)
|
|
@@ -303,9 +385,6 @@ class APIPoolManager:
|
|
| 303 |
return wait_slot, wait_spacing
|
| 304 |
|
| 305 |
def can_use_now(self, key: str, model: str) -> bool:
|
| 306 |
-
"""
|
| 307 |
-
True if key is active, RPD ok, and both slot and spacing waits are zero.
|
| 308 |
-
"""
|
| 309 |
if key not in self.active_keys:
|
| 310 |
return False
|
| 311 |
if not self._key_is_rpd_ok(key, model):
|
|
@@ -313,30 +392,22 @@ class APIPoolManager:
|
|
| 313 |
wait_slot, wait_spacing = self._key_wait_info(key, model)
|
| 314 |
return wait_slot <= 0.0 and wait_spacing <= 0.0
|
| 315 |
|
|
|
|
| 316 |
def get_next_key(self, model: str, max_sleep_once: bool = True) -> str:
|
| 317 |
-
"""
|
| 318 |
-
Choose an API key that can be used immediately for the given model.
|
| 319 |
-
If none available now, compute minimum sleep needed across all keys, sleep once,
|
| 320 |
-
then re-evaluate. Loop until a key is found or no keys left.
|
| 321 |
-
"""
|
| 322 |
with self.lock:
|
| 323 |
if not self.active_keys:
|
| 324 |
raise RuntimeError("No available API keys left due to rate limits.")
|
| 325 |
|
| 326 |
-
# Quick pass: try to find an immediately-available key starting from current_index
|
| 327 |
n = len(self.active_keys)
|
| 328 |
for i in range(n):
|
| 329 |
idx = (self.current_index + i) % n
|
| 330 |
key = self.active_keys[idx]
|
| 331 |
if self.can_use_now(key, model):
|
| 332 |
-
# advance pointer fairly to next key for next call
|
| 333 |
self.current_index = (idx + 1) % max(1, len(self.active_keys))
|
| 334 |
return key
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
min_wait = None
|
| 339 |
-
for key in list(self.active_keys): # list() to be safe if removal happens
|
| 340 |
if not self._key_is_rpd_ok(key, model):
|
| 341 |
continue
|
| 342 |
wait_slot, wait_spacing = self._key_wait_info(key, model)
|
|
@@ -345,139 +416,145 @@ class APIPoolManager:
|
|
| 345 |
min_wait = wait
|
| 346 |
|
| 347 |
if min_wait is None:
|
| 348 |
-
# No keys left after RPD filtering
|
| 349 |
raise RuntimeError("No available API keys left (RPD exhausted).")
|
| 350 |
|
| 351 |
if min_wait and min_wait > 0:
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
else:
|
| 355 |
-
time.sleep(min_wait)
|
| 356 |
return self.get_next_key(model, max_sleep_once=True)
|
| 357 |
|
| 358 |
-
def record_usage(self, key: str, model: str, timestamp: Optional[float] = None):
|
| 359 |
-
"""
|
| 360 |
-
Call this after you receive the response to record actual usage/time.
|
| 361 |
-
timestamp default is now (time of completion).
|
| 362 |
-
"""
|
| 363 |
if self.rate_limits is None:
|
| 364 |
return
|
| 365 |
t = timestamp or time.time()
|
| 366 |
with self.lock:
|
| 367 |
if key not in self.active_keys:
|
| 368 |
-
# safety - if key was removed in-between, ignore or re-add depending on policy
|
| 369 |
return
|
| 370 |
self._refresh_daily(key, model)
|
| 371 |
self.usage[key][model]["timestamps"].append(t)
|
| 372 |
self.usage[key][model]["daily_requests"] += 1
|
| 373 |
-
# Remove if daily limit reached
|
| 374 |
_, rpd = self.rate_limits[model]
|
| 375 |
if self.usage[key][model]["daily_requests"] >= rpd:
|
| 376 |
if key in self.active_keys:
|
| 377 |
self.active_keys.remove(key)
|
| 378 |
|
| 379 |
|
| 380 |
-
|
|
|
|
|
|
|
| 381 |
if config["type"] == "gemini":
|
| 382 |
-
|
| 383 |
-
thinking_budget = config.get("thinking_budget", 300)
|
| 384 |
-
llm = GeminiLLM(
|
| 385 |
model_name=config["model_name"],
|
| 386 |
-
structured_output=structured_output,
|
| 387 |
-
thinking_budget=thinking_budget,
|
| 388 |
manager=manager,
|
| 389 |
-
**config.get("params", {})
|
| 390 |
)
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
raise ValueError(f"Unknown LLM type: {config['type']}")
|
| 394 |
|
|
|
|
| 395 |
def extract_and_parse_json(text: str) -> Dict[str, Any]:
|
| 396 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
try:
|
| 398 |
return json.loads(text.strip())
|
| 399 |
-
except:
|
| 400 |
pass
|
| 401 |
|
| 402 |
-
|
| 403 |
-
if
|
| 404 |
try:
|
| 405 |
-
return json.loads(
|
| 406 |
-
except:
|
| 407 |
pass
|
| 408 |
|
| 409 |
-
|
| 410 |
-
if
|
| 411 |
try:
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
except:
|
| 415 |
pass
|
| 416 |
|
| 417 |
try:
|
| 418 |
-
|
| 419 |
-
return json.loads(repaired_json)
|
| 420 |
except Exception as e:
|
| 421 |
-
|
| 422 |
return {
|
| 423 |
"thought": f"JSON parsing failed: {str(e)}",
|
| 424 |
"action": "compose_response",
|
| 425 |
-
"params": {
|
| 426 |
-
"text": f"I encountered an error processing your request. Original response: {text[:200]}..."
|
| 427 |
-
},
|
| 428 |
"_parse_error": True,
|
| 429 |
-
"_original_text": text
|
| 430 |
}
|
| 431 |
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
| 434 |
for k in keys[:-1]:
|
| 435 |
d = d.setdefault(k, {})
|
| 436 |
d[keys[-1]] = value
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
|
|
|
| 440 |
if partitions is None:
|
| 441 |
partitions = ["user_profile", "medical_history", "flags_and_assessments", "plans"]
|
| 442 |
-
|
| 443 |
-
summary = {}
|
| 444 |
for partition in partitions:
|
| 445 |
-
if partition in memory and memory[partition]
|
| 446 |
-
|
| 447 |
-
else:
|
| 448 |
-
summary[partition] = "empty"
|
| 449 |
|
| 450 |
-
return json.dumps(summary, indent=2)
|
| 451 |
|
| 452 |
def update_memory_partition(memory: Dict[str, Any], partition: str, data: Any) -> None:
|
| 453 |
-
"""
|
| 454 |
if partition not in memory:
|
| 455 |
memory[partition] = {}
|
| 456 |
-
|
| 457 |
if isinstance(data, dict) and isinstance(memory[partition], dict):
|
| 458 |
memory[partition].update(data)
|
| 459 |
else:
|
| 460 |
memory[partition] = data
|
| 461 |
-
|
| 462 |
-
|
| 463 |
|
|
|
|
| 464 |
class FileCheckpointSaver(BaseCheckpointSaver):
|
| 465 |
-
|
|
|
|
|
|
|
| 466 |
self.directory = directory
|
| 467 |
os.makedirs(directory, exist_ok=True)
|
| 468 |
|
| 469 |
def put(self, config: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
|
| 470 |
-
"""Save checkpoint to file"""
|
| 471 |
thread_id = config.get("configurable", {}).get("thread_id", "default")
|
| 472 |
filepath = os.path.join(self.directory, f"checkpoint_{thread_id}.pkl")
|
| 473 |
-
with open(filepath,
|
| 474 |
pickle.dump(checkpoint, f)
|
| 475 |
|
| 476 |
def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 477 |
-
"""Load checkpoint from file"""
|
| 478 |
thread_id = config.get("configurable", {}).get("thread_id", "default")
|
| 479 |
filepath = os.path.join(self.directory, f"checkpoint_{thread_id}.pkl")
|
| 480 |
if os.path.exists(filepath):
|
| 481 |
-
with open(filepath,
|
| 482 |
return pickle.load(f)
|
| 483 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities: LLM wrapper, API-key pool with rate limiting, JSON helpers, and a
|
| 2 |
+
LangGraph file checkpointer.
|
| 3 |
+
|
| 4 |
+
Phase 0 cleanup notes:
|
| 5 |
+
|
| 6 |
+
* Removed the duplicate ``GeminiLLM`` definition (the second class silently
|
| 7 |
+
shadowed the first; both remained import-visible).
|
| 8 |
+
* Dropped ``from google.colab import userdata`` so the module imports cleanly
|
| 9 |
+
outside Colab. API keys come in via ``create_llm_instances`` or env.
|
| 10 |
+
* Replaced ``print(...)`` calls with module loggers under ``nutrition_mas.*``.
|
| 11 |
+
* Routed all reads of ``config.X`` through :func:`config.get_settings`.
|
| 12 |
+
|
| 13 |
+
Larger refactors (Pydantic-typed agent IO, native Gemini ``response_schema``,
|
| 14 |
+
async ``acall``) land in Phase 1.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
import json
|
| 20 |
+
import os
|
|
|
|
|
|
|
| 21 |
import pickle
|
| 22 |
+
import re
|
|
|
|
|
|
|
|
|
|
| 23 |
import time
|
|
|
|
|
|
|
| 24 |
from collections import deque
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from datetime import date, datetime
|
| 27 |
from threading import Lock
|
| 28 |
+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar
|
| 29 |
|
| 30 |
+
from google import genai
|
| 31 |
+
from google.genai import types
|
| 32 |
+
from json_repair import repair_json
|
| 33 |
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
| 34 |
+
from pydantic import BaseModel, ValidationError
|
| 35 |
|
| 36 |
+
from config import get_settings
|
| 37 |
+
from logging_setup import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
_logger = get_logger("utils")
|
| 40 |
+
_llm_logger = get_logger("llm.gemini")
|
| 41 |
+
_pool_logger = get_logger("utils.api_pool")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
T = TypeVar("T", bound=BaseModel)
|
|
|
|
|
|
|
| 44 |
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
# --- Phase 1 fallback metrics --------------------------------------------------
|
| 47 |
+
@dataclass
class ParseMetrics:
    """Process-wide tally of how typed LLM responses ended up being parsed.

    Three global counters are kept, plus a per-model breakdown in
    ``by_model`` keyed by model name and then by parse kind.
    """

    # Parses where the SDK-provided ``response.parsed`` was usable directly.
    native_parses: int = 0
    # Parses that needed the ``extract_and_parse_json`` fallback.
    fallback_parses: int = 0
    # Responses that never validated against the Pydantic model.
    schema_failures: int = 0
    # ``{model_name: {kind: count}}`` breakdown.
    by_model: Dict[str, Dict[str, int]] = field(default_factory=dict)

    def record(self, model: str, kind: str) -> None:
        """Count one parse outcome of type ``kind`` for ``model``.

        ``kind`` values ``"native"``, ``"fallback"`` and ``"failure"`` bump the
        matching global counter. Any other value leaves the global counters
        untouched but is still tallied in the per-model breakdown.
        """
        counter_name = {
            "native": "native_parses",
            "fallback": "fallback_parses",
            "failure": "schema_failures",
        }.get(kind)
        if counter_name is not None:
            setattr(self, counter_name, getattr(self, counter_name) + 1)

        per_model = self.by_model.setdefault(
            model, {"native": 0, "fallback": 0, "failure": 0}
        )
        per_model[kind] = per_model.get(kind, 0) + 1
|
| 69 |
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
_parse_metrics = ParseMetrics()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
def get_parse_metrics() -> ParseMetrics:
    """Expose the module-level ``ParseMetrics`` singleton.

    The same object is returned on every call, so callers observe (and could
    mutate) the live counters rather than a snapshot.
    """
    return _parse_metrics
|
| 77 |
|
|
|
|
| 78 |
|
| 79 |
+
# --- Debug-scope helper --------------------------------------------------------
|
| 80 |
+
def should_debug(scope: str, name: str) -> bool:
    """Tell whether debug output is enabled for ``name`` within ``scope``.

    Debugging is on only when ``settings.debug_mode`` is truthy, ``scope`` is a
    key of ``settings.debug_scopes``, and that scope's list contains either
    the wildcard ``"all"`` or ``name`` itself.
    """
    settings = get_settings()
    if not settings.debug_mode or scope not in settings.debug_scopes:
        return False
    enabled = settings.debug_scopes[scope]
    return any(tag in enabled for tag in ("all", name))
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
# --- Filesystem logging --------------------------------------------------------
|
| 92 |
+
def save_to_json(data: Dict[str, Any], filename: str, subdirectory: Optional[str] = None) -> None:
    """Write ``data`` as indented JSON under ``settings.log_dir``.

    Does nothing when ``settings.log_dir`` is ``None``. When ``subdirectory``
    is given, the file goes into that folder beneath the log directory; the
    target folder is created on demand. Values that ``json`` cannot encode are
    stringified via ``default=str``.
    """
    base_dir = get_settings().log_dir
    # NOTE(review): only ``None`` disables logging here; an empty-string
    # log_dir is not treated as "off" — confirm config normalises it.
    if base_dir is None:
        return

    target_dir = base_dir if not subdirectory else os.path.join(base_dir, subdirectory)
    os.makedirs(target_dir, exist_ok=True)

    # ISO timestamps in filenames carry ``:``, which Windows rejects.
    destination = os.path.join(target_dir, filename.replace(":", "-"))
    with open(destination, "w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=2, default=str)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# --- LLM abstractions ----------------------------------------------------------
|
| 107 |
+
class LLM:
    """Interface every LLM backend implements.

    Subclasses override both methods; the base versions only raise
    ``NotImplementedError``.
    """

    def __call__(self, prompt: str, **kwargs: Any) -> list[str]:  # pragma: no cover - interface
        """Run ``prompt`` through the model and return a one-element list of text."""
        raise NotImplementedError

    def format_prompt(self, messages: List[Dict[str, str]]) -> str:  # pragma: no cover - interface
        """Flatten a list of chat ``messages`` into a single prompt string."""
        raise NotImplementedError
|
| 115 |
+
|
| 116 |
|
| 117 |
class GeminiLLM(LLM):
|
| 118 |
+
"""Synchronous Gemini wrapper with API-key pooling.
|
| 119 |
+
|
| 120 |
+
Phase 1 will add an ``acall`` async path and replace the JSON-in-text
|
| 121 |
+
contract with native ``response_schema`` Pydantic models.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
model_name: str,
|
| 127 |
+
structured_output: bool = False,
|
| 128 |
+
thinking_budget: int = 300,
|
| 129 |
+
manager: Optional["APIPoolManager"] = None,
|
| 130 |
+
**kwargs: Any,
|
| 131 |
+
) -> None:
|
| 132 |
self.model_name = model_name
|
| 133 |
self.structured_output = structured_output
|
| 134 |
self.thinking_budget = thinking_budget
|
|
|
|
| 136 |
self.manager = manager
|
| 137 |
self.is_gemma = "gemma" in model_name.lower()
|
| 138 |
if self.is_gemma:
|
| 139 |
+
# Gemma family doesn't support thinking_config or JSON response schema.
|
| 140 |
self.structured_output = False
|
| 141 |
self.thinking_budget = None
|
|
|
|
| 142 |
|
| 143 |
+
def __call__(self, prompt: str, **kwargs: Any) -> list[str]:
|
| 144 |
+
"""Untyped streaming call. Returns ``[response_text]``.
|
| 145 |
+
|
| 146 |
+
Backwards-compat path used by code that still parses JSON-from-text.
|
| 147 |
+
Prefer :meth:`call_typed` when a Pydantic schema is available.
|
| 148 |
+
"""
|
| 149 |
+
text, _ = self._invoke(prompt, response_schema=None, **kwargs)
|
| 150 |
+
return [text]
|
| 151 |
+
|
| 152 |
+
def call_typed(
|
| 153 |
+
self,
|
| 154 |
+
prompt: str,
|
| 155 |
+
response_model: Type[T],
|
| 156 |
+
**kwargs: Any,
|
| 157 |
+
) -> Optional[T]:
|
| 158 |
+
"""Call Gemini with constrained-decoded JSON matching ``response_model``.
|
| 159 |
+
|
| 160 |
+
Returns a parsed instance of ``response_model``, or ``None`` if every
|
| 161 |
+
parse strategy failed (in which case the parse-metrics ``schema_failures``
|
| 162 |
+
counter is incremented so the eval harness can spot it).
|
| 163 |
+
"""
|
| 164 |
+
text, parsed = self._invoke(prompt, response_schema=response_model, **kwargs)
|
| 165 |
+
|
| 166 |
+
# Strategy 1: SDK already parsed it for us via response_schema.
|
| 167 |
+
if isinstance(parsed, response_model):
|
| 168 |
+
_parse_metrics.record(self.model_name, "native")
|
| 169 |
+
return parsed
|
| 170 |
+
|
| 171 |
+
# Strategy 2: SDK gave us a dict; try to validate it.
|
| 172 |
+
if isinstance(parsed, dict):
|
| 173 |
+
try:
|
| 174 |
+
instance = response_model.model_validate(parsed)
|
| 175 |
+
_parse_metrics.record(self.model_name, "native")
|
| 176 |
+
return instance
|
| 177 |
+
except ValidationError as e:
|
| 178 |
+
_llm_logger.debug("response.parsed dict failed Pydantic validation: %s", e)
|
| 179 |
+
|
| 180 |
+
# Strategy 3: regex / json_repair fallback on the raw text.
|
| 181 |
+
try:
|
| 182 |
+
data = extract_and_parse_json(text)
|
| 183 |
+
instance = response_model.model_validate(data)
|
| 184 |
+
_parse_metrics.record(self.model_name, "fallback")
|
| 185 |
+
_llm_logger.warning(
|
| 186 |
+
"Used JSON-repair fallback for %s on model %s — fix the prompt or schema",
|
| 187 |
+
response_model.__name__,
|
| 188 |
+
self.model_name,
|
| 189 |
+
)
|
| 190 |
+
return instance
|
| 191 |
+
except (ValidationError, Exception) as e: # noqa: BLE001
|
| 192 |
+
_parse_metrics.record(self.model_name, "failure")
|
| 193 |
+
_llm_logger.error(
|
| 194 |
+
"Failed to parse %s from %s response: %s",
|
| 195 |
+
response_model.__name__,
|
| 196 |
+
self.model_name,
|
| 197 |
+
str(e),
|
| 198 |
+
)
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
def _invoke(
|
| 202 |
+
self,
|
| 203 |
+
prompt: str,
|
| 204 |
+
response_schema: Optional[Type[BaseModel]] = None,
|
| 205 |
+
**kwargs: Any,
|
| 206 |
+
) -> Tuple[str, Any]:
|
| 207 |
+
"""Single Gemini round-trip. Returns ``(text, response.parsed)``.
|
| 208 |
+
|
| 209 |
+
``parsed`` is whatever the SDK populated on ``response.parsed`` —
|
| 210 |
+
usually a Pydantic instance when ``response_schema`` is supplied, ``None``
|
| 211 |
+
otherwise.
|
| 212 |
+
"""
|
| 213 |
if self.manager is None:
|
| 214 |
raise ValueError("APIPoolManager must be provided for rate limiting.")
|
| 215 |
|
| 216 |
merged_kwargs = {**self.kwargs, **kwargs}
|
|
|
|
|
|
|
| 217 |
api_key = self.manager.get_next_key(self.model_name)
|
| 218 |
|
| 219 |
try:
|
| 220 |
client = genai.Client(api_key=api_key)
|
| 221 |
+
contents = [types.Content(role="user", parts=[types.Part.from_text(text=prompt)])]
|
| 222 |
+
generate_content_config = self._build_config(merged_kwargs, response_schema=response_schema)
|
| 223 |
|
| 224 |
+
start_time = time.time()
|
| 225 |
+
# Non-streaming when we want response.parsed (the streaming API
|
| 226 |
+
# doesn't populate it). Streaming for free-text plain calls.
|
| 227 |
+
if response_schema is not None:
|
| 228 |
+
response = client.models.generate_content(
|
| 229 |
+
model=self.model_name,
|
| 230 |
+
contents=contents,
|
| 231 |
+
config=generate_content_config,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
)
|
| 233 |
+
response_text = response.text or ""
|
| 234 |
+
parsed = getattr(response, "parsed", None)
|
| 235 |
else:
|
| 236 |
+
response_text = ""
|
| 237 |
+
parsed = None
|
| 238 |
+
for chunk in client.models.generate_content_stream(
|
| 239 |
+
model=self.model_name,
|
| 240 |
+
contents=contents,
|
| 241 |
+
config=generate_content_config,
|
| 242 |
+
):
|
| 243 |
+
if chunk.text:
|
| 244 |
+
response_text += chunk.text
|
| 245 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
completion_time = time.time()
|
| 247 |
if self.manager.rate_limits is not None:
|
| 248 |
self.manager.record_usage(api_key, self.model_name, completion_time)
|
| 249 |
|
| 250 |
+
_llm_logger.debug(
|
| 251 |
+
"LLM call completed for %s using key …%s in %.2fs (schema=%s)",
|
| 252 |
+
self.model_name,
|
| 253 |
+
api_key[-4:],
|
| 254 |
+
completion_time - start_time,
|
| 255 |
+
response_schema.__name__ if response_schema else "none",
|
| 256 |
+
)
|
| 257 |
+
return response_text.strip(), parsed
|
| 258 |
+
|
| 259 |
+
except Exception as e: # noqa: BLE001 — narrow this in Phase 4 (per-error retries)
|
| 260 |
+
_llm_logger.warning(
|
| 261 |
+
"LLM call failed for %s using key …%s: %s",
|
| 262 |
+
self.model_name,
|
| 263 |
+
api_key[-4:],
|
| 264 |
+
str(e),
|
| 265 |
+
)
|
| 266 |
+
return f"Error: LLM call failed - {str(e)}", None
|
| 267 |
+
|
| 268 |
+
def _build_config(
|
| 269 |
+
self,
|
| 270 |
+
merged_kwargs: Dict[str, Any],
|
| 271 |
+
response_schema: Optional[Type[BaseModel]] = None,
|
| 272 |
+
) -> types.GenerateContentConfig:
|
| 273 |
+
max_tokens = merged_kwargs.get("max_tokens", 5120)
|
| 274 |
+
temperature = merged_kwargs.get("temperature", 0.3)
|
| 275 |
|
| 276 |
+
if self.is_gemma:
|
| 277 |
+
# Gemma can't do thinking_config or response_schema.
|
| 278 |
+
return types.GenerateContentConfig(
|
| 279 |
+
response_mime_type="text/plain",
|
| 280 |
+
max_output_tokens=max_tokens,
|
| 281 |
+
temperature=temperature,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
thinking_cfg = types.ThinkingConfig(thinking_budget=self.thinking_budget)
|
| 285 |
+
if response_schema is not None:
|
| 286 |
+
return types.GenerateContentConfig(
|
| 287 |
+
thinking_config=thinking_cfg,
|
| 288 |
+
response_mime_type="application/json",
|
| 289 |
+
response_schema=response_schema,
|
| 290 |
+
max_output_tokens=max_tokens,
|
| 291 |
+
temperature=temperature,
|
| 292 |
+
)
|
| 293 |
+
mime = "application/json" if self.structured_output else "text/plain"
|
| 294 |
+
return types.GenerateContentConfig(
|
| 295 |
+
thinking_config=thinking_cfg,
|
| 296 |
+
response_mime_type=mime,
|
| 297 |
+
max_output_tokens=max_tokens,
|
| 298 |
+
temperature=temperature,
|
| 299 |
+
)
|
| 300 |
|
| 301 |
def format_prompt(self, messages: List[Dict[str, str]]) -> str:
|
| 302 |
prompt = ""
|
|
|
|
| 311 |
return prompt
|
| 312 |
|
| 313 |
|
| 314 |
+
# --- API key pool with optional rate limiting ----------------------------------
|
| 315 |
class APIPoolManager:
|
| 316 |
+
"""Round-robin Gemini API keys with per-key RPM/RPD enforcement.
|
| 317 |
+
|
| 318 |
+
``rate_limits`` is ``{model_name: (rpm, rpd)}``. When ``None``, the pool
|
| 319 |
+
just rotates keys without any throttling.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
def __init__(
|
| 323 |
+
self,
|
| 324 |
+
api_keys: List[str],
|
| 325 |
+
rate_limits: Optional[Dict[str, Tuple[int, int]]] = None,
|
| 326 |
+
) -> None:
|
| 327 |
self.api_keys = list(api_keys)
|
| 328 |
self.active_keys = list(api_keys)
|
| 329 |
self.rate_limits = rate_limits
|
|
|
|
| 334 |
if rate_limits is not None:
|
| 335 |
for key in api_keys:
|
| 336 |
self.usage[key] = {}
|
| 337 |
+
for model, (rpm, _rpd) in rate_limits.items():
|
| 338 |
self.usage[key][model] = {
|
| 339 |
"timestamps": deque(maxlen=max(1, rpm)),
|
| 340 |
"daily_requests": 0,
|
| 341 |
+
"last_day": date.today(),
|
| 342 |
}
|
|
|
|
|
|
|
| 343 |
|
| 344 |
+
# --- internal helpers ------------------------------------------------------
|
| 345 |
+
def _refresh_daily(self, key: str, model: str) -> None:
|
| 346 |
usage = self.usage[key][model]
|
| 347 |
today = date.today()
|
| 348 |
if usage["last_day"] < today:
|
|
|
|
| 355 |
self._refresh_daily(key, model)
|
| 356 |
_, rpd = self.rate_limits[model]
|
| 357 |
if self.usage[key][model]["daily_requests"] >= rpd:
|
|
|
|
| 358 |
if key in self.active_keys:
|
| 359 |
self.active_keys.remove(key)
|
| 360 |
return False
|
| 361 |
return True
|
| 362 |
|
| 363 |
def _key_wait_info(self, key: str, model: str) -> Tuple[float, float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
if self.rate_limits is None:
|
| 365 |
return 0.0, 0.0
|
| 366 |
rpm, _ = self.rate_limits[model]
|
| 367 |
usage = self.usage[key][model]
|
| 368 |
now = time.time()
|
| 369 |
|
|
|
|
| 370 |
timestamps = usage["timestamps"]
|
| 371 |
while timestamps and now - timestamps[0] > 60:
|
| 372 |
timestamps.popleft()
|
|
|
|
| 377 |
wait_slot = max(0.0, 60.0 - (now - oldest))
|
| 378 |
|
| 379 |
wait_spacing = 0.0
|
| 380 |
+
if timestamps:
|
| 381 |
time_since_last = now - timestamps[-1]
|
| 382 |
min_interval = 60.0 / rpm if rpm > 0 else 0.0
|
| 383 |
wait_spacing = max(0.0, min_interval - time_since_last)
|
|
|
|
| 385 |
return wait_slot, wait_spacing
|
| 386 |
|
| 387 |
def can_use_now(self, key: str, model: str) -> bool:
|
|
|
|
|
|
|
|
|
|
| 388 |
if key not in self.active_keys:
|
| 389 |
return False
|
| 390 |
if not self._key_is_rpd_ok(key, model):
|
|
|
|
| 392 |
wait_slot, wait_spacing = self._key_wait_info(key, model)
|
| 393 |
return wait_slot <= 0.0 and wait_spacing <= 0.0
|
| 394 |
|
| 395 |
+
# --- public API ------------------------------------------------------------
|
| 396 |
def get_next_key(self, model: str, max_sleep_once: bool = True) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
with self.lock:
|
| 398 |
if not self.active_keys:
|
| 399 |
raise RuntimeError("No available API keys left due to rate limits.")
|
| 400 |
|
|
|
|
| 401 |
n = len(self.active_keys)
|
| 402 |
for i in range(n):
|
| 403 |
idx = (self.current_index + i) % n
|
| 404 |
key = self.active_keys[idx]
|
| 405 |
if self.can_use_now(key, model):
|
|
|
|
| 406 |
self.current_index = (idx + 1) % max(1, len(self.active_keys))
|
| 407 |
return key
|
| 408 |
|
| 409 |
+
min_wait: Optional[float] = None
|
| 410 |
+
for key in list(self.active_keys):
|
|
|
|
|
|
|
| 411 |
if not self._key_is_rpd_ok(key, model):
|
| 412 |
continue
|
| 413 |
wait_slot, wait_spacing = self._key_wait_info(key, model)
|
|
|
|
| 416 |
min_wait = wait
|
| 417 |
|
| 418 |
if min_wait is None:
|
|
|
|
| 419 |
raise RuntimeError("No available API keys left (RPD exhausted).")
|
| 420 |
|
| 421 |
if min_wait and min_wait > 0:
|
| 422 |
+
_pool_logger.debug("Waiting %.2fs for next API slot", min_wait)
|
| 423 |
+
time.sleep(min_wait)
|
|
|
|
|
|
|
| 424 |
return self.get_next_key(model, max_sleep_once=True)
|
| 425 |
|
| 426 |
+
def record_usage(self, key: str, model: str, timestamp: Optional[float] = None) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
if self.rate_limits is None:
|
| 428 |
return
|
| 429 |
t = timestamp or time.time()
|
| 430 |
with self.lock:
|
| 431 |
if key not in self.active_keys:
|
|
|
|
| 432 |
return
|
| 433 |
self._refresh_daily(key, model)
|
| 434 |
self.usage[key][model]["timestamps"].append(t)
|
| 435 |
self.usage[key][model]["daily_requests"] += 1
|
|
|
|
| 436 |
_, rpd = self.rate_limits[model]
|
| 437 |
if self.usage[key][model]["daily_requests"] >= rpd:
|
| 438 |
if key in self.active_keys:
|
| 439 |
self.active_keys.remove(key)
|
| 440 |
|
| 441 |
|
| 442 |
+
# --- Factory -------------------------------------------------------------------
|
| 443 |
+
def create_llm(config: dict, manager: APIPoolManager) -> LLM:
    """Build the LLM described by ``config``, wired to the shared key pool.

    ``config["type"]`` selects the backend; only ``"gemini"`` is recognised and
    anything else raises ``ValueError``. ``structured_output`` defaults to
    ``False``, ``thinking_budget`` to ``300``, and the optional ``params``
    mapping is forwarded to the constructor as extra keyword arguments.
    """
    if config["type"] != "gemini":
        raise ValueError(f"Unknown LLM type: {config['type']}")

    extra_params = config.get("params", {})
    return GeminiLLM(
        model_name=config["model_name"],
        structured_output=config.get("structured_output", False),
        thinking_budget=config.get("thinking_budget", 300),
        manager=manager,
        **extra_params,
    )
|
| 454 |
+
|
|
|
|
| 455 |
|
| 456 |
+
# --- JSON helpers --------------------------------------------------------------
|
| 457 |
def extract_and_parse_json(text: str) -> Dict[str, Any]:
    """Best-effort JSON extraction with a chain of fallbacks.

    Strategies are tried in order and the first one that parses wins:

    1. ``json.loads`` on the stripped text.
    2. The contents of a fenced ```` ```json ```` block.
    3. The outermost ``{...}`` span, passed through ``repair_json``.
    4. The whole text passed through ``repair_json``.

    Args:
        text: Raw model output. ``None`` is treated as an empty string and
            other non-string values are stringified, so the function never
            raises on an unexpected input type.

    Returns:
        The parsed JSON value. When every strategy fails, a ``compose_response``
        action dict is returned instead, flagged with ``"_parse_error": True``
        and carrying the input under ``"_original_text"``.
    """
    # Without this guard a None/non-str input slips past the first ``try``
    # (the AttributeError is swallowed) and then crashes inside ``re.search``.
    if not isinstance(text, str):
        text = "" if text is None else str(text)

    try:
        return json.loads(text.strip())
    except Exception:  # noqa: BLE001 - deliberate best-effort, try next strategy
        pass

    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except Exception:  # noqa: BLE001 - deliberate best-effort, try next strategy
            pass

    # Greedy match: from the first ``{`` to the last ``}`` in the text.
    braces = re.search(r"\{.*\}", text, re.DOTALL)
    if braces:
        try:
            return json.loads(repair_json(braces.group(0)))
        except Exception:  # noqa: BLE001 - deliberate best-effort, try next strategy
            pass

    try:
        return json.loads(repair_json(text))
    except Exception as e:  # noqa: BLE001 - last strategy; fall back to error payload
        _logger.warning("All JSON parsing strategies failed: %s", str(e))
        return {
            "thought": f"JSON parsing failed: {str(e)}",
            "action": "compose_response",
            "params": {"text": f"I encountered an error processing your request. Original response: {text[:200]}..."},
            "_parse_error": True,
            "_original_text": text,
        }
|
| 494 |
|
| 495 |
+
|
| 496 |
+
def set_nested(d: Dict[str, Any], key: str, value: Any) -> None:
    """Store ``value`` in ``d`` at the location named by the dotted path ``key``.

    Each path segment before the last selects (or creates, as an empty dict)
    a nested mapping; the final segment is the key that receives ``value``.
    """
    *parents, leaf = key.split(".")
    node = d
    for segment in parents:
        node = node.setdefault(segment, {})
    node[leaf] = value
|
| 502 |
|
| 503 |
+
|
| 504 |
+
def get_memory_summary(memory: Dict[str, Any], partitions: Optional[List[str]] = None) -> str:
    """Render the chosen memory partitions as an indented JSON string.

    When ``partitions`` is ``None`` the four standard partitions are used.
    A partition that is absent from ``memory`` or holds a falsy value is
    reported as the string ``"empty"``. Non-JSON values are stringified.
    """
    selected = (
        ["user_profile", "medical_history", "flags_and_assessments", "plans"]
        if partitions is None
        else partitions
    )
    summary: Dict[str, Any] = {name: memory.get(name) or "empty" for name in selected}
    return json.dumps(summary, indent=2, default=str)
|
|
|
|
|
|
|
| 512 |
|
|
|
|
| 513 |
|
| 514 |
def update_memory_partition(memory: Dict[str, Any], partition: str, data: Any) -> None:
    """Fold ``data`` into ``memory[partition]`` in place.

    A missing partition starts out as an empty dict. When both the existing
    partition and ``data`` are dicts the partition is updated key-by-key;
    in every other case the partition is replaced by ``data`` wholesale.
    """
    current = memory.setdefault(partition, {})
    if isinstance(data, dict) and isinstance(current, dict):
        current.update(data)
    else:
        memory[partition] = data
    _logger.debug("Updated memory partition %r with new data", partition)
|
| 523 |
+
|
| 524 |
|
| 525 |
+
# --- Checkpointer --------------------------------------------------------------
|
| 526 |
class FileCheckpointSaver(BaseCheckpointSaver):
    """Pickle LangGraph checkpoints to ``directory/checkpoint_<thread_id>.pkl``.

    One file is kept per thread; each ``put`` overwrites the previous
    checkpoint for that thread. The thread id is read from
    ``config["configurable"]["thread_id"]`` and defaults to ``"default"``.
    """

    def __init__(self, directory: str) -> None:
        """Remember ``directory`` and create it if it does not exist yet."""
        self.directory = directory
        os.makedirs(directory, exist_ok=True)

    def _checkpoint_path(self, config: Dict[str, Any]) -> str:
        """Return the checkpoint file path for the thread named in ``config``."""
        thread_id = config.get("configurable", {}).get("thread_id", "default")
        # Thread ids may originate outside the process: neutralise path
        # separators and other unsafe characters so the file always lands
        # inside ``self.directory`` and is a valid name on every platform.
        safe_id = re.sub(r"[^A-Za-z0-9_.-]", "_", str(thread_id))
        return os.path.join(self.directory, f"checkpoint_{safe_id}.pkl")

    def put(self, config: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
        """Persist ``checkpoint`` for the thread identified by ``config``.

        The data is written to a temporary sibling file and then moved into
        place, so an interrupted write cannot leave a truncated checkpoint.
        """
        filepath = self._checkpoint_path(config)
        tmp_path = f"{filepath}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(checkpoint, f)
            os.replace(tmp_path, filepath)
        except Exception:
            # Don't leave a half-written temp file behind on failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def get(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Load the checkpoint for ``config``'s thread, or ``None`` if absent."""
        try:
            with open(self._checkpoint_path(config), "rb") as f:
                # NOTE(security): unpickling can execute arbitrary code. This is
                # only safe while the checkpoint directory is writable solely
                # by trusted processes.
                return pickle.load(f)
        except FileNotFoundError:
            return None
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# Public API of this module. ``ParseMetrics`` and ``get_parse_metrics`` are
# listed so the parse counters are reachable through ``from utils import *``.
__all__ = [
    "APIPoolManager",
    "FileCheckpointSaver",
    "GeminiLLM",
    "LLM",
    "ParseMetrics",
    "create_llm",
    "extract_and_parse_json",
    "get_memory_summary",
    "get_parse_metrics",
    "save_to_json",
    "set_nested",
    "should_debug",
    "update_memory_partition",
]
|
validation.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ValidationAgent — the critic in the generator-critic loop.
|
| 2 |
+
|
| 3 |
+
Why this exists
|
| 4 |
+
----------------
|
| 5 |
+
The original README promised a ``ValidationAgent`` but it was never
|
| 6 |
+
implemented; the system shipped plans straight from the Planner to the user.
|
| 7 |
+
Modern multi-agent literature (Anthropic's research-system writeup, every
|
| 8 |
+
LangGraph reflection-pattern tutorial) is unanimous that a separate critic
|
| 9 |
+
node materially raises output quality on tasks with hard constraints.
|
| 10 |
+
|
| 11 |
+
Design
|
| 12 |
+
------
|
| 13 |
+
We combine two layers:
|
| 14 |
+
|
| 15 |
+
1. **Deterministic checks** (no LLM, no cost, instant):
|
| 16 |
+
* allergy violations,
|
| 17 |
+
* calorie deviation > 3 % of daily target,
|
| 18 |
+
* each macro deviation > 5 % of its target,
|
| 19 |
+
* disliked foods present (advisory),
|
| 20 |
+
* professional-consultation flag set without disclaimer.
|
| 21 |
+
|
| 22 |
+
2. **LLM-graded checks** (one Gemini round-trip, structured output):
|
| 23 |
+
* medical-flag respect (e.g., diabetes user should avoid high-GL meals),
|
| 24 |
+
* citation presence for clinical recommendations,
|
| 25 |
+
* cultural appropriateness against user's country/cuisine preference.
|
| 26 |
+
|
| 27 |
+
Verdict semantics
|
| 28 |
+
-----------------
|
| 29 |
+
* ``pass`` — Coach proceeds to ``compose_response``.
|
| 30 |
+
* ``revise`` — Issues are bundled into the next Planner task; Coach loops back
|
| 31 |
+
to ``call_agent('PlannerAgent', task=...)``. Capped at 2
|
| 32 |
+
revisions (enforced by Coach prompt) to avoid infinite loops.
|
| 33 |
+
* ``reject`` — Hard stop with ``severity='high'``. Coach must compose a
|
| 34 |
+
warning + HITL escalation chip (Phase 4 wires up the chip).
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
import json
|
| 40 |
+
from datetime import datetime
|
| 41 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 42 |
+
|
| 43 |
+
from logging_setup import get_logger
|
| 44 |
+
from schemas import ValidationDecision, ValidationIssue
|
| 45 |
+
from utils import save_to_json
|
| 46 |
+
|
| 47 |
+
_logger = get_logger("agents.validation")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Tolerances are module-level constants so tests/configs can override them.
|
| 51 |
+
CALORIE_TOLERANCE = 0.03 # +/- 3 %
|
| 52 |
+
MACRO_TOLERANCE = 0.05 # +/- 5 %
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
_VALIDATION_SYSTEM_PROMPT = """\
|
| 56 |
+
You are the Validation Agent. You receive a meal plan and the medical
|
| 57 |
+
assessment context. Your job is to grade the plan, NOT redesign it.
|
| 58 |
+
|
| 59 |
+
Mandatory checks (in addition to the deterministic ones already supplied):
|
| 60 |
+
1. Medical-flag respect: for each flag in flags_and_assessments.flags
|
| 61 |
+
(e.g., "diabetes_risk", "high_ldl"), confirm the plan does not contain
|
| 62 |
+
foods that contraindicate the flag. Cite which food fails which flag.
|
| 63 |
+
2. Evidence: clinical recommendations in flags_and_assessments.recommendations
|
| 64 |
+
must be reflected in the plan or notes. Mention any unaddressed item.
|
| 65 |
+
3. Cultural appropriateness: if user_profile.country is set, confirm at
|
| 66 |
+
least 60 % of foods are commonly available / culturally familiar there.
|
| 67 |
+
Otherwise emit a low-severity issue suggesting substitutions.
|
| 68 |
+
|
| 69 |
+
Output JSON shape (enforced by schema):
|
| 70 |
+
{
|
| 71 |
+
"verdict": "pass" | "revise" | "reject",
|
| 72 |
+
"issues": [
|
| 73 |
+
{"code": "...", "description": "...",
|
| 74 |
+
"severity": "low" | "medium" | "high"}
|
| 75 |
+
],
|
| 76 |
+
"notes": "...",
|
| 77 |
+
"requires_human_review": false
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
Rules:
|
| 81 |
+
- Mark requires_human_review=true if any issue has severity="high" OR if
|
| 82 |
+
flags_and_assessments.requires_professional_consultation is true.
|
| 83 |
+
- Use verdict="reject" only for hard safety violations (allergy made it
|
| 84 |
+
through, food explicitly contraindicated by medication).
|
| 85 |
+
- Use verdict="revise" for fixable problems (over-budget calories, missing
|
| 86 |
+
guideline citation, monotonous menu).
|
| 87 |
+
- Use verdict="pass" only when issues is empty OR all issues are severity="low".
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ValidationAgent:
    """Generator-critic gate for the Planner's output.

    The agent runs a deterministic layer (allergies, dislikes, calorie and
    macro tolerances) and, unless that layer already found a high-severity
    issue, one LLM round-trip. The final verdict is always recomputed from the
    merged issue severities; the verdict proposed by the LLM is not used.
    """

    def __init__(self, llm_instance):
        # ``llm_instance`` must expose ``call_typed(prompt, schema)``; see
        # ``_llm_review`` for the only call site.
        self.llm = llm_instance

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def handle_task(self, task: str, memory: Dict[str, Any]) -> str:
        """Validate the current plan in ``memory.plans.current_plan``.

        Args:
            task: Free-text task description from the Coach; it is only
                recorded alongside the persisted decision.
            memory: Shared memory dict. Reads ``plans``, ``user_profile`` and
                ``flags_and_assessments``; writes the decision back under
                ``flags_and_assessments``.

        Returns:
            A JSON string of the ``ValidationDecision`` so the Coach can read
            structured fields back out (``verdict``, ``issues``).
        """
        _logger.info("\n🛡️ VALIDATION AGENT STARTED")

        # ``or {}`` keeps a partition that is present-but-None from crashing
        # the lookup, matching how ``_deterministic_checks`` reads memory.
        plan = (memory.get("plans") or {}).get("current_plan")
        if plan is None:
            _logger.warning("No current_plan in memory; nothing to validate.")
            decision = ValidationDecision(
                verdict="reject",
                issues=[
                    ValidationIssue(
                        code="missing_plan",
                        description="No current_plan in memory; Planner did not finalise.",
                        severity="high",
                    )
                ],
                notes="Validator received no plan. Re-run PlannerAgent.",
                requires_human_review=False,
            )
            return self._save_and_return(task, memory, decision)

        # 1. Deterministic checks
        det_issues = self._deterministic_checks(plan, memory)

        # 2. LLM-graded checks (only if deterministic ones don't already reject)
        llm_decision: Optional[ValidationDecision] = None
        hard_block = any(i.severity == "high" for i in det_issues)
        if not hard_block:
            llm_decision = self._llm_review(plan, memory, det_issues)

        # 3. Merge
        all_issues = list(det_issues)
        notes_parts: List[str] = []
        requires_hr = False
        if llm_decision is not None:
            all_issues.extend(llm_decision.issues)
            if llm_decision.notes:
                notes_parts.append(llm_decision.notes)
            requires_hr |= llm_decision.requires_human_review

        # Force human review when the medical assessment said so.
        assessment = memory.get("flags_and_assessments") or {}
        if assessment.get("requires_professional_consultation"):
            requires_hr = True

        verdict = self._compute_verdict(all_issues)
        decision = ValidationDecision(
            verdict=verdict,
            issues=all_issues,
            notes=" | ".join(notes_parts) if notes_parts else "",
            requires_human_review=requires_hr,
        )
        _logger.info("🛡️ Validation verdict: %s (%d issue(s))", verdict, len(all_issues))
        return self._save_and_return(task, memory, decision)

    # ------------------------------------------------------------------
    # Deterministic layer
    # ------------------------------------------------------------------
    @staticmethod
    def _normalise_terms(raw: Any) -> set[str]:
        """Return a set of stripped, lower-cased terms.

        Accepts either a comma-separated string or a collection of items, so
        a profile field stored in either shape is matched term by term rather
        than character by character.
        """
        if not raw:
            return set()
        if isinstance(raw, str):
            candidates = raw.split(",")
        elif isinstance(raw, (list, tuple, set, frozenset)):
            candidates = raw
        else:
            candidates = [raw]
        return {
            str(item).strip().lower()
            for item in candidates
            if item is not None and str(item).strip()
        }

    @staticmethod
    def _deterministic_checks(plan: Dict[str, Any], memory: Dict[str, Any]) -> List[ValidationIssue]:
        """Run the rule-based checks that need no LLM call.

        Issues are emitted in a fixed order: allergy violations (high),
        disliked foods (low), calorie deviation (medium), then one entry per
        macro that deviates (medium). Term matching is a case-insensitive
        substring test against each food's ``name``.
        """
        issues: List[ValidationIssue] = []

        user_profile = memory.get("user_profile", {}) or {}
        allergies = ValidationAgent._normalise_terms(user_profile.get("allergies"))
        dislikes = ValidationAgent._normalise_terms(user_profile.get("food_dislikes"))

        flags = memory.get("flags_and_assessments", {}) or {}
        calc = flags.get("calculations", {}) or {}
        target_calories = calc.get("daily_target_calories")
        macro_targets = calc.get("macro_targets") or {}

        # Walk plan, accumulating foods and totals.
        foods, totals = ValidationAgent._extract_foods_and_totals(plan)

        # 1. Allergy violations (severity high — never let these through)
        # 2. Disliked foods (advisory)
        # Terms are sorted so the issue order is stable between runs.
        term_checks = (
            (allergies, "allergy_violation", "allergen", "high"),
            (dislikes, "disliked_food", "user dislike", "low"),
        )
        for terms, code, label, severity in term_checks:
            for food in foods:
                name = str(food.get("name") or "").lower()
                for term in sorted(terms):
                    if term in name:
                        issues.append(
                            ValidationIssue(
                                code=code,
                                description=f"Food '{name}' matches {label} '{term}'.",
                                severity=severity,
                            )
                        )

        # 3. Calorie tolerance (skipped when either side is missing or zero)
        if target_calories and totals.get("calories"):
            dev = abs(totals["calories"] - target_calories) / target_calories
            if dev > CALORIE_TOLERANCE:
                issues.append(
                    ValidationIssue(
                        code="calorie_deviation",
                        description=(
                            f"Plan total {totals['calories']:.0f} kcal vs target "
                            f"{target_calories} kcal ({dev*100:.1f}% deviation)."
                        ),
                        severity="medium",
                    )
                )

        # 4. Macro tolerances (target key -> key used in the plan totals)
        macro_map = {"protein_g": "protein", "fat_g": "fat", "carbohydrates_g": "carbohydrates"}
        for tgt_key, plan_key in macro_map.items():
            target = macro_targets.get(tgt_key)
            actual = totals.get(plan_key)
            if target and actual:
                dev = abs(actual - target) / target
                if dev > MACRO_TOLERANCE:
                    issues.append(
                        ValidationIssue(
                            code=f"{plan_key}_deviation",
                            description=(
                                f"{plan_key} total {actual:.0f}g vs target {target}g "
                                f"({dev*100:.1f}% deviation)."
                            ),
                            severity="medium",
                        )
                    )

        return issues

    @staticmethod
    def _to_float(value: Any) -> float:
        """Coerce ``value`` to float, treating unparseable input as 0.0."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _extract_foods_and_totals(
        plan: Dict[str, Any],
    ) -> Tuple[List[Dict[str, Any]], Dict[str, float]]:
        """Best-effort: support both 'days' shape and a flat dict-of-foods.

        We tolerate the LLM's free-form ``drafted_plan`` shape too, since the
        Planner's final_plan isn't yet strictly typed.

        A dict counts as a food when it has a ``name`` and one of the calorie
        keys; such a dict is recorded and not descended into further.
        Unparseable numeric values contribute 0 instead of aborting the walk.

        NOTE(review): totals are summed over every food found, so a multi-day
        plan without ``daily_totals`` would be compared against a single-day
        target — confirm against the Planner's actual plan shape.

        Returns:
            ``(foods, totals)`` where ``totals`` has the keys ``calories``,
            ``protein``, ``fat`` and ``carbohydrates``.
        """
        foods: List[Dict[str, Any]] = []
        totals: Dict[str, float] = {"calories": 0.0, "protein": 0.0, "fat": 0.0, "carbohydrates": 0.0}

        calorie_keys = ("calories", "kcal", "calories_g")
        # Total name -> keys to read from a food node, in order of preference.
        source_keys = {
            "calories": calorie_keys,
            "protein": ("protein_g", "protein"),
            "fat": ("fat_g", "fat"),
            "carbohydrates": ("carbohydrates_g", "carbohydrates"),
        }

        def _first_number(node: Dict[str, Any], keys: Tuple[str, ...]) -> float:
            # The first key holding a non-None value wins, even if it is 0.
            for key in keys:
                value = node.get(key)
                if value is not None:
                    return ValidationAgent._to_float(value)
            return 0.0

        def _walk(node: Any) -> None:
            if isinstance(node, list):
                for item in node:
                    _walk(item)
            elif isinstance(node, dict):
                if "name" in node and any(k in node for k in calorie_keys):
                    foods.append(node)
                    for total_key, keys in source_keys.items():
                        totals[total_key] += _first_number(node, keys)
                else:
                    for child in node.values():
                        _walk(child)

        _walk(plan)

        # Plans may also surface daily_totals directly — prefer those when present.
        declared = plan.get("daily_totals") if isinstance(plan, dict) else None
        if isinstance(declared, dict):
            for key in ("calories", "protein", "fat", "carbohydrates"):
                if key not in declared:
                    continue
                try:
                    totals[key] = float(declared[key])
                except (TypeError, ValueError):
                    # Keep the summed value when the declared one is unusable.
                    continue
        return foods, totals

    # ------------------------------------------------------------------
    # LLM layer
    # ------------------------------------------------------------------
    def _llm_review(
        self,
        plan: Dict[str, Any],
        memory: Dict[str, Any],
        deterministic_issues: List[ValidationIssue],
    ) -> Optional[ValidationDecision]:
        """Ask the LLM for issues beyond the deterministic findings.

        The prompt bundles the system prompt, the plan, the user profile, the
        medical assessment and a summary of the issues already raised.

        Returns:
            Whatever ``self.llm.call_typed`` returns; ``None`` is logged and
            makes the caller skip the LLM layer.
        """
        det_summary = "\n".join(f"- [{i.severity}] {i.code}: {i.description}" for i in deterministic_issues) or "None"
        prompt = (
            f"{_VALIDATION_SYSTEM_PROMPT}\n\n--- Plan ---\n{json.dumps(plan, indent=2, default=str)}\n\n"
            f"--- User profile ---\n{json.dumps(memory.get('user_profile', {}), indent=2, default=str)}\n\n"
            f"--- Medical assessment ---\n"
            f"{json.dumps(memory.get('flags_and_assessments', {}), indent=2, default=str)}\n\n"
            f"--- Deterministic findings already raised ---\n{det_summary}\n\n"
            "Add only NEW issues. Do not repeat the deterministic ones."
        )
        decision = self.llm.call_typed(prompt, ValidationDecision)
        if decision is None:
            _logger.warning("Validator LLM call returned no parseable decision; skipping LLM layer.")
        return decision

    # ------------------------------------------------------------------
    @staticmethod
    def _compute_verdict(issues: List[ValidationIssue]) -> str:
        """Map issue severities to a verdict.

        Any ``high`` gives ``reject``; otherwise any ``medium`` gives
        ``revise``; otherwise ``pass``.
        """
        if any(i.severity == "high" for i in issues):
            return "reject"
        if any(i.severity == "medium" for i in issues):
            return "revise"
        return "pass"

    @staticmethod
    def _save_and_return(task: str, memory: Dict[str, Any], decision: ValidationDecision) -> str:
        """Record the decision in memory, dump it via ``save_to_json``, return it as JSON."""
        # Persist to memory so the Coach can inspect the verdict next turn.
        # A partition that is missing or None is replaced by a fresh dict.
        if memory.get("flags_and_assessments") is None:
            memory["flags_and_assessments"] = {}

        # One timestamp keeps the memory entry, payload and filename in agreement.
        stamp = datetime.now().isoformat()
        memory["flags_and_assessments"]["last_validation"] = decision.model_dump()
        memory["flags_and_assessments"]["last_validation_at"] = stamp

        save_to_json(
            {
                "task": task,
                "decision": decision.model_dump(),
                "timestamp": stamp,
            },
            f"validation_agent_{stamp}.json",
            subdirectory="ValidationAgent",
        )
        return decision.model_dump_json()
|
workflow.py
CHANGED
|
@@ -1,121 +1,135 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from state import NutritionState
|
| 4 |
-
from utils import
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
|
| 9 |
def should_continue(state: NutritionState) -> str:
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
return "end"
|
| 12 |
if state["num_turns"] >= state["max_turns"]:
|
| 13 |
return "end"
|
| 14 |
return "execute_action"
|
| 15 |
|
|
|
|
| 16 |
def coach_node(state: NutritionState, coach_agent) -> NutritionState:
|
| 17 |
return coach_agent.handle_task(state)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
return state
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
if
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
task = action["params"]["task"]
|
| 34 |
-
elif action["action"] == "ask_user":
|
| 35 |
-
print(f"❓Asking user: {action['params']['prompt']}")
|
| 36 |
-
elif action["action"] == "write_memory":
|
| 37 |
-
print(f"Writing to memory partition: {action['params']['partition']}")
|
| 38 |
-
|
| 39 |
-
# Handle JSON parsing errors
|
| 40 |
if action.get("_parse_error"):
|
| 41 |
error_message = "I encountered an error processing the request. Let me try a different approach."
|
| 42 |
state["conversation_history"].append({"role": "assistant", "content": error_message})
|
| 43 |
return {**state, "agent_result": error_message}
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
state['previous_actions'] = []
|
| 48 |
|
| 49 |
try:
|
| 50 |
-
if
|
| 51 |
-
agent_name =
|
| 52 |
-
task =
|
| 53 |
agent_result = agents[agent_name].handle_task(task, state["memory"])
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
return {**state, "agent_result": success_message}
|
| 59 |
|
| 60 |
-
|
| 61 |
-
tool_name =
|
| 62 |
-
task =
|
| 63 |
-
tool_result =
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
return {**state, "agent_result": tool_result}
|
| 67 |
|
| 68 |
-
|
| 69 |
-
partition =
|
| 70 |
-
data =
|
| 71 |
updated_data = {**data, "last_updated": datetime.now().isoformat()}
|
| 72 |
set_nested(state["memory"], partition, updated_data)
|
| 73 |
-
|
| 74 |
-
state['previous_actions'].append(action_description)
|
| 75 |
return {**state, "agent_result": "Memory updated successfully"}
|
| 76 |
|
| 77 |
-
|
| 78 |
-
response_text =
|
| 79 |
if not response_text:
|
| 80 |
raise ValueError("Missing 'text' or 'response' in params for compose_response")
|
| 81 |
state["conversation_history"].append({"role": "assistant", "content": response_text})
|
| 82 |
-
|
| 83 |
-
state['previous_actions'].append(action_description)
|
| 84 |
return {**state, "agent_result": response_text}
|
| 85 |
|
| 86 |
-
|
| 87 |
-
prompt_text =
|
| 88 |
state["conversation_history"].append({"role": "assistant", "content": prompt_text})
|
| 89 |
-
|
| 90 |
-
state['previous_actions'].append(action_description)
|
| 91 |
return {**state, "agent_result": "User prompted for input"}
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
state['previous_actions'].append(action_description)
|
| 96 |
-
return {**state, "agent_result": f"Unknown action: {action['action']}"}
|
| 97 |
|
| 98 |
-
except Exception as e:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
state
|
| 102 |
-
return {**state, "agent_result": error_result}
|
| 103 |
|
| 104 |
-
|
|
|
|
| 105 |
workflow = StateGraph(NutritionState)
|
| 106 |
workflow.add_node("coach", lambda state: coach_node(state, coach_agent))
|
| 107 |
workflow.add_node("execute_action", lambda state: execute_action_node(state, agents, tools))
|
| 108 |
workflow.set_entry_point("coach")
|
| 109 |
workflow.add_edge("coach", "execute_action")
|
| 110 |
-
workflow.add_conditional_edges(
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
if persistence_dir:
|
| 113 |
checkpointer = FileCheckpointSaver(persistence_dir)
|
| 114 |
-
|
| 115 |
else:
|
| 116 |
checkpointer = MemorySaver()
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
app = workflow.compile(checkpointer=checkpointer)
|
| 120 |
-
return app
|
| 121 |
|
|
|
|
|
|
| 1 |
+
"""LangGraph wiring for the Coach <-> action loop.
|
| 2 |
+
|
| 3 |
+
Phase 1 keeps the same two-node graph (``coach`` -> ``execute_action`` -> loop)
|
| 4 |
+
so the public contract is unchanged. Phase 2 will explode this into subgraphs
|
| 5 |
+
with parallel branches and a Validator critic loop.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from typing import Any, Dict
|
| 12 |
+
|
| 13 |
from langgraph.checkpoint.memory import MemorySaver
|
| 14 |
+
from langgraph.graph import END, StateGraph
|
| 15 |
+
|
| 16 |
+
from config import get_settings
|
| 17 |
+
from logging_setup import get_logger
|
| 18 |
from state import NutritionState
|
| 19 |
+
from utils import FileCheckpointSaver, set_nested
|
| 20 |
+
|
| 21 |
+
_logger = get_logger("workflow")
|
| 22 |
+
|
| 23 |
|
| 24 |
def should_continue(state: NutritionState) -> str:
    """Decide where the graph goes after ``execute_action``.

    Returns ``"end"`` when the current action is a terminal one
    (``compose_response`` or ``ask_user``) or when ``num_turns`` has reached
    ``max_turns``; otherwise returns ``"execute_action"``.
    """
    pending = state.get("current_action") or {}
    is_terminal = pending.get("action") in {"compose_response", "ask_user"}
    # The turn counters are only read when the action is not terminal.
    if is_terminal or state["num_turns"] >= state["max_turns"]:
        return "end"
    return "execute_action"
|
| 32 |
|
| 33 |
+
|
| 34 |
def coach_node(state: NutritionState, coach_agent) -> NutritionState:
    """Graph node that hands the state to the Coach and returns what it produces."""
    next_state = coach_agent.handle_task(state)
    return next_state
|
| 36 |
|
| 37 |
+
|
| 38 |
+
def execute_action_node(state: NutritionState, agents: Dict[str, Any], tools: Dict[str, Any]) -> NutritionState:
    """Execute the Coach's ``current_action`` and report the outcome.

    Supported actions are ``call_agent``, ``call_tool``, ``write_memory``,
    ``compose_response`` and ``ask_user``. Each handled action appends a short
    description to ``state["previous_actions"]`` and returns a copy of the
    state with ``agent_result`` set.

    Args:
        state: Graph state; ``current_action`` holds ``action`` and ``params``.
        agents: Agent name -> object exposing ``handle_task(task, memory)``.
        tools: Tool name -> object exposing ``handle_task(task)``.

    Returns:
        The unchanged ``state`` when there is no action name, otherwise a new
        dict with ``agent_result``. Exceptions raised while executing an
        action are logged and reported through ``agent_result`` rather than
        propagated.
    """
    action = state.get("current_action") or {}
    action_name = action.get("action")
    params = action.get("params", {}) or {}

    if not action_name:
        return state

    settings = get_settings()
    if settings.debug_mode:
        _logger.debug("Executing Action: %s", action_name)
    elif action_name == "ask_user":
        _logger.info("❓ Asking user: %s", params.get("prompt"))
    elif action_name == "write_memory":
        _logger.info("Writing to memory partition: %s", params.get("partition"))

    # The Coach flags actions it could not parse; surface a generic apology.
    if action.get("_parse_error"):
        error_message = "I encountered an error processing the request. Let me try a different approach."
        state["conversation_history"].append({"role": "assistant", "content": error_message})
        return {**state, "agent_result": error_message}

    if "previous_actions" not in state:
        state["previous_actions"] = []

    try:
        if action_name == "call_agent":
            agent_name = params["agent_name"]
            task = params["task"]
            state["previous_actions"].append(f"Called agent {agent_name} with task: {task}")
            # Report an unknown agent the same way an unknown tool is
            # reported, instead of surfacing a bare KeyError.
            if agent_name not in agents:
                return {**state, "agent_result": f"Unknown agent: {agent_name}"}
            agent_result = agents[agent_name].handle_task(task, state["memory"])
            success_message = (
                f"{agent_name} task completed and stored in the memory successfully"
                if agent_result
                else f"{agent_name} task failed"
            )
            return {**state, "agent_result": success_message}

        if action_name == "call_tool":
            tool_name = params["tool_name"]
            task = params["task"]
            tool_result = (
                tools[tool_name].handle_task(task) if tool_name in tools else f"Unknown tool: {tool_name}"
            )
            state["previous_actions"].append(f"Called tool {tool_name} with task: {task}")
            return {**state, "agent_result": tool_result}

        if action_name == "write_memory":
            partition = params["partition"]
            data = params["data"]
            # Stamp every write so readers can tell how fresh a partition is.
            updated_data = {**data, "last_updated": datetime.now().isoformat()}
            set_nested(state["memory"], partition, updated_data)
            state["previous_actions"].append(f"Wrote to memory partition: {partition}")
            return {**state, "agent_result": "Memory updated successfully"}

        if action_name == "compose_response":
            response_text = params.get("text") or params.get("response")
            if not response_text:
                raise ValueError("Missing 'text' or 'response' in params for compose_response")
            state["conversation_history"].append({"role": "assistant", "content": response_text})
            state["previous_actions"].append("Composed response to user")
            return {**state, "agent_result": response_text}

        if action_name == "ask_user":
            prompt_text = params["prompt"]
            state["conversation_history"].append({"role": "assistant", "content": prompt_text})
            state["previous_actions"].append(f"Asked user: {prompt_text}")
            return {**state, "agent_result": "User prompted for input"}

        state["previous_actions"].append(f"Executed {action_name} with params: {params}")
        return {**state, "agent_result": f"Unknown action: {action_name}"}

    except Exception as e:  # noqa: BLE001
        # Top-level boundary for one action: log the traceback and hand the
        # failure back to the Coach instead of aborting the graph run.
        _logger.exception("Error executing %s", action_name)
        state["previous_actions"].append(f"Attempted {action_name} with params: {params}")
        return {**state, "agent_result": f"Error executing {action_name}: {e}"}
|
|
|
|
| 114 |
|
| 115 |
+
|
| 116 |
+
def setup_workflow(coach_agent, agents: Dict[str, Any], tools: Dict[str, Any], persistence_dir: str | None = None):
    """Build and compile the two-node Coach / execute_action graph.

    The graph enters at ``coach``, always proceeds to ``execute_action``, and
    then either loops back to ``coach`` or ends, as decided by
    ``should_continue``. Checkpoints go to a ``FileCheckpointSaver`` when
    ``persistence_dir`` is truthy and to an in-memory ``MemorySaver``
    otherwise.

    Returns:
        The compiled graph.
    """
    graph = StateGraph(NutritionState)

    # Nodes close over their collaborators so the graph only passes ``state``.
    graph.add_node("coach", lambda state: coach_node(state, coach_agent))
    graph.add_node("execute_action", lambda state: execute_action_node(state, agents, tools))

    graph.set_entry_point("coach")
    graph.add_edge("coach", "execute_action")
    routes = {"execute_action": "coach", "end": END}
    graph.add_conditional_edges("execute_action", should_continue, routes)

    if not persistence_dir:
        saver = MemorySaver()
        _logger.info("MAS workflow compiled with in-memory persistence.")
    else:
        saver = FileCheckpointSaver(persistence_dir)
        _logger.info("MAS workflow compiled with file-based persistence at %s.", persistence_dir)

    return graph.compile(checkpointer=saver)
|