iLOVE2D commited on
Commit
5374a2d
·
verified ·
1 Parent(s): 4cb1c2a

Upload 2846 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +14 -0
  2. evoagentx/.ipynb_checkpoints/test aflow-checkpoint.ipynb +427 -0
  3. evoagentx/__init__.py +2 -0
  4. evoagentx/__pycache__/__init__.cpython-311.pyc +0 -0
  5. evoagentx/actions/__init__.py +5 -0
  6. evoagentx/actions/__pycache__/__init__.cpython-311.pyc +0 -0
  7. evoagentx/actions/__pycache__/action.cpython-311.pyc +0 -0
  8. evoagentx/actions/__pycache__/agent_generation.cpython-311.pyc +0 -0
  9. evoagentx/actions/__pycache__/code_extraction.cpython-311.pyc +0 -0
  10. evoagentx/actions/__pycache__/code_verification.cpython-311.pyc +0 -0
  11. evoagentx/actions/__pycache__/customize_action.cpython-311.pyc +0 -0
  12. evoagentx/actions/__pycache__/task_planning.cpython-311.pyc +0 -0
  13. evoagentx/actions/action.py +256 -0
  14. evoagentx/actions/agent_generation.py +198 -0
  15. evoagentx/actions/code_extraction.py +276 -0
  16. evoagentx/actions/code_verification.py +63 -0
  17. evoagentx/actions/customize_action.py +559 -0
  18. evoagentx/actions/task_planning.py +80 -0
  19. evoagentx/agents/__init__.py +6 -0
  20. evoagentx/agents/__pycache__/__init__.cpython-311.pyc +0 -0
  21. evoagentx/agents/__pycache__/action_agent.cpython-311.pyc +0 -0
  22. evoagentx/agents/__pycache__/agent.cpython-311.pyc +0 -0
  23. evoagentx/agents/__pycache__/agent_generator.cpython-311.pyc +0 -0
  24. evoagentx/agents/__pycache__/agent_manager.cpython-311.pyc +0 -0
  25. evoagentx/agents/__pycache__/customize_agent.cpython-311.pyc +0 -0
  26. evoagentx/agents/__pycache__/task_planner.cpython-311.pyc +0 -0
  27. evoagentx/agents/__pycache__/workflow_reviewer.cpython-311.pyc +0 -0
  28. evoagentx/agents/action_agent.py +502 -0
  29. evoagentx/agents/agent.py +531 -0
  30. evoagentx/agents/agent_generator.py +23 -0
  31. evoagentx/agents/agent_manager.py +505 -0
  32. evoagentx/agents/customize_agent.py +522 -0
  33. evoagentx/agents/long_term_memory_agent.py +491 -0
  34. evoagentx/agents/task_planner.py +35 -0
  35. evoagentx/agents/workflow_reviewer.py +14 -0
  36. evoagentx/app/__init__.py +0 -0
  37. evoagentx/app/api.py +329 -0
  38. evoagentx/app/app.env +22 -0
  39. evoagentx/app/config.py +83 -0
  40. evoagentx/app/db.py +177 -0
  41. evoagentx/app/main.py +177 -0
  42. evoagentx/app/requirements.txt +23 -0
  43. evoagentx/app/schemas.py +168 -0
  44. evoagentx/app/security.py +172 -0
  45. evoagentx/app/services.py +463 -0
  46. evoagentx/benchmark/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  47. evoagentx/benchmark/.ipynb_checkpoints/test_load_json-checkpoint.ipynb +570 -0
  48. evoagentx/benchmark/README.md +178 -0
  49. evoagentx/benchmark/Untitled.ipynb +6 -0
  50. evoagentx/benchmark/WorfBench.py +155 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/antibiotic_pred/EC_antibiotic.xlsx filter=lfs diff=lfs merge=lfs -text
37
+ examples/antibiotic_pred/ec_train filter=lfs diff=lfs merge=lfs -text
38
+ examples/antibiotic_pred/ec_train.json filter=lfs diff=lfs merge=lfs -text
39
+ examples/hotpotqa/test[[:space:]]structopt[[:space:]]toolcall.ipynb filter=lfs diff=lfs merge=lfs -text
40
+ examples/pertqa/.ipynb_checkpoints/test[[:space:]]structual[[:space:]]ourloop-withsearch-checkpoint.ipynb filter=lfs diff=lfs merge=lfs -text
41
+ examples/pertqa/k562_processed_grn.csv filter=lfs diff=lfs merge=lfs -text
42
+ examples/pertqa/pert_folder/EGRET_K562.csv filter=lfs diff=lfs merge=lfs -text
43
+ examples/pertqa/reploge_train.json filter=lfs diff=lfs merge=lfs -text
44
+ examples/pertqa/replogle_update_train.csv filter=lfs diff=lfs merge=lfs -text
45
+ examples/pertqa/replogle_update_train.json filter=lfs diff=lfs merge=lfs -text
46
+ examples/pertqa/test[[:space:]]structual[[:space:]]ourloop-withsearch.ipynb filter=lfs diff=lfs merge=lfs -text
47
+ examples/pubmedqa/pubmedqa_train.json filter=lfs diff=lfs merge=lfs -text
48
+ examples/workflow/invest/300750/20250815/graphs/300750_candlestick_chart.png filter=lfs diff=lfs merge=lfs -text
49
+ examples/workflow/invest/300750/20250815/graphs/300750_technical_charts.png filter=lfs diff=lfs merge=lfs -text
evoagentx/.ipynb_checkpoints/test aflow-checkpoint.ipynb ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e2d3caf8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import matplotlib.pyplot as plt\n",
12
+ "import pickle\n",
13
+ "import glob\n",
14
+ "import pandas as pd\n",
15
+ "import glob\n",
16
+ "from tqdm import tqdm\n",
17
+ "import base64\n",
18
+ "import requests\n",
19
+ "# OpenAI API Key\n",
20
+ "api_key = \"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\"\n",
21
+ "# Function to encode the image\n",
22
+ "import os\n",
23
+ "os.environ[\"OPENAI_API_KEY\"] = api_key\n"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "f870b639",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stderr",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/gpfs/radev/home/tl688/.conda/envs/evoagentx/lib/python3.11/site-packages/PyPDF2/__init__.py:21: DeprecationWarning: PyPDF2 is deprecated. Please move to the pypdf library instead.\n",
37
+ " warnings.warn(\n"
38
+ ]
39
+ }
40
+ ],
41
+ "source": [
42
+ "import os\n",
43
+ "from dotenv import load_dotenv\n",
44
+ "from evoagentx.optimizers import AFlowOptimizer\n",
45
+ "from evoagentx.models import LiteLLMConfig, LiteLLM, OpenAILLMConfig, OpenAILLM\n",
46
+ "from evoagentx.benchmark import AFlowHumanEval\n",
47
+ "\n",
48
+ "# Load environment variables\n",
49
+ "load_dotenv()\n",
50
+ "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
51
+ "# ANTHROPIC_API_KEY = os.getenv(\"ANTHROPIC_API_KEY\")"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 3,
57
+ "id": "1f3dd892",
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "# # Configure the optimizer LLM (Claude 3.5 Sonnet)\n",
62
+ "# claude_config = LiteLLMConfig(\n",
63
+ "# model=\"anthropic/claude-3-5-sonnet-20240620\", \n",
64
+ "# anthropic_key=ANTHROPIC_API_KEY\n",
65
+ "# )\n",
66
+ "# optimizer_llm = LiteLLM(config=claude_config)\n",
67
+ "\n",
68
+ "# Configure the executor LLM (GPT-4o-mini)\n",
69
+ "openai_config = OpenAILLMConfig(\n",
70
+ " model=\"gpt-4o-mini\", \n",
71
+ " openai_key=OPENAI_API_KEY\n",
72
+ ")\n",
73
+ "\n",
74
+ "claude_config = LiteLLMConfig(\n",
75
+ " model=\"gpt-4o-mini\", \n",
76
+ " openai_key=OPENAI_API_KEY\n",
77
+ ")\n",
78
+ "executor_llm = OpenAILLM(config=openai_config)\n",
79
+ "optimizer_llm = LiteLLM(config=claude_config)"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 4,
85
+ "id": "a87feb08",
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "EXPERIMENTAL_CONFIG = {\n",
90
+ " \"humaneval\": {\n",
91
+ " \"question_type\": \"code\", \n",
92
+ " \"operators\": [\"Custom\", \"CustomCodeGenerate\", \"Test\", \"ScEnsemble\"] \n",
93
+ " }, \n",
94
+ " \"mbpp\": {\n",
95
+ " \"question_type\": \"code\", \n",
96
+ " \"operators\": [\"Custom\", \"CustomCodeGenerate\", \"Test\", \"ScEnsemble\"] \n",
97
+ " },\n",
98
+ " \"hotpotqa\": {\n",
99
+ " \"question_type\": \"qa\", \n",
100
+ " \"operators\": [\"Custom\", \"AnswerGenerate\", \"QAScEnsemble\"]\n",
101
+ " },\n",
102
+ " \"gsm8k\": {\n",
103
+ " \"question_type\": \"math\", \n",
104
+ " \"operators\": [\"Custom\", \"ScEnsemble\", \"Programmer\"]\n",
105
+ " },\n",
106
+ " \"math\": {\n",
107
+ " \"question_type\": \"math\", \n",
108
+ " \"operators\": [\"Custom\", \"ScEnsemble\", \"Programmer\"]\n",
109
+ " }\n",
110
+ "}"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 5,
116
+ "id": "b6054068",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "import evoagentx.workflow.operators as operator\n",
121
+ "import examples.aflow.code_generation.prompt as prompt_custom # noqa: F401\n",
122
+ "from evoagentx.models.model_configs import LLMConfig\n",
123
+ "from evoagentx.benchmark.benchmark import Benchmark\n",
124
+ "from evoagentx.models.model_utils import create_llm_instance\n",
125
+ "\n",
126
+ "class Workflow:\n",
127
+ " \n",
128
+ " def __init__(\n",
129
+ " self,\n",
130
+ " name: str,\n",
131
+ " llm_config: LLMConfig,\n",
132
+ " benchmark: Benchmark\n",
133
+ " ):\n",
134
+ " self.name = name\n",
135
+ " self.llm = create_llm_instance(llm_config)\n",
136
+ " self.benchmark = benchmark \n",
137
+ " self.custom = operator.Custom(self.llm)\n",
138
+ " self.custom_code_generate = operator.CustomCodeGenerate(self.llm)\n",
139
+ "\n",
140
+ " async def __call__(self, problem: str, entry_point: str):\n",
141
+ " \"\"\"\n",
142
+ " Implementation of the workflow\n",
143
+ " Custom operator to generate anything you want.\n",
144
+ " But when you want to get standard code, you should use custom_code_generate operator.\n",
145
+ " \"\"\"\n",
146
+ " # await self.custom(input=, instruction=\"\")\n",
147
+ " solution = await self.custom_code_generate(problem=problem, entry_point=entry_point, instruction=prompt_custom.GENERATE_PYTHON_CODE_PROMPT) # But When you want to get standard code ,you should use customcodegenerator.\n",
148
+ " return solution['response']"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 6,
154
+ "id": "27e574ad",
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "name": "stderr",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "\u001b[32m2025-10-12 15:04:04.523\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m182\u001b[0m - \u001b[1mLoading train data from None\u001b[0m\n",
162
+ "\u001b[32m2025-10-12 15:04:04.524\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m185\u001b[0m - \u001b[1mLoading dev data from humaneval_validate.jsonl\u001b[0m\n",
163
+ "\u001b[32m2025-10-12 15:04:04.525\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m188\u001b[0m - \u001b[1mLoading test data from humaneval_test.jsonl\u001b[0m\n"
164
+ ]
165
+ }
166
+ ],
167
+ "source": [
168
+ "# Initialize the benchmark\n",
169
+ "humaneval = AFlowHumanEval()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 8,
175
+ "id": "2f8da181",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "optimizer = AFlowOptimizer(\n",
180
+ " graph_path=\"../examples/aflow/code_generation\", # Path to the initial workflow graph\n",
181
+ " optimized_path=\"../examples/aflow/humaneval/optimized\", # Path to save optimized workflows\n",
182
+ " optimizer_llm=optimizer_llm, # LLM for optimization\n",
183
+ " executor_llm=executor_llm, # LLM for execution\n",
184
+ " validation_rounds=3, # Number of times to run validation on the development set during optimization\n",
185
+ " eval_rounds=3, # Number of times to run evaluation on the test set during testing\n",
186
+ " max_rounds=20, # Maximum optimization rounds\n",
187
+ " **EXPERIMENTAL_CONFIG[\"humaneval\"] # Task-specific configuration, used to specify the task type and available operators\n",
188
+ ")"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 9,
194
+ "id": "74937699",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "import nest_asyncio\n",
199
+ "nest_asyncio.apply()"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "98ac4a63",
206
+ "metadata": {
207
+ "scrolled": true
208
+ },
209
+ "outputs": [
210
+ {
211
+ "name": "stderr",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "\u001b[32m2025-10-12 15:04:50.304\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
215
+ "\u001b[32m2025-10-12 15:04:50.305\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 1/3)\u001b[0m\n",
216
+ "\u001b[32m2025-10-12 15:04:55.310\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
217
+ "\u001b[32m2025-10-12 15:04:55.311\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 2/3)\u001b[0m\n",
218
+ "\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
219
+ "\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 3/3)\u001b[0m\n",
220
+ "\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m149\u001b[0m - \u001b[1mMax retries reached.\u001b[0m\n",
221
+ "\u001b[32m2025-10-12 15:05:05.323\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36moptimize\u001b[0m:\u001b[36m112\u001b[0m - \u001b[1mScore for round 1: None\u001b[0m\n",
222
+ "\u001b[32m2025-10-12 15:05:05.326\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 1/3)\u001b[0m\n",
223
+ "\u001b[32m2025-10-12 15:05:10.332\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 2/3)\u001b[0m\n",
224
+ "\u001b[32m2025-10-12 15:05:20.344\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 3/3)\u001b[0m\n",
225
+ "\u001b[32m2025-10-12 15:05:20.344\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m149\u001b[0m - \u001b[1mMax retries reached.\u001b[0m\n",
226
+ "\u001b[32m2025-10-12 15:05:20.345\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36moptimize\u001b[0m:\u001b[36m112\u001b[0m - \u001b[1mScore for round 2: None\u001b[0m\n",
227
+ "\u001b[32m2025-10-12 15:05:20.347\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 1/3)\u001b[0m\n"
228
+ ]
229
+ }
230
+ ],
231
+ "source": [
232
+ "# Optimize the workflow\n",
233
+ "optimizer.optimize(humaneval)"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "id": "1010d583",
240
+ "metadata": {
241
+ "scrolled": true
242
+ },
243
+ "outputs": [],
244
+ "source": [
245
+ "optimizer.test(humaneval)"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "becb5a82",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "import pandas as pd"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": 16,
261
+ "id": "5c076d29",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "df = pd.read_json(\"/home/tl688/pitl688/selfevolve/AFlow/data/datasets/scicode_dev.jsonl\", lines=True)"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": 23,
271
+ "id": "481602a9",
272
+ "metadata": {},
273
+ "outputs": [
274
+ {
275
+ "data": {
276
+ "text/plain": [
277
+ "'def get_alpha(recvec, alpha_scaling=5):\\n \"\"\"\\n Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n \"\"\"\\n alpha = alpha_scaling * np.max(np.linalg.norm(recvec, axis=1))\\n return alpha'"
278
+ ]
279
+ },
280
+ "execution_count": 23,
281
+ "metadata": {},
282
+ "output_type": "execute_result"
283
+ }
284
+ ],
285
+ "source": [
286
+ "df['ground_truth_code'].values[0]"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": 21,
292
+ "id": "ffb0be7e",
293
+ "metadata": {},
294
+ "outputs": [
295
+ {
296
+ "data": {
297
+ "text/plain": [
298
+ "\"def get_alpha(recvec, alpha_scaling=5):\\n '''Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n '''\""
299
+ ]
300
+ },
301
+ "execution_count": 21,
302
+ "metadata": {},
303
+ "output_type": "execute_result"
304
+ }
305
+ ],
306
+ "source": [
307
+ "df['function_header'].values[0]"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 24,
313
+ "id": "69acf613",
314
+ "metadata": {},
315
+ "outputs": [
316
+ {
317
+ "data": {
318
+ "text/plain": [
319
+ "'import numpy as np\\nfrom scipy.special import erfc'"
320
+ ]
321
+ },
322
+ "execution_count": 24,
323
+ "metadata": {},
324
+ "output_type": "execute_result"
325
+ }
326
+ ],
327
+ "source": [
328
+ "df['required_dependencies'].values[0]"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 25,
334
+ "id": "b5696e0e",
335
+ "metadata": {},
336
+ "outputs": [
337
+ {
338
+ "data": {
339
+ "text/plain": [
340
+ "Index(['step_number', 'step_description_prompt', 'step_background',\n",
341
+ " 'ground_truth_code', 'function_header', 'test_cases', 'return_line',\n",
342
+ " 'required_dependencies'],\n",
343
+ " dtype='object')"
344
+ ]
345
+ },
346
+ "execution_count": 25,
347
+ "metadata": {},
348
+ "output_type": "execute_result"
349
+ }
350
+ ],
351
+ "source": [
352
+ "df.columns"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": 27,
358
+ "id": "0a3085a9",
359
+ "metadata": {},
360
+ "outputs": [
361
+ {
362
+ "data": {
363
+ "text/plain": [
364
+ "\"def get_alpha(recvec, alpha_scaling=5):\\n '''Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n '''\""
365
+ ]
366
+ },
367
+ "execution_count": 27,
368
+ "metadata": {},
369
+ "output_type": "execute_result"
370
+ }
371
+ ],
372
+ "source": [
373
+ "df['function_header'].values[0]"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": 28,
379
+ "id": "e6a76c86",
380
+ "metadata": {},
381
+ "outputs": [
382
+ {
383
+ "data": {
384
+ "text/plain": [
385
+ "\"ref1 = -1.74756\\nEX1 = {\\n 'latvec': np.array([\\n [0.0, 1.0, 1.0],\\n [1.0, 0.0, 1.0],\\n [1.0, 1.0, 0.0]\\n ]),\\n 'atom_charges': np.array([1]),\\n 'atom_coords': np.array([\\n [0.0, 0.0, 0.0]\\n ]),\\n 'configs': np.array([\\n [1.0, 1.0, 1.0]\\n ]),\\n}\\nassert np.allclose(get_alpha(np.linalg.inv(EX1['latvec']).T), target)\""
386
+ ]
387
+ },
388
+ "execution_count": 28,
389
+ "metadata": {},
390
+ "output_type": "execute_result"
391
+ }
392
+ ],
393
+ "source": [
394
+ "df['test_cases'].values[0]"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "99775141",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": []
404
+ }
405
+ ],
406
+ "metadata": {
407
+ "kernelspec": {
408
+ "display_name": "Python 3 (ipykernel)",
409
+ "language": "python",
410
+ "name": "python3"
411
+ },
412
+ "language_info": {
413
+ "codemirror_mode": {
414
+ "name": "ipython",
415
+ "version": 3
416
+ },
417
+ "file_extension": ".py",
418
+ "mimetype": "text/x-python",
419
+ "name": "python",
420
+ "nbconvert_exporter": "python",
421
+ "pygments_lexer": "ipython3",
422
+ "version": "3.11.13"
423
+ }
424
+ },
425
+ "nbformat": 4,
426
+ "nbformat_minor": 5
427
+ }
evoagentx/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ __version__ = '0.1.0'
evoagentx/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
evoagentx/actions/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .action import Action, ActionInput, ActionOutput
2
+ from .code_verification import CodeVerification
3
+ from .code_extraction import CodeExtraction
4
+
5
+ __all__ = ["Action", "ActionInput", "ActionOutput", "CodeVerification", "CodeExtraction"]
evoagentx/actions/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (494 Bytes). View file
 
evoagentx/actions/__pycache__/action.cpython-311.pyc ADDED
Binary file (15.3 kB). View file
 
evoagentx/actions/__pycache__/agent_generation.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
evoagentx/actions/__pycache__/code_extraction.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
evoagentx/actions/__pycache__/code_verification.cpython-311.pyc ADDED
Binary file (6.33 kB). View file
 
evoagentx/actions/__pycache__/customize_action.cpython-311.pyc ADDED
Binary file (31.5 kB). View file
 
evoagentx/actions/__pycache__/task_planning.cpython-311.pyc ADDED
Binary file (5.56 kB). View file
 
evoagentx/actions/action.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pydantic import model_validator
3
+ from pydantic_core import PydanticUndefined
4
+ from typing import Optional, Type, Tuple, Union, List, Any
5
+
6
+ from ..core.module import BaseModule
7
+ from ..core.module_utils import get_type_name
8
+ from ..core.registry import MODULE_REGISTRY
9
+ # from ..core.base_config import Parameter
10
+ from ..core.parser import Parser
11
+ from ..core.message import Message
12
+ from ..models.base_model import BaseLLM, LLMOutputParser
13
+ from ..tools.tool import Toolkit
14
+ from ..prompts.context_extraction import CONTEXT_EXTRACTION
15
+ from ..prompts.template import PromptTemplate
16
+
17
+
18
+ class ActionInput(LLMOutputParser):
19
+ """Input specification and parsing for actions.
20
+
21
+ This class defines the input requirements for actions and provides methods
22
+ to generate structured input specifications. It inherits from LLMOutputParser
23
+ to allow parsing of LLM outputs into structured inputs for actions.
24
+
25
+ Notes:
26
+ Parameters in ActionInput should be defined in Pydantic Field format.
27
+ For optional variables, use format:
28
+ var: Optional[int] = Field(default=None, description="xxx")
29
+ Remember to add `default=None` for optional parameters.
30
+ """
31
+
32
+ @classmethod
33
+ def get_input_specification(cls, ignore_fields: List[str] = []) -> str:
34
+ """Generate a JSON specification of the input requirements.
35
+
36
+ Examines the class fields and produces a structured specification of
37
+ the input parameters, including their types, descriptions, and whether
38
+ they are required.
39
+
40
+ Args:
41
+ ignore_fields (List[str]): List of field names to exclude from the specification.
42
+
43
+ Returns:
44
+ A JSON string containing the input specification, or an empty string
45
+ if no fields are defined or all are ignored.
46
+ """
47
+ fields_info = {}
48
+ attrs = cls.get_attrs()
49
+ for field_name, field_info in cls.model_fields.items():
50
+ if field_name in ignore_fields:
51
+ continue
52
+ if field_name not in attrs:
53
+ continue
54
+ field_type = get_type_name(field_info.annotation)
55
+ field_desc = field_info.description if field_info.description is not None else None
56
+ # field_required = field_info.is_required()
57
+ field_default = str(field_info.default) if field_info.default is not PydanticUndefined else None
58
+ field_required = True if field_default is None else False
59
+ description = field_type + ", "
60
+ if field_desc is not None:
61
+ description += (field_desc.strip() + ", ")
62
+ description += ("required" if field_required else "optional")
63
+ if field_default is not None:
64
+ description += (", Default value: " + field_default)
65
+ fields_info[field_name] = description
66
+
67
+ if len(fields_info) == 0:
68
+ return ""
69
+ fields_info_str = json.dumps(fields_info, indent=4)
70
+ return fields_info_str
71
+
72
+ @classmethod
73
+ def get_required_input_names(cls) -> List[str]:
74
+ """Get a list of all required input parameter names.
75
+
76
+ Returns:
77
+ List[str]: Names of all parameters that are required (don't have default values).
78
+ """
79
+ required_fields = []
80
+ attrs = cls.get_attrs()
81
+ for field_name, field_info in cls.model_fields.items():
82
+ if field_name not in attrs:
83
+ continue
84
+ field_default = field_info.default
85
+ # A field is required if it doesn't have a default value
86
+ if field_default is PydanticUndefined:
87
+ required_fields.append(field_name)
88
+ return required_fields
89
+
90
+
91
+ class ActionOutput(LLMOutputParser):
92
+ """Output representation for actions.
93
+
94
+ This class handles the structured output of actions, providing methods
95
+ to convert the output to structured data. It inherits from LLMOutputParser
96
+ to support parsing of LLM outputs into structured action results.
97
+ """
98
+
99
+ def to_str(self) -> str:
100
+ """Convert the output to a formatted JSON string.
101
+
102
+ Returns:
103
+ A pretty-printed JSON string representation of the structured data.
104
+ """
105
+ return json.dumps(self.get_structured_data(), indent=4)
106
+
107
+
108
+ class Action(BaseModule):
109
+ """Base class for all actions in the EvoAgentX framework.
110
+
111
+ Actions represent discrete operations that can be performed by agents.
112
+ They define inputs, outputs, and execution behavior, and can optionally
113
+ use tools to accomplish their tasks.
114
+
115
+ Attributes:
116
+ name (str): Unique identifier for the action.
117
+ description (str): Human-readable description of what the action does.
118
+ prompt (Optional[str]): Optional prompt template for this action.
119
+ tools (Optional[List[Toolkit]]): Optional list of tools that can be used by this action.
120
+ inputs_format (Optional[Type[ActionInput]]): Optional class defining the expected input structure.
121
+ outputs_format (Optional[Type[Parser]]): Optional class defining the expected output structure.
122
+ """
123
+
124
+ name: str
125
+ description: str
126
+ prompt: Optional[str] = None
127
+ prompt_template: Optional[PromptTemplate] = None
128
+ tools: Optional[List[Toolkit]] = None # specify the possible tool for the action
129
+ inputs_format: Optional[Type[ActionInput]] = None # specify the input format of the action
130
+ outputs_format: Optional[Type[Parser]] = None # specify the possible structured output format
131
+
132
+ def init_module(self):
133
+ """Initialize the action module.
134
+
135
+ This method is called after the action is instantiated.
136
+ Subclasses can override this to perform custom initialization.
137
+ """
138
+ pass
139
+
140
+ def to_dict(self, exclude_none: bool = True, ignore: List[str] = [], **kwargs) -> dict:
141
+ """
142
+ Convert the action to a dictionary for saving.
143
+ """
144
+ data = super().to_dict(exclude_none=exclude_none, ignore=ignore, **kwargs)
145
+ if self.inputs_format:
146
+ data["inputs_format"] = self.inputs_format.__name__
147
+ if self.outputs_format:
148
+ data["outputs_format"] = self.outputs_format.__name__
149
+ # TODO: customize serialization for the tools
150
+ return data
151
+
152
+ @model_validator(mode="before")
153
+ @classmethod
154
+ def validate_data(cls, data: Any) -> Any:
155
+ if "inputs_format" in data and data["inputs_format"] and isinstance(data["inputs_format"], str):
156
+ # only used when loading from a file
157
+ data["inputs_format"] = MODULE_REGISTRY.get_module(data["inputs_format"])
158
+ if "outputs_format" in data and data["outputs_format"] and isinstance(data["outputs_format"], str):
159
+ # only used when loading from a file
160
+ data["outputs_format"] = MODULE_REGISTRY.get_module(data["outputs_format"])
161
+ # TODO: customize loading for the tools
162
+ return data
163
+
164
+ def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> Optional[Union[Parser, Tuple[Parser, str]]]:
165
+ """Execute the action to produce a result.
166
+
167
+ This is the main entry point for executing an action. Subclasses must
168
+ implement this method to define the action's behavior.
169
+
170
+ Args:
171
+ llm (Optional[BaseLLM]): The LLM used to execute the action.
172
+ inputs (Optional[dict]): Input data for the action execution. The input data should be a dictionary that matches the input format of the provided prompt.
173
+ For example, if the prompt contains a variable `{input_var}`, the `inputs` dictionary should have a key `input_var`, otherwise the variable will be set to empty string.
174
+ sys_msg (Optional[str]): Optional system message for the LLM.
175
+ return_prompt (bool): Whether to return the complete prompt passed to the LLM.
176
+ **kwargs (Any): Additional keyword arguments for the execution.
177
+
178
+ Returns:
179
+ If `return_prompt` is False, the method returns a Parser object containing the structured result of the action.
180
+ If `return_prompt` is True, the method returns a tuple containing the Parser object and the complete prompt passed to the LLM.
181
+ """
182
+ raise NotImplementedError(f"`execute` function of {type(self).__name__} is not implemented!")
183
+
184
+ async def async_execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> Optional[Union[Parser, Tuple[Parser, str]]]:
185
+ """
186
+ Asynchronous execution of the action.
187
+
188
+ This method is the asynchronous counterpart of the `execute` method.
189
+ It allows the action to be executed asynchronously using an LLM.
190
+ """
191
+ raise NotImplementedError(f"`async_execute` function of {type(self).__name__} is not implemented!")
192
+
193
+ class ContextExtraction(Action):
194
+ """Action for extracting structured inputs from context.
195
+
196
+ This action analyzes a conversation context to extract relevant information
197
+ that can be used as inputs for other actions. It uses the LLM to interpret
198
+ unstructured contextual information and format it according to the target
199
+ action's input requirements.
200
+ """
201
+
202
+ def __init__(self, **kwargs):
203
+ name = kwargs.pop("name") if "name" in kwargs else CONTEXT_EXTRACTION["name"]
204
+ description = kwargs.pop("description") if "description" in kwargs else CONTEXT_EXTRACTION["description"]
205
+ super().__init__(name=name, description=description, **kwargs)
206
+
207
+ def get_context_from_messages(self, messages: List[Message]) -> str:
208
+ str_context = "\n\n".join([str(msg) for msg in messages])
209
+ return str_context
210
+
211
+ def execute(self, llm: Optional[BaseLLM] = None, action: Action = None, context: List[Message] = None, **kwargs) -> Union[dict, None]:
212
+ """Extract structured inputs for an action from conversation context.
213
+
214
+ This method uses the LLM to analyze the conversation context and extract
215
+ information that matches the input requirements of the target action.
216
+
217
+ Args:
218
+ llm: The language model to use for extraction.
219
+ action: The target action whose input requirements (`inputs_format`) define what to extract.
220
+ context: List of messages providing the conversation context.
221
+ **kwargs: Additional keyword arguments.
222
+
223
+ Returns:
224
+ A dictionary containing the extracted inputs for the target action,
225
+ or None if extraction is not possible (e.g., if the action doesn't
226
+ require inputs or if context is missing).
227
+ """
228
+ if action is None or context is None:
229
+ return None
230
+
231
+ action_inputs_cls: Type[ActionInput] = action.inputs_format
232
+ if action_inputs_cls is None:
233
+ # the action does not require inputs
234
+ return None
235
+
236
+ action_inputs_desc = action_inputs_cls.get_input_specification()
237
+ str_context = self.get_context_from_messages(messages=context)
238
+
239
+ if not action_inputs_desc or not str_context:
240
+ return None
241
+
242
+ prompt = CONTEXT_EXTRACTION["prompt"].format(
243
+ context=str_context,
244
+ action_name=action.name,
245
+ action_description=action.description,
246
+ action_inputs=action_inputs_desc
247
+ )
248
+
249
+ action_inputs = llm.generate(
250
+ prompt=prompt,
251
+ system_message=CONTEXT_EXTRACTION["system_prompt"],
252
+ parser=action_inputs_cls
253
+ )
254
+ action_inputs_data = action_inputs.get_structured_data()
255
+
256
+ return action_inputs_data
evoagentx/actions/agent_generation.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from pydantic import Field, model_validator
3
+ from typing import Optional, List
4
+
5
+ from ..core.logging import logger
6
+ from ..core.module import BaseModule
7
+ from ..core.base_config import Parameter
8
+ from ..models.base_model import BaseLLM
9
+ from .action import Action, ActionInput, ActionOutput
10
+ from ..prompts.agent_generator import AGENT_GENERATION_ACTION
11
+ from ..prompts.tool_calling import AGENT_GENERATION_TOOLS_PROMPT
12
+ from ..utils.utils import normalize_text
13
+
14
class AgentGenerationInput(ActionInput):
    """
    Input specification for the agent generation action.

    The first three fields are required and describe the overall goal, the
    complete workflow, and the specific sub-task for which agents must be
    generated. The remaining fields are optional refinement context.
    """

    goal: str = Field(description="A detailed statement of the workflow's goal, explaining the objectives the entire workflow aims to achieve")
    workflow: str = Field(description="An overview of the entire workflow, detailing all sub-tasks with their respective names, descriptions, inputs, and outputs")
    task: str = Field(description="A detailed JSON representation of the sub-task requiring agent generation. It should include the task's name, description, inputs, and outputs.")

    # Optional context: prior generations, refinement hints, and descriptions
    # of agents/tools that already exist and may be reused.
    history: Optional[str] = Field(default=None, description="Optional field containing previously selected or generated agents.")
    suggestion: Optional[str] = Field(default=None, description="Optional suggestions to refine the generated agents.")
    existing_agents: Optional[str] = Field(default=None, description="Optional field containing the description of predefined agents, including each agent's name, role, and available actions.")
    tools: Optional[str] = Field(default=None, description="Optional field containing the description of tools that agents can use, including each tool's name and functionality.")
27
+
28
+
29
class GeneratedAgent(BaseModule):
    """
    Representation of a generated agent with validation capabilities.

    On construction, `validate_prompt` normalizes the agent's prompt template
    so that input placeholders are brace-wrapped and output section headers
    match the declared `outputs` parameters.
    """

    name: str
    description: str
    inputs: List[Parameter]
    outputs: List[Parameter]
    prompt: str
    tool_names: Optional[List[str]] = None

    @classmethod
    def find_output_name(cls, text: str, outputs: List[str]) -> str:
        """Return the output name most similar to `text`.

        Similarity is the number of shared (normalized) words; ties resolve
        to the first candidate with the maximum overlap.
        """
        def sim(t1: str, t2: str) -> int:
            t1_words = normalize_text(t1).split()
            t2_words = normalize_text(t2).split()
            return len(set(t1_words)&set(t2_words))

        similarities = [sim(text, output) for output in outputs]
        max_sim = max(similarities)
        return outputs[similarities.index(max_sim)]

    @model_validator(mode="after")
    @classmethod
    def validate_prompt(cls, agent: 'GeneratedAgent'):
        """Validate and fix the agent's prompt template.

        This validator ensures that:
        1. All input parameters are properly referenced in the prompt
        2. Input references use the correct format with braces
        3. All output sections match the defined output parameters

        If there are mismatches in the output sections, it attempts to
        fix them by finding the most similar output name.

        Args:
            agent: The GeneratedAgent instance to validate.

        Returns:
            The validated and potentially modified GeneratedAgent.

        Raises:
            ValueError: If inputs are missing from the prompt or output sections don't match the defined outputs.
        """
        # check whether all the inputs are present in the prompt
        # (substring match: a raw mention anywhere in the prompt counts)
        input_names = [inp.name for inp in agent.inputs]
        prompt_has_inputs = [name in agent.prompt for name in input_names]
        if not all(prompt_has_inputs):
            missing_input_names = [name for name, has_input in zip(input_names, prompt_has_inputs) if not has_input]
            raise ValueError(f'The prompt miss inputs: {missing_input_names}')

        # check the format of the prompt to make sure it is wrapped in brackets.
        # Only the "### Instructions" ... "### Output Format" span is rewritten.
        pattern = r"### Instructions(.*?)### Output Format"
        prompt = agent.prompt

        def replace_with_braces(match):
            instructions = match.group(1)
            for name in input_names:
                # `{{*` / `}}*` in the f-pattern mean "zero or more literal
                # braces", so <input>name</input>, <input>{name}</input>, etc.
                # all normalize to exactly one pair of braces.
                instructions = re.sub(fr'<input>{{*\b{re.escape(name)}\b}}*</input>', fr'<input>{{{name}}}</input>', instructions)
            return "### Instructions" + instructions + "### Output Format"

        modified_prompt = re.sub(pattern, replace_with_braces, prompt, flags=re.DOTALL)
        agent.prompt = modified_prompt

        # check whether all the outputs are present in the prompt
        prompt = agent.prompt
        pattern = r"### Output Format(.*)"
        outputs_names = [out.name for out in agent.outputs]

        def fix_output_names(match):
            output_format = match.group(1)
            # "## <name>" section headers; a "Thought" section is reserved
            # and excluded from the comparison.
            matches = re.findall(r"## ([^\n#]+)", output_format, flags=re.DOTALL)
            generated_outputs = [m.strip() for m in matches if m.strip() != "Thought"]
            # check the number of generated outputs and agent outputs
            if len(generated_outputs) != len(outputs_names):
                raise ValueError(f"The number of outputs in the prompt is different from that defined in the `outputs` field of the agent. The outputs in the prompt are: {generated_outputs}, while the outputs from the agent's `outputs` field are: {outputs_names}")
            # check whether the generated output names are the same as agent outputs;
            # mismatches are repaired with the closest declared output name.
            for generated_output in generated_outputs:
                if generated_output not in outputs_names:
                    most_similar_output_name = cls.find_output_name(text=generated_output, outputs=outputs_names)
                    output_format = output_format.replace(generated_output, most_similar_output_name)
                    logger.warning(f"Couldn't find output name in prompt ('{generated_output}') in agent's outputs. Replace it with the most similar agent output: '{most_similar_output_name}'")
            return "### Output Format" + output_format

        modified_prompt = re.sub(pattern, fix_output_names, prompt, flags=re.DOTALL)
        agent.prompt = modified_prompt

        return agent
118
+
119
+
120
class AgentGenerationOutput(ActionOutput):
    """
    Output specification for the agent generation action.

    Holds the names of existing agents selected for the sub-task and the full
    specifications of any newly generated agents.
    """

    selected_agents: List[str] = Field(description="A list of selected agent's names")
    # Fixed typo in the description ("agetns" -> "agents"): this text is part
    # of the field metadata and may surface in generated output specifications.
    generated_agents: List[GeneratedAgent] = Field(description="A list of generated agents to address a sub-task")
124
+
125
+
126
class AgentGeneration(Action):
    """
    Action for generating agent specifications for workflow tasks.

    This action analyzes task requirements and generates appropriate agent
    specifications, including their prompts, inputs, and outputs. It can either
    select from existing agents or create new ones tailored to the task.
    """

    def __init__(self, **kwargs):
        # Fall back to the canonical AGENT_GENERATION_ACTION definition for
        # any field the caller does not override.
        name = kwargs.pop("name") if "name" in kwargs else AGENT_GENERATION_ACTION["name"]
        description = kwargs.pop("description") if "description" in kwargs else AGENT_GENERATION_ACTION["description"]
        prompt = kwargs.pop("prompt") if "prompt" in kwargs else AGENT_GENERATION_ACTION["prompt"]
        inputs_format = kwargs.pop("inputs_format", None) or AgentGenerationInput
        outputs_format = kwargs.pop("outputs_format", None) or AgentGenerationOutput
        tools = kwargs.pop("tools", None)
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)
        # NOTE(review): assigned after super().__init__; presumably the base
        # class permits setting extra attributes — confirm against Action.
        self.tools = tools

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> AgentGenerationOutput:
        """Execute the agent generation process.

        This method uses the provided language model to generate agent specifications
        based on the workflow context and task requirements.

        Args:
            llm: The language model to use for generation.
            inputs: Input data containing workflow and task information.
            sys_msg: Optional system message for the language model.
            return_prompt: Whether to return both the generated agents and the prompt used.
            **kwargs: Additional keyword arguments.

        Returns:
            If return_prompt is False (default): The generated agents output.
            If return_prompt is True: A tuple of (generated agents, prompt used).

        Raises:
            ValueError: If the inputs are None or empty.
        """
        if not inputs:
            logger.error("AgentGeneration action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to AgentGeneration action is None or empty.')

        # NOTE(review): these hold the parser classes themselves (not
        # instances); the annotations name the class for readability.
        inputs_format: AgentGenerationInput = self.inputs_format
        outputs_format: AgentGenerationOutput = self.outputs_format

        # Missing prompt parameters default to empty strings rather than
        # raising a KeyError during `format`.
        prompt_params_names = inputs_format.get_attrs()
        prompt_params_values = {param: inputs.get(param, "") for param in prompt_params_names}
        if self.tools:
            # Summarize each tool as {name: [function descriptions]} so the
            # LLM knows what capabilities generated agents may rely on.
            tool_description = [
                {
                    tool.name: [
                        s["function"]["description"] for s in tool.get_tool_schemas()
                    ],
                }
                for tool in self.tools
            ]
            prompt_params_values["tools"] = AGENT_GENERATION_TOOLS_PROMPT.format(tools_description=tool_description)
        prompt = self.prompt.format(**prompt_params_values)
        agents = llm.generate(
            prompt = prompt,
            system_message = sys_msg,
            parser=outputs_format,
            parse_mode="json"
        )

        if return_prompt:
            return agents, prompt

        return agents
198
+
evoagentx/actions/code_extraction.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, List, Dict
3
+ from pydantic import Field
4
+
5
+ from ..models.base_model import BaseLLM, LLMOutputParser
6
+ from .action import Action, ActionInput, ActionOutput
7
+ from ..prompts.code_extraction import CODE_EXTRACTION
8
+
9
+
10
class CodeExtractionInput(ActionInput):
    """
    Input parameters for the CodeExtraction action.

    Callers supply raw text containing fenced code blocks plus a destination
    directory; an optional project name nests output in a sub-folder.
    """
    code_string: str = Field(description="The string containing code blocks to extract")
    target_directory: str = Field(description="The directory path where extracted code files will be saved")
    project_name: Optional[str] = Field(default=None, description="Optional name for the project folder")
17
+
18
+
19
class CodeExtractionOutput(ActionOutput):
    """
    Output of the CodeExtraction action.

    `error` is populated (and `extracted_files` left empty) when any step
    fails; the action reports failures via this field instead of raising.
    """
    extracted_files: Dict[str, str] = Field(description="Map of filename to file path of saved files")
    main_file: Optional[str] = Field(default=None, description="Path to the main file if identified")
    error: Optional[str] = Field(default=None, description="Error message if any operation failed")
26
+
27
+
28
class CodeBlockInfo(LLMOutputParser):
    """
    Information about an extracted code block.

    Parsed from the LLM's JSON response; one instance per code block.
    """
    language: str = Field(description="Programming language of the code block")
    filename: str = Field(description="Suggested filename for the code block")
    content: str = Field(description="The actual code content")
35
+
36
+
37
class CodeBlockList(LLMOutputParser):
    """
    List of code blocks extracted from text.

    Top-level parser target for the LLM's JSON response in CodeExtraction.
    """
    code_blocks: List[CodeBlockInfo] = Field(description="List of code blocks")
42
+
43
+
44
class CodeExtraction(Action):
    """
    An action that extracts and organizes code blocks from text.

    This action uses an LLM to analyze text containing code blocks, extract them,
    suggest appropriate filenames, and save them to a specified directory. It can
    also identify which file is likely the main entry point based on heuristics.

    Attributes:
        name: The name of the action.
        description: A description of what the action does.
        prompt: The prompt template used by the action.
        inputs_format: The expected format of inputs to this action.
        outputs_format: The format of the action's output.
    """

    def __init__(self, **kwargs):
        # Fall back to the canonical CODE_EXTRACTION definition for any field
        # the caller does not override.
        name = kwargs.pop("name") if "name" in kwargs else CODE_EXTRACTION["name"]
        description = kwargs.pop("description") if "description" in kwargs else CODE_EXTRACTION["description"]
        prompt = kwargs.pop("prompt") if "prompt" in kwargs else CODE_EXTRACTION["prompt"]
        inputs_format = kwargs.pop("inputs_format", None) or CodeExtractionInput
        outputs_format = kwargs.pop("outputs_format", None) or CodeExtractionOutput
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)

    def identify_main_file(self, saved_files: Dict[str, str]) -> Optional[str]:
        """Identify the main file from the saved files based on content and file type.

        This method uses a combination of common filename conventions and content
        analysis to determine which file is likely the main entry point of a project.

        Args:
            saved_files: Dictionary mapping filenames to their full paths

        Returns:
            Path to the main file if found, None otherwise
        """
        # Priority lookup for common main files by language
        main_file_priorities = [
            # HTML files
            "index.html",
            # Python files
            "main.py",
            "app.py",
            # JavaScript files
            "index.js",
            "main.js",
            "app.js",
            # Java files
            "Main.java",
            # C/C++ files
            "main.cpp",
            "main.c",
            # Go files
            "main.go",
            # Other common entry points
            "index.php",
            "Program.cs"
        ]

        # First check priority list
        for main_file in main_file_priorities:
            if main_file in saved_files:
                return saved_files[main_file]

        # If no priority file found, use heuristics based on file extensions

        # If we have HTML files, use the first one
        html_files = {k: v for k, v in saved_files.items() if k.endswith('.html')}
        if html_files:
            return next(iter(html_files.values()))

        # Check for Python files with a "__main__" guard
        py_files = {k: v for k, v in saved_files.items() if k.endswith('.py')}
        if py_files:
            for filename, path in py_files.items():
                with open(path, 'r', encoding='utf-8') as f:
                    content = f.read()
                if "if __name__ == '__main__'" in content or 'if __name__ == "__main__"' in content:
                    return path
            # No explicit entry point found: fall back to the first Python file.
            # (Removed a redundant inner `if py_files:` re-check — `py_files`
            # is always truthy inside this branch.)
            return next(iter(py_files.values()))

        # If we have Java files, look for one with a main method
        java_files = {k: v for k, v in saved_files.items() if k.endswith('.java')}
        if java_files:
            for filename, path in java_files.items():
                with open(path, 'r', encoding='utf-8') as f:
                    content = f.read()
                if "public static void main" in content:
                    return path
            # No main method found: fall back to the first Java file.
            # (Removed the same redundant `if java_files:` re-check.)
            return next(iter(java_files.values()))

        # For JavaScript applications
        js_files = {k: v for k, v in saved_files.items() if k.endswith('.js')}
        if js_files:
            return next(iter(js_files.values()))

        # If all else fails, return the first file
        if saved_files:
            return next(iter(saved_files.values()))

        # No files found
        return None

    def save_code_blocks(self, code_blocks: List[Dict], target_directory: str) -> Dict[str, str]:
        """Save code blocks to files in the target directory.

        Creates the target directory if it doesn't exist and saves each code block
        to a file with an appropriate name, handling filename conflicts by
        appending a numeric suffix before the extension.

        Args:
            code_blocks: List of dictionaries containing code block information
            target_directory: Directory path where files should be saved

        Returns:
            Dictionary mapping filenames to their full paths
        """
        os.makedirs(target_directory, exist_ok=True)
        saved_files = {}

        for block in code_blocks:
            filename = block.get("filename", "unknown.txt")
            content = block.get("content", "")

            # Skip empty blocks
            if not content.strip():
                continue

            # Handle filename conflicts: foo.py -> foo_1.py, foo_2.py, ...
            base_filename = filename
            counter = 1
            while filename in saved_files:
                name_parts = base_filename.split('.')
                if len(name_parts) > 1:
                    filename = f"{'.'.join(name_parts[:-1])}_{counter}.{name_parts[-1]}"
                else:
                    filename = f"{base_filename}_{counter}"
                counter += 1

            # Save to file
            file_path = os.path.join(target_directory, filename)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Add to map
            saved_files[filename] = file_path

        return saved_files

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> CodeExtractionOutput:
        """Execute the CodeExtraction action.

        Extracts code blocks from the provided text using the specified LLM,
        saves them to the target directory, and identifies the main file.
        All failures are reported via the `error` field of the output rather
        than raised.

        Args:
            llm: The LLM to use for code extraction
            inputs: Dictionary containing:
                - code_string: The string with code blocks to extract
                - target_directory: Where to save the files
                - project_name: Optional project folder name
            sys_msg: Optional system message override for the LLM
            return_prompt: Whether to return the prompt along with the result
            **kwargs (Any): Additional keyword arguments

        Returns:
            CodeExtractionOutput with extracted file information
        """
        if not llm:
            error_msg = "CodeExtraction action requires an LLM."
            return CodeExtractionOutput(extracted_files={}, error=error_msg)

        if not inputs:
            error_msg = "CodeExtraction action received invalid `inputs`: None or empty."
            return CodeExtractionOutput(extracted_files={}, error=error_msg)

        code_string = inputs.get("code_string", "")
        target_directory = inputs.get("target_directory", "")
        project_name = inputs.get("project_name", None)

        if not code_string:
            error_msg = "No code string provided."
            return CodeExtractionOutput(extracted_files={}, error=error_msg)

        if not target_directory:
            error_msg = "No target directory provided."
            return CodeExtractionOutput(extracted_files={}, error=error_msg)

        # Create project folder if name is provided
        if project_name:
            project_dir = os.path.join(target_directory, project_name)
        else:
            project_dir = target_directory

        try:
            # Use LLM to extract code blocks and suggest filenames
            prompt_params = {"code_string": code_string}
            system_message = CODE_EXTRACTION["system_prompt"] if sys_msg is None else sys_msg
            # Format once and reuse for both generation and `return_prompt`.
            prompt = self.prompt.format(**prompt_params)

            llm_response: CodeBlockList = llm.generate(
                prompt=prompt,
                system_message=system_message,
                parser=CodeBlockList,
                parse_mode="json"
            )
            code_blocks = llm_response.get_structured_data().get("code_blocks", [])

            # Save code blocks to files
            saved_files = self.save_code_blocks(code_blocks, project_dir)

            # Identify main file
            main_file = self.identify_main_file(saved_files)

            result = CodeExtractionOutput(
                extracted_files=saved_files,
                main_file=main_file
            )

            if return_prompt:
                return result, prompt

            return result

        except Exception as e:
            error_msg = f"Error extracting code: {str(e)}"
            return CodeExtractionOutput(extracted_files={}, error=error_msg)
evoagentx/actions/code_verification.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import Field
2
+ from typing import Optional
3
+
4
+ from ..core.logging import logger
5
+ from ..core.module_utils import extract_code_blocks
6
+ from ..models.base_model import BaseLLM
7
+ from .action import Action, ActionInput, ActionOutput
8
+ from ..prompts.code_verification import CODE_VERIFICATION_ACTION
9
+
10
+
11
class CodeVerificationInput(ActionInput):
    """Input parameters for the CodeVerification action."""

    code: str = Field(description="The code string to be verified for correctness and completeness.")
    requirements: Optional[str] = Field(default=None, description="Optional field containing requirements or specifications for the code.")
15
+
16
+
17
class CodeVerificationOutput(ActionOutput):
    """Structured result of code verification.

    All analysis fields are optional: when the LLM response cannot be parsed
    into titled sections, only `verified_code` is populated from extracted
    code blocks.
    """

    analysis_summary: Optional[str] = Field(default=None, description="Brief summary of your findings, highlighting key issues or confirming overall quality.")
    issues_identified: Optional[str] = Field(default=None, description="Categorized list of issues found, with explanation of impact and severity.")
    thought_process: Optional[str] = Field(default=None, description="Detailed explanation of your verification reasoning and methodology applied.")
    modification_strategy: Optional[str] = Field(default=None, description="Describe the changes you made (or will make) to address the issues. Include any assumptions, design choices, or additional components you decided to add to make the code complete and robust.")
    verified_code: str = Field(description="The complete, corrected code if issues are found, or the original code if no issues are found.")
24
+
25
+
26
class CodeVerification(Action):
    """Action that verifies code for correctness and completeness using an LLM.

    The LLM reviews the provided code (optionally against requirements) and
    returns an analysis plus a corrected or confirmed version of the code.
    """

    def __init__(self, **kwargs):
        # Fall back to the canonical CODE_VERIFICATION_ACTION definition for
        # any field the caller does not override.
        name = kwargs.pop("name") if "name" in kwargs else CODE_VERIFICATION_ACTION["name"]
        description = kwargs.pop("description") if "description" in kwargs else CODE_VERIFICATION_ACTION["description"]
        prompt = kwargs.pop("prompt") if "prompt" in kwargs else CODE_VERIFICATION_ACTION["prompt"]
        inputs_format = kwargs.pop("inputs_format", None) or CodeVerificationInput
        outputs_format = kwargs.pop("outputs_format", None) or CodeVerificationOutput
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> CodeVerificationOutput:
        """Run code verification with the given LLM.

        Args:
            llm: The language model to use for verification.
            inputs: Dictionary with `code` and optional `requirements`; missing
                parameters default to "Not Provided" in the prompt.
            sys_msg: Optional system message for the language model.
            return_prompt: Whether to also return the prompt used.

        Returns:
            The parsed CodeVerificationOutput, or (output, prompt) when
            `return_prompt` is True.

        Raises:
            ValueError: If `llm` is None, `inputs` is empty, or no code can be
                recovered from the LLM response.
        """
        # Fail fast with a clear error instead of an AttributeError on
        # `llm.generate` below (consistent with sibling actions' LLM checks).
        if llm is None:
            logger.error("CodeVerification action requires an LLM but received None.")
            raise ValueError("The `llm` passed to CodeVerification action is None.")

        if not inputs:
            logger.error("CodeVerification action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to CodeVerification action is None or empty.')

        prompt_params_names = ["code", "requirements"]
        prompt_params_values = {param: inputs.get(param, "Not Provided") for param in prompt_params_names}
        prompt = self.prompt.format(**prompt_params_values)
        response = llm.generate(prompt = prompt, system_message=sys_msg)

        # Prefer structured title-based parsing; fall back to raw code-block
        # extraction when the response does not follow the expected format.
        try:
            verification_result = self.outputs_format.parse(response.content, parse_mode="title")
        except Exception:
            try:
                code_blocks = extract_code_blocks(response.content, return_type=True)
                code = "\n\n".join([f"```{code_type}\n{code}\n```" for code_type, code in code_blocks])
                verification_result = self.outputs_format(verified_code=code)
            except Exception:
                raise ValueError(f"Failed to extract code blocks from the response: {response.content}")

        if return_prompt:
            return verification_result, prompt

        return verification_result
evoagentx/actions/customize_action.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import Field
2
+ from typing import Optional, Any, Callable, List, Union
3
+ import re
4
+ import json
5
+ import asyncio
6
+ import inspect
7
+ import concurrent.futures
8
+
9
+ from ..core.logging import logger
10
+ from ..models.base_model import BaseLLM
11
+ from .action import Action
12
+ from ..core.message import Message
13
+ from ..prompts.template import StringTemplate, ChatTemplate
14
+ from ..prompts.tool_calling import OUTPUT_EXTRACTION_PROMPT, TOOL_CALLING_TEMPLATE, TOOL_CALLING_HISTORY_PROMPT, TOOL_CALLING_RETRY_PROMPT
15
+ from ..tools.tool import Toolkit
16
+ from ..core.registry import MODULE_REGISTRY
17
+ from ..models.base_model import LLMOutputParser
18
+ from ..core.module_utils import parse_json_from_llm_output, parse_json_from_text
19
+
20
class CustomizeAction(Action):

    # Output-parsing configuration: how the raw LLM output is turned into
    # structured data (`parse_func` is only used when parse_mode is 'custom').
    parse_mode: Optional[str] = Field(default="title", description="the parse mode of the action, must be one of: ['title', 'str', 'json', 'xml', 'custom']")
    parse_func: Optional[Callable] = Field(default=None, exclude=True, description="the function to parse the LLM output. It receives the LLM output and returns a dict.")
    title_format: Optional[str] = Field(default="## {title}", exclude=True, description="the format of the title. It is used when the `parse_mode` is 'title'.")
    custom_output_format: Optional[str] = Field(default=None, exclude=True, description="the format of the output. It is used when the `prompt_template` is provided.")

    # Tooling and conversation state available to the action at run time.
    tools: Optional[List[Toolkit]] = Field(default=None, description="The tools that the action can use")
    conversation: Optional[Message] = Field(default=None, description="Current conversation state")

    max_tool_try: int = Field(default=2, description="Maximum number of tool calling attempts allowed")
31
+
32
+ def __init__(self, **kwargs):
33
+
34
+ name = kwargs.pop("name", "CustomizeAction")
35
+ description = kwargs.pop("description", "Customized action that can use tools to accomplish its task")
36
+
37
+ super().__init__(name=name, description=description, **kwargs)
38
+
39
+ # Validate that at least one of prompt or prompt_template is provided
40
+ if not self.prompt and not self.prompt_template:
41
+ raise ValueError("`prompt` or `prompt_template` is required when creating CustomizeAction action")
42
+ # Prioritize template and give warning if both are provided
43
+ if self.prompt and self.prompt_template:
44
+ logger.warning("Both `prompt` and `prompt_template` are provided for CustomizeAction action. Prioritizing `prompt_template` and ignoring `prompt`.")
45
+ if self.tools:
46
+ self.tools_caller = {}
47
+ self.add_tools(self.tools)
48
+
49
+ def prepare_action_prompt(
50
+ self,
51
+ inputs: Optional[dict] = None,
52
+ system_prompt: Optional[str] = None,
53
+ **kwargs
54
+ ) -> Union[str, List[dict]]:
55
+ """Prepare prompt for action execution.
56
+
57
+ This helper function transforms the input dictionary into a formatted prompt
58
+ for the language model, handling different prompting modes.
59
+
60
+ Args:
61
+ inputs: Dictionary of input parameters
62
+ system_prompt: Optional system prompt to include
63
+
64
+ Returns:
65
+ Union[str, List[dict]]: Formatted prompt ready for LLM (string or chat messages)
66
+
67
+ Raises:
68
+ TypeError: If an input value type is not supported
69
+ ValueError: If neither prompt nor prompt_template is available
70
+ """
71
+ # Process inputs into prompt parameter values
72
+ if inputs is None:
73
+ inputs = {}
74
+
75
+ prompt_params_names = self.inputs_format.get_attrs()
76
+ prompt_params_values = {}
77
+ for param in prompt_params_names:
78
+ value = inputs.get(param, "")
79
+ if isinstance(value, str):
80
+ prompt_params_values[param] = value
81
+ elif isinstance(value, (dict, list)):
82
+ prompt_params_values[param] = json.dumps(value, indent=4)
83
+ else:
84
+ raise TypeError(f"The input type {type(value)} is invalid! Valid types: [str, dict, list].")
85
+
86
+ if self.prompt:
87
+ prompt = self.prompt.format(**prompt_params_values) if prompt_params_values else self.prompt
88
+ if self.tools:
89
+ tools_schemas = [j["function"] for i in [tool.get_tool_schemas() for tool in self.tools] for j in i]
90
+ prompt += "\n\n" + TOOL_CALLING_TEMPLATE.format(tools_description = tools_schemas)
91
+ return prompt
92
+ else:
93
+ # Use goal-based tool calling mode
94
+ if self.tools:
95
+ self.prompt_template.set_tools(self.tools)
96
+ return self.prompt_template.format(
97
+ system_prompt=system_prompt,
98
+ values=prompt_params_values,
99
+ inputs_format=self.inputs_format,
100
+ outputs_format=self.outputs_format,
101
+ parse_mode=self.parse_mode,
102
+ title_format=self.title_format,
103
+ custom_output_format=self.custom_output_format,
104
+ tools=self.tools
105
+ )
106
+
107
+ def prepare_extraction_prompt(self, llm_output_content: str) -> str:
108
+ """Prepare extraction prompt for fallback extraction when parsing fails.
109
+
110
+ Args:
111
+ self: The action instance
112
+ llm_output_content: Raw output content from LLM
113
+
114
+ Returns:
115
+ str: Formatted extraction prompt
116
+ """
117
+ attr_descriptions: dict = self.outputs_format.get_attr_descriptions()
118
+ output_description_list = []
119
+ for i, (name, desc) in enumerate(attr_descriptions.items()):
120
+ output_description_list.append(f"{i+1}. {name}\nDescription: {desc}")
121
+ output_description = "\n\n".join(output_description_list)
122
+ return OUTPUT_EXTRACTION_PROMPT.format(text=llm_output_content, output_description=output_description)
123
+
124
+ def _get_unique_class_name(self, candidate_name: str) -> str:
125
+ """
126
+ Get a unique class name by checking if it already exists in the registry.
127
+ If it does, append "Vx" to make it unique.
128
+ """
129
+ if not MODULE_REGISTRY.has_module(candidate_name):
130
+ return candidate_name
131
+
132
+ i = 1
133
+ while True:
134
+ unique_name = f"{candidate_name}V{i}"
135
+ if not MODULE_REGISTRY.has_module(unique_name):
136
+ break
137
+ i += 1
138
+ return unique_name
139
+
140
+ def add_tools(self, tools: Union[Toolkit, List[Toolkit]]):
141
+ if not tools:
142
+ return
143
+ if isinstance(tools,Toolkit):
144
+ tools = [tools]
145
+ if not all(isinstance(tool, Toolkit) for tool in tools):
146
+ raise TypeError("`tools` must be a Toolkit or list of Toolkit instances.")
147
+ if not self.tools:
148
+ self.tools_caller = {}
149
+ self.tools = []
150
+ # self.tools += tools
151
+ # tools_callers = [tool.get_tools() for tool in tools]
152
+ # tools_callers = [j for i in tools_callers for j in i]
153
+ # for tool_caller in tools_callers:
154
+ # self.tools_caller[tool_caller.name] = tool_caller
155
+
156
+ # avoid duplication & type checks
157
+ for toolkit in tools:
158
+ try:
159
+ tool_callers = toolkit.get_tools()
160
+ if not isinstance(tool_callers, list):
161
+ logger.warning(f"Expected list of tool functions from '{toolkit.name}.get_tools()', got {type(tool_callers)}.")
162
+ continue
163
+
164
+ # add tool callers to the tools_caller dictionary
165
+ valid_tools_count = 0
166
+ valid_tools_names, valid_tool_callers = [], []
167
+ for tool_caller in tool_callers:
168
+ tool_caller_name = getattr(tool_caller, "name", None)
169
+ if not tool_caller_name or not callable(tool_caller):
170
+ logger.warning(f"Invalid tool function in '{toolkit.name}': missing name or not callable.")
171
+ continue
172
+ if tool_caller_name in self.tools_caller:
173
+ logger.warning(f"Duplicate tool function '{tool_caller_name}' detected. Overwriting previous function.")
174
+ # self.tools_caller[tool_caller_name] = tool_caller
175
+ valid_tools_count += 1
176
+ valid_tools_names.append(tool_caller_name)
177
+ valid_tool_callers.append(tool_caller)
178
+
179
+ if valid_tools_count == 0:
180
+ logger.info(f"No valid tools found in toolkit '{toolkit.name}'. Skipping.")
181
+ continue
182
+
183
+ if valid_tools_count > 0 and all(name in self.tools_caller for name in valid_tools_names):
184
+ logger.info(f"All tools from toolkit '{toolkit.name}' are already added. Skipping.")
185
+ continue
186
+
187
+ if valid_tools_count > 0:
188
+ self.tools_caller.update({name: caller for name, caller in zip(valid_tools_names, valid_tool_callers)})
189
+
190
+ # only add toolkit if at least one valid tool is added and toolkit is not already added
191
+ existing_toolkit_names = {tkt.name for tkt in self.tools}
192
+ if valid_tools_count > 0 and toolkit.name not in existing_toolkit_names:
193
+ self.tools.append(toolkit)
194
+ if valid_tools_count > 0:
195
+ logger.info(f"Added toolkit '{toolkit.name}' with {valid_tools_count} valid tools in {self.name}: {valid_tools_names}.")
196
+
197
+ except Exception as e:
198
+ logger.error(f"Failed to load tools from toolkit '{toolkit.name}': {e}")
199
+
200
+
201
+ def _extract_tool_calls(self, llm_output: str, llm: Optional[BaseLLM] = None) -> List[dict]:
202
+ pattern = r"<ToolCalling>\s*(.*?)\s*</ToolCalling>"
203
+
204
+
205
+ # Find all ToolCalling blocks in the output
206
+ matches = re.findall(pattern, llm_output, re.DOTALL)
207
+
208
+ if not matches:
209
+ return []
210
+
211
+ parsed_tool_calls = []
212
+ for match_content in matches:
213
+ try:
214
+ json_content = match_content.strip()
215
+ json_list = parse_json_from_text(json_content)
216
+ if not json_list:
217
+ logger.warning("No valid JSON found in ToolCalling block")
218
+ continue
219
+ # Only use the first JSON string from each block
220
+ parsed_tool_call = json.loads(json_list[0])
221
+ if isinstance(parsed_tool_call, dict):
222
+ parsed_tool_calls.append(parsed_tool_call)
223
+ elif isinstance(parsed_tool_call, list):
224
+ parsed_tool_calls.extend(parsed_tool_call)
225
+ else:
226
+ logger.warning(f"Invalid tool call format: {parsed_tool_call}")
227
+ continue
228
+ except (json.JSONDecodeError, IndexError) as e:
229
+ logger.warning(f"Failed to parse tool calls from LLM output: {e}")
230
+ if llm is not None:
231
+ retry_prompt = TOOL_CALLING_RETRY_PROMPT.format(text=match_content)
232
+ try:
233
+ fixed_output = llm.generate(prompt=retry_prompt).content.strip()
234
+ logger.info(f"Retrying tool call parse with fixed output:\n{fixed_output}")
235
+
236
+ fixed_list = parse_json_from_text(fixed_output)
237
+ if fixed_list:
238
+ parsed_tool_call = json.loads(fixed_list[0])
239
+ if isinstance(parsed_tool_call, dict):
240
+ parsed_tool_calls.append(parsed_tool_call)
241
+ elif isinstance(parsed_tool_call, list):
242
+ parsed_tool_calls.extend(parsed_tool_call)
243
+ except Exception as retry_err:
244
+ logger.error(f"Retry failed: {retry_err}")
245
+ continue
246
+ else:
247
+ continue
248
+
249
+ return parsed_tool_calls
250
+
251
+ def _extract_output(self, llm_output: Any, llm: BaseLLM = None, **kwargs):
252
+
253
+ # Get the raw output content
254
+ llm_output_content = getattr(llm_output, "content", str(llm_output))
255
+
256
+ # Check if there are any defined output fields
257
+ output_attrs = self.outputs_format.get_attrs()
258
+
259
+ # If no output fields are defined, create a simple content-only output
260
+ if not output_attrs:
261
+ # Create output with just the content field
262
+ output = self.outputs_format.parse(content=llm_output_content)
263
+ # print("Created simple content output for agent with no defined outputs:")
264
+ # print(output)
265
+ return output
266
+
267
+ # Use the action's parse_mode and parse_func for parsing
268
+ try:
269
+ # Use the outputs_format's parse method with the action's parse settings
270
+ parsed_output = self.outputs_format.parse(
271
+ content=llm_output_content,
272
+ parse_mode=self.parse_mode,
273
+ parse_func=getattr(self, 'parse_func', None),
274
+ title_format=getattr(self, 'title_format', "## {title}")
275
+ )
276
+
277
+ # print("Successfully parsed output using action's parse settings:")
278
+ # print(parsed_output)
279
+ return parsed_output
280
+
281
+ except Exception as e:
282
+ logger.info(f"Failed to parse with action's parse settings: {e}")
283
+ logger.info("Falling back to using LLM to extract outputs...")
284
+
285
+ # Fall back to extraction prompt if direct parsing fails
286
+ extraction_prompt = self.prepare_extraction_prompt(llm_output_content)
287
+
288
+ llm_extracted_output: LLMOutputParser = llm.generate(prompt=extraction_prompt)
289
+ llm_extracted_data: dict = parse_json_from_llm_output(llm_extracted_output.content)
290
+ output = self.outputs_format.from_dict(llm_extracted_data)
291
+
292
+ # print("Extracted output using fallback:")
293
+ # print(output)
294
+ return output
295
+
296
+ async def _async_extract_output(self, llm_output: Any, llm: BaseLLM = None, **kwargs):
297
+
298
+ # Get the raw output content
299
+ llm_output_content = getattr(llm_output, "content", str(llm_output))
300
+
301
+ # Check if there are any defined output fields
302
+ output_attrs = self.outputs_format.get_attrs()
303
+
304
+ # If no output fields are defined, create a simple content-only output
305
+ if not output_attrs:
306
+ # Create output with just the content field
307
+ output = self.outputs_format.parse(content=llm_output_content)
308
+ # print("Created simple content output for agent with no defined outputs:")
309
+ # print(output)
310
+ return output
311
+
312
+ # Use the action's parse_mode and parse_func for parsing
313
+ try:
314
+ # Use the outputs_format's parse method with the action's parse settings
315
+ parsed_output = self.outputs_format.parse(
316
+ content=llm_output_content,
317
+ parse_mode=self.parse_mode,
318
+ parse_func=getattr(self, 'parse_func', None),
319
+ title_format=getattr(self, 'title_format', "## {title}")
320
+ )
321
+
322
+ # print("Successfully parsed output using action's parse settings:")
323
+ # print(parsed_output)
324
+ return parsed_output
325
+
326
+ except Exception as e:
327
+ logger.info(f"Failed to parse with action's parse settings: {e}")
328
+ logger.info("Falling back to using LLM to extract outputs...")
329
+
330
+ # Fall back to extraction prompt if direct parsing fails
331
+ extraction_prompt = self.prepare_extraction_prompt(llm_output_content)
332
+
333
+ llm_extracted_output = await llm.async_generate(prompt=extraction_prompt)
334
+ llm_extracted_data: dict = parse_json_from_llm_output(llm_extracted_output.content)
335
+ output = self.outputs_format.from_dict(llm_extracted_data)
336
+
337
+ # print("Extracted output using fallback:")
338
+ # print(output)
339
+ return output
340
+
341
+ def _call_single_tool(self, function_param: dict) -> tuple:
342
+ try:
343
+ function_name = function_param.get("function_name")
344
+ function_args = function_param.get("function_args") or {}
345
+
346
+ if not function_name:
347
+ return None, "No function name provided"
348
+
349
+ callable_fn = self.tools_caller.get(function_name)
350
+ if not callable(callable_fn):
351
+ return None, f"Function '{function_name}' not found or not callable"
352
+
353
+ print("_____________________ Start Function Calling _____________________")
354
+ print(f"Executing function calling: {function_name} with parameters: {function_args}")
355
+ result = callable_fn(**function_args)
356
+ return result, None
357
+
358
+ except Exception as e:
359
+ logger.error(f"Error executing tool {function_name}: {e}")
360
+ return None, f"Error executing tool {function_name}: {str(e)}"
361
+
362
+ def _calling_tools(self, tool_call_args: List[dict]) -> dict:
363
+ ## ___________ Call the tools in parallel___________
364
+ errors = []
365
+ results = []
366
+
367
+ with concurrent.futures.ThreadPoolExecutor() as executor:
368
+ future_to_tool = {executor.submit(self._call_single_tool, param): param for param in tool_call_args}
369
+
370
+ for future in concurrent.futures.as_completed(future_to_tool):
371
+ result, error = future.result()
372
+ if error:
373
+ errors.append(error)
374
+ if result is not None:
375
+ results.append(result)
376
+
377
+ return {"result": results, "error": errors}
378
+
379
+ async def _async_call_single_tool(self, function_param: dict) -> tuple:
380
+ try:
381
+ function_name = function_param.get("function_name")
382
+ function_args = function_param.get("function_args") or {}
383
+
384
+ if not function_name:
385
+ return None, "No function name provided"
386
+
387
+ callable_fn = self.tools_caller.get(function_name)
388
+ if not callable(callable_fn):
389
+ return None, f"Function '{function_name}' not found or not callable"
390
+
391
+ print("_____________________ Start Function Calling _____________________")
392
+ print(f"Executing function calling: {function_name} with parameters: {function_args}")
393
+
394
+ if inspect.iscoroutinefunction(callable_fn):
395
+ result = await callable_fn(**function_args)
396
+ else:
397
+ loop = asyncio.get_running_loop()
398
+ result = await loop.run_in_executor(None, lambda: callable_fn(**function_args))
399
+
400
+ return result, None
401
+
402
+ except Exception as e:
403
+ logger.error(f"Error executing tool {function_name}: {e}")
404
+ return None, f"Error executing tool {function_name}: {str(e)}"
405
+
406
+ async def _async_calling_tools(self, tool_call_args: List[dict]) -> dict:
407
+ ## ___________ Call the tools concurrently ___________
408
+ tasks = [self._async_call_single_tool(param) for param in tool_call_args]
409
+ results_with_errors = await asyncio.gather(*tasks)
410
+
411
+ results = [res for res, err in results_with_errors if err is None and res is not None]
412
+ errors = [err for _, err in results_with_errors if err is not None]
413
+
414
+ return {"result": results, "error": errors}
415
+
416
+ def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, time_out = 0, **kwargs):
417
+ # Allow empty inputs if the action has no required input attributes
418
+ input_attributes: dict = self.inputs_format.get_attr_descriptions()
419
+ if not inputs and input_attributes:
420
+ logger.error("CustomizeAction action received invalid `inputs`: None or empty.")
421
+ raise ValueError('The `inputs` to CustomizeAction action is None or empty.')
422
+ # Set inputs to empty dict if None and no inputs are required
423
+ if inputs is None:
424
+ inputs = {}
425
+ final_llm_response = None
426
+
427
+ if self.prompt_template:
428
+
429
+ if isinstance(self.prompt_template, ChatTemplate):
430
+ # must determine whether prompt_template is ChatTemplate first since ChatTemplate is a subclass of StringTemplate
431
+ conversation = self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)
432
+ elif isinstance(self.prompt_template, StringTemplate):
433
+ conversation = [{"role": "system", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
434
+ else:
435
+ raise ValueError(f"`prompt_template` must be a StringTemplate or ChatTemplate instance, but got {type(self.prompt_template)}")
436
+ else:
437
+ conversation = [{"role": "system", "content": sys_msg}, {"role": "user", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
438
+
439
+
440
+ ## 1. get all the input parameters
441
+ prompt_params_values = {k: inputs.get(k, "") for k in input_attributes.keys()}
442
+ while True:
443
+ ### Generate response from LLM
444
+ if time_out > self.max_tool_try:
445
+ # Get the appropriate prompt for return
446
+ current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
447
+ # Use the final LLM response if available, otherwise fall back to execution history
448
+ content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
449
+ if return_prompt:
450
+ return self._extract_output(content_to_extract, llm = llm), current_prompt
451
+ return self._extract_output(content_to_extract, llm = llm)
452
+ time_out += 1
453
+
454
+ # Handle both string prompts and chat message lists
455
+ llm_response = llm.generate(messages=conversation)
456
+ conversation.append({"role": "assistant", "content": llm_response.content})
457
+
458
+ # Store the final LLM response
459
+ final_llm_response = llm_response
460
+
461
+ tool_call_args = self._extract_tool_calls(llm_response.content)
462
+ if not tool_call_args:
463
+ break
464
+
465
+ logger.info("Extracted tool call args:")
466
+ logger.info(json.dumps(tool_call_args, indent=4))
467
+
468
+ results = self._calling_tools(tool_call_args)
469
+
470
+ logger.info("Tool call results:")
471
+ logger.info(json.dumps(results, indent=4))
472
+
473
+ conversation.append({"role": "assistant", "content": TOOL_CALLING_HISTORY_PROMPT.format(
474
+ iteration_number=time_out,
475
+ tool_call_args=f"{tool_call_args}",
476
+ results=f"{results}"
477
+ )})
478
+
479
+ # Get the appropriate prompt for return
480
+ current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
481
+ # Use the final LLM response if available, otherwise fall back to execution history
482
+ content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
483
+ if return_prompt:
484
+ return self._extract_output(content_to_extract, llm = llm), current_prompt
485
+ return self._extract_output(content_to_extract, llm = llm)
486
+
487
+
488
+ async def async_execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, time_out = 0, **kwargs):
489
+ # Allow empty inputs if the action has no required input attributes
490
+ input_attributes: dict = self.inputs_format.get_attr_descriptions()
491
+ if not inputs and input_attributes:
492
+ logger.error("CustomizeAction action received invalid `inputs`: None or empty.")
493
+ raise ValueError('The `inputs` to CustomizeAction action is None or empty.')
494
+ # Set inputs to empty dict if None and no inputs are required
495
+ if inputs is None:
496
+ inputs = {}
497
+ final_llm_response = None
498
+
499
+ if self.prompt_template:
500
+ if isinstance(self.prompt_template, ChatTemplate):
501
+ # must determine whether prompt_template is ChatTemplate first since ChatTemplate is a subclass of StringTemplate
502
+ conversation = self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)
503
+ elif isinstance(self.prompt_template, StringTemplate):
504
+ conversation = [{"role": "system", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
505
+ else:
506
+ raise ValueError(f"`prompt_template` must be a StringTemplate or ChatTemplate instance, but got {type(self.prompt_template)}")
507
+ else:
508
+ conversation = [{"role": "system", "content": sys_msg}, {"role": "user", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
509
+
510
+
511
+ ## 1. get all the input parameters
512
+ prompt_params_values = {k: inputs.get(k, "") for k in input_attributes.keys()}
513
+ while True:
514
+ ### Generate response from LLM
515
+ if time_out > self.max_tool_try:
516
+ # Get the appropriate prompt for return
517
+ current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
518
+ # Use the final LLM response if available, otherwise fall back to execution history
519
+ content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
520
+ if return_prompt:
521
+ return await self._async_extract_output(content_to_extract, llm = llm), current_prompt
522
+ return await self._async_extract_output(content_to_extract, llm = llm)
523
+ time_out += 1
524
+
525
+ # Handle both string prompts and chat message lists
526
+ llm_response = await llm.async_generate(messages=conversation)
527
+ conversation.append({"role": "assistant", "content": llm_response.content})
528
+
529
+ # Store the final LLM response
530
+ final_llm_response = llm_response
531
+
532
+ tool_call_args = self._extract_tool_calls(llm_response.content)
533
+ if not tool_call_args:
534
+ break
535
+
536
+ logger.info("Extracted tool call args:")
537
+ logger.info(json.dumps(tool_call_args, indent=4))
538
+
539
+ results = self._calling_tools(tool_call_args)
540
+
541
+ logger.info("Tool call results:")
542
+ try:
543
+ logger.info(json.dumps(results, indent=4))
544
+ except Exception:
545
+ logger.info(str(results))
546
+
547
+ conversation.append({"role": "assistant", "content": TOOL_CALLING_HISTORY_PROMPT.format(
548
+ iteration_number=time_out,
549
+ tool_call_args=f"{tool_call_args}",
550
+ results=f"{results}"
551
+ )})
552
+
553
+ # Get the appropriate prompt for return
554
+ current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
555
+ # Use the final LLM response if available, otherwise fall back to execution history
556
+ content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
557
+ if return_prompt:
558
+ return await self._async_extract_output(content_to_extract, llm = llm), current_prompt
559
+ return await self._async_extract_output(content_to_extract, llm = llm)
evoagentx/actions/task_planning.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import Field
2
+ from typing import Optional, List
3
+
4
+ from ..core.logging import logger
5
+ from ..models.base_model import BaseLLM
6
+ from .action import Action, ActionInput, ActionOutput
7
+ from ..prompts.task_planner import TASK_PLANNING_ACTION
8
+ from ..workflow.workflow_graph import WorkFlowNode
9
+
10
+
11
class TaskPlanningInput(ActionInput):
    """
    Input specification for the task planning action.

    Field names match the prompt parameters consumed by
    ``TaskPlanning.execute`` ("goal", "history", "suggestion").
    """
    goal: str = Field(description="A clear and detailed description of the user's goal, specifying what needs to be achieved.")
    # Previously generated plan, supplied when refining an earlier attempt.
    history: Optional[str] = Field(default=None, description="Optional field containing previously generated task plan.")
    # Free-form guidance injected into the planning prompt.
    suggestion: Optional[str] = Field(default=None, description="Optional suggestions or ideas to guide the planning process.")
18
+
19
+
20
class TaskPlanningOutput(ActionOutput):
    """
    Output structure for the task planning action.

    Parsed from the LLM's JSON response (see ``TaskPlanning.execute``,
    which uses ``parse_mode="json"``).
    """
    sub_tasks: List[WorkFlowNode] = Field(description="A list of sub-tasks that collectively achieve user's goal.")
25
+
26
+
27
class TaskPlanning(Action):
    """
    Action for planning a series of tasks to achieve a goal.

    Defaults for name/description/prompt come from ``TASK_PLANNING_ACTION``;
    the input/output formats default to ``TaskPlanningInput`` /
    ``TaskPlanningOutput``. Any of these can be overridden via kwargs.
    """

    def __init__(self, **kwargs):
        # `kwargs.pop(key, default)` replaces the verbose
        # `kwargs.pop(key) if key in kwargs else default` pattern.
        name = kwargs.pop("name", TASK_PLANNING_ACTION["name"])
        description = kwargs.pop("description", TASK_PLANNING_ACTION["description"])
        prompt = kwargs.pop("prompt", TASK_PLANNING_ACTION["prompt"])
        # `or` (not a plain default) so an explicit None also falls back.
        inputs_format = kwargs.pop("inputs_format", None) or TaskPlanningInput
        outputs_format = kwargs.pop("outputs_format", None) or TaskPlanningOutput
        super().__init__(
            name=name,
            description=description,
            prompt=prompt,
            inputs_format=inputs_format,
            outputs_format=outputs_format,
            **kwargs,
        )

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str] = None, return_prompt: bool = False, **kwargs) -> TaskPlanningOutput:
        """Execute the task planning process.

        Uses the provided language model to generate a structured plan of
        sub-tasks based on the user's goal and any additional context.

        Args:
            llm: The language model to use for planning.
            inputs: Input data containing the goal and optional context
                ("goal", "history", "suggestion").
            sys_msg: Optional system message for the language model.
            return_prompt: Whether to return both the task plan and the
                prompt used.
            **kwargs: Additional keyword arguments.

        Returns:
            If return_prompt is False (default): The generated task plan.
            If return_prompt is True: A tuple of (task plan, prompt used).

        Raises:
            ValueError: If the inputs are None or empty.
        """
        if not inputs:
            logger.error("TaskPlanning action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to TaskPlanning action is None or empty.')

        # Missing optional params are rendered as empty strings in the prompt.
        prompt_params_values = {param: inputs.get(param, "") for param in ("goal", "history", "suggestion")}
        prompt = self.prompt.format(**prompt_params_values)
        task_plan = llm.generate(
            prompt=prompt,
            system_message=sys_msg,
            parser=self.outputs_format,
            parse_mode="json",
        )

        if return_prompt:
            return task_plan, prompt

        return task_plan
evoagentx/agents/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .agent import Agent
2
+ from .customize_agent import CustomizeAgent
3
+ from .action_agent import ActionAgent
4
+ from .agent_manager import AgentManager
5
+
6
+ __all__ = ["Agent", "CustomizeAgent", "ActionAgent", "AgentManager"]
evoagentx/agents/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (485 Bytes). View file
 
evoagentx/agents/__pycache__/action_agent.cpython-311.pyc ADDED
Binary file (25.6 kB). View file
 
evoagentx/agents/__pycache__/agent.cpython-311.pyc ADDED
Binary file (25.3 kB). View file
 
evoagentx/agents/__pycache__/agent_generator.cpython-311.pyc ADDED
Binary file (2.06 kB). View file
 
evoagentx/agents/__pycache__/agent_manager.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
evoagentx/agents/__pycache__/customize_agent.cpython-311.pyc ADDED
Binary file (28.2 kB). View file
 
evoagentx/agents/__pycache__/task_planner.cpython-311.pyc ADDED
Binary file (2.85 kB). View file
 
evoagentx/agents/__pycache__/workflow_reviewer.cpython-311.pyc ADDED
Binary file (1.23 kB). View file
 
evoagentx/agents/action_agent.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ from pydantic import create_model, Field
4
+ from typing import Optional, Callable, Type, List, Any
5
+
6
+ from .agent import Agent
7
+ from ..core.logging import logger
8
+ from ..core.registry import MODULE_REGISTRY, ACTION_FUNCTION_REGISTRY
9
+ from ..models.model_configs import LLMConfig
10
+ from ..actions.action import Action, ActionOutput, ActionInput
11
+ from ..utils.utils import generate_dynamic_class_name, make_parent_folder
12
+ from ..core.message import Message, MessageType
13
+
14
+
15
+ class ActionAgent(Agent):
16
+ """
17
+ ActionAgent is a specialized agent that executes a provided function directly without LLM.
18
+ It creates an action that uses the provided function as the execution backbone.
19
+
20
+ Attributes:
21
+ name (str): The name of the agent.
22
+ description (str): A description of the agent's purpose and capabilities.
23
+ inputs (List[dict]): List of input specifications, where each dict contains:
24
+ - name (str): Name of the input parameter
25
+ - type (str): Type of the input
26
+ - description (str): Description of what the input represents
27
+ - required (bool, optional): Whether this input is required (default: True)
28
+ outputs (List[dict]): List of output specifications, where each dict contains:
29
+ - name (str): Name of the output field
30
+ - type (str): Type of the output
31
+ - description (str): Description of what the output represents
32
+ - required (bool, optional): Whether this output is required (default: True)
33
+ execute_func (Callable): The function to execute the agent.
34
+ async_execute_func (Callable, Optional): Async version of the function. If not provided,
35
+ an async wrapper will be automatically created around execute_func.
36
+ llm_config (LLMConfig, optional): Configuration for the language model (minimal usage).
37
+ """
38
+
39
+
40
+ def __init__(
41
+ self,
42
+ name: str,
43
+ description: str,
44
+ inputs: List[dict],
45
+ outputs: List[dict],
46
+ execute_func: Callable,
47
+ async_execute_func: Optional[Callable] = None,
48
+ llm_config: Optional[LLMConfig] = None,
49
+ **kwargs
50
+ ):
51
+ # Validate inputs
52
+ if not callable(execute_func):
53
+ raise ValueError("execute_func must be callable")
54
+
55
+ if async_execute_func is not None and not callable(async_execute_func):
56
+ raise ValueError("async_execute_func must be callable")
57
+
58
+ # Validate inputs and outputs
59
+ self._validate_inputs_outputs(inputs, outputs)
60
+
61
+ # Set is_human based on LLM availability
62
+ is_human = llm_config is None
63
+
64
+ # Initialize parent directly
65
+ super().__init__(
66
+ name=name,
67
+ description=description,
68
+ llm_config=llm_config,
69
+ is_human=is_human,
70
+ **kwargs
71
+ )
72
+
73
+ # Store function references and metadata
74
+ self.execute_func = execute_func
75
+ self.async_execute_func = async_execute_func
76
+ self.inputs = inputs
77
+ self.outputs = outputs
78
+
79
+ # Create and add the function-based action
80
+ action = self._create_function_action_with_params(
81
+ name, execute_func, async_execute_func, inputs, outputs
82
+ )
83
+ self.add_action(action)
84
+
85
+ def init_llm(self):
86
+ pass
87
+
88
+ def _validate_inputs_outputs(self, inputs: List[dict], outputs: List[dict]):
89
+ """Validate the structure of inputs and outputs."""
90
+ # Allow empty inputs for functions that don't require any inputs
91
+ if inputs is None:
92
+ inputs = []
93
+
94
+ if outputs is None:
95
+ outputs = []
96
+
97
+ # Validate inputs structure
98
+ for i, input_field in enumerate(inputs):
99
+ if not isinstance(input_field, dict):
100
+ raise ValueError(f"Input field {i} must be a dictionary, got {type(input_field)}")
101
+
102
+ required_keys = ["name", "type", "description"]
103
+ for key in required_keys:
104
+ if key not in input_field:
105
+ raise ValueError(f"Input field {i} missing required key '{key}'")
106
+
107
+ if not isinstance(input_field["name"], str):
108
+ raise ValueError(f"Input field {i} 'name' must be a string, got {type(input_field['name'])}")
109
+
110
+ if not isinstance(input_field["type"], str):
111
+ raise ValueError(f"Input field {i} 'type' must be a string, got {type(input_field['type'])}")
112
+
113
+ if not isinstance(input_field["description"], str):
114
+ raise ValueError(f"Input field {i} 'description' must be a string, got {type(input_field['description'])}")
115
+
116
+ # Check for duplicate input names
117
+ input_names = [field["name"] for field in inputs]
118
+ if len(input_names) != len(set(input_names)):
119
+ raise ValueError(f"Duplicate input names found: {[name for name in input_names if input_names.count(name) > 1]}")
120
+
121
+ # Validate outputs structure
122
+ for i, output_field in enumerate(outputs):
123
+ if not isinstance(output_field, dict):
124
+ raise ValueError(f"Output field {i} must be a dictionary, got {type(output_field)}")
125
+
126
+ required_keys = ["name", "type", "description"]
127
+ for key in required_keys:
128
+ if key not in output_field:
129
+ raise ValueError(f"Output field {i} missing required key '{key}'")
130
+
131
+ if not isinstance(output_field["name"], str):
132
+ raise ValueError(f"Output field {i} 'name' must be a string, got {type(output_field['name'])}")
133
+
134
+ if not isinstance(output_field["type"], str):
135
+ raise ValueError(f"Output field {i} 'type' must be a string, got {type(output_field['type'])}")
136
+
137
+ if not isinstance(output_field["description"], str):
138
+ raise ValueError(f"Output field {i} 'description' must be a string, got {type(output_field['description'])}")
139
+
140
+ # Check for duplicate output names
141
+ output_names = [field["name"] for field in outputs]
142
+ if len(output_names) != len(set(output_names)):
143
+ raise ValueError(f"Duplicate output names found: {[name for name in output_names if output_names.count(name) > 1]}")
144
+
145
+ def _create_function_action_input_type(self, name: str, inputs: List[dict]) -> Type[ActionInput]:
146
+ """Create ActionInput type from input specifications."""
147
+ action_input_fields = {}
148
+ for field in inputs:
149
+ required = field.get("required", True)
150
+ if required:
151
+ action_input_fields[field["name"]] = (str, Field(description=field["description"]))
152
+ else:
153
+ action_input_fields[field["name"]] = (Optional[str], Field(default=None, description=field["description"]))
154
+
155
+ action_input_type = create_model(
156
+ self._get_unique_class_name(
157
+ generate_dynamic_class_name(f"{name} action_input")
158
+ ),
159
+ **action_input_fields,
160
+ __base__=ActionInput
161
+ )
162
+ return action_input_type
163
+
164
+ def _create_function_action_output_type(self, name: str, outputs: List[dict]) -> Type[ActionOutput]:
165
+ """Create ActionOutput type from output specifications."""
166
+ action_output_fields = {}
167
+ for field in outputs:
168
+ required = field.get("required", True)
169
+ if required:
170
+ action_output_fields[field["name"]] = (Any, Field(description=field["description"]))
171
+ else:
172
+ action_output_fields[field["name"]] = (Optional[Any], Field(default=None, description=field["description"]))
173
+
174
+ action_output_type = create_model(
175
+ self._get_unique_class_name(
176
+ generate_dynamic_class_name(f"{name} action_output")
177
+ ),
178
+ **action_output_fields,
179
+ __base__=ActionOutput
180
+ )
181
+ return action_output_type
182
+
183
+ def _create_execute_method(self, execute_func: Callable):
184
+ """Create the execute method for the action."""
185
+ def execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs):
186
+ # Validate inputs
187
+ if inputs is None:
188
+ inputs = {}
189
+
190
+ # Validate that all required inputs are provided
191
+ required_inputs = action_self.inputs_format.get_required_input_names()
192
+ missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs]
193
+ if missing_inputs:
194
+ raise ValueError(f"Missing required inputs: {missing_inputs}")
195
+
196
+ # Validate input types (basic validation)
197
+ filtered_inputs = {}
198
+ for input_name, input_value in inputs.items():
199
+ if input_name in [field["name"] for field in self.inputs]:
200
+ filtered_inputs[input_name] = input_value
201
+ else:
202
+ logger.warning(f"Unexpected input '{input_name}' provided")
203
+
204
+ # Execute function
205
+ try:
206
+ result = execute_func(**filtered_inputs)
207
+ except Exception as e:
208
+ # Create error output - try to use error field if it exists, otherwise use first available field
209
+ try:
210
+ # Check if output format has an error field
211
+ output_fields = action_self.outputs_format.get_attrs()
212
+ if "error" in output_fields:
213
+ error_output = action_self.outputs_format(
214
+ error=f"Function execution failed: {str(e)}"
215
+ )
216
+ elif len(output_fields) > 0:
217
+ # Use the first field as error field
218
+ first_field = output_fields[0]
219
+ error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"})
220
+ else:
221
+ # Fallback to creating a simple output with error message
222
+ error_output = action_self.outputs_format()
223
+ except Exception as create_error:
224
+ # If all else fails, create a minimal output
225
+ logger.error(f"Failed to create error output: {create_error}")
226
+ error_output = action_self.outputs_format()
227
+ return error_output, "Function execution"
228
+
229
+ # Create success output using the parse method
230
+ if isinstance(result, dict):
231
+ # For dict results, create output directly
232
+ output = action_self.outputs_format(**result)
233
+ else:
234
+ # For simple values, create output with the first field
235
+ output_fields = action_self.outputs_format.get_attrs()
236
+ if len(output_fields) > 0:
237
+ first_field = output_fields[0]
238
+ output = action_self.outputs_format(**{first_field: result})
239
+ else:
240
+ # Fallback to creating empty output
241
+ output = action_self.outputs_format()
242
+
243
+ return output, "Function execution"
244
+
245
+ return execute_method
246
+
247
+ def _create_async_execute_method(self, async_execute_func: Callable, execute_func: Callable):
248
+ """Create the async execute method for the action."""
249
+ async def async_execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs):
250
+ # Validate inputs
251
+ if inputs is None:
252
+ inputs = {}
253
+
254
+ # Validate that all required inputs are provided
255
+ required_inputs = action_self.inputs_format.get_required_input_names()
256
+ missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs]
257
+ if missing_inputs:
258
+ raise ValueError(f"Missing required inputs: {missing_inputs}")
259
+
260
+ # Validate input types (basic validation)
261
+ filtered_inputs = {}
262
+ for input_name, input_value in inputs.items():
263
+ if input_name in [field["name"] for field in self.inputs]:
264
+ filtered_inputs[input_name] = input_value
265
+ else:
266
+ logger.warning(f"Unexpected input '{input_name}' provided")
267
+
268
+ # Execute async function
269
+ try:
270
+ if async_execute_func is not None:
271
+ result = await async_execute_func(**filtered_inputs)
272
+ else:
273
+ # Use sync function in async context
274
+ loop = asyncio.get_event_loop()
275
+ result = await loop.run_in_executor(None, lambda: execute_func(**filtered_inputs))
276
+ except Exception as e:
277
+ # Create error output - try to use error field if it exists, otherwise use first available field
278
+ try:
279
+ # Check if output format has an error field
280
+ output_fields = action_self.outputs_format.get_attrs()
281
+ if "error" in output_fields:
282
+ error_output = action_self.outputs_format(
283
+ error=f"Async function execution failed: {str(e)}"
284
+ )
285
+ elif len(output_fields) > 0:
286
+ # Use the first field as error field
287
+ first_field = list(output_fields.keys())[0]
288
+ error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"})
289
+ else:
290
+ # Fallback to creating a simple output with error message
291
+ error_output = action_self.outputs_format()
292
+ except Exception as create_error:
293
+ # If all else fails, create a minimal output
294
+ logger.error(f"Failed to create error output: {create_error}")
295
+ error_output = action_self.outputs_format()
296
+ return error_output, "Async function execution"
297
+
298
+ # Create success output using the parse method
299
+ if isinstance(result, dict):
300
+ # For dict results, create output directly
301
+ output = action_self.outputs_format(**result)
302
+ else:
303
+ # For simple values, create output with the first field
304
+ output_fields = action_self.outputs_format.get_attrs()
305
+ if len(output_fields) > 0:
306
+ first_field = output_fields[0]
307
+ output = action_self.outputs_format(**{first_field: result})
308
+ else:
309
+ # Fallback to creating empty output
310
+ output = action_self.outputs_format()
311
+
312
+ return output, "Async function execution"
313
+
314
+ return async_execute_method
315
+
316
+ def _create_function_action_with_params(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action:
317
+ """Create an action that executes the provided function with given parameters."""
318
+
319
+ # Create input/output types
320
+ action_input_type = self._create_function_action_input_type(name, inputs)
321
+ action_output_type = self._create_function_action_output_type(name, outputs)
322
+
323
+ # Create custom action class
324
+ action_cls_name = self._get_unique_class_name(
325
+ generate_dynamic_class_name(f"{name} function action")
326
+ )
327
+
328
+ # Create action class with function execution
329
+ function_action_cls = create_model(
330
+ action_cls_name,
331
+ __base__=Action
332
+ )
333
+
334
+ # Create action instance
335
+ function_action = function_action_cls(
336
+ name=action_cls_name,
337
+ description=f"Executes {execute_func.__name__} function",
338
+ inputs_format=action_input_type,
339
+ outputs_format=action_output_type
340
+ )
341
+
342
+ # Override execute methods - bind them properly to the action instance
343
+ execute_method = self._create_execute_method(execute_func)
344
+ async_execute_method = self._create_async_execute_method(async_execute_func, execute_func)
345
+
346
+ # Bind the methods to the action instance
347
+ function_action.execute = execute_method.__get__(function_action, type(function_action))
348
+ function_action.async_execute = async_execute_method.__get__(function_action, type(function_action))
349
+
350
+ return function_action
351
+
352
+ def _create_function_action(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action:
353
+ """Create an action that executes the provided function."""
354
+ return self._create_function_action_with_params(
355
+ name,
356
+ execute_func,
357
+ async_execute_func,
358
+ inputs,
359
+ outputs
360
+ )
361
+
362
+ def get_config(self) -> dict:
363
+ """Get configuration for the ActionAgent."""
364
+ # Get base config from Agent
365
+ config = super().get_config()
366
+
367
+ # Add ActionAgent-specific information
368
+ config.update({
369
+ "class_name": "ActionAgent",
370
+ "execute_func_name": self.execute_func.__name__ if self.execute_func else None,
371
+ "async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None,
372
+ "inputs": self.inputs,
373
+ "outputs": self.outputs
374
+ })
375
+ return config
376
+
377
+ def save_module(self, path: str, ignore: List[str] = [], **kwargs) -> str:
378
+ """Save the ActionAgent configuration to a JSON file.
379
+
380
+ Args:
381
+ path: File path where the configuration should be saved
382
+ ignore: List of keys to exclude from the saved configuration
383
+ **kwargs (Any): Additional parameters for the save operation
384
+
385
+ Returns:
386
+ The path where the configuration was saved
387
+ """
388
+ config = self.get_config()
389
+
390
+ # Add ActionAgent-specific information
391
+ config.update({
392
+ "class_name": "ActionAgent",
393
+ "execute_func_name": self.execute_func.__name__ if self.execute_func else None,
394
+ "async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None,
395
+ "inputs": self.inputs,
396
+ "outputs": self.outputs
397
+ })
398
+
399
+ # Remove non-serializable items
400
+ for ignore_key in ignore:
401
+ config.pop(ignore_key, None)
402
+
403
+ # Save to JSON file
404
+ make_parent_folder(path)
405
+ with open(path, 'w', encoding='utf-8') as f:
406
+ json.dump(config, f, indent=4, ensure_ascii=False)
407
+
408
+ return path
409
+
410
+ @classmethod
411
+ def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> "ActionAgent":
412
+ """Load the ActionAgent from a JSON file.
413
+
414
+ Args:
415
+ path: The path of the file
416
+ llm_config: The LLMConfig instance (optional)
417
+ **kwargs: Additional keyword arguments
418
+
419
+ Returns:
420
+ ActionAgent: The loaded agent instance
421
+
422
+ Raises:
423
+ KeyError: If required functions are not found in the registry
424
+ """
425
+ # Load configuration
426
+ with open(path, 'r', encoding='utf-8') as f:
427
+ config = json.load(f)
428
+
429
+ # Extract function names
430
+ execute_func_name = config.get("execute_func_name")
431
+ async_execute_func_name = config.get("async_execute_func_name")
432
+
433
+ # Retrieve functions from registry
434
+ execute_func = None
435
+ async_execute_func = None
436
+
437
+ if execute_func_name:
438
+ if not ACTION_FUNCTION_REGISTRY.has_function(execute_func_name):
439
+ raise KeyError(f"Function '{execute_func_name}' not found in registry. Please register it first.")
440
+ execute_func = ACTION_FUNCTION_REGISTRY.get_function(execute_func_name)
441
+
442
+ if async_execute_func_name:
443
+ if not ACTION_FUNCTION_REGISTRY.has_function(async_execute_func_name):
444
+ raise KeyError(f"Function '{async_execute_func_name}' not found in registry. Please register it first.")
445
+ async_execute_func = ACTION_FUNCTION_REGISTRY.get_function(async_execute_func_name)
446
+
447
+ # Create agent
448
+ agent = cls(
449
+ name=config["name"],
450
+ description=config["description"],
451
+ inputs=config["inputs"],
452
+ outputs=config["outputs"],
453
+ execute_func=execute_func,
454
+ async_execute_func=async_execute_func,
455
+ llm_config=llm_config,
456
+ **kwargs
457
+ )
458
+
459
+ return agent
460
+
461
+ def __call__(self, inputs: dict = None, return_msg_type: MessageType = MessageType.UNKNOWN, **kwargs) -> Message:
462
+ """
463
+ Call the main function action.
464
+
465
+ Args:
466
+ inputs (dict): The inputs to the function action.
467
+ return_msg_type (MessageType): The type of message to return.
468
+ **kwargs (Any): Additional keyword arguments.
469
+
470
+ Returns:
471
+ Message: The output of the function action.
472
+ """
473
+ inputs = inputs or {}
474
+ return super().__call__(action_name=self.main_action_name, action_input_data=inputs, return_msg_type=return_msg_type, **kwargs)
475
+
476
+ @property
477
+ def main_action_name(self) -> str:
478
+ """
479
+ Get the name of the main function action for this agent.
480
+
481
+ Returns:
482
+ The name of the main function action
483
+ """
484
+ for action in self.actions:
485
+ if action.name != self.cext_action_name:
486
+ return action.name
487
+ raise ValueError("Couldn't find the main action name!")
488
+
489
+ def _get_unique_class_name(self, candidate_name: str) -> str:
490
+ """
491
+ Get a unique class name by checking if it already exists in the registry.
492
+ If it does, append "Vx" to make it unique.
493
+ """
494
+ if not MODULE_REGISTRY.has_module(candidate_name):
495
+ return candidate_name
496
+
497
+ counter = 1
498
+ while True:
499
+ new_name = f"{candidate_name}V{counter}"
500
+ if not MODULE_REGISTRY.has_module(new_name):
501
+ return new_name
502
+ counter += 1
evoagentx/agents/agent.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ from pydantic import Field
4
+ from typing import Type, Optional, Union, Tuple, List, Any, Coroutine
5
+
6
+ from ..core.module import BaseModule
7
+ from ..core.module_utils import generate_id
8
+ from ..core.message import Message, MessageType
9
+ from ..core.registry import MODEL_REGISTRY
10
+ from ..models.model_configs import LLMConfig
11
+ from ..models.base_model import BaseLLM
12
+ from ..memory.memory import ShortTermMemory
13
+ from ..memory.long_term_memory import LongTermMemory
14
+ from ..memory.memory_manager import MemoryManager
15
+ from ..storages.base import StorageHandler
16
+ from ..actions.action import Action
17
+ from ..actions.action import ContextExtraction
18
+
19
+
20
class Agent(BaseModule):
    """
    Base class for all agents.

    Attributes:
        name (str): Unique identifier for the agent
        description (str): Human-readable description of the agent's purpose
        llm_config (Optional[LLMConfig]): Configuration for the language model. If provided, a new LLM instance will be created.
            Otherwise, the existing LLM instance specified in the `llm` field will be used.
        llm (Optional[BaseLLM]): Language model instance. If provided, the existing LLM instance will be used.
        agent_id (Optional[str]): Unique ID for the agent, auto-generated if not provided
        system_prompt (Optional[str]): System prompt for the Agent.
        actions (List[Action]): List of available actions
        n (Optional[int]): Number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.
        is_human (bool): Whether this agent represents a human user
        version (int): Version number of the agent, default is 0.
    """

    name: str  # should be unique
    description: str
    llm_config: Optional[LLMConfig] = None
    llm: Optional[BaseLLM] = None
    agent_id: Optional[str] = Field(default_factory=generate_id)
    system_prompt: Optional[str] = None
    # store short term memory for a single workflow.
    short_term_memory: Optional[ShortTermMemory] = Field(default_factory=ShortTermMemory)
    use_long_term_memory: Optional[bool] = False
    storage_handler: Optional[StorageHandler] = None
    long_term_memory: Optional[LongTermMemory] = None
    long_term_memory_manager: Optional[MemoryManager] = None
    actions: List[Action] = Field(default=None)
    # NOTE: annotated Optional[int] to match the `None` default (was `int`).
    n: Optional[int] = Field(default=None, description="number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.")
    is_human: bool = Field(default=False)
    version: int = 0

    def init_module(self):
        """Initialize LLM, optional long-term memory, the action map and the context extractor."""
        if not self.is_human:
            self.init_llm()
        if self.use_long_term_memory:
            self.init_long_term_memory()
        self.actions = [] if self.actions is None else self.actions
        self._action_map = {action.name: action for action in self.actions} if self.actions else dict()
        self._save_ignore_fields = ["llm", "llm_config"]
        self.init_context_extractor()

    def __call__(self, *args: Any, **kwargs: Any) -> Union[dict, Coroutine[Any, Any, dict]]:
        """Make the operator callable and automatically choose between sync and async execution."""
        try:
            # Safe way to check if we're inside an async environment
            asyncio.get_running_loop()
            return self.async_execute(*args, **kwargs)
        except RuntimeError:
            # No running loop — likely in sync context or worker thread
            return self.execute(*args, **kwargs)

    def _prepare_execution(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        **kwargs
    ) -> Tuple[Action, dict]:
        """Prepare for action execution by updating memory and getting inputs.

        Helper method used by both execute and aexecute methods.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            **kwargs: Additional workflow parameters

        Returns:
            Tuple containing the action object and input data

        Raises:
            AssertionError: If neither msgs nor action_input_data is provided
        """
        assert msgs is not None or action_input_data is not None, "must provide either `msgs` or `action_input_data`"
        action = self.get_action(action_name=action_name)

        # update short-term memory
        if msgs is not None:
            # directly add messages to short-term memory
            self.short_term_memory.add_messages(msgs)
        if action_input_data is not None:
            # create a message from action_input_data and add it to short-term memory
            input_message = Message(
                content = action_input_data,
                next_actions = [action_name],
                msg_type = MessageType.INPUT,
                wf_goal = kwargs.get("wf_goal", None),
                wf_task = kwargs.get("wf_task", None),
                wf_task_desc = kwargs.get("wf_task_desc", None)
            )
            self.short_term_memory.add_message(input_message)

        # obtain action input data from short term memory if not provided
        action_input_data = action_input_data or self.get_action_inputs(action=action)

        return action, action_input_data

    def _create_output_message(
        self,
        action_output,
        prompt: str,
        action_name: str,
        return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
        **kwargs
    ) -> Message:
        """Create a message from execution results and update memory.

        Helper method used by both execute and aexecute methods.

        Args:
            action_output: The output from action execution
            prompt: The prompt used for execution
            action_name: The name of the executed action
            return_msg_type: Message type for the return message
            **kwargs: Additional workflow parameters

        Returns:
            Message object containing execution results
        """
        # formulate a message
        message = Message(
            content=action_output,
            agent=self.name,
            action=action_name,
            prompt=prompt,
            msg_type=return_msg_type,
            wf_goal = kwargs.get("wf_goal", None),
            wf_task = kwargs.get("wf_task", None),
            wf_task_desc = kwargs.get("wf_task_desc", None)
        )

        # update short-term memory
        self.short_term_memory.add_message(message)

        return message

    async def async_execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, dict]]:
        """Execute an action asynchronously with the given context and return results.

        This is the async version of the execute method, allowing it to perform actions
        based on the current conversation context.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: When True, also return the resolved input data
            **kwargs (Any): Additional parameters, may include workflow information

        Returns:
            Message: A message containing the execution results
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # execute action asynchronously; fall back to the sync `execute` if the
        # action's `async_execute` is an unimplemented stub (detected by source
        # inspection — a heuristic, but preserved from the original design).
        async_execute_source = inspect.getsource(action.async_execute)
        if "NotImplementedError" in async_execute_source:
            execution_results = action.execute(
                llm=self.llm,
                inputs=action_input_data,
                sys_msg=self.system_prompt,
                return_prompt=True,
                **kwargs
            )
        else:
            execution_results = await action.async_execute(
                llm=self.llm,
                inputs=action_input_data,
                sys_msg=self.system_prompt,
                return_prompt=True,
                **kwargs
            )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, dict]]:
        """Execute an action with the given context and return results.

        This is the core method for agent functionality, allowing it to perform actions
        based on the current conversation context.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: When True, also return the resolved input data
            **kwargs (Any): Additional parameters, may include workflow information

        Returns:
            Message: A message containing the execution results
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # execute action
        execution_results = action.execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            **kwargs
        )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def init_llm(self):
        """
        Initialize the language model for the agent.

        Raises:
            ValueError: If the agent is not human and neither `llm_config` nor `llm` is provided.
        """
        if not self.is_human and (not self.llm_config and not self.llm):
            raise ValueError("must provide `llm_config` or `llm` when `is_human` is False")
        if not self.is_human and (self.llm_config or self.llm):
            if self.llm_config and not self.llm:
                llm_cls = MODEL_REGISTRY.get_model(self.llm_config.llm_type)
                self.llm = llm_cls(config=self.llm_config)
            if self.llm:
                self.llm_config = self.llm.config
        # If is_human=True, self.llm remains None

    def init_long_term_memory(self):
        """
        Initialize long-term memory components.
        """
        assert self.storage_handler is not None, "must provide ``storage_handler`` when use_long_term_memory=True"
        # TODO revise the initialisation of long_term_memory and long_term_memory_manager
        if not self.long_term_memory:
            self.long_term_memory = LongTermMemory()
        if not self.long_term_memory_manager:
            self.long_term_memory_manager = MemoryManager(
                storage_handler=self.storage_handler,
                memory=self.long_term_memory
            )

    def init_context_extractor(self):
        """
        Initialize the context extraction action.
        """
        cext_action = ContextExtraction()
        self.cext_action_name = cext_action.name
        self.add_action(cext_action)

    def add_action(self, action: Type[Action]):
        """
        Add a new action to the agent's available actions. Silently ignores
        actions whose name is already registered.

        Args:
            action: The action instance to add
        """
        action_name = action.name
        if action_name in self._action_map:
            return
        self.actions.append(action)
        self._action_map[action_name] = action

    def check_action_name(self, action_name: str):
        """
        Check if an action name is valid for this agent.

        Args:
            action_name: Name of the action to check

        Raises:
            KeyError: If the action name is not registered on this agent.
        """
        if action_name not in self._action_map:
            raise KeyError(f"'{action_name}' is an invalid action for {self.name}! Available action names: {list(self._action_map.keys())}")

    def get_action(self, action_name: str) -> Action:
        """
        Retrieves the Action instance associated with the given name.

        Args:
            action_name: Name of the action to retrieve

        Returns:
            The Action instance with the specified name
        """
        self.check_action_name(action_name=action_name)
        return self._action_map[action_name]

    def get_action_name(self, action_cls: Type[Action]) -> str:
        """
        Searches through the agent's actions to find one matching the specified type.

        Args:
            action_cls: The Action class type to search for

        Returns:
            The name of the matching action

        Raises:
            ValueError: If no registered action is an instance of `action_cls`.
        """
        for name, action in self._action_map.items():
            if isinstance(action, action_cls):
                return name
        raise ValueError(f"Couldn't find an action that matches Type '{action_cls.__name__}'")

    def get_action_inputs(self, action: Action) -> Union[dict, None]:
        """
        Uses the context extraction action to determine appropriate inputs
        for the specified action based on the conversation history.

        Args:
            action: The action for which to extract inputs

        Returns:
            Dictionary of extracted input data, or None if extraction fails
        """
        context = self.short_term_memory.get(n=self.n)
        cext_action = self.get_action(self.cext_action_name)
        action_inputs = cext_action.execute(llm=self.llm, action=action, context=context)
        return action_inputs

    def get_all_actions(self) -> List[Action]:
        """Get all actions except the context extraction action.

        Returns:
            List of Action instances available for execution
        """
        actions = [action for action in self.actions if action.name != self.cext_action_name]
        return actions

    def get_agent_profile(self, action_names: List[str] = None) -> str:
        """Generate a human-readable profile of the agent and its capabilities.

        Args:
            action_names: Optional list of action names to include in the profile.
                If None, all actions are included.

        Returns:
            A formatted string containing the agent profile
        """
        all_actions = self.get_all_actions()
        if action_names is None:
            # if `action_names` is None, return description of all actions
            action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions])
        else:
            # otherwise, only return description of actions that matches `action_names`
            action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions if action.name in action_names])
        profile = f"Agent Name: {self.name}\nDescription: {self.description}\nAvailable Actions:\n{action_descriptions}"
        return profile

    def clear_short_term_memory(self):
        """
        Remove all content from the agent's short-term memory.

        NOTE(review): currently a no-op stub — the docstring promises clearing
        but nothing is done. TODO: implement via the ShortTermMemory API.
        """
        pass

    def __eq__(self, other: object) -> bool:
        # Guard against comparison with non-Agent objects: previously
        # `other.agent_id` raised AttributeError for unrelated types.
        if not isinstance(other, Agent):
            return NotImplemented
        return self.agent_id == other.agent_id

    def __hash__(self) -> int:
        # `__hash__` must return an int; agent_id is a generated identifier
        # (previously returned directly, which breaks for non-int ids).
        return hash(self.agent_id)

    def get_prompts(self) -> dict:
        """
        Get all the prompts of the agent.

        Returns:
            dict: A dictionary keyed by action name, each value containing
                the agent's system_prompt and the action's prompt.
        """
        prompts = {}
        for action in self.get_all_actions():
            prompts[action.name] = {
                "system_prompt": self.system_prompt,
                "prompt": action.prompt
            }
        return prompts

    def set_prompt(self, action_name: str, prompt: str, system_prompt: Optional[str] = None) -> bool:
        """
        Set the prompt for a specific action of this agent.

        Args:
            action_name: Name of the action whose prompt should be updated
            prompt: New prompt text to set for the action
            system_prompt: Optional new system prompt to set for the agent

        Returns:
            bool: True if the prompt was successfully updated, False otherwise

        Raises:
            KeyError: If the action_name does not exist for this agent
        """
        try:
            action = self.get_action(action_name)
            action.prompt = prompt

            if system_prompt is not None:
                self.system_prompt = system_prompt

            return True
        except KeyError:
            raise KeyError(f"Action '{action_name}' not found in agent '{self.name}'")

    def set_prompts(self, prompts: dict) -> bool:
        """
        Set the prompts for all actions of this agent.

        Args:
            prompts: A dictionary with keys in the format 'action_name' and values
                containing the system_prompt and action prompt.

        Returns:
            bool: True if the prompts were successfully updated, False otherwise

        Raises:
            ValueError: If any entry is not a dict or lacks a 'prompt' key.
        """
        for action_name, prompt_data in prompts.items():
            if not isinstance(prompt_data, dict):
                raise ValueError(f"Invalid prompt data for action '{action_name}'. Expected a dictionary with 'prompt' and 'system_prompt' (optional) keys.")
            if "prompt" not in prompt_data:
                raise ValueError(f"Missing 'prompt' key in prompt data for action '{action_name}'.")
            self.set_prompt(action_name, prompt_data["prompt"], prompt_data.get("system_prompt", None))
        return True

    def save_module(self, path: str, ignore: List[str] = [], **kwargs) -> str:
        """Save the agent to persistent storage.

        Args:
            path: Path where the agent should be saved
            ignore: List of field names to exclude from serialization
            **kwargs (Any): Additional parameters for the save operation

        Returns:
            The path where the agent was saved
        """
        ignore_fields = self._save_ignore_fields + ignore
        # Bug fix: the docstring promises the saved path but the original
        # dropped the superclass's return value.
        return super().save_module(path=path, ignore=ignore_fields, **kwargs)

    @classmethod
    def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> "Agent":
        """
        load the agent from local storage. Must provide `llm_config` when loading the agent from local storage.

        Args:
            path: The path of the file
            llm_config: The LLMConfig instance

        Returns:
            Agent: The loaded agent instance
        """
        agent = super().load_module(path=path, **kwargs)
        if llm_config is not None:
            agent["llm_config"] = llm_config.to_dict()
        return agent

    def get_config(self) -> dict:
        """
        Get a dictionary containing all necessary configuration to recreate this agent.

        Returns:
            dict: A configuration dictionary that can be used to initialize a new Agent instance
                with the same properties as this one.
        """
        config = self.to_dict()
        return config
32
+ actions (List[Action]): List of available actions
33
+ n (Optional[int]): Number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.
34
+ is_human (bool): Whether this agent represents a human user
35
+ version (int): Version number of the agent, default is 0.
36
+ """
37
+
38
+ name: str # should be unique
39
+ description: str
40
+ llm_config: Optional[LLMConfig] = None
41
+ llm: Optional[BaseLLM] = None
42
+ agent_id: Optional[str] = Field(default_factory=generate_id)
43
+ system_prompt: Optional[str] = None
44
+ short_term_memory: Optional[ShortTermMemory] = Field(default_factory=ShortTermMemory) # store short term memory for a single workflow.
45
+ use_long_term_memory: Optional[bool] = False
46
+ storage_handler: Optional[StorageHandler] = None
47
+ long_term_memory: Optional[LongTermMemory] = None
48
+ long_term_memory_manager: Optional[MemoryManager] = None
49
+ actions: List[Action] = Field(default=None)
50
+ n: int = Field(default=None, description="number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.")
51
+ is_human: bool = Field(default=False)
52
+ version: int = 0
53
+
54
+ def init_module(self):
55
+ if not self.is_human:
56
+ self.init_llm()
57
+ if self.use_long_term_memory:
58
+ self.init_long_term_memory()
59
+ self.actions = [] if self.actions is None else self.actions
60
+ self._action_map = {action.name: action for action in self.actions} if self.actions else dict()
61
+ self._save_ignore_fields = ["llm", "llm_config"]
62
+ self.init_context_extractor()
63
+
64
+ # def __call__(self, *args, **kwargs) -> Message:
65
+ # """Make the agent callable and automatically choose between sync and async execution"""
66
+ # if asyncio.iscoroutinefunction(self.async_execute) and asyncio.get_event_loop().is_running():
67
+ # # If the operator is in an asynchronous environment and has an execute_async method, return a coroutine
68
+ # return self.async_execute(*args, **kwargs)
69
+ # # Otherwise, use the synchronous method
70
+ # return self.execute(*args, **kwargs)
71
+
72
+ def __call__(self, *args: Any, **kwargs: Any) -> Union[dict, Coroutine[Any, Any, dict]]:
73
+ """Make the operator callable and automatically choose between sync and async execution."""
74
+ try:
75
+ # Safe way to check if we're inside an async environment
76
+ asyncio.get_running_loop()
77
+ return self.async_execute(*args, **kwargs)
78
+ except RuntimeError:
79
+ # No running loop — likely in sync context or worker thread
80
+ return self.execute(*args, **kwargs)
81
+
82
+ def _prepare_execution(
83
+ self,
84
+ action_name: str,
85
+ msgs: Optional[List[Message]] = None,
86
+ action_input_data: Optional[dict] = None,
87
+ **kwargs
88
+ ) -> Tuple[Action, dict]:
89
+ """Prepare for action execution by updating memory and getting inputs.
90
+
91
+ Helper method used by both execute and aexecute methods.
92
+
93
+ Args:
94
+ action_name: The name of the action to execute
95
+ msgs: Optional list of messages providing context for the action
96
+ action_input_data: Optional pre-extracted input data for the action
97
+ **kwargs: Additional workflow parameters
98
+
99
+ Returns:
100
+ Tuple containing the action object and input data
101
+
102
+ Raises:
103
+ AssertionError: If neither msgs nor action_input_data is provided
104
+ """
105
+ assert msgs is not None or action_input_data is not None, "must provide either `msgs` or `action_input_data`"
106
+ action = self.get_action(action_name=action_name)
107
+
108
+ # update short-term memory
109
+ if msgs is not None:
110
+ # directly add messages to short-term memory
111
+ self.short_term_memory.add_messages(msgs)
112
+ if action_input_data is not None:
113
+ # create a message from action_input_data and add it to short-term memory
114
+ input_message = Message(
115
+ content = action_input_data,
116
+ next_actions = [action_name],
117
+ msg_type = MessageType.INPUT,
118
+ wf_goal = kwargs.get("wf_goal", None),
119
+ wf_task = kwargs.get("wf_task", None),
120
+ wf_task_desc = kwargs.get("wf_task_desc", None)
121
+ )
122
+ self.short_term_memory.add_message(input_message)
123
+
124
+ # obtain action input data from short term memory if not provided
125
+ action_input_data = action_input_data or self.get_action_inputs(action=action)
126
+
127
+ return action, action_input_data
128
+
129
+ def _create_output_message(
130
+ self,
131
+ action_output,
132
+ prompt: str,
133
+ action_name: str,
134
+ return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
135
+ **kwargs
136
+ ) -> Message:
137
+ """Create a message from execution results and update memory.
138
+
139
+ Helper method used by both execute and aexecute methods.
140
+
141
+ Args:
142
+ action_output: The output from action execution
143
+ prompt: The prompt used for execution
144
+ action_name: The name of the executed action
145
+ return_msg_type: Message type for the return message
146
+ **kwargs: Additional workflow parameters
147
+
148
+ Returns:
149
+ Message object containing execution results
150
+ """
151
+ # formulate a message
152
+ message = Message(
153
+ content=action_output,
154
+ agent=self.name,
155
+ action=action_name,
156
+ prompt=prompt,
157
+ msg_type=return_msg_type,
158
+ wf_goal = kwargs.get("wf_goal", None),
159
+ wf_task = kwargs.get("wf_task", None),
160
+ wf_task_desc = kwargs.get("wf_task_desc", None)
161
+ )
162
+
163
+ # update short-term memory
164
+ self.short_term_memory.add_message(message)
165
+
166
+ return message
167
+
168
+ async def async_execute(
169
+ self,
170
+ action_name: str,
171
+ msgs: Optional[List[Message]] = None,
172
+ action_input_data: Optional[dict] = None,
173
+ return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
174
+ return_action_input_data: Optional[bool] = False,
175
+ **kwargs
176
+ ) -> Union[Message, Tuple[Message, dict]]:
177
+ """Execute an action asynchronously with the given context and return results.
178
+
179
+ This is the async version of the execute method, allowing it to perform actions
180
+ based on the current conversation context.
181
+
182
+ Args:
183
+ action_name: The name of the action to execute
184
+ msgs: Optional list of messages providing context for the action
185
+ action_input_data: Optional pre-extracted input data for the action
186
+ return_msg_type: Message type for the return message
187
+ **kwargs (Any): Additional parameters, may include workflow information
188
+
189
+ Returns:
190
+ Message: A message containing the execution results
191
+ """
192
+ action, action_input_data = self._prepare_execution(
193
+ action_name=action_name,
194
+ msgs=msgs,
195
+ action_input_data=action_input_data,
196
+ **kwargs
197
+ )
198
+
199
+ # execute action asynchronously
200
+ async_execute_source = inspect.getsource(action.async_execute)
201
+ if "NotImplementedError" in async_execute_source:
202
+ # if the async_execute method is not implemented, use the execute method instead
203
+ execution_results = action.execute(
204
+ llm=self.llm,
205
+ inputs=action_input_data,
206
+ sys_msg=self.system_prompt,
207
+ return_prompt=True,
208
+ **kwargs
209
+ )
210
+ else:
211
+ execution_results = await action.async_execute(
212
+ llm=self.llm,
213
+ inputs=action_input_data,
214
+ sys_msg=self.system_prompt,
215
+ return_prompt=True,
216
+ **kwargs
217
+ )
218
+ action_output, prompt = execution_results
219
+
220
+ message = self._create_output_message(
221
+ action_output=action_output,
222
+ prompt=prompt,
223
+ action_name=action_name,
224
+ return_msg_type=return_msg_type,
225
+ **kwargs
226
+ )
227
+ if return_action_input_data:
228
+ return message, action_input_data
229
+ return message
230
+
231
+ def execute(
232
+ self,
233
+ action_name: str,
234
+ msgs: Optional[List[Message]] = None,
235
+ action_input_data: Optional[dict] = None,
236
+ return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
237
+ return_action_input_data: Optional[bool] = False,
238
+ **kwargs
239
+ ) -> Union[Message, Tuple[Message, dict]]:
240
+ """Execute an action with the given context and return results.
241
+
242
+ This is the core method for agent functionality, allowing it to perform actions
243
+ based on the current conversation context.
244
+
245
+ Args:
246
+ action_name: The name of the action to execute
247
+ msgs: Optional list of messages providing context for the action
248
+ action_input_data: Optional pre-extracted input data for the action
249
+ return_msg_type: Message type for the return message
250
+ **kwargs (Any): Additional parameters, may include workflow information
251
+
252
+ Returns:
253
+ Message: A message containing the execution results
254
+ """
255
+ action, action_input_data = self._prepare_execution(
256
+ action_name=action_name,
257
+ msgs=msgs,
258
+ action_input_data=action_input_data,
259
+ **kwargs
260
+ )
261
+
262
+ # execute action
263
+ execution_results = action.execute(
264
+ llm=self.llm,
265
+ inputs=action_input_data,
266
+ sys_msg=self.system_prompt,
267
+ return_prompt=True,
268
+ **kwargs
269
+ )
270
+ action_output, prompt = execution_results
271
+
272
+ message = self._create_output_message(
273
+ action_output=action_output,
274
+ prompt=prompt,
275
+ action_name=action_name,
276
+ return_msg_type=return_msg_type,
277
+ **kwargs
278
+ )
279
+ if return_action_input_data:
280
+ return message, action_input_data
281
+ return message
282
+
283
+ def init_llm(self):
284
+ """
285
+ Initialize the language model for the agent.
286
+ """
287
+ # Only initialize LLM if not human and LLM is provided
288
+ if not self.is_human and (not self.llm_config and not self.llm):
289
+ raise ValueError("must provide `llm_config` or `llm` when `is_human` is False")
290
+ if not self.is_human and (self.llm_config or self.llm):
291
+ if self.llm_config and not self.llm:
292
+ llm_cls = MODEL_REGISTRY.get_model(self.llm_config.llm_type)
293
+ self.llm = llm_cls(config=self.llm_config)
294
+ if self.llm:
295
+ self.llm_config = self.llm.config
296
+ # If is_human=True or no LLM provided, self.llm remains None
297
+
298
+ def init_long_term_memory(self):
299
+ """
300
+ Initialize long-term memory components.
301
+ """
302
+ assert self.storage_handler is not None, "must provide ``storage_handler`` when use_long_term_memory=True"
303
+ # TODO revise the initialisation of long_term_memory and long_term_memory_manager
304
+ if not self.long_term_memory:
305
+ self.long_term_memory = LongTermMemory()
306
+ if not self.long_term_memory_manager:
307
+ self.long_term_memory_manager = MemoryManager(
308
+ storage_handler=self.storage_handler,
309
+ memory=self.long_term_memory
310
+ )
311
+
312
+ def init_context_extractor(self):
313
+ """
314
+ Initialize the context extraction action.
315
+ """
316
+ cext_action = ContextExtraction()
317
+ self.cext_action_name = cext_action.name
318
+ self.add_action(cext_action)
319
+
320
+ def add_action(self, action: Type[Action]):
321
+ """
322
+ Add a new action to the agent's available actions.
323
+
324
+ Args:
325
+ action: The action instance to add
326
+ """
327
+ action_name = action.name
328
+ if action_name in self._action_map:
329
+ return
330
+ self.actions.append(action)
331
+ self._action_map[action_name] = action
332
+
333
+ def check_action_name(self, action_name: str):
334
+ """
335
+ Check if an action name is valid for this agent.
336
+
337
+ Args:
338
+ action_name: Name of the action to check
339
+ """
340
+ if action_name not in self._action_map:
341
+ raise KeyError(f"'{action_name}' is an invalid action for {self.name}! Available action names: {list(self._action_map.keys())}")
342
+
343
+ def get_action(self, action_name: str) -> Action:
344
+ """
345
+ Retrieves the Action instance associated with the given name.
346
+
347
+ Args:
348
+ action_name: Name of the action to retrieve
349
+
350
+ Returns:
351
+ The Action instance with the specified name
352
+ """
353
+ self.check_action_name(action_name=action_name)
354
+ return self._action_map[action_name]
355
+
356
+ def get_action_name(self, action_cls: Type[Action]) -> str:
357
+ """
358
+ Searches through the agent's actions to find one matching the specified type.
359
+
360
+ Args:
361
+ action_cls: The Action class type to search for
362
+
363
+ Returns:
364
+ The name of the matching action
365
+ """
366
+ for name, action in self._action_map.items():
367
+ if isinstance(action, action_cls):
368
+ return name
369
+ raise ValueError(f"Couldn't find an action that matches Type '{action_cls.__name__}'")
370
+
371
+ def get_action_inputs(self, action: Action) -> Union[dict, None]:
372
+ """
373
+ Uses the context extraction action to determine appropriate inputs
374
+ for the specified action based on the conversation history.
375
+
376
+ Args:
377
+ action: The action for which to extract inputs
378
+
379
+ Returns:
380
+ Dictionary of extracted input data, or None if extraction fails
381
+ """
382
+ # return the input data of an action.
383
+ context = self.short_term_memory.get(n=self.n)
384
+ cext_action = self.get_action(self.cext_action_name)
385
+ action_inputs = cext_action.execute(llm=self.llm, action=action, context=context)
386
+ return action_inputs
387
+
388
+ def get_all_actions(self) -> List[Action]:
389
+ """Get all actions except the context extraction action.
390
+
391
+ Returns:
392
+ List of Action instances available for execution
393
+ """
394
+ actions = [action for action in self.actions if action.name != self.cext_action_name]
395
+ return actions
396
+
397
+ def get_agent_profile(self, action_names: List[str] = None) -> str:
398
+ """Generate a human-readable profile of the agent and its capabilities.
399
+
400
+ Args:
401
+ action_names: Optional list of action names to include in the profile.
402
+ If None, all actions are included.
403
+
404
+ Returns:
405
+ A formatted string containing the agent profile
406
+ """
407
+ all_actions = self.get_all_actions()
408
+ if action_names is None:
409
+ # if `action_names` is None, return description of all actions
410
+ action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions])
411
+ else:
412
+ # otherwise, only return description of actions that matches `action_names`
413
+ action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions if action.name in action_names])
414
+ profile = f"Agent Name: {self.name}\nDescription: {self.description}\nAvailable Actions:\n{action_descriptions}"
415
+ return profile
416
+
417
+ def clear_short_term_memory(self):
418
+ """
419
+ Remove all content from the agent's short-term memory.
420
+ """
421
+ pass
422
+
423
+ def __eq__(self, other: "Agent"):
424
+ return self.agent_id == other.agent_id
425
+
426
+ def __hash__(self):
427
+ return self.agent_id
428
+
429
+ def get_prompts(self) -> dict:
430
+ """
431
+ Get all the prompts of the agent.
432
+
433
+ Returns:
434
+ dict: A dictionary with keys in the format 'agent_name::action_name' and values
435
+ containing the system_prompt and action prompt.
436
+ """
437
+ prompts = {}
438
+ for action in self.get_all_actions():
439
+ prompts[action.name] = {
440
+ "system_prompt": self.system_prompt,
441
+ "prompt": action.prompt
442
+ }
443
+ return prompts
444
+
445
+ def set_prompt(self, action_name: str, prompt: str, system_prompt: Optional[str] = None) -> bool:
446
+ """
447
+ Set the prompt for a specific action of this agent.
448
+
449
+ Args:
450
+ action_name: Name of the action whose prompt should be updated
451
+ prompt: New prompt text to set for the action
452
+ system_prompt: Optional new system prompt to set for the agent
453
+
454
+ Returns:
455
+ bool: True if the prompt was successfully updated, False otherwise
456
+
457
+ Raises:
458
+ KeyError: If the action_name does not exist for this agent
459
+ """
460
+ try:
461
+ action = self.get_action(action_name)
462
+ action.prompt = prompt
463
+
464
+ if system_prompt is not None:
465
+ self.system_prompt = system_prompt
466
+
467
+ return True
468
+ except KeyError:
469
+ raise KeyError(f"Action '{action_name}' not found in agent '{self.name}'")
470
+
471
+ def set_prompts(self, prompts: dict) -> bool:
472
+ """
473
+ Set the prompts for all actions of this agent.
474
+
475
+ Args:
476
+ prompts: A dictionary with keys in the format 'action_name' and values
477
+ containing the system_prompt and action prompt.
478
+
479
+ Returns:
480
+ bool: True if the prompts were successfully updated, False otherwise
481
+ """
482
+ for action_name, prompt_data in prompts.items():
483
+ # self.set_prompt(action_name, prompt_data["prompt"], prompt_data["system_prompt"])
484
+ if not isinstance(prompt_data, dict):
485
+ raise ValueError(f"Invalid prompt data for action '{action_name}'. Expected a dictionary with 'prompt' and 'system_prompt' (optional) keys.")
486
+ if "prompt" not in prompt_data:
487
+ raise ValueError(f"Missing 'prompt' key in prompt data for action '{action_name}'.")
488
+ self.set_prompt(action_name, prompt_data["prompt"], prompt_data.get("system_prompt", None))
489
+ return True
490
+
491
+ def save_module(self, path: str, ignore: List[str] = [], **kwargs)-> str:
492
+ """Save the agent to persistent storage.
493
+
494
+ Args:
495
+ path: Path where the agent should be saved
496
+ ignore: List of field names to exclude from serialization
497
+ **kwargs (Any): Additional parameters for the save operation
498
+
499
+ Returns:
500
+ The path where the agent was saved
501
+ """
502
+ ignore_fields = self._save_ignore_fields + ignore
503
+ super().save_module(path=path, ignore=ignore_fields, **kwargs)
504
+
505
+ @classmethod
506
+ def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> "Agent":
507
+ """
508
+ load the agent from local storage. Must provide `llm_config` when loading the agent from local storage.
509
+
510
+ Args:
511
+ path: The path of the file
512
+ llm_config: The LLMConfig instance
513
+
514
+ Returns:
515
+ Agent: The loaded agent instance
516
+ """
517
+ agent = super().load_module(path=path, **kwargs)
518
+ if llm_config is not None:
519
+ agent["llm_config"] = llm_config.to_dict()
520
+ return agent
521
+
522
+ def get_config(self) -> dict:
523
+ """
524
+ Get a dictionary containing all necessary configuration to recreate this agent.
525
+
526
+ Returns:
527
+ dict: A configuration dictionary that can be used to initialize a new Agent instance
528
+ with the same properties as this one.
529
+ """
530
+ config = self.to_dict()
531
+ return config
evoagentx/agents/agent_generator.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .agent import Agent
2
+ from ..actions.agent_generation import AgentGeneration
3
+ from ..prompts.agent_generator import AGENT_GENERATOR
4
+
5
+
6
class AgentGenerator(Agent):

    """
    An agent responsible for generating agents for a task.

    Any of ``name``/``description``/``system_prompt`` not supplied by the
    caller falls back to the defaults declared in the ``AGENT_GENERATOR``
    prompt bundle.
    """

    def __init__(self, **kwargs):
        # `dict.pop(key, default)` replaces the original
        # `kwargs.pop(k) if k in kwargs else default` double lookups.
        name = kwargs.pop("name", AGENT_GENERATOR["name"])
        description = kwargs.pop("description", AGENT_GENERATOR["description"])
        system_prompt = kwargs.pop("system_prompt", AGENT_GENERATOR["system_prompt"])
        if "actions" in kwargs:
            actions = kwargs.pop("actions")
        else:
            # Deliberately NOT `kwargs.pop("actions", default)`: the default
            # would be evaluated eagerly, consuming `tools` and constructing a
            # spurious AgentGeneration action even when `actions` is supplied.
            actions = [AgentGeneration(tools=kwargs.pop("tools", []))]
        super().__init__(name=name, description=description, system_prompt=system_prompt, actions=actions, **kwargs)

    @property
    def agent_generation_action_name(self):
        """Name of this agent's registered AgentGeneration action."""
        return self.get_action_name(action_cls=AgentGeneration)
23
+
evoagentx/agents/agent_manager.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from enum import Enum
3
+ from typing import Union, Optional, Dict, List
4
+ from pydantic import Field
5
+ from copy import deepcopy
6
+
7
+ from .agent import Agent
8
+ # from .agent_generator import AgentGenerator
9
+ from .customize_agent import CustomizeAgent
10
+ from ..core.module import BaseModule
11
+ from ..core.decorators import atomic_method
12
+ from ..storages.base import StorageHandler
13
+ from ..models.model_configs import LLMConfig
14
+ from ..tools.tool import Toolkit, Tool
15
class AgentState(str, Enum):
    """Lifecycle states tracked by AgentManager for each managed agent."""
    AVAILABLE = "available"  # agent is idle and can accept new work
    RUNNING = "running"      # agent is currently executing a task
18
+
19
+
20
class AgentManager(BaseModule):
    """
    Responsible for creating and managing all Agent objects required for workflow operation.

    Attributes:
        storage_handler (StorageHandler): Used to load and save agents from/to storage.
        agents (List[Agent]): A list to keep track of all managed Agent instances.
        agent_states (Dict[str, AgentState]): A dictionary to track the state of each Agent by name.
        tools (Optional[List[Union[Toolkit, Tool]]]): Pool of tools that agent
            specifications may reference by name via ``tool_names``.
    """
    agents: List[Agent] = Field(default_factory=list)
    agent_states: Dict[str, AgentState] = Field(default_factory=dict)  # agent_name to AgentState mapping
    storage_handler: Optional[StorageHandler] = None  # used to load and save agent from storage.
    tools: Optional[List[Union[Toolkit, Tool]]] = None

    def init_module(self):
        """Set up runtime-only state: the manager lock, one Condition per agent
        (used for state-change notifications), and a default AVAILABLE state
        for every agent supplied at construction time."""
        self._lock = threading.Lock()
        self._state_conditions = {}
        if self.agents:
            for agent in self.agents:
                self.agent_states[agent.name] = self.agent_states.get(agent.name, AgentState.AVAILABLE)
                if agent.name not in self._state_conditions:
                    self._state_conditions[agent.name] = threading.Condition()
            self.check_agents()

    def check_agents(self):
        """Validate agent list integrity and state consistency.

        Performs thorough validation of the agent manager's internal state:
        1. Checks for duplicate agent names
        2. Ensures agent list and state dictionary sizes match
        3. Verifies that agent states exist for all agents

        Raises:
            ValueError: On duplicate names, size mismatch, or missing states.
        """
        duplicate_agent_names = self.find_duplicate_agents(self.agents)
        if duplicate_agent_names:
            raise ValueError(f"The agents should be unique. Found duplicate agent names: {duplicate_agent_names}!")
        if len(self.agents) != len(self.agent_states):
            raise ValueError(f"The lengths of self.agents ({len(self.agents)}) and self.agent_states ({len(self.agent_states)}) are different!")
        missing_agents = self.find_missing_agent_states()
        if missing_agents:
            raise ValueError(f"The following agents' states were not found: {missing_agents}")

    def find_duplicate_agents(self, agents: List[Agent]) -> List[str]:
        """Return the names that occur more than once in ``agents``."""
        unique_agent_names = set()
        duplicate_agent_names = set()
        for agent in agents:
            agent_name = agent.name
            if agent_name in unique_agent_names:
                duplicate_agent_names.add(agent_name)
            unique_agent_names.add(agent_name)
        return list(duplicate_agent_names)

    def find_missing_agent_states(self) -> List[str]:
        """Return names of managed agents that lack an ``agent_states`` entry."""
        return [agent.name for agent in self.agents if agent.name not in self.agent_states]

    def list_agents(self) -> List[str]:
        """Return the names of all managed agents, in registration order."""
        return [agent.name for agent in self.agents]

    def has_agent(self, agent_name: str) -> bool:
        """Check if an agent with the given name exists in the manager.

        Args:
            agent_name: The name of the agent to check

        Returns:
            True if an agent with the given name exists, False otherwise
        """
        return agent_name in self.list_agents()

    @property
    def size(self):
        """Total number of agents managed by this manager."""
        return len(self.agents)

    def load_agent(self, agent_name: str, **kwargs) -> Agent:
        """Load an agent from local storage through storage_handler.

        Retrieves agent data from storage and creates an Agent instance.

        Args:
            agent_name: The name of the agent to load
            **kwargs (Any): Additional parameters for agent creation

        Returns:
            Agent instance with data loaded from storage

        Raises:
            ValueError: If no ``storage_handler`` is configured.
        """
        if not self.storage_handler:
            raise ValueError("must provide ``self.storage_handler`` to use ``load_agent``")
        agent_data = self.storage_handler.load_agent(agent_name=agent_name)
        agent: Agent = self.create_customize_agent(agent_data=agent_data)
        return agent

    def load_all_agents(self, **kwargs):
        """Load all agents from storage and add them to the manager.

        NOTE: not implemented yet — this is a no-op placeholder.

        Args:
            **kwargs (Any): Additional parameters passed to storage handler
        """
        pass

    def update_tools(self, agent_data: dict) -> None:
        """
        Update agent_data with tools based on tool_names.

        Handles four scenarios:
        1. Neither tool_names nor tools exist: return directly
        2. Only tool_names exists: resolve tool_names to tools and set tools field
        3. Only tools exists: return directly (no action needed)
        4. Both exist: merge tool_names into existing tools (skip duplicates)

        Args:
            agent_data (dict): Agent configuration dictionary that may contain
                'tool_names' and/or 'tools'

        Raises:
            ValueError: If tool_names exist but self.tools is None, or if
                requested tools are not found
        """
        tool_names = agent_data.get("tool_names", None)
        existing_tools = agent_data.get("tools", None)

        # Cases 1 & 3: no names to resolve (tools may already be explicit).
        if not tool_names:
            return

        # Cases 2 & 4: tool_names present — the manager's pool is required.
        if self.tools is None:
            raise ValueError(
                f"Agent requires tools {tool_names}, but no tools are available in AgentManager. "
                f"Please set self.tools before creating agents with tool_names."
            )

        # Map the available tools by name for O(1) resolution.
        tool_mapping = {tool.name: tool for tool in self.tools}

        # Case 2: names only — start from an empty tools list.
        if not existing_tools:
            existing_tools = []

        # Resolve each requested name, skipping ones already attached.
        existing_tool_names = {tool.name for tool in existing_tools}
        tools_to_add = []
        missing_tools = []
        for tool_name in tool_names:
            if tool_name in existing_tool_names:
                continue
            if tool_name in tool_mapping:
                tools_to_add.append(tool_mapping[tool_name])
            else:
                missing_tools.append(tool_name)

        if missing_tools:
            available_tools = list(tool_mapping.keys())
            raise ValueError(
                f"The following tools are not available: {missing_tools}. "
                f"Available tools: {available_tools}"
            )

        # Merge new tools with existing ones (copy to avoid aliasing input).
        if tools_to_add:
            agent_data["tools"] = list(existing_tools) + tools_to_add

    def create_customize_agent(self, agent_data: dict, llm_config: Optional[Union[LLMConfig, dict]]=None, **kwargs) -> CustomizeAgent:
        """
        create a customized agent from the provided `agent_data`.

        Args:
            agent_data: The data used to create an Agent instance, must contain 'name', 'description' and 'prompt' keys.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent.
                It will be used as the default LLM for agents without a `llm_config` key.
                If not provided, the `agent_data` should contain a `llm_config` key.
                If provided and `agent_data` contains a `llm_config` key, the `llm_config` in `agent_data` will be used.
            **kwargs (Any): Additional parameters for agent creation

        Returns:
            Agent: the instantiated agent instance.
        """
        # Work on a copy so the caller's dict is never mutated.
        agent_data = deepcopy(agent_data)
        agent_llm_config = agent_data.get("llm_config", llm_config)
        # Human-in-the-loop agents don't need an LLM; everything else does.
        if not agent_data.get("is_human", False) and not agent_llm_config:
            raise ValueError("`agent_data` should contain a `llm_config` key or `llm_config` should be provided.")

        if agent_llm_config:
            # Normalize the config to a plain dict for serialization.
            if isinstance(agent_llm_config, dict):
                agent_data["llm_config"] = agent_llm_config
            elif isinstance(agent_llm_config, LLMConfig):
                agent_data["llm_config"] = agent_llm_config.to_dict()

        self.update_tools(agent_data=agent_data)  # resolve `tool_names` into a `tools` field if needed
        return CustomizeAgent.from_dict(data=agent_data)

    def get_agent_name(self, agent: Union[str, dict, Agent]) -> str:
        """Extract agent name from different agent representations.

        Handles different ways to specify an agent (string name, dictionary, or
        Agent instance) and extracts the agent name.

        Args:
            agent: Agent specified as a string name, dictionary with 'name' key,
                or Agent instance

        Returns:
            The extracted agent name as a string

        Raises:
            ValueError: If ``agent`` is none of the supported types.
        """
        if isinstance(agent, str):
            agent_name = agent
        elif isinstance(agent, dict):
            agent_name = agent["name"]
        elif isinstance(agent, Agent):
            agent_name = agent.name
        else:
            raise ValueError(f"{type(agent)} is not a supported type for ``get_agent_name``. Supported types: [str, dict, Agent].")
        return agent_name

    def create_agent(self, agent: Union[str, dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs) -> Agent:
        """Resolve ``agent`` (name, spec dict, or instance) into an Agent.

        A string is looked up among managed agents (no storage handler) or
        loaded from storage; a dict is turned into a CustomizeAgent; an Agent
        instance is returned unchanged.
        """
        if isinstance(agent, str):
            if self.storage_handler is None:
                # Without storage, the named agent must already be managed.
                if not self.has_agent(agent_name=agent):
                    raise ValueError(f"Agent ``{agent}`` does not exist! You should provide a dictionary or an Agent instance when ``self.storage_handler`` is not provided.")
                return self.get_agent(agent_name=agent)
            else:
                # With storage, the named agent is loaded from it.
                agent_instance = self.load_agent(agent_name=agent)
        elif isinstance(agent, dict):
            if not agent.get("is_human", False) and (llm_config is None and "llm_config" not in agent):
                raise ValueError("When providing an agent as a dictionary, you must either include 'llm_config' in the dictionary or provide it as a parameter.")
            agent_instance = self.create_customize_agent(agent_data=agent, llm_config=llm_config, **kwargs)
        elif isinstance(agent, Agent):
            agent_instance = agent
        else:
            raise ValueError(f"{type(agent)} is not a supported input type of ``create_agent``. Supported types: [str, dict, Agent].")
        return agent_instance

    @atomic_method
    def add_agent(self, agent: Union[str, dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        add a single agent, ignore if the agent already exists (judged by the name of an agent).

        Args:
            agent: The agent to be added, specified as:
                - String: Agent name to load from storage
                - Dictionary: Agent specification to create a CustomizeAgent
                - Agent: Existing Agent instance to add directly
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent. Only used when the `agent` is a dictionary, used to create a CustomizeAgent.
            **kwargs (Any): Additional parameters for agent creation
        """
        agent_name = self.get_agent_name(agent=agent)
        if self.has_agent(agent_name=agent_name):
            return
        agent_instance = self.create_agent(agent=agent, llm_config=llm_config, **kwargs)
        self.agents.append(agent_instance)
        self.agent_states[agent_instance.name] = AgentState.AVAILABLE
        if agent_instance.name not in self._state_conditions:
            self._state_conditions[agent_instance.name] = threading.Condition()
        self.check_agents()

    def add_agents(self, agents: List[Union[str, dict, Agent]], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        add several agents by using self.add_agent().
        """
        for agent in agents:
            self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def add_agents_from_workflow(self, workflow_graph, llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Initialize agents from the nodes of a given WorkFlowGraph and add these agents to self.agents.

        Args:
            workflow_graph (WorkFlowGraph): The workflow graph containing nodes with agents information.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agents.
            **kwargs (Any): Additional parameters passed to add_agent
        """
        # Imported lazily to avoid a circular import with the workflow package.
        from ..workflow.workflow_graph import WorkFlowGraph
        if not isinstance(workflow_graph, WorkFlowGraph):
            raise TypeError("workflow_graph must be an instance of WorkFlowGraph")
        for node in workflow_graph.nodes:
            if node.agents:
                for agent in node.agents:
                    self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def update_agents_from_workflow(self, workflow_graph, llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Update agents from a given WorkFlowGraph.

        Args:
            workflow_graph (WorkFlowGraph): The workflow graph containing nodes with agents information.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agents.
            **kwargs: Additional parameters passed to update_agent
        """
        # Imported lazily to avoid a circular import with the workflow package.
        from ..workflow.workflow_graph import WorkFlowGraph
        if not isinstance(workflow_graph, WorkFlowGraph):
            raise TypeError("workflow_graph must be an instance of WorkFlowGraph")
        for node in workflow_graph.nodes:
            if node.agents:
                for agent in node.agents:
                    agent_name = self.get_agent_name(agent=agent)
                    if self.has_agent(agent_name=agent_name):
                        # use the llm_config of the existing agent
                        agent_llm_config = self.get_agent(agent_name).llm_config
                        self.update_agent(agent=agent, llm_config=agent_llm_config, **kwargs)
                    else:
                        self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def get_agent(self, agent_name: str, **kwargs) -> Agent:
        """Retrieve an agent by its name from managed agents.

        Searches the list of managed agents for an agent with the specified name.

        Args:
            agent_name: The name of the agent to retrieve
            **kwargs (Any): Additional parameters (unused)

        Returns:
            The Agent instance with the specified name

        Raises:
            ValueError: If no managed agent has that name.
        """
        for agent in self.agents:
            if agent.name == agent_name:
                return agent
        raise ValueError(f"Agent ``{agent_name}`` does not exists!")

    def update_agent(self, agent: Union[dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Update an agent in the manager by removing and re-adding it.

        Args:
            agent: The agent to be updated, specified as:
                - Dictionary: Agent specification to update a CustomizeAgent
                - Agent: Existing Agent instance to update
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent.
        """
        agent_name = self.get_agent_name(agent=agent)
        self.remove_agent(agent_name=agent_name)
        self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    @atomic_method
    def remove_agent(self, agent_name: str, remove_from_storage: bool=False, **kwargs):
        """
        Remove an agent from the manager and optionally from storage.

        Args:
            agent_name: The name of the agent to remove
            remove_from_storage: If True, also remove the agent from storage
            **kwargs (Any): Additional parameters passed to storage_handler.remove_agent

        Raises:
            ValueError: If ``remove_from_storage`` is True but no
                ``storage_handler`` is configured.
        """
        self.agents = [agent for agent in self.agents if agent.name != agent_name]
        self.agent_states.pop(agent_name, None)
        self._state_conditions.pop(agent_name, None)
        if remove_from_storage:
            # Fail with a clear error rather than an AttributeError on None.
            if self.storage_handler is None:
                raise ValueError("must provide ``self.storage_handler`` to remove an agent from storage")
            self.storage_handler.remove_agent(agent_name=agent_name, **kwargs)
        self.check_agents()

    def get_agent_state(self, agent_name: str) -> AgentState:
        """
        Get the state of a specific agent by its name.

        Args:
            agent_name: The name of the agent.

        Returns:
            AgentState: The current state of the agent.
        """
        return self.agent_states[agent_name]

    @atomic_method
    def set_agent_state(self, agent_name: str, new_state: AgentState) -> bool:
        """
        Changes an agent's state and notifies any threads waiting on that agent's state.
        Thread-safe operation for coordinating multi-threaded agent execution.

        Args:
            agent_name: The name of the agent
            new_state: The new state to set

        Returns:
            True if the state was updated successfully, False otherwise
        """
        if agent_name in self.agent_states and isinstance(new_state, AgentState):
            # Lazily create the condition so callers never hit a KeyError.
            if agent_name not in self._state_conditions:
                self._state_conditions[agent_name] = threading.Condition()
            with self._state_conditions[agent_name]:
                self.agent_states[agent_name] = new_state
                # Wake every thread blocked in wait_for_agent_available().
                self._state_conditions[agent_name].notify_all()
            return True
        return False

    def get_all_agent_states(self) -> Dict[str, AgentState]:
        """Get the states of all managed agents.

        Returns:
            Dict[str, AgentState]: A dictionary mapping agent names to their states.
        """
        return self.agent_states

    @atomic_method
    def save_all_agents(self, **kwargs):
        """Save all managed agents to persistent storage.

        NOTE: not implemented yet — this is a no-op placeholder.

        Args:
            **kwargs (Any): Additional parameters passed to the storage handler
        """
        pass

    @atomic_method
    def clear_agents(self):
        """
        Remove all agents from the manager.
        """
        self.agents = []
        self.agent_states = {}
        self._state_conditions = {}
        self.check_agents()

    def wait_for_agent_available(self, agent_name: str, timeout: Optional[float] = None) -> bool:
        """Wait for an agent to be available.

        Args:
            agent_name: The name of the agent to wait for
            timeout: Maximum time to wait in seconds, or None to wait indefinitely

        Returns:
            True if the agent became available, False if timed out
        """
        if agent_name not in self._state_conditions:
            self._state_conditions[agent_name] = threading.Condition()
        condition = self._state_conditions[agent_name]

        with condition:
            return condition.wait_for(
                lambda: self.agent_states.get(agent_name) == AgentState.AVAILABLE,
                timeout=timeout
            )

    def copy(self) -> "AgentManager":
        """
        Create a shallow copy of the AgentManager.

        The agent list, storage handler, and tool pool are shared (not deep
        copied); runtime state (lock, conditions) is rebuilt by init_module.
        """
        # Include `tools` so the copy can still resolve `tool_names` specs;
        # the original silently dropped them.
        return AgentManager(agents=self.agents, storage_handler=self.storage_handler, tools=self.tools)
evoagentx/agents/customize_agent.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import inspect
3
+ from pydantic import create_model, Field
4
+ from typing import Optional, Callable, Type, List, Any, Union, Dict
5
+
6
+ from .agent import Agent
7
+ from ..core.logging import logger
8
+ from ..core.registry import MODULE_REGISTRY, PARSE_FUNCTION_REGISTRY
9
+ from ..core.message import Message, MessageType
10
+ from ..models.model_configs import LLMConfig
11
+ from ..models.base_model import PARSER_VALID_MODE
12
+ from ..prompts.utils import DEFAULT_SYSTEM_PROMPT
13
+ from ..prompts.template import PromptTemplate
14
+ from ..actions.action import Action, ActionOutput
15
+ from ..utils.utils import generate_dynamic_class_name, make_parent_folder
16
+ from ..actions.customize_action import CustomizeAction
17
+ from ..actions.action import ActionInput
18
+ from ..tools.tool import Toolkit, Tool
19
+
20
+
21
+ class CustomizeAgent(Agent):
22
+
23
+ """
24
+ CustomizeAgent provides a flexible framework for creating specialized LLM-powered agents without
25
+ writing custom code. It enables the creation of agents with well-defined inputs and outputs,
26
+ custom prompt templates, and configurable parsing strategies.
27
+
28
+ Attributes:
29
+ name (str): The name of the agent.
30
+ description (str): A description of the agent's purpose and capabilities.
31
+ prompt_template (PromptTemplate, optional): The prompt template that will be used for the agent's primary action.
32
+ prompt (str, optional): The prompt template that will be used for the agent's primary action.
33
+ Should contain placeholders in the format `{input_name}` for each input parameter.
34
+ llm_config (LLMConfig, optional): Configuration for the language model.
35
+ inputs (List[dict], optional): List of input specifications, where each dict (e.g., `{"name": str, "type": str, "description": str, ["required": bool]}`) contains:
36
+ - name (str): Name of the input parameter
37
+ - type (str): Type of the input
38
+ - description (str): Description of what the input represents
39
+ - required (bool, optional): Whether this input is required (default: True)
40
+ outputs (List[dict], optional): List of output specifications, where each dict (e.g., `{"name": str, "type": str, "description": str, ["required": bool]}`) contains:
41
+ - name (str): Name of the output field
42
+ - type (str): Type of the output
43
+ - description (str): Description of what the output represents
44
+ - required (bool, optional): Whether this output is required (default: True)
45
+ system_prompt (str, optional): The system prompt for the LLM. Defaults to DEFAULT_SYSTEM_PROMPT.
46
+ output_parser (Type[ActionOutput], optional): A custom class for parsing the LLM's output.
47
+ Must be a subclass of ActionOutput.
48
+ parse_mode (str, optional): Mode for parsing LLM output. Options are:
49
+ - "title": Parse outputs using section titles (default)
50
+ - "str": Parse as plain text
51
+ - "json": Parse as JSON
52
+ - "xml": Parse as XML
53
+ - "custom": Use a custom parsing function
54
+ parse_func (Callable, optional): Custom function for parsing LLM output when parse_mode is "custom".
55
+ Must accept a "content" parameter and return a dictionary.
56
+ title_format (str, optional): Format string for title parsing mode with {title} placeholder.
57
+ Default is "## {title}".
58
+ tools (list[Toolkit], optional): List of tools to be used by the agent.
59
+ max_tool_calls (int, optional): Maximum number of tool calls. Defaults to 5.
60
+ custom_output_format (str, optional): Specify the output format. Only used when `prompt_template` is used.
61
+ If not provided, the output format will be constructed from the `outputs` specification and `parse_mode`.
62
+ """
63
    def __init__(
        self,
        name: str,
        description: str,
        prompt: Optional[str] = None,
        prompt_template: Optional[PromptTemplate] = None,
        llm_config: Optional[LLMConfig] = None,
        inputs: Optional[List[dict]] = None,
        outputs: Optional[List[dict]] = None,
        system_prompt: Optional[str] = None,
        output_parser: Optional[Type[ActionOutput]] = None,
        parse_mode: Optional[str] = "title",
        parse_func: Optional[Callable] = None,
        title_format: Optional[str] = None,
        tools: Optional[List[Union[Toolkit, Tool]]] = None,
        max_tool_calls: Optional[int] = 5,
        custom_output_format: Optional[str] = None,
        **kwargs
    ):
        """Validate the agent spec, build its single custom action, and
        initialize the underlying Agent. See the class docstring for the
        meaning of each parameter."""
        system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        inputs = inputs or []
        outputs = outputs or []
        if tools is not None:
            # Remember the caller-supplied objects by name BEFORE normalization
            # so the original Tool/Toolkit instances can be recorded below.
            raw_tool_map = {tool.name: tool for tool in tools}
            # Normalize: wrap every bare Tool into a single-tool Toolkit.
            tools = [tool if isinstance(tool, Toolkit) else Toolkit(name=tool.name, tools=[tool]) for tool in tools]
        else:
            raw_tool_map = None

        if prompt is not None and prompt_template is not None:
            # `prompt_template` takes precedence; the raw prompt is discarded.
            logger.warning("Both `prompt` and `prompt_template` are provided in `CustomizeAgent`. `prompt_template` will be used.")
            prompt = None

        if isinstance(parse_func, str):
            # Serialized agents reference parse functions by registered name;
            # resolve the name back to the callable.
            if not PARSE_FUNCTION_REGISTRY.has_function(parse_func):
                raise ValueError(f"parse function `{parse_func}` is not registered! To instantiate a CustomizeAgent from a file, you should use decorator `@register_parse_function` to register the parse function.")
            parse_func = PARSE_FUNCTION_REGISTRY.get_function(parse_func)

        if isinstance(output_parser, str):
            # Output parser classes may likewise be referenced by registry name.
            output_parser = MODULE_REGISTRY.get_module(output_parser)

        # set default title format
        if parse_mode == "title" and title_format is None:
            title_format = "## {title}"

        # validate the data
        # NOTE(review): invoked before super().__init__(), so it must not rely
        # on any instance state — confirm it is a pure validation helper.
        self.validate_data(
            prompt=prompt,
            prompt_template=prompt_template,
            inputs=inputs,
            outputs=outputs,
            output_parser=output_parser,
            parse_mode=parse_mode,
            parse_func=parse_func,
            title_format=title_format
        )

        # The agent's behavior is carried by a single dynamically-built action.
        customize_action = self.create_customize_action(
            name=name,
            desc=description,
            prompt=prompt,
            prompt_template=prompt_template,
            inputs=inputs,
            outputs=outputs,
            parse_mode=parse_mode,
            parse_func=parse_func,
            output_parser=output_parser,
            title_format=title_format,
            custom_output_format=custom_output_format,
            tools=tools,
            max_tool_calls=max_tool_calls
        )
        super().__init__(
            name=name,
            description=description,
            llm_config=llm_config,
            system_prompt=system_prompt,
            actions=[customize_action],
            **kwargs
        )
        # Record the spec and parsing configuration on the instance AFTER the
        # base initializer so they survive pydantic-style initialization.
        self._store_inputs_outputs_info(inputs, outputs, raw_tool_map)
        self.output_parser = output_parser
        self.parse_mode = parse_mode
        self.parse_func = parse_func
        self.title_format = title_format
        self.tools = tools
        self.max_tool_calls = max_tool_calls
        self.custom_output_format = custom_output_format
150
+
151
+ def _add_tools(self, tools: List[Toolkit]):
152
+ self.get_action(self.customize_action_name).add_tools(tools)
153
+
154
+ @property
155
+ def customize_action_name(self) -> str:
156
+ """
157
+ Get the name of the primary custom action for this agent.
158
+
159
+ Returns:
160
+ The name of the primary custom action
161
+ """
162
+ for action in self.actions:
163
+ if action.name != self.cext_action_name:
164
+ return action.name
165
+ raise ValueError("Couldn't find the customize action name!")
166
+
167
+ @property
168
+ def action(self) -> Action:
169
+ """
170
+ Get the primary custom action for this agent.
171
+
172
+ Returns:
173
+ The primary custom action
174
+ """
175
+ return self.get_action(self.customize_action_name)
176
+
177
+ @property
178
+ def prompt(self) -> str:
179
+ """
180
+ Get the prompt for the primary custom action.
181
+
182
+ Returns:
183
+ The prompt for the primary custom action
184
+ """
185
+ return self.action.prompt
186
+
187
    @property
    def prompt_template(self) -> PromptTemplate:
        """
        Get the prompt template for the primary custom action.

        Returns:
            The prompt template of the primary custom action (may be None when
            the agent was configured with a plain prompt instead).
        """
        return self.action.prompt_template
196
+
197
    def validate_data(self, prompt: str, prompt_template: PromptTemplate, inputs: List[dict], outputs: List[dict], output_parser: Type[ActionOutput], parse_mode: str, parse_func: Callable, title_format: str):
        """Validate CustomizeAgent constructor arguments before the action is built.

        Checks, in order:
        1. At least one of ``prompt`` / ``prompt_template`` is provided.
        2. When only ``prompt`` is given, every declared input name appears as a
           ``{name}`` placeholder in the prompt.
        3. ``output_parser`` (if any) is compatible with ``outputs``
           (delegated to ``_check_output_parser``).
        4. ``parse_mode`` is a supported mode, and ``parse_func`` is a callable
           with a ``content`` parameter when required.
        5. ``title_format`` (if any) contains the ``{title}`` placeholder.

        Raises:
            ValueError: For a missing prompt, invalid parse mode, invalid parse
                function, incompatible output parser, or bad title format.
            KeyError: When declared inputs are missing from the prompt.
            TypeError: When ``output_parser`` is not a class
                (raised by ``_check_output_parser``).
        """

        # check if the prompt is provided
        if prompt is None and prompt_template is None:
            raise ValueError("`prompt` or `prompt_template` is required when creating a CustomizeAgent.")

        # check if all the inputs are in the prompt (only used when prompt_template is not provided)
        if prompt_template is None and inputs:
            all_input_names = [input_item["name"] for input_item in inputs]
            inputs_names_not_in_prompt = [name for name in all_input_names if f'{{{name}}}' not in prompt]
            if inputs_names_not_in_prompt:
                raise KeyError(f"The following inputs are not found in the prompt: {inputs_names_not_in_prompt}.")

        # check if the output_parser is valid
        if output_parser is not None:
            self._check_output_parser(outputs, output_parser)

        # check the parse_mode, parse_func, and title_format
        if parse_mode not in PARSER_VALID_MODE:
            raise ValueError(f"'{parse_mode}' is an invalid value for `parse_mode`. Available choices: {PARSER_VALID_MODE}.")

        if parse_mode == "custom":
            if parse_func is None:
                raise ValueError("`parse_func` (a callable function with an input argument `content`) must be provided when `parse_mode` is 'custom'.")

        if parse_func is not None:
            if not callable(parse_func):
                raise ValueError("`parse_func` must be a callable function with an input argument `content`.")
            signature = inspect.signature(parse_func)
            if "content" not in signature.parameters:
                raise ValueError("`parse_func` must have an input argument `content`.")
            # an unregistered parse function cannot be restored when the agent
            # is later re-loaded from a saved file, so warn (but don't fail) here
            if not PARSE_FUNCTION_REGISTRY.has_function(parse_func.__name__):
                logger.warning(
                    f"parse function `{parse_func.__name__}` is not registered. This can cause issues when loading the agent from a file. "
                    f"It is recommended to register the parse function using `register_parse_function`:\n"
                    f"from evoagentx.core.registry import register_parse_function\n"
                    f"@register_parse_function\n"
                    f"def {parse_func.__name__}(content: str) -> dict:\n"
                    r"    return {'output_name': output_value}"
                )

        if title_format is not None:
            if parse_mode != "title":
                logger.warning(f"`title_format` will not be used because `parse_mode` is '{parse_mode}', not 'title'. Set `parse_mode='title'` to use title formatting.")
            if r'{title}' not in title_format:
                raise ValueError(r"`title_format` must contain the placeholder `{title}`.")
243
+
244
    def create_customize_action(
        self,
        name: str,
        desc: str,
        prompt: str,
        prompt_template: PromptTemplate,
        inputs: List[dict],
        outputs: List[dict],
        parse_mode: str,
        parse_func: Optional[Callable] = None,
        output_parser: Optional[ActionOutput] = None,
        title_format: Optional[str] = "## {title}",
        custom_output_format: Optional[str] = None,
        tools: Optional[List[Toolkit]] = None,
        max_tool_calls: Optional[int] = 5
    ) -> Action:
        """Create a custom action based on the provided specifications.

        This method dynamically generates an Action class and instance with:
        - Input parameters defined by the inputs specification
        - Output format defined by the outputs specification (or the supplied
          `output_parser` class)
        - Parsing behavior controlled by `parse_mode` / `parse_func` / `title_format`

        Args:
            name: Base name for the action
            desc: Description of the action
            prompt: Prompt template for the action
            prompt_template: Prompt template for the action
            inputs: List of input field specifications
            outputs: List of output field specifications
            parse_mode: Mode to use for parsing LLM output
            parse_func: Optional custom parsing function
            output_parser: Optional custom output parser class
            title_format: Format string used when `parse_mode` is 'title'
            custom_output_format: Optional custom output-format instruction
            tools: Optional list of tools
            max_tool_calls: Maximum number of tool invocations per execution

        Returns:
            A newly created Action instance
        """
        assert prompt is not None or prompt_template is not None, "must provide `prompt` or `prompt_template` when creating CustomizeAgent"

        # create the action input type: one (typed) pydantic field per declared input
        action_input_fields = {}
        for field in inputs:
            required = field.get("required", True)
            if required:
                action_input_fields[field["name"]] = (str, Field(description=field["description"]))
            else:
                action_input_fields[field["name"]] = (Optional[str], Field(default=None, description=field["description"]))

        action_input_type = create_model(
            self._get_unique_class_name(
                generate_dynamic_class_name(name+" action_input")
            ),
            **action_input_fields,
            __base__=ActionInput
        )

        # create the action output type; an explicit output_parser overrides
        # the dynamically generated model
        if output_parser is None:
            action_output_fields = {}
            for field in outputs:
                required = field.get("required", True)
                if required:
                    action_output_fields[field["name"]] = (Any, Field(description=field["description"]))
                else:
                    action_output_fields[field["name"]] = (Optional[Any], Field(default=None, description=field["description"]))
            action_output_type = create_model(
                self._get_unique_class_name(
                    generate_dynamic_class_name(name+" action_output")
                ),
                **action_output_fields,
                __base__=ActionOutput,
                # get_content_data=customize_get_content_data,
                # to_str=customize_to_str
            )
        else:
            # self._check_output_parser(outputs, output_parser)
            action_output_type = output_parser

        action_cls_name = self._get_unique_class_name(
            generate_dynamic_class_name(name+" action")
        )

        # Create CustomizeAction-based action with parsing properties only
        customize_action_cls = create_model(
            action_cls_name,
            __base__=CustomizeAction
        )

        customize_action = customize_action_cls(
            name=action_cls_name,
            description=desc,
            prompt=prompt,
            prompt_template=prompt_template,
            inputs_format=action_input_type,
            outputs_format=action_output_type,
            parse_mode=parse_mode,
            parse_func=parse_func,
            title_format=title_format,
            custom_output_format=custom_output_format,
            # note: CustomizeAction's field is named `max_tool_try`
            max_tool_try=max_tool_calls,
            tools=tools
        )

        return customize_action
350
+
351
+ def _check_output_parser(self, outputs: List[dict], output_parser: Type[ActionOutput]):
352
+
353
+ if output_parser is not None:
354
+ if not isinstance(output_parser, type):
355
+ raise TypeError(f"output_parser must be a class, but got {type(output_parser).__name__}")
356
+ if not issubclass(output_parser, ActionOutput):
357
+ raise ValueError(f"`output_parser` must be a class and a subclass of `ActionOutput`, but got `{output_parser.__name__}`.")
358
+
359
+ # check if the output parser is compatible with the outputs
360
+ output_parser_fields = output_parser.get_attrs()
361
+ all_output_names = [output_item["name"] for output_item in outputs]
362
+ for field in output_parser_fields:
363
+ if field not in all_output_names:
364
+ raise ValueError(
365
+ f"The output parser `{output_parser.__name__}` is not compatible with the `outputs`.\n"
366
+ f"The output parser fields: {output_parser_fields}.\n"
367
+ f"The outputs: {all_output_names}.\n"
368
+ f"All the fields in the output parser must be present in the outputs."
369
+ )
370
+
371
+ def _store_inputs_outputs_info(self, inputs: List[dict], outputs: List[dict], tool_map: Dict[str, Union[Toolkit, Tool]]):
372
+
373
+ self._action_input_types, self._action_input_required = {}, {}
374
+ for field in inputs:
375
+ required = field.get("required", True)
376
+ self._action_input_types[field["name"]] = field["type"]
377
+ self._action_input_required[field["name"]] = required
378
+ self._action_output_types, self._action_output_required = {}, {}
379
+ for field in outputs:
380
+ required = field.get("required", True)
381
+ self._action_output_types[field["name"]] = field["type"]
382
+ self._action_output_required[field["name"]] = required
383
+ self._raw_tool_map = tool_map
384
+
385
    def __call__(self, inputs: dict = None, return_msg_type: MessageType = MessageType.UNKNOWN, **kwargs) -> Message:
        """
        Call the customize action.

        Args:
            inputs (dict): The inputs to the customize action. Defaults to an empty dict.
            return_msg_type (MessageType): The message type attached to the returned message.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            Message: The message produced by executing the primary custom action.
        """
        inputs = inputs or {}
        return super().__call__(action_name=self.customize_action_name, action_input_data=inputs, return_msg_type=return_msg_type, **kwargs)
399
+
400
    def get_customize_agent_info(self) -> dict:
        """
        Get the information of the customize agent.

        Builds a JSON-serializable configuration dict (prompt, input/output
        specs, parsing options, tool names) sufficient to recreate the agent
        given an `llm_config` and the referenced tools.
        """
        customize_action = self.get_action(self.customize_action_name)
        action_input_params = customize_action.inputs_format.get_attrs()
        action_output_params = customize_action.outputs_format.get_attrs()

        config = {
            "class_name": "CustomizeAgent",
            "name": self.name,
            "description": self.description,
            "prompt": customize_action.prompt,
            "prompt_template": customize_action.prompt_template.to_dict() if customize_action.prompt_template is not None else None,
            # "llm_config": self.llm_config.to_dict(exclude_none=True),
            # input/output specs are rebuilt from the pydantic model fields,
            # filtered to the attrs the action actually declares
            "inputs": [
                {
                    "name": field,
                    "type": self._action_input_types[field],
                    "description": field_info.description,
                    "required": self._action_input_required[field]
                }
                for field, field_info in customize_action.inputs_format.model_fields.items() if field in action_input_params
            ],
            "outputs": [
                {
                    "name": field,
                    "type": self._action_output_types[field],
                    "description": field_info.description,
                    "required": self._action_output_required[field]
                }
                for field, field_info in customize_action.outputs_format.model_fields.items() if field in action_output_params
            ],
            "system_prompt": self.system_prompt,
            # callables/classes are stored by name so the config stays serializable
            "output_parser": self.output_parser.__name__ if self.output_parser is not None else None,
            "parse_mode": self.parse_mode,
            "parse_func": self.parse_func.__name__ if self.parse_func is not None else None,
            "title_format": self.title_format,
            "tool_names": [tool.name for tool in customize_action.tools] if customize_action.tools else [],
            "max_tool_calls": self.max_tool_calls,
            "custom_output_format": self.custom_output_format
        }
        return config
443
+
444
    @classmethod
    def load_module(cls, path: str, llm_config: LLMConfig = None, tools: List[Union[Toolkit, Tool]] = None, **kwargs) -> "CustomizeAgent":
        """
        Load the agent from local storage. Must provide `llm_config` when loading the agent from local storage.
        If the saved config contains a non-empty `tool_names` list, `tools` must also be provided
        and must cover every saved tool name.

        Args:
            path: The path of the file
            llm_config: The LLMConfig instance
            tools: Toolkit/Tool instances matched by name against the saved
                `tool_names`; plain Tools are wrapped into single-tool Toolkits.

        Returns:
            CustomizeAgent: The loaded agent instance
        """
        # NOTE(review): the object returned by super().load_module is subscripted
        # like a dict below, despite the "CustomizeAgent" annotation — confirm.
        match_dict = {}
        agent = super().load_module(path=path, llm_config=llm_config, **kwargs)
        if tools:
            match_dict = {tool.name:tool for tool in tools}
        if agent.get("tool_names", None):
            assert tools is not None, "must provide `tools: List[Union[Toolkit, Tool]]` when using `load_module` or `from_file` to load the agent from local storage and `tool_names` is not None or empty"
            added_tools = [match_dict[tool_name] for tool_name in agent["tool_names"]]
            # plain Tools are normalized into single-tool Toolkits
            agent["tools"] = [tool if isinstance(tool, Toolkit) else Toolkit(name=tool.name, tools=[tool]) for tool in added_tools]
        return agent
468
+
469
+ def save_module(self, path: str, ignore: List[str] = [], **kwargs)-> str:
470
+ """Save the customize agent's configuration to a JSON file.
471
+
472
+ Args:
473
+ path: File path where the configuration should be saved
474
+ ignore: List of keys to exclude from the saved configuration
475
+ **kwargs (Any): Additional parameters for the save operation
476
+
477
+ Returns:
478
+ The path where the configuration was saved
479
+ """
480
+ config = self.get_customize_agent_info()
481
+
482
+ for ignore_key in ignore:
483
+ config.pop(ignore_key, None)
484
+
485
+ # Save to JSON file
486
+ make_parent_folder(path)
487
+ with open(path, 'w', encoding='utf-8') as f:
488
+ json.dump(config, f, indent=4, ensure_ascii=False)
489
+
490
+ return path
491
+
492
+ def _get_unique_class_name(self, candidate_name: str) -> str:
493
+ """
494
+ Get a unique class name by checking if it already exists in the registry.
495
+ If it does, append "Vx" to make it unique.
496
+ """
497
+ if not MODULE_REGISTRY.has_module(candidate_name):
498
+ return candidate_name
499
+
500
+ i = 1
501
+ while True:
502
+ unique_name = f"{candidate_name}V{i}"
503
+ if not MODULE_REGISTRY.has_module(unique_name):
504
+ break
505
+ i += 1
506
+ return unique_name
507
+
508
    def get_config(self) -> dict:
        """
        Get a dictionary containing all necessary configuration to recreate this agent.

        Unlike `get_customize_agent_info`, the result also includes the
        serialized `llm_config` and resolves `tool_names` back to the raw
        Toolkit/Tool objects captured at construction time.

        Returns:
            dict: A configuration dictionary that can be used to initialize a new Agent instance
                with the same properties as this one.
        """
        config = self.get_customize_agent_info()
        config["llm_config"] = self.llm_config.to_dict()
        # replace serialized tool names with the original tool objects
        tool_names = config.pop("tool_names", None)
        if tool_names:
            config["tools"] = [self._raw_tool_map[name] for name in tool_names]
        return config
522
+
evoagentx/agents/long_term_memory_agent.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import asyncio
3
+ from uuid import uuid4
4
+ from pydantic import Field
5
+ from datetime import datetime
6
+ from typing import Optional, List, Tuple, Dict, Union
7
+
8
+ from evoagentx.agents import Agent
9
+ from evoagentx.core.parser import Parser
10
+ from evoagentx.models import BaseLLM
11
+ from evoagentx.core.logging import logger
12
+ from evoagentx.models import OpenAILLMConfig
13
+ from evoagentx.storages.base import StorageHandler
14
+ from evoagentx.core.message import Message, MessageType
15
+ from evoagentx.memory.memory_manager import MemoryManager
16
+ from evoagentx.memory.long_term_memory import LongTermMemory
17
+ from evoagentx.actions.action import Action, ActionInput, ActionOutput
18
+ from evoagentx.rag.rag_config import RAGConfig
19
+
20
+
21
class MemoryActionInput(ActionInput):
    """Input schema for `MemoryAction`: the user prompt plus memory-retrieval options."""
    user_prompt: str = Field(description="The user's input prompt")
    conversation_id: Optional[str] = Field(default=None, description="ID for tracking conversation")
    top_k: Optional[int] = Field(default=5, description="Number of memory results to retrieve")
    metadata_filters: Optional[Dict] = Field(default=None, description="Filters for memory retrieval")
26
+
27
+
28
class MemoryActionOutput(ActionOutput):
    """Output schema for `MemoryAction`: the memory-grounded LLM response."""
    response: str = Field(description="The agent's response based on memory and prompt")
30
+
31
+
32
class MemoryAction(Action):
    """Action that augments the user's prompt with long-term-memory context,
    queries the LLM, and writes the exchange back into memory.
    """

    def __init__(
        self,
        name: str = "MemoryAction",
        description: str = "Action that processes user input with long-term memory context",
        prompt: str = "Based on the following context and user prompt, provide a relevant response:\n\nContext: {context}\n\nUser Prompt: {user_prompt}\n\n",
        inputs_format: ActionInput = None,
        outputs_format: ActionOutput = None,
        **kwargs
    ):
        # fall back to the default memory input/output schemas when none are given
        inputs_format = inputs_format or MemoryActionInput
        outputs_format = outputs_format or MemoryActionOutput
        super().__init__(
            name=name,
            description=description,
            prompt=prompt,
            inputs_format=inputs_format,
            outputs_format=outputs_format,
            **kwargs
        )

    def execute(self, llm: BaseLLM | None = None,
        inputs: Dict | None = None,
        sys_msg: str | None = None,
        return_prompt: bool = False,
        memory_manager: Optional[MemoryManager] = None,
        **kwargs
    ) -> Parser | Tuple[Parser | str] | None:
        """Synchronous wrapper around `async_execute`.

        NOTE(review): `asyncio.run` raises RuntimeError when an event loop is
        already running — confirm this path is only reached from sync callers.
        """
        return asyncio.run(self.async_execute(llm, inputs, sys_msg, return_prompt, memory_manager, **kwargs))

    async def async_execute(
        self,
        llm: Optional["BaseLLM"] = None,
        inputs: Optional[Dict] = None,
        sys_msg: Optional[str] = None,
        return_prompt: bool = False,
        memory_manager: Optional[MemoryManager] = None,
        **kwargs
    ) -> Union[MemoryActionOutput, tuple]:
        """Retrieve memory context, query the LLM, and persist the response.

        Args:
            llm: LLM used to generate the response.
            inputs: Data matching `MemoryActionInput` (user_prompt, conversation_id, ...).
            sys_msg: Optional system message for the LLM call.
            return_prompt: When True, also return the fully formatted prompt.
            memory_manager: Required manager used for retrieval and storage.

        Returns:
            The action output, or ``(output, prompt)`` when `return_prompt` is True.

        Raises:
            ValueError: If `memory_manager` is not provided.
        """
        if not memory_manager:
            logger.error("MemoryManager is required for MemoryAction execution")
            raise ValueError("MemoryManager is required for MemoryAction")

        action_input = self.inputs_format(**inputs)
        user_prompt = action_input.user_prompt
        conversation_id = action_input.conversation_id
        if not conversation_id:
            conversation_id = str(uuid4())
            logger.warning("No conversation_id provided; generated a new UUID4 for this session")
        top_k = action_input.top_k
        metadata_filters = action_input.metadata_filters

        # build a memory-augmented message for this turn of the conversation
        message = await memory_manager.create_conversation_message(
            user_prompt=user_prompt,
            conversation_id=conversation_id,
            top_k=top_k,
            metadata_filters=metadata_filters
        )

        # fill the prompt template with the action inputs plus the retrieved context
        action_input_attrs = self.inputs_format.get_attrs()
        action_input_data = {attr: getattr(action_input, attr, "undefined") for attr in action_input_attrs}
        action_input_data["context"] = message.content
        prompt = self.prompt.format(**action_input_data)
        logger.info(f"The New Created Message by LongTermMemory:\n\n{prompt}")

        output = await llm.async_generate(
            prompt=prompt,
            system_message=sys_msg,
            parser=self.outputs_format,
            parse_mode='str'
        )

        # persist the LLM response back into long-term memory
        response_message = Message(
            content=output.content,
            msg_type=MessageType.RESPONSE,
            timestamp=datetime.now().isoformat(),
            conversation_id=conversation_id,
            memory_ids=message.memory_ids
        )
        memory_ids = await memory_manager.handle_memory(
            action="add",
            data=response_message,
        )

        # Prepare the final output
        # NOTE(review): `memory_ids` is not a declared field on MemoryActionOutput;
        # it is only kept if the pydantic model tolerates extra fields — confirm.
        final_output = self.outputs_format(
            response=output.content,
            memory_ids=memory_ids
        )

        if return_prompt:
            return final_output, prompt
        return final_output
125
+
126
+
127
class MemoryAgent(Agent):
    """Agent backed by a `LongTermMemory` + `MemoryManager` pair.

    Each execution retrieves relevant past messages as context, generates a
    response via `MemoryAction`, and writes both the user prompt and the
    response back to long-term memory.
    """

    # manager used for memory retrieval/storage (rebuilt in __init__, excluded from save)
    memory_manager: Optional[MemoryManager] = Field(default=None, description="Manager for long-term memory operations")
    inputs: List[Dict] = Field(default_factory=list, description="Input specifications for the memory action")
    outputs: List[Dict] = Field(default_factory=list, description="Output specifications for the memory action")

    def __init__(
        self,
        name: str = "MemoryAgent",
        description: str = "An agent that uses long-term memory to provide context-aware responses",
        inputs: Optional[List[Dict]] = None,
        outputs: Optional[List[Dict]] = None,
        llm_config: Optional[OpenAILLMConfig] = None,
        storage_handler: Optional[StorageHandler] = None,
        rag_config: Optional[RAGConfig] = None,
        conversation_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        prompt: str = "Based on the following context and user prompt, provide a relevant response:\n\nContext: {context}\n\nUser Prompt: {user_prompt}",
        **kwargs
    ):
        # Define inputs and outputs inspired by CustomizeAgent
        inputs = inputs or []
        outputs = outputs or []

        # Initialize base Agent with provided parameters
        super().__init__(
            name=name,
            description=description,
            llm_config=llm_config,
            system_prompt=system_prompt,
            storage_handler=storage_handler,
            inputs=inputs,
            outputs=outputs,
            **kwargs
        )

        # memory stack: LongTermMemory (storage+RAG) wrapped by a MemoryManager
        self.long_term_memory = LongTermMemory(
            storage_handler=storage_handler,
            rag_config=rag_config,
            default_corpus_id=conversation_id
        )
        self.memory_manager = MemoryManager(
            memory=self.long_term_memory,
            llm=llm_config.get_llm() if llm_config else None,
            use_llm_management=True
        )

        # Initialize inputs and outputs
        self.inputs = inputs
        self.outputs = outputs

        # Initialize actions list and add MemoryAction
        # (any actions registered by the base class are intentionally replaced)
        self.actions = []
        self._action_map = {}
        memory_action = MemoryAction(
            name="MemoryAction",
            description="Action that processes user input with long-term memory context",
            prompt=prompt,
            inputs_format=MemoryActionInput,
            outputs_format=MemoryActionOutput
        )
        self.add_action(memory_action)

    def _create_output_message(
        self,
        action_output,
        action_name: str,
        action_input_data: Optional[Dict],
        prompt: str,
        return_msg_type: MessageType = MessageType.RESPONSE,
        **kwargs
    ) -> Message:
        """Build the output message and schedule persisting the exchange to memory.

        NOTE(review): `asyncio.create_task` requires a running event loop; in the
        synchronous `execute()` path this will raise RuntimeError — confirm.
        """
        msg = super()._create_output_message(
            action_output=action_output,
            action_name=action_name,
            action_input_data=action_input_data,
            prompt=prompt,
            return_msg_type=return_msg_type,
            **kwargs
        )

        if action_input_data and "user_prompt" in action_input_data:
            # store the user's request ...
            user_msg = Message(
                content=action_input_data["user_prompt"],
                msg_type=MessageType.REQUEST,
                conversation_id=msg.conversation_id
            )
            asyncio.create_task(self.memory_manager.handle_memory(action="add", data=user_msg))

            # ... and the agent's response in long-term memory
            response_msg = Message(
                content=action_output.response if hasattr(action_output, "response") else str(action_output),
                msg_type=MessageType.RESPONSE,
                conversation_id=msg.conversation_id
            )
            asyncio.create_task(self.memory_manager.handle_memory(action="add", data=response_msg))

        return msg

    async def async_execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[Dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.RESPONSE,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, Dict]]:
        """
        Execute an action asynchronously with memory management.

        Args:
            action_name: Name of the action to execute
            msgs: Optional list of messages providing context
            action_input_data: Optional input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: Whether to return the action input data
            **kwargs: Additional parameters

        Returns:
            Message or tuple: The execution result, optionally with input data
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # Execute action with memory_manager
        execution_results = await action.async_execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            memory_manager=self.memory_manager,
            **kwargs
        )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            action_input_data=action_input_data,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[Dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.RESPONSE,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, Dict]]:
        """
        Execute an action synchronously with memory management.

        Args:
            action_name: Name of the action to execute
            msgs: Optional list of messages providing context
            action_input_data: Optional input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: Whether to return the action input data
            **kwargs: Additional parameters

        Returns:
            Message or tuple: The execution result, optionally with input data
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # Execute action with memory_manager
        execution_results = action.execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            memory_manager=self.memory_manager,
            **kwargs
        )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            action_input_data=action_input_data,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def chat(
        self,
        user_prompt: str,
        *,
        conversation_id: Optional[str] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[dict] = None,
        return_message: bool = True,
        **kwargs
    ):
        """Synchronous single-turn chat through `MemoryAction`.

        Returns the Message, or its content string when `return_message` is False.
        """
        action_input_data = {
            "user_prompt": user_prompt,
            "conversation_id": conversation_id or self._default_conversation_id(),
            "top_k": top_k if top_k is not None else 3,
            "metadata_filters": metadata_filters or {},
        }
        msg = self.execute(
            action_name="MemoryAction",
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            **kwargs
        )
        return msg if return_message else (getattr(msg, "content", None) or str(msg))


    async def async_chat(
        self,
        user_prompt: str,
        *,
        conversation_id: Optional[str] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[dict] = None,
        return_message: bool = True,
        **kwargs
    ):
        """Asynchronous single-turn chat through `MemoryAction`.

        Returns the Message, or its content string when `return_message` is False.
        """
        action_input_data = {
            "user_prompt": user_prompt,
            "conversation_id": conversation_id or self._default_conversation_id(),
            "top_k": top_k if top_k is not None else 3,
            "metadata_filters": metadata_filters or {},
        }
        msg = await self.async_execute(
            action_name="MemoryAction",
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            **kwargs
        )
        return msg if return_message else (getattr(msg, "content", None) or str(msg))


    def _default_conversation_id(self) -> str:
        """
        Session scope: By default, a new uuid4() is returned (new session).
        User/global scope: Reuse LongTermMemory.default_corpus_id (stable namespace).
        Note: The final ID is still uniformly managed by MemoryAgent._prepare_execution() (which will override based on the scope).
        """
        scope = getattr(self, "conversation_scope", "session")
        if scope == "session":
            return str(uuid4())
        return getattr(getattr(self, "long_term_memory", None), "default_corpus_id", None) or "global_corpus"

    async def interactive_chat(
        self,
        conversation_id: Optional[str] = None,
        top_k: int = 3,
        metadata_filters: Optional[dict] = None
    ):
        """
        In interactive chat, each round of input will:
        1. Retrieve from memory
        2. Generate a response based on historical context
        3. Write the input/output to long-term memory and refresh the index
        """
        conversation_id = conversation_id or self._default_conversation_id()
        metadata_filters = metadata_filters or {}

        print("💬 MemoryAgent has been started (type 'exit' to quit)\n")

        while True:
            user_prompt = input("You: ").strip()
            if user_prompt.lower() in ["exit", "quit"]:
                print("🔚 Conversation ended")
                break

            # Retrieve historical context
            retrieved_memories = await self.memory_manager.handle_memory(
                action="search",
                user_prompt=user_prompt,
                top_k=top_k,
                metadata_filters=metadata_filters
            )

            context_texts = []
            for msg, _ in retrieved_memories:
                if hasattr(msg, "content") and msg.content:
                    context_texts.append(msg.content)
            context_str = "\n".join(context_texts)

            # if context_str:
            #     print(f"📖 Retrieved context from memory:\n{context_str}\n")

            # Concatenate the historical context into the user input and invoke async_chat
            full_prompt = f"Context:\n{context_str}\n\nUser: {user_prompt}" if context_str else user_prompt
            msg = await self.async_chat(
                user_prompt=full_prompt,
                conversation_id=conversation_id,
                top_k=top_k,
                metadata_filters=metadata_filters
            )

            print(f"Agent: {msg.content}\n")

            # Refresh the index to ensure it can be retrieved in the next round
            if hasattr(self.memory_manager, "handle_memory_flush"):
                await self.memory_manager.handle_memory_flush()
            else:
                await asyncio.sleep(0.1)



    def save_module(self, path: str, ignore: List[str] = ["llm", "llm_config", "memory_manager"], **kwargs) -> str:
        """
        Save the agent's configuration to a JSON file, excluding memory_manager by default.

        Args:
            path: File path to save the configuration
            ignore: List of keys to exclude from the saved configuration
            **kwargs: Additional parameters for saving

        Returns:
            str: The path where the configuration was saved
        """
        return super().save_module(path=path, ignore=ignore, **kwargs)

    @classmethod
    def from_file(cls, path: str, llm_config: OpenAILLMConfig, storage_handler: Optional[StorageHandler] = None, rag_config: Optional[RAGConfig] = None, **kwargs) -> "MemoryAgent":
        """
        Load a MemoryAgent from a JSON configuration file.

        Args:
            path: Path to the JSON configuration file
            llm_config: LLM configuration
            storage_handler: Optional storage handler
            rag_config: Optional RAG configuration
            **kwargs: Additional parameters

        Returns:
            MemoryAgent: The loaded agent instance
        """
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        return cls(
            name=config.get("name", "MemoryAgent"),
            description=config.get("description", "An agent that uses long-term memory"),
            llm_config=llm_config,
            storage_handler=storage_handler,
            rag_config=rag_config,
            system_prompt=config.get("system_prompt"),
            prompt=config.get("prompt"),
            # forwarded through **kwargs of __init__ — not an explicit parameter
            use_long_term_memory=config.get("use_long_term_memory", True),
            **kwargs
        )
evoagentx/agents/task_planner.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .agent import Agent
2
+ from ..actions.task_planning import TaskPlanning
3
+ from ..prompts.task_planner import TASK_PLANNER
4
+
5
+
6
class TaskPlanner(Agent):
    """An agent responsible for planning and decomposing high-level tasks into smaller sub-tasks.

    The TaskPlanner agent analyzes complex goals and breaks them down into a structured
    sequence of smaller, more manageable tasks. It serves as a critical component in the
    workflow by creating execution plans that other specialized agents can follow.

    Attributes:
        name (str): Name of the task planner agent, defaults to the value in TASK_PLANNER
        description (str): Description of the agent's purpose and capabilities, defaults to the value in TASK_PLANNER
        system_prompt (str): System prompt guiding the agent's behavior, defaults to the value in TASK_PLANNER
        actions (List[Action]): List of actions the agent can perform, defaults to [TaskPlanning()]
    """

    def __init__(self, **kwargs):
        # Pull overrides out of kwargs, falling back to the TASK_PLANNER defaults.
        name = kwargs.pop("name", TASK_PLANNER["name"])
        description = kwargs.pop("description", TASK_PLANNER["description"])
        system_prompt = kwargs.pop("system_prompt", TASK_PLANNER["system_prompt"])
        # Only instantiate the default TaskPlanning action when no explicit
        # actions were supplied (keeps the default lazy).
        if "actions" in kwargs:
            actions = kwargs.pop("actions")
        else:
            actions = [TaskPlanning()]
        super().__init__(name=name, description=description, system_prompt=system_prompt, actions=actions, **kwargs)

    @property
    def task_planning_action_name(self):
        """Get the name of the TaskPlanning action associated with this agent.

        Returns:
            The name of the TaskPlanning action in this agent's action registry
        """
        return self.get_action_name(action_cls=TaskPlanning)
35
+
evoagentx/agents/workflow_reviewer.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ from .agent import Agent
3
+ from ..core.message import Message # MessageType
4
+
5
+
6
class WorkFlowReviewer(Agent):
    """Placeholder for the Agent that is responsible for reviewing workflow plans and agents.

    Not implemented yet: every call to ``execute`` raises ``NotImplementedError``.
    """

    def execute(self, action_name: str, msgs: Optional[List[Message]] = None, action_input_data: Optional[dict] = None, **kwargs) -> Message:
        """Reviewing is not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("WorkflowReviewer is not implemented yet.")
evoagentx/app/__init__.py ADDED
File without changes
evoagentx/app/api.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API routes for EvoAgentX application.
3
+ """
4
+ from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
5
+ from fastapi.security import OAuth2PasswordRequestForm
6
+ from typing import List, Dict, Any # , Optional
7
+ from fastapi import Response
8
+
9
+ from datetime import timedelta
10
+
11
+ from evoagentx.app.config import settings
12
+ from evoagentx.app.schemas import (
13
+ AgentCreate, AgentUpdate, AgentResponse,
14
+ WorkflowCreate, WorkflowUpdate, WorkflowResponse,
15
+ ExecutionCreate, ExecutionResponse,
16
+ PaginationParams, SearchParams,
17
+ Token, UserCreate, UserResponse, # UserLogin,
18
+ )
19
+ from evoagentx.app.services import AgentService, WorkflowService, WorkflowExecutionService
20
+ from evoagentx.app.security import (
21
+ create_access_token,
22
+ authenticate_user,
23
+ create_user,
24
+ get_current_active_user,
25
+ get_current_admin_user
26
+ )
27
+ from evoagentx.app.db import Database, ExecutionStatus
28
+
29
+ # Create routers for different route groups
30
+ auth_router = APIRouter(prefix=settings.API_PREFIX)
31
+ agents_router = APIRouter(prefix=settings.API_PREFIX)
32
+ workflows_router = APIRouter(prefix=settings.API_PREFIX)
33
+ executions_router = APIRouter(prefix=settings.API_PREFIX)
34
+ system_router = APIRouter(prefix=settings.API_PREFIX)
35
+
36
+ # Authentication Routes
37
+ @auth_router.post("/auth/register", response_model=UserResponse, tags=["Authentication"])
38
+ async def register_user(user: UserCreate):
39
+ """Register a new user."""
40
+ return await create_user(user)
41
+
42
+ @auth_router.post("/auth/login", response_model=Token, tags=["Authentication"])
43
+ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
44
+ """Login and return access token."""
45
+ user = await authenticate_user(form_data.username, form_data.password)
46
+ if not user:
47
+ raise HTTPException(
48
+ status_code=status.HTTP_401_UNAUTHORIZED,
49
+ detail="Incorrect username or password",
50
+ headers={"WWW-Authenticate": "Bearer"},
51
+ )
52
+
53
+ access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
54
+ access_token = create_access_token(
55
+ subject=user['email'],
56
+ expires_delta=access_token_expires
57
+ )
58
+
59
+ return {
60
+ "access_token": access_token,
61
+ "token_type": "bearer"
62
+ }
63
+
64
+ # Agent Routes
65
+ @agents_router.post("/agents", response_model=AgentResponse, tags=["Agents"])
66
+ async def create_agent(
67
+ agent: AgentCreate,
68
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
69
+ ):
70
+ """Create a new agent."""
71
+ try:
72
+ created_agent = await AgentService.create_agent(
73
+ agent,
74
+ user_id=str(current_user['_id'])
75
+ )
76
+ # Convert the ObjectId to string before creating the response model
77
+ created_agent["_id"] = str(created_agent["_id"])
78
+ return AgentResponse(**created_agent)
79
+ except ValueError as e:
80
+ raise HTTPException(status_code=400, detail=str(e))
81
+
82
+ @agents_router.get("/agents/{agent_id}", response_model=AgentResponse, tags=["Agents"])
83
+ async def get_agent(
84
+ agent_id: str,
85
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
86
+ ):
87
+ """Retrieve a specific agent by ID."""
88
+ agent = await AgentService.get_agent(agent_id)
89
+ if not agent:
90
+ raise HTTPException(status_code=404, detail="Agent not found")
91
+ agent["_id"] = str(agent["_id"])
92
+ return AgentResponse(**agent)
93
+
94
+ @agents_router.put("/agents/{agent_id}", response_model=AgentResponse, tags=["Agents"])
95
+ async def update_agent(
96
+ agent_id: str,
97
+ agent_update: AgentUpdate,
98
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
99
+ ):
100
+ """Update an existing agent."""
101
+ try:
102
+ updated_agent = await AgentService.update_agent(agent_id, agent_update)
103
+ if not updated_agent:
104
+ raise HTTPException(status_code=404, detail="Agent not found")
105
+ updated_agent["_id"] = str(updated_agent["_id"])
106
+ return AgentResponse(**updated_agent)
107
+ except ValueError as e:
108
+ raise HTTPException(status_code=400, detail=str(e))
109
+
110
+ @agents_router.get("/agents", response_model=List[AgentResponse], tags=["Agents"])
111
+ async def list_agents(
112
+ pagination: PaginationParams = Depends(),
113
+ search: SearchParams = Depends(),
114
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
115
+ ):
116
+ """List agents with optional pagination and search."""
117
+ agents, total = await AgentService.list_agents(pagination, search)
118
+ # Convert _id to string for each agent in the list
119
+ for agent in agents:
120
+ agent["_id"] = str(agent["_id"])
121
+ return [AgentResponse(**agent) for agent in agents]
122
+
123
+ @agents_router.delete("/agents/{agent_id}", status_code=204, tags=["Agents"])
124
+ async def delete_agent(
125
+ agent_id: str,
126
+ current_user: Dict[str, Any] = Depends(get_current_admin_user)
127
+ ):
128
+ """Delete an agent (admin-only)."""
129
+ try:
130
+ success = await AgentService.delete_agent(agent_id)
131
+ if not success:
132
+ raise HTTPException(status_code=404, detail="Agent not found")
133
+ return # With 204, no content is returned
134
+ except ValueError as e:
135
+ raise HTTPException(status_code=400, detail=str(e))
136
+
137
+
138
+
139
+ # Workflow Routes
140
+ @workflows_router.post("/workflows", response_model=WorkflowResponse,status_code=201, tags=["Workflows"])
141
+ async def create_workflow(
142
+ workflow: WorkflowCreate,
143
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
144
+ ):
145
+ """Create a new workflow."""
146
+ try:
147
+ created_workflow = await WorkflowService.create_workflow(
148
+ workflow,
149
+ user_id=str(current_user['_id'])
150
+ )
151
+ # Convert the ObjectId to string for consistency
152
+ created_workflow["_id"] = str(created_workflow["_id"])
153
+ return WorkflowResponse(**created_workflow)
154
+ except ValueError as e:
155
+ raise HTTPException(status_code=400, detail=str(e))
156
+
157
+
158
+
159
+ @workflows_router.get("/workflows/{workflow_id}", response_model=WorkflowResponse, tags=["Workflows"])
160
+ async def get_workflow(
161
+ workflow_id: str,
162
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
163
+ ):
164
+ """Retrieve a specific workflow by ID."""
165
+ workflow = await WorkflowService.get_workflow(workflow_id)
166
+ if not workflow:
167
+ raise HTTPException(status_code=404, detail="Workflow not found")
168
+ # Convert ObjectId to string
169
+ workflow["_id"] = str(workflow["_id"])
170
+ return WorkflowResponse(**workflow)
171
+
172
+ @workflows_router.put("/workflows/{workflow_id}", response_model=WorkflowResponse, tags=["Workflows"])
173
+ async def update_workflow(
174
+ workflow_id: str,
175
+ workflow_update: WorkflowUpdate,
176
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
177
+ ):
178
+ """Update an existing workflow."""
179
+ try:
180
+ updated_workflow = await WorkflowService.update_workflow(workflow_id, workflow_update)
181
+ if not updated_workflow:
182
+ raise HTTPException(status_code=404, detail="Workflow not found")
183
+
184
+ updated_workflow["_id"] = str(updated_workflow["_id"])
185
+ return WorkflowResponse(**updated_workflow)
186
+ except ValueError as e:
187
+ raise HTTPException(status_code=400, detail=str(e))
188
+
189
+ @workflows_router.delete("/workflows/{workflow_id}", status_code=204, tags=["Workflows"])
190
+ async def delete_workflow(
191
+ workflow_id: str,
192
+ current_user: Dict[str, Any] = Depends(get_current_admin_user)
193
+ ):
194
+ """Delete a workflow (admin-only)."""
195
+ try:
196
+ success = await WorkflowService.delete_workflow(workflow_id)
197
+ if not success:
198
+ raise HTTPException(status_code=404, detail="Workflow not found")
199
+ return Response(status_code=204)
200
+ except ValueError as e:
201
+ raise HTTPException(status_code=400, detail=str(e))
202
+
203
+
204
+ @workflows_router.get("/workflows", response_model=List[WorkflowResponse], tags=["Workflows"])
205
+ async def list_workflows(
206
+ pagination: PaginationParams = Depends(),
207
+ search: SearchParams = Depends(),
208
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
209
+ ):
210
+ """List workflows with optional pagination and search."""
211
+ workflows, total = await WorkflowService.list_workflows(pagination, search)
212
+
213
+ # Convert ObjectId to string for each workflow
214
+ converted_workflows = [
215
+ {**workflow, "_id": str(workflow["_id"])}
216
+ for workflow in workflows
217
+ ]
218
+
219
+ return [WorkflowResponse(**workflow) for workflow in converted_workflows]
220
+
221
+
222
+ # Workflow Execution Routes
223
+ @executions_router.post("/executions", response_model=ExecutionResponse, status_code=202)
224
+ async def create_execution(
225
+ execution: ExecutionCreate,
226
+ background_tasks: BackgroundTasks,
227
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
228
+ ):
229
+ """Create and start a workflow execution."""
230
+ try:
231
+ execution_result = await WorkflowExecutionService.create_execution(
232
+ execution_data=execution,
233
+ user_id=str(current_user['_id'])
234
+ )
235
+ # Convert _id to string for consistency
236
+ execution_result["_id"] = str(execution_result["_id"])
237
+ return ExecutionResponse(**execution_result)
238
+ except ValueError as e:
239
+ raise HTTPException(status_code=400, detail=str(e))
240
+
241
+
242
+ @executions_router.get("/executions/{execution_id}", response_model=ExecutionResponse)
243
+ async def get_execution(
244
+ execution_id: str,
245
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
246
+ ):
247
+ """Retrieve a specific workflow execution by ID."""
248
+ try:
249
+ execution = await WorkflowExecutionService.get_execution(execution_id)
250
+ if not execution:
251
+ raise HTTPException(status_code=404, detail="Execution not found")
252
+ execution["_id"] = str(execution["_id"])
253
+ return ExecutionResponse(**execution)
254
+ except ValueError as e:
255
+ raise HTTPException(status_code=400, detail=str(e))
256
+
257
+
258
+ @executions_router.post("/executions/{execution_id}/stop", response_model=ExecutionResponse)
259
+ async def stop_execution(
260
+ execution_id: str,
261
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
262
+ ):
263
+ """Stop (cancel) a workflow execution."""
264
+ try:
265
+ updated_execution = await WorkflowExecutionService.update_execution_status(
266
+ execution_id=execution_id,
267
+ status=ExecutionStatus.CANCELLED
268
+ )
269
+ if not updated_execution:
270
+ raise HTTPException(status_code=404, detail="Execution not found")
271
+ # Convert ObjectId to string for consistency
272
+ updated_execution["_id"] = str(updated_execution["_id"])
273
+ return ExecutionResponse(**updated_execution)
274
+ except ValueError as e:
275
+ raise HTTPException(status_code=400, detail=str(e))
276
+
277
+
278
+ @executions_router.get("/executions", response_model=List[ExecutionResponse])
279
+ async def list_executions(
280
+ pagination: PaginationParams = Depends(),
281
+ search: SearchParams = Depends(),
282
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
283
+ ):
284
+ """List workflow executions with optional pagination and search."""
285
+ executions, total = await WorkflowExecutionService.list_executions(
286
+ params=pagination,
287
+ search=search
288
+ )
289
+ # Convert _id to string for each execution
290
+ for exec_item in executions:
291
+ exec_item["_id"] = str(exec_item["_id"])
292
+ return [ExecutionResponse(**exec_item) for exec_item in executions]
293
+
294
+
295
+ @executions_router.get("/executions/{execution_id}/logs", response_model=List[Dict[str, Any]])
296
+ async def get_execution_logs(
297
+ execution_id: str,
298
+ pagination: PaginationParams = Depends(),
299
+ current_user: Dict[str, Any] = Depends(get_current_active_user)
300
+ ):
301
+ """Retrieve logs for a specific execution."""
302
+ logs, total = await WorkflowExecutionService.get_execution_logs(execution_id, params=pagination)
303
+ # Convert _id in each log entry to string
304
+ for log in logs:
305
+ log["_id"] = str(log["_id"])
306
+ return logs
307
+
308
+ # Health Check Route
309
+ @system_router.get("/health", tags=["System"])
310
+ async def health_check():
311
+ """Simple health check endpoint."""
312
+ try:
313
+ # You can add more comprehensive health checks here
314
+ await Database.db.command('ping')
315
+ return {
316
+ "status": "healthy",
317
+ "version": "1.0.0"
318
+ }
319
+ except Exception as e:
320
+ raise HTTPException(status_code=500, detail=f"Database connection error: {str(e)}")
321
+
322
+ # Export the routers
323
+ __all__ = [
324
+ 'auth_router',
325
+ 'agents_router',
326
+ 'workflows_router',
327
+ 'executions_router',
328
+ 'system_router'
329
+ ]
evoagentx/app/app.env ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .env file
2
+ APP_NAME=EvoAgentX
3
+ DEBUG=True
4
+ API_PREFIX=/api/v1
5
+ HOST=0.0.0.0
6
+ PORT=8000
7
+
8
+ # MongoDB settings
9
+ # WARNING: this connection string embeds real credentials; a committed .env
+ # leaks them to anyone with repo access. Move the secret out of version
+ # control (e.g. inject MONGODB_URL from the deployment environment).
+ MONGODB_URL=mongodb+srv://eax:eax@cluster0.1lkbi0y.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0
10
+ MONGODB_DB_NAME=evoagentx
11
+
12
+ # JWT Authentication
13
+ SECRET_KEY=your-secret-key
14
+ ACCESS_TOKEN_EXPIRE_MINUTES=30
15
+ ALGORITHM=HS256
16
+
17
+ # Logging
18
+ LOG_LEVEL=INFO
19
+
20
+ # CORS settings
21
+ # dotenv files take KEY=VALUE assignments; the previous lines were Python
+ # type annotations ("ALLOWED_HOSTS: List[str]") and were invalid here.
+ # JSON lists are parsed by pydantic-settings for List[str] fields.
+ ALLOWED_HOSTS=["localhost","127.0.0.1"]
+ CORS_ORIGINS=["http://localhost:3000","http://localhost:8000"]
evoagentx/app/config.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the EvoAgentX application.
3
+ """
4
# import os
from typing import Optional, Dict, Any, List

from pydantic import BaseModel, Field, field_validator, validator
from pydantic_settings import BaseSettings
8
+
9
class Settings(BaseSettings):
    """Application settings, loaded from the environment / ``.env`` file.

    All fields without defaults are required and must be provided by the
    environment (see ``app.env``).
    """

    # Application settings
    APP_NAME: str
    DEBUG: bool
    API_PREFIX: str
    HOST: str
    PORT: int

    # MongoDB settings
    MONGODB_URL: str
    MONGODB_DB_NAME: str

    # JWT Authentication
    SECRET_KEY: str
    ACCESS_TOKEN_EXPIRE_MINUTES: int
    ALGORITHM: str

    # Logging configuration
    LOG_LEVEL: str

    # CORS settings
    CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
    CORS_ALLOW_CREDENTIALS: bool = True

    class Config:
        env_file = ".env"
        case_sensitive = True
        # NOTE(review): the previous `env_delimiter = ","` key is not a
        # pydantic-settings option and was silently ignored. List values are
        # parsed from JSON in env vars; use `env_nested_delimiter` if nested
        # env parsing is ever needed.
37
+
38
+
39
+
40
+ # Global settings instance
41
+ settings = Settings()
42
+
43
+ # Agent and Workflow configuration
44
class AgentConfig(BaseModel):
    """Base configuration for an LLM agent.

    Attributes:
        model_name: Identifier of the LLM to use.
        temperature: Sampling temperature, validated to [0, 1].
        max_tokens: Maximum tokens to generate.
        api_key_env_var: Name of the env var holding the API key, if any.
        system_prompt: Optional system prompt for the agent.
        extra_params: Arbitrary provider-specific parameters.
    """

    # "model_name" collides with pydantic v2's protected "model_" namespace;
    # disable the check, consistent with MongoBaseModel in evoagentx.app.db.
    model_config = {"protected_namespaces": ()}

    model_name: str
    temperature: float = 0.7
    max_tokens: int = 2048
    api_key_env_var: Optional[str] = None
    system_prompt: Optional[str] = None
    extra_params: Dict[str, Any] = Field(default_factory=dict)

    # pydantic v2 API (the project pins pydantic==2.7.0); the v1 `@validator`
    # decorator is deprecated there.
    @field_validator('temperature')
    @classmethod
    def validate_temperature(cls, v):
        """Reject temperatures outside the inclusive range [0, 1]."""
        if v < 0 or v > 1:
            raise ValueError('Temperature must be between 0 and 1')
        return v
58
+
59
class WorkflowStepConfig(BaseModel):
    """Configuration for a single step in a workflow."""

    step_id: str    # unique id of this step within the workflow
    agent_id: str   # agent that executes this step
    action: str     # action name invoked on the agent
    # presumably maps workflow-state keys to/from the action's inputs and
    # outputs — confirm against the workflow executor
    input_mapping: Dict[str, str] = Field(default_factory=dict)
    output_mapping: Dict[str, str] = Field(default_factory=dict)
    timeout_seconds: int = 300  # per-step timeout
    retry_count: int = 3        # attempts before the step is considered failed
68
+
69
class WorkflowConfig(BaseModel):
    """Configuration for a workflow composed of agent steps."""

    name: str
    description: Optional[str] = None
    steps: List[WorkflowStepConfig]       # ordered list of step configurations
    parallel_execution: bool = False      # run steps concurrently when True
    timeout_seconds: int = 3600  # Default to 1 hour total timeout
76
+
77
class ExecutionConfig(BaseModel):
    """Configuration for a workflow execution."""

    workflow_id: str                                       # workflow to run
    input_params: Dict[str, Any] = Field(default_factory=dict)
    user_id: Optional[str] = None                          # requesting user, if any
    priority: int = 1  # Higher number means higher priority
    callback_url: Optional[str] = None                     # notified on completion, presumably — confirm in executor
evoagentx/app/db.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Database connection and models for EvoAgentX.
3
+ """
4
+ # import asyncio
5
+ import logging
6
+ from datetime import datetime
7
+ from enum import Enum
8
+ from typing import Optional, List, Dict, Any # , Union
9
+ from motor.motor_asyncio import AsyncIOMotorClient
10
+ from pymongo import ASCENDING, TEXT
11
+ from pydantic_core import core_schema
12
+ from bson import ObjectId
13
+ from pydantic import GetCoreSchemaHandler
14
+ from pydantic import Field, BaseModel
15
+ from evoagentx.app.config import settings
16
+
17
+ # Setup logger
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Custom PyObjectId for MongoDB ObjectId compatibility with Pydantic
21
+ class PyObjectId(ObjectId):
22
+ @classmethod
23
+ def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
24
+ return core_schema.no_info_after_validator_function(cls.validate, core_schema.str_schema())
25
+
26
+ @classmethod
27
+ def validate(cls, v):
28
+ if not ObjectId.is_valid(v):
29
+ raise ValueError("Invalid ObjectId")
30
+ return ObjectId(v)
31
+
32
+ # Base model with ObjectId handling
33
+ class MongoBaseModel(BaseModel):
34
+ id: Optional[PyObjectId] = Field(alias="_id", default=None)
35
+
36
+ model_config = {
37
+ "protected_namespaces": (),
38
+ "populate_by_name": True, # Replace `allow_population_by_field_name`
39
+ "arbitrary_types_allowed": True, # Keep custom types like ObjectId
40
+ "json_encoders": {
41
+ ObjectId: str # Ensure ObjectId is serialized as a string
42
+ }
43
+ }
44
+
45
+ # Status Enums
46
+ class AgentStatus(str, Enum):
47
+ CREATED = "created"
48
+ ACTIVE = "active"
49
+ INACTIVE = "inactive"
50
+ ERROR = "error"
51
+
52
+ class WorkflowStatus(str, Enum):
53
+ CREATED = "created"
54
+ RUNNING = "running"
55
+ COMPLETED = "completed"
56
+ FAILED = "failed"
57
+ CANCELLED = "cancelled"
58
+
59
+ class ExecutionStatus(str, Enum):
60
+ PENDING = "pending"
61
+ RUNNING = "running"
62
+ COMPLETED = "completed"
63
+ FAILED = "failed"
64
+ TIMEOUT = "timeout"
65
+ CANCELLED = "cancelled"
66
+
67
+ # Database Models
68
+ class Agent(MongoBaseModel):
69
+ id: str = Field(..., alias="_id")
70
+ name: str
71
+ description: Optional[str] = None
72
+ config: Dict[str, Any]
73
+ state: Dict[str, Any] = Field(default_factory=dict)
74
+ runtime_params: Dict[str, Any] = Field(default_factory=dict)
75
+ status: AgentStatus = AgentStatus.CREATED
76
+ created_at: datetime = Field(default_factory=datetime.utcnow)
77
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
78
+ created_by: Optional[str] = None
79
+ tags: List[str] = Field(default_factory=list)
80
+
81
+ class Workflow(MongoBaseModel):
82
+ id: str = Field(..., alias="_id")
83
+ name: str
84
+ description: Optional[str] = None
85
+ definition: Dict[str, Any]
86
+ agent_ids: List[str] = Field(default_factory=list)
87
+ status: WorkflowStatus = WorkflowStatus.CREATED
88
+ created_at: datetime = Field(default_factory=datetime.utcnow)
89
+ updated_at: datetime = Field(default_factory=datetime.utcnow)
90
+ created_by: Optional[str] = None
91
+ tags: List[str] = Field(default_factory=list)
92
+ version: int = 1
93
+
94
+ class ExecutionLog(MongoBaseModel):
95
+ workflow_id: str
96
+ execution_id: str
97
+ step_id: Optional[str] = None
98
+ agent_id: Optional[str] = None
99
+ timestamp: datetime = Field(default_factory=datetime.utcnow)
100
+ level: str = "INFO"
101
+ message: str
102
+ details: Dict[str, Any] = Field(default_factory=dict)
103
+
104
class WorkflowExecution(MongoBaseModel):
    """A single run of a workflow, with per-step results and status."""

    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    input_params: Dict[str, Any] = Field(default_factory=dict)
    results: Dict[str, Any] = Field(default_factory=dict)
    created_by: Optional[str] = None
    step_results: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    current_step: Optional[str] = None
    error_message: Optional[str] = None
    # NOTE(review): datetime.utcnow is deprecated (Python 3.12+) and yields
    # naive timestamps; consider lambda: datetime.now(timezone.utc) — but that
    # changes stored values, so it must be migrated consistently.
    created_at: datetime = Field(default_factory=datetime.utcnow)
116
+
117
+ # Database client
118
class Database:
    """Process-wide holder for the MongoDB client, database and collection handles.

    All members are class-level; call ``connect()`` once at startup and
    ``disconnect()`` at shutdown (done by the FastAPI lifespan handler).
    """

    client: AsyncIOMotorClient = None
    db = None

    # Collections (populated by connect())
    agents = None
    workflows = None
    executions = None
    logs = None

    @classmethod
    async def connect(cls):
        """Connect to MongoDB, bind collection handles and create indexes."""
        # Log only the database name: MONGODB_URL can embed credentials
        # (user:password@host) and must never be written to logs.
        logger.info("Connecting to MongoDB database '%s'...", settings.MONGODB_DB_NAME)
        cls.client = AsyncIOMotorClient(settings.MONGODB_URL)
        cls.db = cls.client[settings.MONGODB_DB_NAME]

        # Set up collections
        cls.agents = cls.db.agents
        cls.workflows = cls.db.workflows
        cls.executions = cls.db.workflow_executions
        cls.logs = cls.db.execution_logs

        # Create indexes
        await cls._create_indexes()

        logger.info("Connected to MongoDB successfully")

    @classmethod
    async def disconnect(cls):
        """Close the MongoDB client, if one was opened."""
        if cls.client:
            cls.client.close()
            logger.info("Disconnected from MongoDB")

    @classmethod
    async def _create_indexes(cls):
        """Create the indexes each collection relies on (create_index is idempotent)."""
        # Agent indexes: unique name, full-text search, recency, tag filtering.
        await cls.agents.create_index([("name", ASCENDING)], unique=True)
        await cls.agents.create_index([("name", TEXT), ("description", TEXT)])
        await cls.agents.create_index([("created_at", ASCENDING)])
        await cls.agents.create_index([("tags", ASCENDING)])

        # Workflow indexes (name is NOT unique, unlike agents).
        await cls.workflows.create_index([("name", ASCENDING)])
        await cls.workflows.create_index([("name", TEXT), ("description", TEXT)])
        await cls.workflows.create_index([("created_at", ASCENDING)])
        await cls.workflows.create_index([("agent_ids", ASCENDING)])
        await cls.workflows.create_index([("tags", ASCENDING)])

        # Execution indexes
        await cls.executions.create_index([("workflow_id", ASCENDING)])
        await cls.executions.create_index([("created_at", ASCENDING)])
        await cls.executions.create_index([("status", ASCENDING)])

        # Log indexes
        await cls.logs.create_index([("execution_id", ASCENDING)])
        await cls.logs.create_index([("timestamp", ASCENDING)])
        await cls.logs.create_index([("workflow_id", ASCENDING), ("execution_id", ASCENDING)])
evoagentx/app/main.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main application entry point for EvoAgentX.
3
+ """
4
+ import logging
5
+ # import asyncio
6
+ from contextlib import asynccontextmanager
7
+
8
+ import uvicorn
9
+ from fastapi import FastAPI, Request
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import JSONResponse
12
+ from fastapi.exceptions import RequestValidationError, HTTPException
13
+
14
+ from evoagentx.app.config import settings
15
+ from evoagentx.app.db import Database
16
+ from evoagentx.app.security import init_users_collection
17
+ from evoagentx.app.api import (
18
+ auth_router,
19
+ agents_router,
20
+ workflows_router,
21
+ executions_router,
22
+ system_router
23
+ )
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=getattr(logging, settings.LOG_LEVEL.upper()),
28
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Lifespan context manager for startup and shutdown events
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ """
36
+ Async context manager to handle application startup and shutdown events.
37
+ """
38
+ # Startup tasks
39
+ try:
40
+ # Connect to database
41
+ await Database.connect()
42
+
43
+ # Initialize users collection and create admin user if not exists
44
+ await init_users_collection()
45
+
46
+ logger.info("Application startup completed successfully")
47
+ yield
48
+ except Exception as e:
49
+ logger.error(f"Error during application startup: {e}")
50
+ raise
51
+ finally:
52
+ # Shutdown tasks
53
+ try:
54
+ await Database.disconnect()
55
+ logger.info("Application shutdown completed successfully")
56
+ except Exception as e:
57
+ logger.error(f"Error during application shutdown: {e}")
58
+
59
+ # Create FastAPI application
60
+ app = FastAPI(
61
+ title="EvoAgentX API",
62
+ description="API for EvoAgentX platform",
63
+ version="1.0.0",
64
+ lifespan=lifespan,
65
+ docs_url="/docs",
66
+ redoc_url="/redoc"
67
+ )
68
+
69
+ # Configure CORS
70
+ app.add_middleware(
71
+ CORSMiddleware,
72
+ allow_origins=settings.CORS_ORIGINS,
73
+ allow_credentials=True,
74
+ allow_methods=["*"],
75
+ allow_headers=["*"],
76
+ )
77
+
78
+ # Include routers
79
+ app.include_router(auth_router)
80
+ app.include_router(agents_router)
81
+ app.include_router(workflows_router)
82
+ app.include_router(executions_router)
83
+ app.include_router(system_router)
84
+
85
+ # Global exception handlers
86
+ @app.exception_handler(RequestValidationError)
87
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
88
+ """
89
+ Custom validation error handler to provide more detailed error responses.
90
+ """
91
+ return JSONResponse(
92
+ status_code=422,
93
+ content={
94
+ "status": "error",
95
+ "message": "Validation error",
96
+ "errors": exc.errors()
97
+ }
98
+ )
99
+
100
+ @app.exception_handler(HTTPException)
101
+ async def http_exception_handler(request: Request, exc: HTTPException):
102
+ """
103
+ Custom HTTP exception handler to provide consistent error responses.
104
+ """
105
+ return JSONResponse(
106
+ status_code=exc.status_code,
107
+ content={
108
+ "status": "error",
109
+ "message": exc.detail
110
+ }
111
+ )
112
+
113
+ # Root endpoint for health check
114
+ @app.get("/")
115
+ async def root():
116
+ """
117
+ Root endpoint for application health check.
118
+ """
119
+ return {
120
+ "app_name": settings.APP_NAME,
121
+ "status": "healthy",
122
+ "version": "0.1.0"
123
+ }
124
+
125
+ # Workflow logging and monitoring endpoint
126
+ @app.get("/metrics")
127
+ async def get_metrics():
128
+ """
129
+ Endpoint to retrieve system metrics and stats.
130
+ """
131
+ # Collect metrics from different services
132
+ try:
133
+ # Collect agent metrics
134
+ total_agents = await Database.agents.count_documents({})
135
+ active_agents = await Database.agents.count_documents({"status": "active"})
136
+
137
+ # Collect workflow metrics
138
+ total_workflows = await Database.workflows.count_documents({})
139
+ running_workflows = await Database.workflows.count_documents({"status": "running"})
140
+
141
+ # Collect execution metrics
142
+ total_executions = await Database.executions.count_documents({})
143
+ failed_executions = await Database.executions.count_documents({"status": "failed"})
144
+
145
+ return {
146
+ "agents": {
147
+ "total": total_agents,
148
+ "active": active_agents
149
+ },
150
+ "workflows": {
151
+ "total": total_workflows,
152
+ "running": running_workflows
153
+ },
154
+ "executions": {
155
+ "total": total_executions,
156
+ "failed": failed_executions
157
+ }
158
+ }
159
+ except Exception as e:
160
+ logger.error(f"Error retrieving metrics: {e}")
161
+ return {
162
+ "status": "error",
163
+ "message": "Unable to retrieve metrics"
164
+ }
165
+
166
+ # Run the application if this script is executed directly
167
+ if __name__ == "__main__":
168
+ # Configuration for running the server
169
+ uvicorn_config = {
170
+ "host": settings.HOST,
171
+ "port": settings.PORT,
172
+ "reload": settings.DEBUG,
173
+ "log_level": settings.LOG_LEVEL.lower()
174
+ }
175
+
176
+ # Start the server
177
+ uvicorn.run("evoagentx.app.main:app", **uvicorn_config)
evoagentx/app/requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI and ASGI server
2
+ fastapi==0.115.10
3
+ uvicorn==0.22.0
4
+ pydantic==2.7.0
5
+ pydantic-settings==2.8.1
6
+
7
+ # MongoDB ODM
8
+ motor==3.3.1
9
+ pymongo==4.6.0
10
+ sqlalchemy==2.0.38
11
+
12
+ python-jose==3.3.0
13
+ passlib==1.7.4
14
+ python-multipart==0.0.6
15
+ bcrypt==4.0.1
16
+ celery==5.3.4
17
+ redis==5.0.0
18
+ pytest==7.4.2
19
+ pytest-asyncio==0.21.0
20
+ httpx==0.24.1
21
+ asgi-lifespan==1.0.1
22
+ python-dotenv==1.0.0
23
+ loguru==0.7.3
evoagentx/app/schemas.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for request/response validation in the EvoAgentX API.
3
+ """
4
+ from datetime import datetime
5
+ from typing import Optional, List, Dict, Any # , Union
6
+ from pydantic import BaseModel, Field # , validator
7
+ from bson import ObjectId
8
+ from evoagentx.app.db import AgentStatus, WorkflowStatus, ExecutionStatus
9
+
10
+ # Helper for ObjectId validation
11
+ class PyObjectId(ObjectId):
12
+ @classmethod
13
+ def __get_validators__(cls):
14
+ yield cls.validate
15
+
16
+ @classmethod
17
+ def validate(cls, v):
18
+ if not ObjectId.is_valid(v):
19
+ raise ValueError("Invalid ObjectId")
20
+ return ObjectId(v)
21
+
22
+ @classmethod
23
+ def __modify_schema__(cls, field_schema):
24
+ field_schema.update(type="string")
25
+
26
+ # Base Schema Models
27
+ class BaseSchema(BaseModel):
28
+ class Config:
29
+ allow_population_by_field_name = True
30
+ arbitrary_types_allowed = True
31
+ json_encoders = {
32
+ ObjectId: str,
33
+ datetime: lambda dt: dt.isoformat()
34
+ }
35
+
36
+ # Agent Schemas
37
+ class AgentCreate(BaseSchema):
38
+ name: str
39
+ description: Optional[str] = None
40
+ config: Dict[str, Any]
41
+ runtime_params: Dict[str, Any] = Field(default_factory=dict)
42
+ tags: List[str] = Field(default_factory=list)
43
+
44
+ class AgentUpdate(BaseSchema):
45
+ name: Optional[str] = None
46
+ description: Optional[str] = None
47
+ config: Optional[Dict[str, Any]] = None
48
+ runtime_params: Optional[Dict[str, Any]] = None
49
+ status: Optional[AgentStatus] = None
50
+ tags: Optional[List[str]] = None
51
+
52
+ class AgentResponse(BaseSchema):
53
+ id: str = Field(..., alias="_id")
54
+ name: str
55
+ description: Optional[str] = None
56
+ config: Dict[str, Any]
57
+ status: AgentStatus
58
+ runtime_params: Dict[str, Any]
59
+ created_at: datetime
60
+ updated_at: datetime
61
+ created_by: Optional[str] = None
62
+ tags: List[str]
63
+
64
+ # Workflow Schemas
65
+ class WorkflowStepDefinition(BaseSchema):
66
+ step_id: str
67
+ agent_id: str
68
+ action: str
69
+ input_mapping: Dict[str, str] = Field(default_factory=dict)
70
+ output_mapping: Dict[str, str] = Field(default_factory=dict)
71
+ timeout_seconds: int = 300
72
+ retry_count: int = 3
73
+ depends_on: List[str] = Field(default_factory=list)
74
+
75
+ class WorkflowCreate(BaseSchema):
76
+ name: str
77
+ description: Optional[str] = None
78
+ definition: Dict[str, Any]
79
+ tags: List[str] = Field(default_factory=list)
80
+
81
+ class WorkflowUpdate(BaseSchema):
82
+ name: Optional[str] = None
83
+ description: Optional[str] = None
84
+ definition: Optional[Dict[str, Any]] = None
85
+ status: Optional[WorkflowStatus] = None
86
+ tags: Optional[List[str]] = None
87
+
88
+ class WorkflowResponse(BaseSchema):
89
+ id: str = Field(..., alias="_id")
90
+ name: str
91
+ description: Optional[str] = None
92
+ definition: Dict[str, Any]
93
+ agent_ids: List[str]
94
+ status: WorkflowStatus
95
+ created_at: datetime
96
+ updated_at: datetime
97
+ created_by: Optional[str] = None
98
+ tags: List[str]
99
+ version: int
100
+
101
+ # Execution Schemas
102
+ class ExecutionCreate(BaseSchema):
103
+ workflow_id: str
104
+ input_params: Dict[str, Any] = Field(default_factory=dict)
105
+ callback_url: Optional[str] = None
106
+
107
+ class ExecutionResponse(BaseSchema):
108
+ id: str = Field(..., alias="_id")
109
+ workflow_id: str
110
+ status: ExecutionStatus
111
+ start_time: Optional[datetime] = None
112
+ end_time: Optional[datetime] = None
113
+ input_params: Dict[str, Any]
114
+ results: Dict[str, Any]
115
+ created_by: Optional[str] = None
116
+ step_results: Dict[str, Dict[str, Any]]
117
+ current_step: Optional[str] = None
118
+ error_message: Optional[str] = None
119
+ created_at: datetime
120
+
121
+ class ExecutionLogResponse(BaseSchema):
122
+ id: str = Field(..., alias="_id")
123
+ workflow_id: str
124
+ execution_id: str
125
+ step_id: Optional[str] = None
126
+ agent_id: Optional[str] = None
127
+ timestamp: datetime
128
+ level: str
129
+ message: str
130
+ details: Dict[str, Any]
131
+
132
+ # User auth schemas
133
+ class Token(BaseSchema):
134
+ access_token: str
135
+ token_type: str
136
+
137
+ class TokenPayload(BaseSchema):
138
+ sub: Optional[str] = None
139
+ exp: Optional[int] = None
140
+
141
+ class UserCreate(BaseSchema):
142
+ email: str
143
+ password: str
144
+ full_name: Optional[str] = None
145
+
146
+ class UserLogin(BaseSchema):
147
+ email: str
148
+ password: str
149
+
150
+ class UserResponse(BaseSchema):
151
+ id: str = Field(..., alias="_id")
152
+ email: str
153
+ full_name: Optional[str] = None
154
+ is_active: bool
155
+ is_admin: bool
156
+ created_at: datetime
157
+
158
+ # Query parameters
159
+ class PaginationParams(BaseSchema):
160
+ skip: int = 0
161
+ limit: int = 100
162
+
163
+ class SearchParams(BaseSchema):
164
+ query: Optional[str] = None
165
+ tags: Optional[List[str]] = None
166
+ status: Optional[str] = None
167
+ start_date: Optional[datetime] = None
168
+ end_date: Optional[datetime] = None
evoagentx/app/security.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Security components for authentication and authorization.
3
+ """
4
+ import jwt
5
+ from datetime import datetime, timedelta
6
+ from typing import Optional, Dict, Any # , List
7
+ from passlib.context import CryptContext
8
+ from fastapi import Depends, HTTPException, status
9
+ from fastapi.security import OAuth2PasswordBearer
10
+ from pydantic import BaseModel, ValidationError
11
+ from pymongo.errors import DuplicateKeyError
12
+ from bson import ObjectId
13
+
14
+ from evoagentx.app.config import settings
15
+ from evoagentx.app.db import Database
16
+ from evoagentx.app.schemas import TokenPayload, UserCreate, UserResponse
17
+
18
+ # Password hashing
19
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
20
+
21
+ # OAuth2 scheme for token authentication
22
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/auth/login")
23
+
24
+ # User model for database
25
+ class UserInDB(BaseModel):
26
+ _id: Optional[ObjectId] = None
27
+ email: str
28
+ hashed_password: str
29
+ full_name: Optional[str] = None
30
+ is_active: bool = True
31
+ is_admin: bool = False
32
+ created_at: datetime = datetime.utcnow()
33
+
34
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
35
+ """Verify a password against a hash."""
36
+ return pwd_context.verify(plain_password, hashed_password)
37
+
38
+ def get_password_hash(password: str) -> str:
39
+ """Hash a password for storing."""
40
+ return pwd_context.hash(password)
41
+
42
+ async def get_user_by_email(email: str) -> Optional[Dict[str, Any]]:
43
+ """Get a user by email."""
44
+ return await Database.db.users.find_one({"email": email})
45
+
46
+ async def authenticate_user(email: str, password: str) -> Optional[Dict[str, Any]]:
47
+ """Authenticate a user by email and password."""
48
+ user = await get_user_by_email(email)
49
+ if not user:
50
+ return None
51
+ if not verify_password(password, user["hashed_password"]):
52
+ return None
53
+ if not user.get("is_active", True):
54
+ return None
55
+ return user
56
+
57
+ async def create_user(user_create: UserCreate) -> UserResponse:
58
+ """Create a new user."""
59
+ # Check if user already exists
60
+ existing_user = await get_user_by_email(user_create.email)
61
+ if existing_user:
62
+ raise HTTPException(
63
+ status_code=status.HTTP_400_BAD_REQUEST,
64
+ detail="Email already registered"
65
+ )
66
+
67
+ # Create new user
68
+ user_dict = user_create.dict()
69
+ hashed_password = get_password_hash(user_dict.pop("password"))
70
+
71
+ new_user = {
72
+ "email": user_dict["email"],
73
+ "hashed_password": hashed_password,
74
+ "full_name": user_dict.get("full_name"),
75
+ "is_active": True,
76
+ "is_admin": False,
77
+ "created_at": datetime.utcnow()
78
+ }
79
+
80
+ try:
81
+ insert_result = await Database.db.users.insert_one(new_user)
82
+ new_user["_id"] = insert_result.inserted_id
83
+ return UserResponse(**new_user)
84
+ except DuplicateKeyError:
85
+ raise HTTPException(
86
+ status_code=status.HTTP_400_BAD_REQUEST,
87
+ detail="Email already registered"
88
+ )
89
+
90
+ def create_access_token(subject: str, expires_delta: Optional[timedelta] = None) -> str:
91
+ """Create a new JWT access token."""
92
+ if expires_delta:
93
+ expire = datetime.utcnow() + expires_delta
94
+ else:
95
+ expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
96
+
97
+ to_encode = {"exp": expire, "sub": subject}
98
+ encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
99
+ return encoded_jwt
100
+
101
+ async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
102
+ """Get the current user from a JWT token."""
103
+ try:
104
+ payload = jwt.decode(
105
+ token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
106
+ )
107
+ token_data = TokenPayload(**payload)
108
+
109
+ if datetime.fromtimestamp(token_data.exp) < datetime.utcnow():
110
+ raise HTTPException(
111
+ status_code=status.HTTP_401_UNAUTHORIZED,
112
+ detail="Token expired",
113
+ headers={"WWW-Authenticate": "Bearer"},
114
+ )
115
+ except (jwt.PyJWTError, ValidationError):
116
+ raise HTTPException(
117
+ status_code=status.HTTP_403_FORBIDDEN,
118
+ detail="Could not validate credentials",
119
+ headers={"WWW-Authenticate": "Bearer"},
120
+ )
121
+
122
+ user = await get_user_by_email(token_data.sub)
123
+ if user is None:
124
+ raise HTTPException(
125
+ status_code=status.HTTP_404_NOT_FOUND,
126
+ detail="User not found"
127
+ )
128
+
129
+ return user
130
+
131
+ async def get_current_active_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
132
+ """Get the current active user."""
133
+ if not current_user.get("is_active", True):
134
+ raise HTTPException(status_code=400, detail="Inactive user")
135
+ return current_user
136
+
137
+ async def get_current_admin_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
138
+ """Get the current admin user."""
139
+ if not current_user.get("is_admin", False):
140
+ raise HTTPException(
141
+ status_code=status.HTTP_403_FORBIDDEN,
142
+ detail="Not enough permissions"
143
+ )
144
+ return current_user
145
+
146
+ # Initialize the users collection
147
+ async def init_users_collection():
148
+ """Initialize the users collection with indexes."""
149
+ await Database.db.users.create_index("email", unique=True)
150
+
151
+ # Create admin user if it doesn't exist
152
+ admin_email = "admin@clayx.ai"
153
+ admin = await get_user_by_email(admin_email)
154
+ if not admin:
155
+ admin_user = UserCreate(
156
+ email=admin_email,
157
+ password="adminpassword", # Change this in production!
158
+ full_name="Admin User"
159
+ )
160
+ user_dict = admin_user.dict()
161
+ hashed_password = get_password_hash(user_dict["password"])
162
+
163
+ new_admin = {
164
+ "email": admin_email,
165
+ "hashed_password": hashed_password,
166
+ "full_name": "Admin User",
167
+ "is_active": True,
168
+ "is_admin": True,
169
+ "created_at": datetime.utcnow()
170
+ }
171
+
172
+ await Database.db.users.insert_one(new_admin)
evoagentx/app/services.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Business logic for agents, workflows, and executions.
3
+ """
4
+ import logging
5
+ # import asyncio
6
+ from datetime import datetime
7
+ from typing import List, Dict, Any, Optional, Tuple
8
+ from bson import ObjectId
9
+
10
+ from evoagentx.app.db import (
11
+ Database, # Agent, Workflow, WorkflowExecution, ExecutionLog,
12
+ AgentStatus, WorkflowStatus, ExecutionStatus
13
+ )
14
+ from evoagentx.app.schemas import (
15
+ AgentCreate, AgentUpdate, WorkflowCreate, WorkflowUpdate,
16
+ ExecutionCreate, PaginationParams, SearchParams
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Agent Service
22
+ class AgentService:
23
+ @staticmethod
24
+ async def create_agent(agent_data: AgentCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
25
+ """Create a new agent."""
26
+ agent_dict = agent_data.dict()
27
+ agent_dict["created_by"] = user_id
28
+ agent_dict["created_at"] = datetime.utcnow()
29
+ agent_dict["updated_at"] = agent_dict["created_at"]
30
+ agent_dict["status"] = AgentStatus.CREATED
31
+
32
+ # Validate agent exists with the same name
33
+ existing_agent = await Database.agents.find_one({"name": agent_dict["name"]})
34
+ if existing_agent:
35
+ raise ValueError(f"Agent with name '{agent_dict['name']}' already exists")
36
+
37
+ result = await Database.agents.insert_one(agent_dict)
38
+ agent_dict["_id"] = result.inserted_id
39
+
40
+ logger.info(f"Created agent {agent_dict['name']} with ID {result.inserted_id}")
41
+
42
+ return agent_dict
43
+
44
+ @staticmethod
45
+ async def get_agent(agent_id: str) -> Optional[Dict[str, Any]]:
46
+ """Get an agent by ID."""
47
+ if not ObjectId.is_valid(agent_id):
48
+ raise ValueError(f"Invalid agent ID: {agent_id}")
49
+
50
+ agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
51
+ return agent
52
+
53
+ @staticmethod
54
+ async def get_agent_by_name(name: str) -> Optional[Dict[str, Any]]:
55
+ """Get an agent by name."""
56
+ return await Database.agents.find_one({"name": name})
57
+
58
+ @staticmethod
59
+ async def update_agent(agent_id: str, agent_data: AgentUpdate) -> Optional[Dict[str, Any]]:
60
+ """Update an agent."""
61
+ if not ObjectId.is_valid(agent_id):
62
+ raise ValueError(f"Invalid agent ID: {agent_id}")
63
+
64
+ agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
65
+ if not agent:
66
+ return None
67
+
68
+ update_data = agent_data.dict(exclude_unset=True)
69
+ update_data["updated_at"] = datetime.utcnow()
70
+
71
+ if "name" in update_data:
72
+ # Check if the new name already exists
73
+ existing = await Database.agents.find_one({
74
+ "name": update_data["name"],
75
+ "_id": {"$ne": ObjectId(agent_id)}
76
+ })
77
+ if existing:
78
+ raise ValueError(f"Agent with name '{update_data['name']}' already exists")
79
+
80
+ await Database.agents.update_one(
81
+ {"_id": ObjectId(agent_id)},
82
+ {"$set": update_data}
83
+ )
84
+
85
+ updated_agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
86
+ logger.info(f"Updated agent {agent_id}")
87
+
88
+ return updated_agent
89
+
90
+ @staticmethod
91
+ async def delete_agent(agent_id: str) -> bool:
92
+ """Delete an agent."""
93
+ if not ObjectId.is_valid(agent_id):
94
+ raise ValueError(f"Invalid agent ID: {agent_id}")
95
+
96
+ # Check if agent is used in any workflows
97
+ workflow_count = await Database.workflows.count_documents({"agent_ids": agent_id})
98
+ if workflow_count > 0:
99
+ raise ValueError(f"Cannot delete agent {agent_id} as it is used in {workflow_count} workflows")
100
+
101
+ result = await Database.agents.delete_one({"_id": ObjectId(agent_id)})
102
+ if result.deleted_count:
103
+ logger.info(f"Deleted agent {agent_id}")
104
+ return True
105
+ return False
106
+
107
+ @staticmethod
108
+ async def list_agents(
109
+ params: PaginationParams,
110
+ search: Optional[SearchParams] = None
111
+ ) -> Tuple[List[Dict[str, Any]], int]:
112
+ """List agents with pagination and search."""
113
+ query = {}
114
+
115
+ if search:
116
+ if search.query:
117
+ query["$text"] = {"$search": search.query}
118
+
119
+ if search.tags:
120
+ query["tags"] = {"$all": search.tags}
121
+
122
+ if search.status:
123
+ query["status"] = search.status
124
+
125
+ if search.start_date and search.end_date:
126
+ query["created_at"] = {
127
+ "$gte": search.start_date,
128
+ "$lte": search.end_date
129
+ }
130
+ elif search.start_date:
131
+ query["created_at"] = {"$gte": search.start_date}
132
+ elif search.end_date:
133
+ query["created_at"] = {"$lte": search.end_date}
134
+
135
+ total = await Database.agents.count_documents(query)
136
+
137
+ cursor = Database.agents.find(query)\
138
+ .sort("created_at", -1)\
139
+ .skip(params.skip)\
140
+ .limit(params.limit)
141
+
142
+ agents = await cursor.to_list(length=params.limit)
143
+ return agents, total
144
+
145
+ # Workflow Service
146
+ class WorkflowService:
147
+ @staticmethod
148
+ async def create_workflow(workflow_data: WorkflowCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
149
+ """Create a new workflow."""
150
+ workflow_dict = workflow_data.dict()
151
+ workflow_dict["created_by"] = user_id
152
+ workflow_dict["created_at"] = datetime.utcnow()
153
+ workflow_dict["updated_at"] = workflow_dict["created_at"]
154
+ workflow_dict["status"] = WorkflowStatus.CREATED
155
+ workflow_dict["version"] = 1
156
+
157
+ # Extract agent IDs from the workflow definition
158
+ agent_ids = set()
159
+
160
+ # Extract agent IDs from steps
161
+ steps = workflow_dict["definition"].get("steps", [])
162
+ for step in steps:
163
+ if "agent_id" in step:
164
+ agent_id = step["agent_id"]
165
+ # Validate agent exists
166
+ agent = await AgentService.get_agent(agent_id)
167
+ if not agent:
168
+ raise ValueError(f"Agent with ID {agent_id} does not exist")
169
+ agent_ids.add(agent_id)
170
+
171
+ workflow_dict["agent_ids"] = list(agent_ids)
172
+
173
+ # Check for existing workflow with the same name
174
+ existing = await Database.workflows.find_one({"name": workflow_dict["name"]})
175
+ if existing:
176
+ raise ValueError(f"Workflow with name '{workflow_dict['name']}' already exists")
177
+
178
+ result = await Database.workflows.insert_one(workflow_dict)
179
+ workflow_dict["_id"] = result.inserted_id
180
+
181
+ logger.info(f"Created workflow {workflow_dict['name']} with ID {result.inserted_id}")
182
+
183
+ return workflow_dict
184
+
185
+ @staticmethod
186
+ async def get_workflow(workflow_id: str) -> Optional[Dict[str, Any]]:
187
+ """Get a workflow by ID."""
188
+ if not ObjectId.is_valid(workflow_id):
189
+ raise ValueError(f"Invalid workflow ID: {workflow_id}")
190
+ workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
191
+ return workflow
192
+
193
+ @staticmethod
194
+ async def get_workflow_by_name(name: str) -> Optional[Dict[str, Any]]:
195
+ """Get a workflow by name."""
196
+ return await Database.workflows.find_one({"name": name})
197
+
198
+ @staticmethod
199
+ async def update_workflow(workflow_id: str, workflow_data: WorkflowUpdate) -> Optional[Dict[str, Any]]:
200
+ """Update a workflow."""
201
+ if not ObjectId.is_valid(workflow_id):
202
+ raise ValueError(f"Invalid workflow ID: {workflow_id}")
203
+
204
+ workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
205
+ if not workflow:
206
+ return None
207
+
208
+ update_data = workflow_data.dict(exclude_unset=True)
209
+ update_data["updated_at"] = datetime.utcnow()
210
+
211
+ # Update version if definition changes
212
+ if "definition" in update_data:
213
+ update_data["version"] = workflow.get("version", 1) + 1
214
+
215
+ # Extract agent IDs from the updated workflow definition
216
+ agent_ids = set()
217
+ steps = update_data["definition"].get("steps", [])
218
+ for step in steps:
219
+ if "agent_id" in step:
220
+ agent_id = step["agent_id"]
221
+ # Validate agent exists
222
+ agent = await AgentService.get_agent(agent_id)
223
+ if not agent:
224
+ raise ValueError(f"Agent with ID {agent_id} does not exist")
225
+ agent_ids.add(agent_id)
226
+
227
+ update_data["agent_ids"] = list(agent_ids)
228
+
229
+ # Check for name conflict if name is being updated
230
+ if "name" in update_data:
231
+ existing = await Database.workflows.find_one({
232
+ "name": update_data["name"],
233
+ "_id": {"$ne": ObjectId(workflow_id)}
234
+ })
235
+ if existing:
236
+ raise ValueError(f"Workflow with name '{update_data['name']}' already exists")
237
+
238
+ await Database.workflows.update_one(
239
+ {"_id": ObjectId(workflow_id)},
240
+ {"$set": update_data}
241
+ )
242
+
243
+ updated_workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
244
+ logger.info(f"Updated workflow {workflow_id}")
245
+
246
+ return updated_workflow
247
+
248
+ @staticmethod
249
+ async def delete_workflow(workflow_id: str) -> bool:
250
+ """Delete a workflow."""
251
+ if not ObjectId.is_valid(workflow_id):
252
+ raise ValueError(f"Invalid workflow ID: {workflow_id}")
253
+
254
+ # Check if workflow has any ongoing or recent executions
255
+ recent_executions = await Database.executions.count_documents({
256
+ "workflow_id": workflow_id,
257
+ "status": {"$in": [
258
+ ExecutionStatus.PENDING,
259
+ ExecutionStatus.RUNNING
260
+ ]}
261
+ })
262
+
263
+ if recent_executions > 0:
264
+ raise ValueError(f"Cannot delete workflow {workflow_id} with {recent_executions} active executions")
265
+
266
+ result = await Database.workflows.delete_one({"_id": ObjectId(workflow_id)})
267
+ if result.deleted_count:
268
+ # Delete associated execution logs
269
+ await Database.logs.delete_many({"workflow_id": workflow_id})
270
+ await Database.executions.delete_many({"workflow_id": workflow_id})
271
+
272
+ logger.info(f"Deleted workflow {workflow_id}")
273
+ return True
274
+ return False
275
+
276
+ @staticmethod
277
+ async def list_workflows(
278
+ params: PaginationParams,
279
+ search: Optional[SearchParams] = None
280
+ ) -> Tuple[List[Dict[str, Any]], int]:
281
+ """List workflows with pagination and search."""
282
+ query = {}
283
+
284
+ if search:
285
+ if search.query:
286
+ query["$text"] = {"$search": search.query}
287
+
288
+ if search.tags:
289
+ query["tags"] = {"$all": search.tags}
290
+
291
+ if search.status:
292
+ query["status"] = search.status
293
+
294
+ if search.start_date and search.end_date:
295
+ query["created_at"] = {
296
+ "$gte": search.start_date,
297
+ "$lte": search.end_date
298
+ }
299
+ elif search.start_date:
300
+ query["created_at"] = {"$gte": search.start_date}
301
+ elif search.end_date:
302
+ query["created_at"] = {"$lte": search.end_date}
303
+
304
+ total = await Database.workflows.count_documents(query)
305
+
306
+ cursor = Database.workflows.find(query)\
307
+ .sort("created_at", -1)\
308
+ .skip(params.skip)\
309
+ .limit(params.limit)
310
+
311
+ workflows = await cursor.to_list(length=params.limit)
312
+ return workflows, total
313
+
314
+ # Workflow Execution Service
315
+ class WorkflowExecutionService:
316
+ @staticmethod
317
+ async def create_execution(execution_data: ExecutionCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
318
+ """Create a new workflow execution."""
319
+ # Validate workflow exists
320
+ workflow = await WorkflowService.get_workflow(execution_data.workflow_id)
321
+ if not workflow:
322
+ raise ValueError(f"Workflow {execution_data.workflow_id} not found")
323
+
324
+ # Prepare execution document
325
+ execution_dict = {
326
+ "workflow_id": execution_data.workflow_id,
327
+ "status": ExecutionStatus.PENDING,
328
+ "start_time": datetime.utcnow(),
329
+ "input_params": execution_data.input_params,
330
+ "created_by": user_id,
331
+ "created_at": datetime.utcnow(),
332
+ "step_results": {},
333
+ "current_step": None,
334
+ "results": {},
335
+ "error_message": None
336
+ }
337
+
338
+ # Insert execution record
339
+ result = await Database.executions.insert_one(execution_dict)
340
+ execution_dict["_id"] = result.inserted_id
341
+
342
+ logger.info(f"Created workflow execution {result.inserted_id}")
343
+
344
+ # Optional: Queue execution for async processing
345
+ # This would typically use a task queue like Celery
346
+ # await execute_workflow_async.delay(execution_dict)
347
+
348
+ return execution_dict
349
+
350
+ @staticmethod
351
+ async def get_execution(execution_id: str) -> Optional[Dict[str, Any]]:
352
+ """Get a workflow execution by ID."""
353
+ if not ObjectId.is_valid(execution_id):
354
+ raise ValueError(f"Invalid execution ID: {execution_id}")
355
+
356
+ execution = await Database.executions.find_one({"_id": ObjectId(execution_id)})
357
+ return execution
358
+
359
+ @staticmethod
360
+ async def update_execution_status(execution_id: str, status: ExecutionStatus, error_message: Optional[str] = None) -> Optional[Dict[str, Any]]:
361
+ """Update execution status."""
362
+ if not ObjectId.is_valid(execution_id):
363
+ raise ValueError(f"Invalid execution ID: {execution_id}")
364
+
365
+ update_data = {
366
+ "status": status,
367
+ "updated_at": datetime.utcnow()
368
+ }
369
+
370
+ if status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED, ExecutionStatus.CANCELLED]:
371
+ update_data["end_time"] = datetime.utcnow()
372
+
373
+ if error_message:
374
+ update_data["error_message"] = error_message
375
+
376
+ result = await Database.executions.find_one_and_update(
377
+ {"_id": ObjectId(execution_id)},
378
+ {"$set": update_data},
379
+ return_document=True
380
+ )
381
+
382
+ return result
383
+
384
+ @staticmethod
385
+ async def list_executions(
386
+ workflow_id: Optional[str] = None,
387
+ params: PaginationParams = PaginationParams(),
388
+ search: Optional[SearchParams] = None
389
+ ) -> Tuple[List[Dict[str, Any]], int]:
390
+ """List workflow executions with pagination and search."""
391
+ query = {}
392
+
393
+ if workflow_id:
394
+ query["workflow_id"] = workflow_id
395
+
396
+ if search:
397
+ if search.status:
398
+ query["status"] = search.status
399
+
400
+ if search.start_date and search.end_date:
401
+ query["created_at"] = {
402
+ "$gte": search.start_date,
403
+ "$lte": search.end_date
404
+ }
405
+ elif search.start_date:
406
+ query["created_at"] = {"$gte": search.start_date}
407
+ elif search.end_date:
408
+ query["created_at"] = {"$lte": search.end_date}
409
+
410
+ total = await Database.executions.count_documents(query)
411
+
412
+ cursor = Database.executions.find(query)\
413
+ .sort("created_at", -1)\
414
+ .skip(params.skip)\
415
+ .limit(params.limit)
416
+
417
+ executions = await cursor.to_list(length=params.limit)
418
+ return executions, total
419
+
420
+ @staticmethod
421
+ async def log_execution_event(
422
+ workflow_id: str,
423
+ execution_id: str,
424
+ message: str,
425
+ step_id: Optional[str] = None,
426
+ agent_id: Optional[str] = None,
427
+ level: str = "INFO",
428
+ details: Optional[Dict[str, Any]] = None
429
+ ) -> Dict[str, Any]:
430
+ """Log an event in a workflow execution."""
431
+ log_entry = {
432
+ "workflow_id": workflow_id,
433
+ "execution_id": execution_id,
434
+ "step_id": step_id,
435
+ "agent_id": agent_id,
436
+ "timestamp": datetime.utcnow(),
437
+ "level": level,
438
+ "message": message,
439
+ "details": details or {}
440
+ }
441
+
442
+ result = await Database.logs.insert_one(log_entry)
443
+ log_entry["_id"] = result.inserted_id
444
+
445
+ return log_entry
446
+
447
+ @staticmethod
448
+ async def get_execution_logs(
449
+ execution_id: str,
450
+ params: PaginationParams = PaginationParams()
451
+ ) -> Tuple[List[Dict[str, Any]], int]:
452
+ """Retrieve logs for a specific execution."""
453
+ query = {"execution_id": execution_id}
454
+
455
+ total = await Database.logs.count_documents(query)
456
+
457
+ cursor = Database.logs.find(query)\
458
+ .sort("timestamp", 1)\
459
+ .skip(params.skip)\
460
+ .limit(params.limit)
461
+
462
+ logs = await cursor.to_list(length=params.limit)
463
+ return logs, total
evoagentx/benchmark/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
evoagentx/benchmark/.ipynb_checkpoints/test_load_json-checkpoint.ipynb ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "385cefbd",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "ename": "ImportError",
11
+ "evalue": "attempted relative import with no known parent package",
12
+ "output_type": "error",
13
+ "traceback": [
14
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
15
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
16
+ "Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mshutil\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Union, Any, Callable, List, Dict, Tuple\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mscicode\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m \u001b[38;5;66;03m# Many SciCode tests use numpy\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbenchmark\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodingBenchmark\n",
17
+ "File \u001b[0;32m/gpfs/radev/pi/ying_rex/tl688/selfevolve/EvoAgentX/evoagentx/benchmark/scicode.py:10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mscicode\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m \u001b[38;5;66;03m# Many SciCode tests use numpy\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbenchmark\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodingBenchmark\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlogging\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m logger\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m download_file\n",
18
+ "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "import os\n",
24
+ "import re\n",
25
+ "import gzip\n",
26
+ "import shutil\n",
27
+ "from typing import Union, Any, Callable, List, Dict, Tuple\n",
28
+ "\n",
29
+ "import scicode\n",
30
+ "\n",
31
+ "import numpy as np # Many SciCode tests use numpy\n",
32
+ "from .benchmark import CodingBenchmark\n",
33
+ "from ..core.logging import logger\n",
34
+ "from ..utils.utils import download_file\n",
35
+ "from ..core.module_utils import load_json\n",
36
+ "from ..utils.aflow_utils.data_utils import AFLOW_DATASET_FILES_MAP, download_aflow_benchmark_data\n",
37
+ "\n",
38
+ "\n",
39
+ "# ----------------------------\n",
40
+ "# Raw SciCode (community) data\n",
41
+ "# ----------------------------\n",
42
+ "\n",
43
+ "SCICODE_DEFAULT_URL = \"https://raw.githubusercontent.com/scicode-bench/scicode/main/data/scicode.jsonl.gz\" # If you mirror elsewhere, update here.\n",
44
+ "\n",
45
+ "\n",
46
+ "def download_raw_scicode_data(save_folder: str, url: str = SCICODE_DEFAULT_URL) -> str:\n",
47
+ " \"\"\"\n",
48
+ " Download and unzip the raw SciCode jsonl(.gz) to `save_folder`.\n",
49
+ "\n",
50
+ " Returns:\n",
51
+ " str: Path to the unzipped jsonl file.\n",
52
+ " \"\"\"\n",
53
+ " os.makedirs(save_folder, exist_ok=True)\n",
54
+ " gz_path = os.path.join(save_folder, \"scicode.jsonl.gz\")\n",
55
+ " jsonl_path = os.path.join(save_folder, \"scicode.jsonl\")\n",
56
+ "\n",
57
+ " logger.info(f\"Downloading SciCode data from {url} ...\")\n",
58
+ " download_file(url=url, save_file=gz_path)\n",
59
+ "\n",
60
+ " logger.info(\"Unzipping SciCode data ...\")\n",
61
+ " with gzip.open(gz_path, \"rb\") as f_in, open(jsonl_path, \"wb\") as f_out:\n",
62
+ " shutil.copyfileobj(f_in, f_out)\n",
63
+ " if os.path.exists(gz_path):\n",
64
+ " os.remove(gz_path)\n",
65
+ "\n",
66
+ " return jsonl_path\n",
67
+ "\n",
68
+ "\n",
69
+ "# ----------------------------\n",
70
+ "# Schema helpers\n",
71
+ "# ----------------------------\n",
72
+ "\n",
73
+ "def _extract_entry_point_from_header(header: str) -> str:\n",
74
+ " \"\"\"\n",
75
+ " Given a SciCode 'function_header' string like:\n",
76
+ " \"def get_alpha(recvec, alpha_scaling=5):\\n '''...'''\"\n",
77
+ " return \"get_alpha\".\n",
78
+ " \"\"\"\n",
79
+ " m = re.search(r\"def\\s+([A-Za-z_][A-Za-z0-9_]*)\\s*\\(\", header)\n",
80
+ " if not m:\n",
81
+ " raise ValueError(\"Could not parse entry point from function_header\")\n",
82
+ " return m.group(1)\n",
83
+ "\n",
84
+ "\n",
85
+ "def _coerce_scicode_row_to_examples(row: Dict[str, Any]) -> List[Dict[str, Any]]:\n",
86
+ " \"\"\"\n",
87
+ " SciCode rows may contain a single task or multiple step tasks.\n",
88
+ " We normalize them to a list of examples with a unified structure:\n",
89
+ " {\n",
90
+ " \"task_id\": \"SciCode/<name>#<sub_id>\",\n",
91
+ " \"prompt\": <function_header + optional docstring block>,\n",
92
+ " \"entry_point\": <func_name>,\n",
93
+ " \"canonical_solution\": <ground_truth_code>,\n",
94
+ " \"tests\": List[str], # list of python test snippets\n",
95
+ " \"imports\": str # optional import prelude (e.g., 'import numpy as np')\n",
96
+ " }\n",
97
+ " \"\"\"\n",
98
+ " examples: List[Dict[str, Any]] = []\n",
99
+ "\n",
100
+ " name = str(row[0]) if 0 in row or isinstance(row, list) else str(row.get(\"name\", \"unknown\"))\n",
101
+ " # Different dumps can be list-based or dict-based; support both:\n",
102
+ " if isinstance(row, list):\n",
103
+ " # Heuristic index layout (based on the example provided by the user):\n",
104
+ " # [name, <maybe_int>, description, <maybe empty>, docstring, imports, steps(list[dict]) or code, tests(list[str]) or None]\n",
105
+ " # We will try to find keys by semantic type\n",
106
+ " description = None\n",
107
+ " doc_or_header = None\n",
108
+ " imports_block = None\n",
109
+ " steps_or_code = None\n",
110
+ " tests = None\n",
111
+ "\n",
112
+ " # Try assigning by scanning\n",
113
+ " for item in row:\n",
114
+ " if isinstance(item, str) and item.strip().startswith('\"\"\"'):\n",
115
+ " # docstring/prompt block for the top-level task\n",
116
+ " doc_or_header = item\n",
117
+ " elif isinstance(item, str) and (item.startswith(\"import \") or \"from \" in item):\n",
118
+ " imports_block = item\n",
119
+ " elif isinstance(item, list):\n",
120
+ " # Could be steps OR tests\n",
121
+ " if item and isinstance(item[0], dict) and \"function_header\" in item[0]:\n",
122
+ " steps_or_code = item\n",
123
+ " elif item and isinstance(item[0], str) and item[0].strip().startswith((\"ref\", \"assert\", \"from \")):\n",
124
+ " tests = item\n",
125
+ " elif isinstance(item, dict):\n",
126
+ " # Some SciCode variants may directly be dicts per step; treat as steps\n",
127
+ " steps_or_code = [item]\n",
128
+ "\n",
129
+ " # If we have step dictionaries, produce one example per step\n",
130
+ " if isinstance(steps_or_code, list) and steps_or_code and isinstance(steps_or_code[0], dict):\n",
131
+ " for idx, step in enumerate(steps_or_code):\n",
132
+ " header = step.get(\"function_header\") or step.get(\"header\") or \"\"\n",
133
+ " code = step.get(\"ground_truth_code\") or step.get(\"solution\") or \"\"\n",
134
+ " step_tests = step.get(\"test_cases\") or []\n",
135
+ " entry_point = _extract_entry_point_from_header(header)\n",
136
+ " prompt = header # keep header as the model prompt (header + docstring already embedded)\n",
137
+ " examples.append(\n",
138
+ " {\n",
139
+ " \"task_id\": f\"SciCode/{name}#step{idx+1}\",\n",
140
+ " \"prompt\": prompt,\n",
141
+ " \"entry_point\": entry_point,\n",
142
+ " \"canonical_solution\": code,\n",
143
+ " \"tests\": step_tests,\n",
144
+ " \"imports\": imports_block or \"\",\n",
145
+ " }\n",
146
+ " )\n",
147
+ " else:\n",
148
+ " # Single task variant: expect a combined \"function_header\" + \"ground_truth_code\" + \"test_cases\" in the row\n",
149
+ " # Try to detect them from the large code string block if present.\n",
150
+ " # Fall back to no-op if missing.\n",
151
+ " # NOTE: The user’s example shows a consolidated block near the end; we’ll try to parse it.\n",
152
+ " code_blob = None\n",
153
+ " for item in row:\n",
154
+ " if isinstance(item, str) and \"def \" in item and \"return\" in item:\n",
155
+ " code_blob = item\n",
156
+ " break\n",
157
+ " # Try to split the big blob into multiple functions; evaluate the last one as the main if we cannot find header separately.\n",
158
+ " if code_blob:\n",
159
+ " # Heuristic: the last \"def ...\" in the blob is the target entry point\n",
160
+ " headers = list(re.finditer(r\"(?ms)^(def\\s+[A-Za-z_][A-Za-z0-9_]*\\s*\\(.*?\\):\\s*\\n)\", code_blob))\n",
161
+ " if headers:\n",
162
+ " last_header = headers[-1].group(1)\n",
163
+ " entry_point = _extract_entry_point_from_header(last_header)\n",
164
+ " else:\n",
165
+ " entry_point = \"solution\"\n",
166
+ "\n",
167
+ " # We will treat entire blob as canonical_solution and create a minimal prompt from the docstring if present\n",
168
+ " prompt = doc_or_header or f\"def {entry_point}(*args, **kwargs):\\n '''Fill in the function body.'''\\n ...\"\n",
169
+ " examples.append(\n",
170
+ " {\n",
171
+ " \"task_id\": f\"SciCode/{name}\",\n",
172
+ " \"prompt\": prompt,\n",
173
+ " \"entry_point\": entry_point,\n",
174
+ " \"canonical_solution\": code_blob,\n",
175
+ " \"tests\": tests or [],\n",
176
+ " \"imports\": imports_block or \"\",\n",
177
+ " }\n",
178
+ " )\n",
179
+ "\n",
180
+ " else:\n",
181
+ " # Dict-style row (fallback): expect keys by name\n",
182
+ " steps = row.get(\"steps\", [])\n",
183
+ " imports_block = row.get(\"imports\", \"\")\n",
184
+ " task_name = row.get(\"name\", \"unknown\")\n",
185
+ "\n",
186
+ " if steps:\n",
187
+ " for idx, step in enumerate(steps):\n",
188
+ " header = step.get(\"function_header\", \"\")\n",
189
+ " code = step.get(\"ground_truth_code\", \"\")\n",
190
+ " step_tests = step.get(\"test_cases\", [])\n",
191
+ " entry_point = _extract_entry_point_from_header(header)\n",
192
+ " examples.append(\n",
193
+ " {\n",
194
+ " \"task_id\": f\"SciCode/{task_name}#step{idx+1}\",\n",
195
+ " \"prompt\": header,\n",
196
+ " \"entry_point\": entry_point,\n",
197
+ " \"canonical_solution\": code,\n",
198
+ " \"tests\": step_tests,\n",
199
+ " \"imports\": imports_block or \"\",\n",
200
+ " }\n",
201
+ " )\n",
202
+ " else:\n",
203
+ " header = row.get(\"function_header\", \"\")\n",
204
+ " code = row.get(\"ground_truth_code\", \"\")\n",
205
+ " tests = row.get(\"test_cases\", [])\n",
206
+ " entry_point = _extract_entry_point_from_header(header) if header else \"solution\"\n",
207
+ " prompt = header or f\"def {entry_point}(*args, **kwargs):\\n pass\"\n",
208
+ " examples.append(\n",
209
+ " {\n",
210
+ " \"task_id\": f\"SciCode/{task_name}\",\n",
211
+ " \"prompt\": prompt,\n",
212
+ " \"entry_point\": entry_point,\n",
213
+ " \"canonical_solution\": code,\n",
214
+ " \"tests\": tests,\n",
215
+ " \"imports\": imports_block or \"\",\n",
216
+ " }\n",
217
+ " )\n",
218
+ "\n",
219
+ " return examples\n",
220
+ "\n",
221
+ "\n",
222
+ "def load_scicode_data(jsonl_path: str) -> List[Dict[str, Any]]:\n",
223
+ " \"\"\"\n",
224
+ " Load SciCode jsonl and expand into normalized examples.\n",
225
+ " \"\"\"\n",
226
+ " raw = load_json(jsonl_path, type=\"jsonl\")\n",
227
+ " all_examples: List[Dict[str, Any]] = []\n",
228
+ " for row in raw:\n",
229
+ " try:\n",
230
+ " all_examples.extend(_coerce_scicode_row_to_examples(row))\n",
231
+ " except Exception as e:\n",
232
+ " logger.warning(f\"[SciCode] Skipping a malformed row due to: {e}\")\n",
233
+ " return all_examples\n",
234
+ "\n",
235
+ "\n",
236
+ "# ----------------------------\n",
237
+ "# Benchmark classes\n",
238
+ "# ----------------------------\n",
239
+ "\n",
240
+ "class SciCode(CodingBenchmark):\n",
241
+ " \"\"\"\n",
242
+ " Benchmark class for evaluating code generation on SciCode.\n",
243
+ "\n",
244
+ " SciCode problems provide:\n",
245
+ " - function_header (prompt stub)\n",
246
+ " - ground_truth_code (reference implementation)\n",
247
+ " - test_cases (list[str] of python asserts)\n",
248
+ "\n",
249
+ " We normalize each item and evaluate by executing the candidate implementation\n",
250
+ " against the provided test cases. Since many SciCode tests reference a variable\n",
251
+ " named `target`, we heuristically pre-compute `target` from the reference\n",
252
+ " implementation when necessary, or set it to True for boolean-allclose tests.\n",
253
+ " \"\"\"\n",
254
+ "\n",
255
+ " def __init__(self, path: str = None, mode: str = \"all\", timeout: int = 60, k: Union[int, list] = 1, **kwargs):\n",
256
+ " path = os.path.expanduser(path or \"~/.evoagentx/data/scicode\")\n",
257
+ " self.k = k\n",
258
+ " super().__init__(name=type(self).__name__, path=path, mode=mode, timeout=timeout, **kwargs)\n",
259
+ "\n",
260
+ " # ---------- Data loading ----------\n",
261
+ "\n",
262
+ " def _load_data(self):\n",
263
+ " data_path = os.path.join(self.path, \"scicode.jsonl\")\n",
264
+ " if not os.path.exists(data_path):\n",
265
+ " data_path = download_raw_scicode_data(self.path)\n",
266
+ "\n",
267
+ " # For SciCode, we place everything into \"test\" split by default.\n",
268
+ "\n",
269
+ " if self.mode in (\"dev\", \"all\"):\n",
270
+ " self._dev_data = load_scicode_data(\"/home/tl688/pitl688/selfevolve/SciCode/eval/data/subproblems_dev.jsonl\")\n",
271
+ " if self.mode in (\"test\", \"all\"):\n",
272
+ " self._test_data = load_scicode_data(\"/home/tl688/pitl688/selfevolve/SciCode/eval/data/subproblems_test.jsonl\")\n",
273
+ "\n",
274
+ " def _get_label(self, example: Any):\n",
275
+ " \"\"\"\n",
276
+ " For SciCode we treat the label as the full test suite plus metadata.\n",
277
+ " \"\"\"\n",
278
+ " return {\n",
279
+ " \"task_id\": example[\"task_id\"],\n",
280
+ " \"entry_point\": example[\"entry_point\"],\n",
281
+ " \"tests\": example.get(\"tests\", []),\n",
282
+ " \"canonical_solution\": example.get(\"canonical_solution\", \"\"),\n",
283
+ " \"imports\": example.get(\"imports\", \"\"),\n",
284
+ " }\n",
285
+ "\n",
286
+ " def _get_id(self, example: Any):\n",
287
+ " return example[\"task_id\"]\n",
288
+ "\n",
289
+ " # ---------- Evaluation ----------\n",
290
+ "\n",
291
+ " @staticmethod\n",
292
+ " def _build_reference_namespace(imports: str, canonical_solution: str) -> Dict[str, Any]:\n",
293
+ " \"\"\"\n",
294
+ " Build an execution namespace that defines the reference function.\n",
295
+ " \"\"\"\n",
296
+ " ns: Dict[str, Any] = {\"np\": np, \"scicode\":scicode}\n",
297
+ " if imports:\n",
298
+ " exec(imports, ns, ns) # e.g., \"import numpy as np\\nfrom scipy.special import erfc\"\n",
299
+ " if canonical_solution:\n",
300
+ " exec(canonical_solution, ns, ns)\n",
301
+ " return ns\n",
302
+ "\n",
303
+ " @staticmethod\n",
304
+ " def _extract_candidate_exprs_from_test(test_src: str) -> List[str]:\n",
305
+ " \"\"\"\n",
306
+ " Heuristically extract expressions that are compared against `target` inside np.allclose(..., target)\n",
307
+ " or equality checks like \"== target\" / \", target)\" etc. Returns a list of python expressions (as strings)\n",
308
+ " that we should evaluate with the *reference* implementation to generate `target`.\n",
309
+ "\n",
310
+ " This is a pragmatic parser covering the most common SciCode patterns.\n",
311
+ " \"\"\"\n",
312
+ " exprs: List[str] = []\n",
313
+ "\n",
314
+ " # Pattern A: np.allclose( <expr>, target )\n",
315
+ " for m in re.finditer(r\"np\\.allclose\\s*\\(\\s*(?P<expr>.+?)\\s*,\\s*target\\s*\\)\", test_src, flags=re.DOTALL):\n",
316
+ " exprs.append(m.group(\"expr\"))\n",
317
+ "\n",
318
+ " # Pattern B: assert <expr> == target\n",
319
+ " for m in re.finditer(r\"assert\\s+(?P<expr>.+?)\\s*==\\s*target\", test_src):\n",
320
+ " exprs.append(m.group(\"expr\"))\n",
321
+ "\n",
322
+ " # Pattern C: assert <expr>, target (when the first arg should be True)\n",
323
+ " # In this case, target is expected to be True; no need to compute it.\n",
324
+ " # We'll handle by leaving exprs empty and later default target=True.\n",
325
+ "\n",
326
+ " # Pattern D: Using slices like target[0], target[1] — we try to recover by\n",
327
+ " # extracting both left-hand expressions in the same line in order:\n",
328
+ " # np.allclose(func(...)[0], target[0]) and np.allclose(func(...)[1], target[1])\n",
329
+ " # Already captured by Pattern A; expr may include \"[0]\" or \"[1]\".\n",
330
+ " return exprs\n",
331
+ "\n",
332
+ " @staticmethod\n",
333
+ " def _compute_target_list(exprs: List[str], ref_ns: Dict[str, Any]) -> Any:\n",
334
+ " \"\"\"\n",
335
+ " Given a list of expressions (strings), evaluate them in the reference namespace.\n",
336
+ " If multiple expressions are found, we pack them into a tuple in the same order.\n",
337
+ " If no expression found, return True (to support tests of the form `assert <bool>, target`).\n",
338
+ " \"\"\"\n",
339
+ " if not exprs:\n",
340
+ " return True\n",
341
+ " values = []\n",
342
+ " for ex in exprs:\n",
343
+ " # Safety: limit builtins\n",
344
+ " local_ns: Dict[str, Any] = {}\n",
345
+ " val = eval(ex, ref_ns, local_ns)\n",
346
+ " values.append(val)\n",
347
+ " if len(values) == 1:\n",
348
+ " return values[0]\n",
349
+ " return tuple(values)\n",
350
+ "\n",
351
+ " def _make_harness(self, task_id: str, entry_point: str, imports: str, canonical_solution: str, tests: List[str], candidate_src: str) -> str:\n",
352
+ " \"\"\"\n",
353
+ " Construct an executable harness that:\n",
354
+ " 1) Defines imports\n",
355
+ " 2) Defines candidate implementation (prompt + candidate completion)\n",
356
+ " 3) Pre-computes `target` using the reference implementation for each test (heuristics)\n",
357
+ " 4) Executes the original test snippet with `target` bound.\n",
358
+ " We run each test independently within the same process, stopping on first failure.\n",
359
+ " \"\"\"\n",
360
+ " # We'll build a block that iterates tests in Python.\n",
361
+ " # We cannot dynamically pass `target` into a raw `assert` snippet without executing it;\n",
362
+ " # so for each test, we will:\n",
363
+ " # a) compute target in a separate namespace using reference function,\n",
364
+ " # b) then execute the original test with the candidate function and that target.\n",
365
+ " # This is orchestrated by the benchmark runtime (not inside the user env).\n",
366
+ "\n",
367
+ " # NOTE: actual orchestration happens in `evaluate()` by repeated calls to `check_solution`;\n",
368
+ " # here we only prepare the body (candidate code). The unit tests are executed by the\n",
369
+ " # framework’s sand-boxed executor using `test` passed in.\n",
370
+ "\n",
371
+ " # We keep the candidate_src as-is. The imports are prepended at runtime via the test body.\n",
372
+ " return candidate_src\n",
373
+ "\n",
374
+ " def handle_special_cases(self, task_id: str, solution: str, test: str) -> Tuple[str, str]:\n",
375
+ " \"\"\"\n",
376
+ " Hook: adjust solution/test for edge cases in SciCode, if needed.\n",
377
+ " Currently, we leave as-is and fallback to the base handler.\n",
378
+ " \"\"\"\n",
379
+ " return super().handle_special_cases(task_id=task_id, solution=solution, test=test)\n",
380
+ "\n",
381
+ " def evaluate(self, prediction: Any, label: Any) -> dict:\n",
382
+ " \"\"\"\n",
383
+ " Evaluate candidate solution(s) against SciCode test cases.\n",
384
+ "\n",
385
+ " Strategy:\n",
386
+ " - For each candidate solution:\n",
387
+ " - For each test snippet:\n",
388
+ " 1) Build reference namespace; compute `target` (heuristics).\n",
389
+ " 2) Build candidate code by concatenating example['prompt'] + candidate solution.\n",
390
+ " 3) Execute the test with `target` and candidate in the sandbox via `check_solution`.\n",
391
+ "\n",
392
+ " - Aggregate per-test pass/fail into a single boolean for the example.\n",
393
+ " - Compute pass@k across candidates.\n",
394
+ " \"\"\"\n",
395
+ " prediction, label = self._check_evaluation_inputs(prediction, label)\n",
396
+ "\n",
397
+ " results = []\n",
398
+ " for solution in prediction:\n",
399
+ " # Each `label` item corresponds to the SAME example in our usage (benchmark runs per example),\n",
400
+ " # but we preserve the structure consistent with the base class.\n",
401
+ " solution_states = []\n",
402
+ " for label_data in label:\n",
403
+ " task_id = label_data[\"task_id\"]\n",
404
+ " entry_point = label_data[\"entry_point\"]\n",
405
+ " tests = label_data.get(\"tests\", [])\n",
406
+ " imports = label_data.get(\"imports\", \"\")\n",
407
+ " canonical_solution = label_data.get(\"canonical_solution\", \"\")\n",
408
+ "\n",
409
+ " # Build reference env for computing `target`\n",
410
+ " ref_ns = self._build_reference_namespace(imports=imports, canonical_solution=canonical_solution)\n",
411
+ "\n",
412
+ " # Build candidate code (prompt + solution)\n",
413
+ " prompt = self.get_example_by_id(task_id)[\"prompt\"]\n",
414
+ " candidate_code = prompt + \"\\n\" + solution\n",
415
+ "\n",
416
+ " # Run each test individually; any failure => whole example fails\n",
417
+ " all_ok = True\n",
418
+ " for raw_test in tests if tests else [\"# no tests provided\\nassert True, True\"]:\n",
419
+ " # Heuristically precompute `target`\n",
420
+ " exprs = self._extract_candidate_exprs_from_test(raw_test)\n",
421
+ " try:\n",
422
+ " target_value = self._compute_target_list(exprs, ref_ns)\n",
423
+ " except Exception as e:\n",
424
+ " # If we cannot compute target from the reference, fall back to True\n",
425
+ " logger.warning(f\"[SciCode] Fallback target=True for {task_id} due to: {e}\")\n",
426
+ " target_value = True\n",
427
+ "\n",
428
+ " # Compose a runnable unit-test block:\n",
429
+ " # We inject `imports`, bind `target`, then execute the original test code.\n",
430
+ " unit_test = (\n",
431
+ " (imports or \"\")\n",
432
+ " + \"\\n\"\n",
433
+ " + \"target = __TARGET_VALUE__\\n\"\n",
434
+ " + raw_test\n",
435
+ " )\n",
436
+ "\n",
437
+ " # Because `check_solution` runs code in separate exec, we stringify the target safely.\n",
438
+ " # We'll register a placeholder and pass the real object via the executor's globals.\n",
439
+ " # Our base framework doesn't support direct object injection; so we serialize small types.\n",
440
+ " # For numpy arrays/tuples we rely on repr + eval. If that fails, we degrade to boolean.\n",
441
+ " try:\n",
442
+ " # Light-weight serializer for numpy arrays / tuples / lists / scalars\n",
443
+ " def _pyrepr(obj):\n",
444
+ " if isinstance(obj, np.ndarray):\n",
445
+ " return f\"np.array({repr(obj.tolist())})\"\n",
446
+ " return repr(obj)\n",
447
+ "\n",
448
+ " unit_test = unit_test.replace(\n",
449
+ " \"__TARGET_VALUE__\", _pyrepr(target_value)\n",
450
+ " )\n",
451
+ " except Exception:\n",
452
+ " unit_test = unit_test.replace(\"__TARGET_VALUE__\", \"True\")\n",
453
+ "\n",
454
+ " # Optional special-case patching hook\n",
455
+ " candidate_code_patched, unit_test_patched = self.handle_special_cases(\n",
456
+ " task_id=task_id, solution=candidate_code, test=unit_test\n",
457
+ " )\n",
458
+ "\n",
459
+ " # Execute\n",
460
+ " state, message = self.check_solution(\n",
461
+ " task_id=task_id,\n",
462
+ " solution=candidate_code_patched,\n",
463
+ " test=unit_test_patched,\n",
464
+ " entry_point=entry_point,\n",
465
+ " )\n",
466
+ " if state != self.SUCCESS:\n",
467
+ " all_ok = False\n",
468
+ " break\n",
469
+ "\n",
470
+ " solution_states.append(self.SUCCESS if all_ok else self.FAILURE)\n",
471
+ " results.append(len(solution_states) == len(label) and all(s == self.SUCCESS for s in solution_states))\n",
472
+ "\n",
473
+ " k_list = [self.k] if isinstance(self.k, int) else self.k\n",
474
+ " pass_at_k = self.compute_pass_at_k(results, k_list)\n",
475
+ " return pass_at_k\n",
476
+ "\n",
477
+ "\n",
478
+ "class AFlowSciCode(SciCode):\n",
479
+ " \"\"\"\n",
480
+ " AFlow-specific implementation of SciCode benchmark.\n",
481
+ " Uses AFLOW_DATASET_FILES_MAP['scicode'] for split files (if provided by your distribution).\n",
482
+ " \"\"\"\n",
483
+ "\n",
484
+ " def __init__(self, path: str = None, mode: str = \"all\", timeout: int = 60, k: Union[int, list] = 1, **kwargs):\n",
485
+ " path = os.path.expanduser(path or \"~/.evoagentx/data/aflow/scicode\")\n",
486
+ " super().__init__(path=path, mode=mode, timeout=timeout, k=k, **kwargs)\n",
487
+ "\n",
488
+ " def _load_data_from_file(self, file_name: str):\n",
489
+ " if file_name is None:\n",
490
+ " return None\n",
491
+ " file_path = os.path.join(self.path, file_name)\n",
492
+ " if not os.path.exists(file_path):\n",
493
+ " logger.info(\"Downloading AFlow SciCode split files ...\")\n",
494
+ " download_aflow_benchmark_data(dataset=\"scicode\", save_folder=self.path)\n",
495
+ " return load_json(path=file_path, type=\"jsonl\")\n",
496
+ "\n",
497
+ " def _load_data(self):\n",
498
+ " # Prefer AFLOW split files when available; otherwise fall back to raw download.\n",
499
+ " if \"scicode\" not in AFLOW_DATASET_FILES_MAP:\n",
500
+ " logger.warning(\"AFLOW_DATASET_FILES_MAP has no entry for 'scicode'; falling back to raw SciCode jsonl.\")\n",
501
+ " return super()._load_data()\n",
502
+ "\n",
503
+ " splits = AFLOW_DATASET_FILES_MAP[\"scicode\"]\n",
504
+ " data_all: Dict[str, List[Dict[str, Any]]] = {}\n",
505
+ "\n",
506
+ " for split in (\"train\", \"dev\", \"test\"):\n",
507
+ " fname = splits.get(split)\n",
508
+ " if fname:\n",
509
+ " logger.info(f\"Loading {split} data from {fname}\")\n",
510
+ " raw_split = self._load_data_from_file(file_name=fname)\n",
511
+ " # Normalize rows to examples\n",
512
+ " examples: List[Dict[str, Any]] = []\n",
513
+ " for row in raw_split or []:\n",
514
+ " try:\n",
515
+ " examples.extend(_coerce_scicode_row_to_examples(row))\n",
516
+ " except Exception as e:\n",
517
+ " logger.warning(f\"[AFlowSciCode] Skipping a malformed row in {split} due to: {e}\")\n",
518
+ " data_all[split] = examples\n",
519
+ " else:\n",
520
+ " data_all[split] = None\n",
521
+ "\n",
522
+ " if self.mode in (\"train\", \"all\"):\n",
523
+ " self._train_data = data_all.get(\"train\")\n",
524
+ " if self.mode in (\"dev\", \"all\"):\n",
525
+ " self._dev_data = data_all.get(\"dev\")\n",
526
+ " if self.mode in (\"test\", \"all\"):\n",
527
+ " self._test_data = data_all.get(\"test\")\n",
528
+ "\n",
529
+ " async def async_evaluate(self, graph: Callable, example: Any) -> float:\n",
530
+ " \"\"\"\n",
531
+ " Generate a solution asynchronously and return pass@1 for the example.\n",
532
+ " \"\"\"\n",
533
+ " prompt, entry_point = example[\"prompt\"], example[\"entry_point\"]\n",
534
+ " solution = await graph(prompt, entry_point)\n",
535
+ " label = self._get_label(example)\n",
536
+ " metrics = await super().async_evaluate(prediction=solution, label=label)\n",
537
+ " return metrics.get(\"pass@1\", 0.0)\n"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "code",
542
+ "execution_count": null,
543
+ "id": "a2bca001",
544
+ "metadata": {},
545
+ "outputs": [],
546
+ "source": []
547
+ }
548
+ ],
549
+ "metadata": {
550
+ "kernelspec": {
551
+ "display_name": "Python 3 (ipykernel)",
552
+ "language": "python",
553
+ "name": "python3"
554
+ },
555
+ "language_info": {
556
+ "codemirror_mode": {
557
+ "name": "ipython",
558
+ "version": 3
559
+ },
560
+ "file_extension": ".py",
561
+ "mimetype": "text/x-python",
562
+ "name": "python",
563
+ "nbconvert_exporter": "python",
564
+ "pygments_lexer": "ipython3",
565
+ "version": "3.11.0"
566
+ }
567
+ },
568
+ "nbformat": 4,
569
+ "nbformat_minor": 5
570
+ }
evoagentx/benchmark/README.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Benchmark
2
+
3
+ ## Benchmark Overview
4
+
5
+ This repository provides a set of benchmarks to facilitate the evaluation of different agent-based systems. Below is a summary of the benchmarks currently included, along with basic dataset statistics:
6
+
7
+
8
+ | Task | Dataset Name | # Train | # Dev | # Test |
9
+ | ------------------------- | --------------- | --------- | ------- | ------ |
10
+ | QA | NQ | 79,168 | 8,757 | 3,610 |
11
+ | Multi-Hop QA | HotPotQA | 90,447 | 7,405 | / |
12
+ | Math | GSM8K | 7,473 | / | 1,319 |
13
+ | Math | MATH | 7,500 | / | 5,000 |
14
+ | Code Generation | HumanEval | / | / | 164 |
15
+ | Code Generation | MBPP | / | / | 427 |
16
+ | Code Generation | LiveCodeBench(v1~v5) | / | / | 400~880 |
17
+ | Code Execution | LiveCodeBench | / | / | 479 |
18
+ | Test Output Prediction | LiveCodeBench | / | / | 442 |
19
+
20
+
21
+ Below, we introduce the preprocessing steps and evaluation metrics for each benchmark.
22
+
23
+ - [Question Answering](#question-answering)
24
+ - [NQ](#nq)
25
+ - [HotPotQA](#hotpotqa)
26
+ - [Math](#math)
27
+ - [GSM8K](#gsm8k)
28
+ - [MATH](#math)
29
+ - [Code Generation](#code-generation)
30
+ - [HumanEval](#humaneval)
31
+ - [MBPP](#mbpp)
32
+ - [LiveCodeBench](#livecodebench)
33
+
34
+ ## Preprocessing and Evaluation Metrics
35
+
36
+ ### Question Answering
37
+
38
+ For the QA datasets, we use Exact Match (EM), F1, and Accuracy (ACC) as evaluation metrics by default. EM requires the predicted answer to be exactly the same as the ground truth answer. ACC requires that the predicted answer contains the ground-truth answer, which is useful when the LLM is used to generate the answer.
39
+
40
+ #### NQ
41
+ [Natural Questions (NQ)](https://github.com/google-research-datasets/natural-questions) contains questions from the Google search engine and the answers, annotated by human annotators, are paragraphs or entities in the Wikipedia page of the top 5 search results. We use the dataset splits provided by the [DPR](https://github.com/facebookresearch/DPR) repository, which contains 79,168 training, 8,757 development, and 3,610 test examples.
42
+
43
+ You can load the dataset using the following code:
44
+ ```python
45
+ from evoagentx.benchmark import NQ
46
+ nq_dataset = NQ() # optional: path="/path/to/save_data"
47
+ test_data = nq_dataset.get_test_data()
48
+ ```
49
+ Each example in the dataset is in the following format:
50
+ ```json
51
+ {
52
+ "id": "test-1",
53
+ "question": "the question",
54
+ "answers": ["possible answers"]
55
+ }
56
+ ```
57
+
58
+
59
+ #### HotPotQA
60
+ [HotPotQA](https://hotpotqa.github.io/) is a multi-hop QA dataset that requires multi-step reasoning to answer the question. We use the distractor setting of the dataset. Each example contains a question, an answer, some context that contains both supporting and distractor information, and supporting facts. We only include the training and development sets, as the test set is not publicly available.
61
+
62
+ You can load the dataset using the following code:
63
+ ```python
64
+ from evoagentx.benchmark import HotPotQA
65
+ hotpotqa_dataset = HotPotQA() # optional: path="/path/to/save_data"
66
+ test_data = hotpotqa_dataset.get_test_data()
67
+ ```
68
+ Each example in the dataset is in the following format, where the second element (int) of a supporting_fact is the index of the sentence in the context that supports the answer.
69
+ ```json
70
+ {
71
+ "_id": "the id of the example",
72
+ "question": "the question",
73
+ "answer": "the answer",
74
+ "context": [["context_title", ["context_sentence", "another_sentence"]]],
75
+ "supporting_facts": [["supporting_title", 0]]
76
+ }
77
+ ```
78
+
79
+
80
+ ### Math
81
+
82
+ For match datasets, we use the solve rate as the evaluation metric. The solve rate is the ratio of the number of examples that are solved correctly to the total number of examples.
83
+
84
+ #### GSM8K
85
+ [GSM8K](https://github.com/openai/grade-school-math) consists of high quality grade school math problems created by human problem writers. These problems require multi-step mathematical reasoning to solve. We use the dataset splits provided by the original repository, which contains 7.5K training problems and 1K test problems.
86
+
87
+ You can load the dataset using the following code:
88
+ ```python
89
+ from evoagentx.benchmark import GSM8K
90
+ gsm8k_dataset = GSM8K() # optional: path="/path/to/save_data"
91
+ test_data = gsm8k_dataset.get_test_data()
92
+ ```
93
+ Each example in the dataset is in the following format:
94
+ ```json
95
+ {
96
+ "id": "test-1",
97
+ "question": "the question",
98
+ "answer": "the answer"
99
+ }
100
+ ```
101
+
102
+ #### MATH
103
+ The [Mathematics Aptitude Test of Heuristics (MATH)](https://github.com/hendrycks/math) dataset consists of problems from mathematics competitions, including the AMC 10, AMC 12, AIME, etc. Each problem in MATH has a step-by-step solution. We use the dataset splits provided by the original repository, which contains 7.5K training problems and 5K test problems.
104
+
105
+ You can load the dataset using the following code:
106
+ ```python
107
+ from evoagentx.benchmark import MATH
108
+ math_dataset = MATH() # optional: path="/path/to/save_data"
109
+ test_data = math_dataset.get_test_data()
110
+ ```
111
+ Each example in the dataset is in the following format. For the `level` field, valid values are: "Level 1", "Level 2", "Level 3", "Level 4", "Level 5", and "Level ?". The `type` field can be one of: "Geometry", "Algebra", "Intermediate Algebra", "Counting & Probability", "Precalculus", "Number Theory", or "Prealgebra".
112
+
113
+ ```json
114
+ {
115
+ "id": "test-1",
116
+ "problem": "the problem",
117
+ "solution": "the solution",
118
+ "level": "Level 1",
119
+ "type": "Algebra"
120
+ }
121
+ ```
122
+
123
+ ### Code Generation
124
+ For the code generation benchmarks, we use pass@k as the evaluation metric, where k is the number of solutions for each problem. By default, k is set to 1.
125
+
126
+ #### HumanEval
127
+ [HumanEval](https://github.com/openai/human-eval) is a dataset of 164 coding problems from the HumanEval benchmark. Each problem contains a function signature, a canonical solution, and a set of unit tests.
128
+
129
+ You can load the dataset using the following code:
130
+ ```python
131
+ from evoagentx.benchmark import HumanEval
132
+ humaneval_dataset = HumanEval() # optional: path="/path/to/save_data"
133
+ test_data = humaneval_dataset.get_test_data()
134
+ ```
135
+ Each example in the dataset is in the following format:
136
+ ```json
137
+ {
138
+ "task_id": "HumanEval/0",
139
+ "prompt": "the prompt of the problem",
140
+ "entry_point": "the name of the function to be tested",
141
+ "canonical_solution": "the canonical solution of the problem",
142
+ "test": "the unit tests of the problem"
143
+ }
144
+ ```
145
+
146
+ #### MBPP
147
+ [Mostly Basic Python Problems (MBPP)](https://github.com/google-research/google-research/tree/master/mbpp) consists of hundreds of entry-level Python programming problems. Each problem consists of a task description, code solution and 3 automated test cases. We use the [sanitized subset](https://github.com/google-research/google-research/blob/master/mbpp/sanitized-mbpp.json) of the MBPP dataset, which consists of 427 problems with data that are hand-verified by the authors. To facilitate the evaluation, we convert the MBPP dataset into the HumanEval format.
148
+
149
+ You can load the dataset using the following code:
150
+ ```python
151
+ from evoagentx.benchmark import MBPP
152
+ mbpp_dataset = MBPP() # optional: path="/path/to/save_data"
153
+ test_data = mbpp_dataset.get_test_data()
154
+ ```
155
+ Each example in the dataset is in the following format, where we keep the original MBPP `task_id`.
156
+ ```json
157
+ {
158
+ "task_id": 2,
159
+ "prompt": "the prompt of the problem",
160
+ "entry_point": "the name of the function to be tested",
161
+ "canonical_solution": "the canonical solution of the problem",
162
+ "test": "the unit tests of the problem"
163
+ }
164
+ ```
165
+ You can also access the original MBPP attributes such as "code", "test_list" in the example by using `example["code"]`.
166
+
167
+
168
+ #### LiveCodeBench
169
+ [LiveCodeBench](https://livecodebench.github.io/) is a contamination-free evaluation benchmark of LLMs for code that continuously collects new problems over time. Particularly, LiveCodeBench also focuses on broader code-related capabilities, such as code execution, and test output prediction, beyond mere code generation. Currently, LiveCodeBench hosts over three hundred high-quality coding problems published between May 2023 and February 2024.
170
+
171
+ You can load the dataset using the following code, where `scenario` can be one of [`code_generation`, `test_output_prediction`, `code_execution`] indicating different tasks. `version` denotes different versions of the code generation datasets, which is only available for `code_generation` scenario, and can be one of `["release_v1", "release_v2", "release_v3", "release_v4", "release_v5", "release_latest"]`. Please refer to the [LiveCodeBench](https://livecodebench.github.io/) repository for more details.
172
+
173
+ ```python
174
+ from evoagentx.benchmark import LiveCodeBench
175
+ livecodebench_dataset = LiveCodeBench(scenario="code_generation", version="release_v1") # optional: path="/path/to/save_data"
176
+ test_data = livecodebench_dataset.get_test_data()
177
+ ```
178
+
evoagentx/benchmark/Untitled.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
evoagentx/benchmark/WorfBench.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ from typing import Any, Dict, Callable, List
5
+ from .benchmark import Benchmark
6
+ from .measures import exact_match_score, f1_score, acc_score
7
+ from ..core.logging import logger
8
+ from ..core.module_utils import load_json
9
+ from datasets import load_dataset
10
+
11
# WorfBench dataset file mapping: split name ("train"/"test") -> local JSON
# filename that download_worfbench_data() writes and the loaders read back.
WORFBENCH_FILES_MAP = {
    "train": "worfbench_train.json",
    "test": "worfbench_test.json"
}
# Convenience list of all cached data filenames derived from the map above.
VALID_WORFBENCH_FILES = list(WORFBENCH_FILES_MAP.values())
17
+
18
def evaluate_workflow_sequence(prediction: List[Any], ground_truth: List[Any]) -> float:
    """Return the F1 score between a predicted and a reference workflow step sequence.

    Args:
        prediction: Predicted sequence of workflow steps.
        ground_truth: Reference sequence of workflow steps.

    Returns:
        F1 score in [0, 1], as computed by ``measures.f1_score``.
    """
    # f1_score is already imported at module level from .measures; the
    # redundant function-local re-import has been removed.
    return f1_score(prediction=prediction, ground_truth=ground_truth)
22
+
23
def evaluate_workflow_graph(prediction: Dict[str, Any], ground_truth: Dict[str, Any]) -> float:
    """Score a predicted workflow graph against a reference graph.

    Both graphs are dicts with optional "nodes" (list of hashable node ids)
    and "edges" (list of edge pairs) entries. The score is the arithmetic
    mean of the node-level F1 and the edge-level F1, each computed over set
    overlap; an empty predicted/reference set contributes 0 precision/recall.

    Returns:
        A float in [0, 1]: (node_f1 + edge_f1) / 2.
    """
    def _set_f1(predicted: set, reference: set) -> float:
        # Standard F1 over set membership; empty sets yield 0 for the
        # corresponding precision/recall term.
        overlap = len(predicted & reference)
        precision = overlap / len(predicted) if predicted else 0
        recall = overlap / len(reference) if reference else 0
        denom = precision + recall
        return 2 * precision * recall / denom if denom > 0 else 0

    node_f1 = _set_f1(
        set(prediction.get("nodes", [])),
        set(ground_truth.get("nodes", [])),
    )
    # Edges arrive as lists; convert to tuples so they are hashable.
    edge_f1 = _set_f1(
        {tuple(edge) for edge in prediction.get("edges", [])},
        {tuple(edge) for edge in ground_truth.get("edges", [])},
    )
    return (node_f1 + edge_f1) / 2
39
+
40
def download_worfbench_data(dataset: str, save_folder: str) -> None:
    """
    Download the WorfBench train/test splits from Hugging Face and cache them
    as JSON files (one list of examples per split) under ``save_folder``.
    Existing files are left untouched.

    Args:
        dataset (str): Dataset name ("worfbench"); used only for logging.
        save_folder (str): Directory to save data.

    Raises:
        Exception: Re-raises any error from downloading or writing a split.
    """
    datasets_map = {
        "train": {"repo_id": "zjunlp/WorFBench_train", "filename": "worfbench_train.json", "split": "train"},
        "test": {"repo_id": "zjunlp/WorFBench_test", "filename": "worfbench_test.json", "split": "test"}
    }

    os.makedirs(save_folder, exist_ok=True)
    for split, info in datasets_map.items():
        repo_id = info["repo_id"]
        filename = info["filename"]
        dataset_split = info["split"]
        save_path = os.path.join(save_folder, filename)

        if not os.path.exists(save_path):
            logger.info(f"Downloading {split} split of {dataset} from {repo_id}...")
            try:
                # Load dataset
                ds = load_dataset(repo_id, split=dataset_split)
                # Convert dataset to list and save as JSON
                data = [item for item in ds]
                with open(save_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=2)
                # Fixed: the original message contained the literal placeholder
                # text "(unknown)" instead of identifying the split.
                logger.info(f"Successfully downloaded and saved {split} split to {save_path}")
            except Exception as e:
                logger.error(f"Failed to download or save {split} split: {e}")
                raise
        else:
            logger.info(f"File {save_path} already exists, skipping download.")
75
+
76
class WorfBench(Benchmark):
    """
    WorfBench evaluation class for assessing LLM agents on complex workflow
    generation tasks.

    Assumed example structure (NOTE(review): confirm against the downloaded
    Hugging Face data — the loader itself does not validate this schema):
    {
        "id": str,
        "task": str,
        "context": list of dicts (e.g., [{"title": str, "content": list of str}]),
        "expected_output": str, list, or dict (sequence or graph workflow),
        "type": str,
        "level": str
    }
    """
    def __init__(self, path: str = None, mode: str = "test", **kwargs):
        # Default cache location for downloaded data when no path is given.
        path = os.path.expanduser(path or "~/.worfbench/data")
        super().__init__(name=type(self).__name__, path=path, mode=mode, **kwargs)

    def _load_data_from_file(self, file_name: str) -> Any:
        """Load one split file from disk, downloading it first if missing.

        Returns the parsed JSON content — a list of examples, as written by
        ``download_worfbench_data`` — or None if the file cannot be obtained
        or parsed.
        """
        if file_name is None:
            return None
        file_path = os.path.join(self.path, file_name)
        if not os.path.exists(file_path):
            download_worfbench_data(dataset="worfbench", save_folder=self.path)
            if not os.path.exists(file_path):
                logger.error(f"File {file_path} still does not exist after download attempt!")
                return None
        logger.info(f"Loading WorfBench data from {file_path} ...")
        data = load_json(path=file_path, type="json")
        if data is None:
            logger.error(f"Failed to load data from {file_path}")
            return None
        return data

    def _load_data(self) -> None:
        """Populate self._train_data / self._test_data according to self.mode.

        In "dev" mode a deterministic (seed 42) ~10% subsample of the train
        split is kept.
        """
        if self.mode in ["train", "dev"]:
            self._train_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["train"])
            if self.mode == "dev" and self._train_data:
                random.seed(42)
                if isinstance(self._train_data, list):
                    # Fixed: download_worfbench_data saves each split as a JSON
                    # *list* of examples; the original code assumed a dict of
                    # parallel lists and crashed on `.keys()` here.
                    n_dev = len(self._train_data) // 10 or 1
                    indices = list(range(len(self._train_data)))
                    random.shuffle(indices)
                    self._train_data = [self._train_data[i] for i in indices[:n_dev]]
                else:
                    # Fallback for a dict-of-parallel-lists layout (original
                    # behavior preserved for that case).
                    keys = list(self._train_data.keys())
                    n_dev = len(self._train_data[keys[0]]) // 10 or 1
                    indices = list(range(len(self._train_data[keys[0]])))
                    random.shuffle(indices)
                    self._train_data = {k: [v[i] for i in indices[:n_dev]] for k, v in self._train_data.items()}
        if self.mode == "test":
            self._test_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["test"])

    def _get_label(self, example: Dict) -> Any:
        # The reference workflow; may be a string, a list (sequence), or a
        # dict (graph) — see evaluate() for the dispatch.
        return example.get("expected_output", "")

    def _get_id(self, example: Dict) -> Any:
        return example.get("id", "")

    def evaluate(self, prediction: Any, label: Any) -> Dict:
        """Compute metrics for one prediction against its reference label.

        F1 is dispatched on type: list vs list -> sequence F1, dict vs dict
        -> graph F1, anything else -> token-level F1 on the string forms.

        Returns:
            Dict with keys "em", "f1", and "acc".
        """
        if isinstance(prediction, list) and isinstance(label, list):
            f1 = evaluate_workflow_sequence(prediction, label)
        elif isinstance(prediction, dict) and isinstance(label, dict):
            f1 = evaluate_workflow_graph(prediction, label)
        else:
            f1 = f1_score(prediction=str(prediction), ground_truth=str(label))
        em = exact_match_score(prediction=prediction, ground_truth=label)
        acc = acc_score(prediction=prediction, ground_truths=[label])
        return {"em": em, "f1": f1, "acc": acc}

    async def async_evaluate(self, graph: Callable, example: Dict) -> float:
        """Run the (awaitable) `graph` callable on one example and return its F1.

        The example's task and flattened context are formatted into a single
        prompt string; generation errors are logged and scored as an empty
        prediction rather than raised.
        """
        task = example.get("task", "")
        context = "\n".join(
            f"{ctx.get('title', '')}: {' '.join(ctx.get('content', []))}"
            for ctx in example.get("context", [])
            if isinstance(ctx, dict)
        )
        inputs = f"Task: {task}\nContext: {context}\nGenerate workflow:\nAnswer:"
        try:
            generated_workflow = await graph(inputs)
        except Exception as e:
            logger.error(f"Error generating workflow: {e}")
            generated_workflow = ""
        label = self._get_label(example)
        metrics = self.evaluate(prediction=generated_workflow, label=label)
        return metrics["f1"]