Upload 2846 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +14 -0
- evoagentx/.ipynb_checkpoints/test aflow-checkpoint.ipynb +427 -0
- evoagentx/__init__.py +2 -0
- evoagentx/__pycache__/__init__.cpython-311.pyc +0 -0
- evoagentx/actions/__init__.py +5 -0
- evoagentx/actions/__pycache__/__init__.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/action.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/agent_generation.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/code_extraction.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/code_verification.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/customize_action.cpython-311.pyc +0 -0
- evoagentx/actions/__pycache__/task_planning.cpython-311.pyc +0 -0
- evoagentx/actions/action.py +256 -0
- evoagentx/actions/agent_generation.py +198 -0
- evoagentx/actions/code_extraction.py +276 -0
- evoagentx/actions/code_verification.py +63 -0
- evoagentx/actions/customize_action.py +559 -0
- evoagentx/actions/task_planning.py +80 -0
- evoagentx/agents/__init__.py +6 -0
- evoagentx/agents/__pycache__/__init__.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/action_agent.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/agent.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/agent_generator.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/agent_manager.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/customize_agent.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/task_planner.cpython-311.pyc +0 -0
- evoagentx/agents/__pycache__/workflow_reviewer.cpython-311.pyc +0 -0
- evoagentx/agents/action_agent.py +502 -0
- evoagentx/agents/agent.py +531 -0
- evoagentx/agents/agent_generator.py +23 -0
- evoagentx/agents/agent_manager.py +505 -0
- evoagentx/agents/customize_agent.py +522 -0
- evoagentx/agents/long_term_memory_agent.py +491 -0
- evoagentx/agents/task_planner.py +35 -0
- evoagentx/agents/workflow_reviewer.py +14 -0
- evoagentx/app/__init__.py +0 -0
- evoagentx/app/api.py +329 -0
- evoagentx/app/app.env +22 -0
- evoagentx/app/config.py +83 -0
- evoagentx/app/db.py +177 -0
- evoagentx/app/main.py +177 -0
- evoagentx/app/requirements.txt +23 -0
- evoagentx/app/schemas.py +168 -0
- evoagentx/app/security.py +172 -0
- evoagentx/app/services.py +463 -0
- evoagentx/benchmark/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- evoagentx/benchmark/.ipynb_checkpoints/test_load_json-checkpoint.ipynb +570 -0
- evoagentx/benchmark/README.md +178 -0
- evoagentx/benchmark/Untitled.ipynb +6 -0
- evoagentx/benchmark/WorfBench.py +155 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/antibiotic_pred/EC_antibiotic.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/antibiotic_pred/ec_train filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/antibiotic_pred/ec_train.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/hotpotqa/test[[:space:]]structopt[[:space:]]toolcall.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/pertqa/.ipynb_checkpoints/test[[:space:]]structual[[:space:]]ourloop-withsearch-checkpoint.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/pertqa/k562_processed_grn.csv filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/pertqa/pert_folder/EGRET_K562.csv filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/pertqa/reploge_train.json filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/pertqa/replogle_update_train.csv filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
examples/pertqa/replogle_update_train.json filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
examples/pertqa/test[[:space:]]structual[[:space:]]ourloop-withsearch.ipynb filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
examples/pubmedqa/pubmedqa_train.json filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
examples/workflow/invest/300750/20250815/graphs/300750_candlestick_chart.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
examples/workflow/invest/300750/20250815/graphs/300750_technical_charts.png filter=lfs diff=lfs merge=lfs -text
|
evoagentx/.ipynb_checkpoints/test aflow-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "e2d3caf8",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import pandas as pd\n",
|
| 11 |
+
"import matplotlib.pyplot as plt\n",
|
| 12 |
+
"import pickle\n",
|
| 13 |
+
"import glob\n",
|
| 14 |
+
"import pandas as pd\n",
|
| 15 |
+
"import glob\n",
|
| 16 |
+
"from tqdm import tqdm\n",
|
| 17 |
+
"import base64\n",
|
| 18 |
+
"import requests\n",
|
| 19 |
+
"# OpenAI API Key\n",
|
| 20 |
+
"api_key = \"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\"\n",
|
| 21 |
+
"# Function to encode the image\n",
|
| 22 |
+
"import os\n",
|
| 23 |
+
"os.environ[\"OPENAI_API_KEY\"] = api_key\n"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": 2,
|
| 29 |
+
"id": "f870b639",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [
|
| 32 |
+
{
|
| 33 |
+
"name": "stderr",
|
| 34 |
+
"output_type": "stream",
|
| 35 |
+
"text": [
|
| 36 |
+
"/gpfs/radev/home/tl688/.conda/envs/evoagentx/lib/python3.11/site-packages/PyPDF2/__init__.py:21: DeprecationWarning: PyPDF2 is deprecated. Please move to the pypdf library instead.\n",
|
| 37 |
+
" warnings.warn(\n"
|
| 38 |
+
]
|
| 39 |
+
}
|
| 40 |
+
],
|
| 41 |
+
"source": [
|
| 42 |
+
"import os\n",
|
| 43 |
+
"from dotenv import load_dotenv\n",
|
| 44 |
+
"from evoagentx.optimizers import AFlowOptimizer\n",
|
| 45 |
+
"from evoagentx.models import LiteLLMConfig, LiteLLM, OpenAILLMConfig, OpenAILLM\n",
|
| 46 |
+
"from evoagentx.benchmark import AFlowHumanEval\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"# Load environment variables\n",
|
| 49 |
+
"load_dotenv()\n",
|
| 50 |
+
"OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
|
| 51 |
+
"# ANTHROPIC_API_KEY = os.getenv(\"ANTHROPIC_API_KEY\")"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": 3,
|
| 57 |
+
"id": "1f3dd892",
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"outputs": [],
|
| 60 |
+
"source": [
|
| 61 |
+
"# # Configure the optimizer LLM (Claude 3.5 Sonnet)\n",
|
| 62 |
+
"# claude_config = LiteLLMConfig(\n",
|
| 63 |
+
"# model=\"anthropic/claude-3-5-sonnet-20240620\", \n",
|
| 64 |
+
"# anthropic_key=ANTHROPIC_API_KEY\n",
|
| 65 |
+
"# )\n",
|
| 66 |
+
"# optimizer_llm = LiteLLM(config=claude_config)\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# Configure the executor LLM (GPT-4o-mini)\n",
|
| 69 |
+
"openai_config = OpenAILLMConfig(\n",
|
| 70 |
+
" model=\"gpt-4o-mini\", \n",
|
| 71 |
+
" openai_key=OPENAI_API_KEY\n",
|
| 72 |
+
")\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"claude_config = LiteLLMConfig(\n",
|
| 75 |
+
" model=\"gpt-4o-mini\", \n",
|
| 76 |
+
" openai_key=OPENAI_API_KEY\n",
|
| 77 |
+
")\n",
|
| 78 |
+
"executor_llm = OpenAILLM(config=openai_config)\n",
|
| 79 |
+
"optimizer_llm = LiteLLM(config=claude_config)"
|
| 80 |
+
]
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"cell_type": "code",
|
| 84 |
+
"execution_count": 4,
|
| 85 |
+
"id": "a87feb08",
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"outputs": [],
|
| 88 |
+
"source": [
|
| 89 |
+
"EXPERIMENTAL_CONFIG = {\n",
|
| 90 |
+
" \"humaneval\": {\n",
|
| 91 |
+
" \"question_type\": \"code\", \n",
|
| 92 |
+
" \"operators\": [\"Custom\", \"CustomCodeGenerate\", \"Test\", \"ScEnsemble\"] \n",
|
| 93 |
+
" }, \n",
|
| 94 |
+
" \"mbpp\": {\n",
|
| 95 |
+
" \"question_type\": \"code\", \n",
|
| 96 |
+
" \"operators\": [\"Custom\", \"CustomCodeGenerate\", \"Test\", \"ScEnsemble\"] \n",
|
| 97 |
+
" },\n",
|
| 98 |
+
" \"hotpotqa\": {\n",
|
| 99 |
+
" \"question_type\": \"qa\", \n",
|
| 100 |
+
" \"operators\": [\"Custom\", \"AnswerGenerate\", \"QAScEnsemble\"]\n",
|
| 101 |
+
" },\n",
|
| 102 |
+
" \"gsm8k\": {\n",
|
| 103 |
+
" \"question_type\": \"math\", \n",
|
| 104 |
+
" \"operators\": [\"Custom\", \"ScEnsemble\", \"Programmer\"]\n",
|
| 105 |
+
" },\n",
|
| 106 |
+
" \"math\": {\n",
|
| 107 |
+
" \"question_type\": \"math\", \n",
|
| 108 |
+
" \"operators\": [\"Custom\", \"ScEnsemble\", \"Programmer\"]\n",
|
| 109 |
+
" }\n",
|
| 110 |
+
"}"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"execution_count": 5,
|
| 116 |
+
"id": "b6054068",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [],
|
| 119 |
+
"source": [
|
| 120 |
+
"import evoagentx.workflow.operators as operator\n",
|
| 121 |
+
"import examples.aflow.code_generation.prompt as prompt_custom # noqa: F401\n",
|
| 122 |
+
"from evoagentx.models.model_configs import LLMConfig\n",
|
| 123 |
+
"from evoagentx.benchmark.benchmark import Benchmark\n",
|
| 124 |
+
"from evoagentx.models.model_utils import create_llm_instance\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"class Workflow:\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" def __init__(\n",
|
| 129 |
+
" self,\n",
|
| 130 |
+
" name: str,\n",
|
| 131 |
+
" llm_config: LLMConfig,\n",
|
| 132 |
+
" benchmark: Benchmark\n",
|
| 133 |
+
" ):\n",
|
| 134 |
+
" self.name = name\n",
|
| 135 |
+
" self.llm = create_llm_instance(llm_config)\n",
|
| 136 |
+
" self.benchmark = benchmark \n",
|
| 137 |
+
" self.custom = operator.Custom(self.llm)\n",
|
| 138 |
+
" self.custom_code_generate = operator.CustomCodeGenerate(self.llm)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
" async def __call__(self, problem: str, entry_point: str):\n",
|
| 141 |
+
" \"\"\"\n",
|
| 142 |
+
" Implementation of the workflow\n",
|
| 143 |
+
" Custom operator to generate anything you want.\n",
|
| 144 |
+
" But when you want to get standard code, you should use custom_code_generate operator.\n",
|
| 145 |
+
" \"\"\"\n",
|
| 146 |
+
" # await self.custom(input=, instruction=\"\")\n",
|
| 147 |
+
" solution = await self.custom_code_generate(problem=problem, entry_point=entry_point, instruction=prompt_custom.GENERATE_PYTHON_CODE_PROMPT) # But When you want to get standard code ,you should use customcodegenerator.\n",
|
| 148 |
+
" return solution['response']"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "code",
|
| 153 |
+
"execution_count": 6,
|
| 154 |
+
"id": "27e574ad",
|
| 155 |
+
"metadata": {},
|
| 156 |
+
"outputs": [
|
| 157 |
+
{
|
| 158 |
+
"name": "stderr",
|
| 159 |
+
"output_type": "stream",
|
| 160 |
+
"text": [
|
| 161 |
+
"\u001b[32m2025-10-12 15:04:04.523\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m182\u001b[0m - \u001b[1mLoading train data from None\u001b[0m\n",
|
| 162 |
+
"\u001b[32m2025-10-12 15:04:04.524\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m185\u001b[0m - \u001b[1mLoading dev data from humaneval_validate.jsonl\u001b[0m\n",
|
| 163 |
+
"\u001b[32m2025-10-12 15:04:04.525\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.benchmark.humaneval\u001b[0m:\u001b[36m_load_data\u001b[0m:\u001b[36m188\u001b[0m - \u001b[1mLoading test data from humaneval_test.jsonl\u001b[0m\n"
|
| 164 |
+
]
|
| 165 |
+
}
|
| 166 |
+
],
|
| 167 |
+
"source": [
|
| 168 |
+
"# Initialize the benchmark\n",
|
| 169 |
+
"humaneval = AFlowHumanEval()"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": 8,
|
| 175 |
+
"id": "2f8da181",
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"optimizer = AFlowOptimizer(\n",
|
| 180 |
+
" graph_path=\"../examples/aflow/code_generation\", # Path to the initial workflow graph\n",
|
| 181 |
+
" optimized_path=\"../examples/aflow/humaneval/optimized\", # Path to save optimized workflows\n",
|
| 182 |
+
" optimizer_llm=optimizer_llm, # LLM for optimization\n",
|
| 183 |
+
" executor_llm=executor_llm, # LLM for execution\n",
|
| 184 |
+
" validation_rounds=3, # Number of times to run validation on the development set during optimization\n",
|
| 185 |
+
" eval_rounds=3, # Number of times to run evaluation on the test set during testing\n",
|
| 186 |
+
" max_rounds=20, # Maximum optimization rounds\n",
|
| 187 |
+
" **EXPERIMENTAL_CONFIG[\"humaneval\"] # Task-specific configuration, used to specify the task type and available operators\n",
|
| 188 |
+
")"
|
| 189 |
+
]
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"cell_type": "code",
|
| 193 |
+
"execution_count": 9,
|
| 194 |
+
"id": "74937699",
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"outputs": [],
|
| 197 |
+
"source": [
|
| 198 |
+
"import nest_asyncio\n",
|
| 199 |
+
"nest_asyncio.apply()"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"id": "98ac4a63",
|
| 206 |
+
"metadata": {
|
| 207 |
+
"scrolled": true
|
| 208 |
+
},
|
| 209 |
+
"outputs": [
|
| 210 |
+
{
|
| 211 |
+
"name": "stderr",
|
| 212 |
+
"output_type": "stream",
|
| 213 |
+
"text": [
|
| 214 |
+
"\u001b[32m2025-10-12 15:04:50.304\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
|
| 215 |
+
"\u001b[32m2025-10-12 15:04:50.305\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 1/3)\u001b[0m\n",
|
| 216 |
+
"\u001b[32m2025-10-12 15:04:55.310\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
|
| 217 |
+
"\u001b[32m2025-10-12 15:04:55.311\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 2/3)\u001b[0m\n",
|
| 218 |
+
"\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.utils.aflow_utils.graph_utils\u001b[0m:\u001b[36mload_graph\u001b[0m:\u001b[36m51\u001b[0m - \u001b[1mError loading graph for round 0: No module named '.'\u001b[0m\n",
|
| 219 |
+
"\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: No module named '.'. Retrying... (Attempt 3/3)\u001b[0m\n",
|
| 220 |
+
"\u001b[32m2025-10-12 15:05:05.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m149\u001b[0m - \u001b[1mMax retries reached.\u001b[0m\n",
|
| 221 |
+
"\u001b[32m2025-10-12 15:05:05.323\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36moptimize\u001b[0m:\u001b[36m112\u001b[0m - \u001b[1mScore for round 1: None\u001b[0m\n",
|
| 222 |
+
"\u001b[32m2025-10-12 15:05:05.326\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 1/3)\u001b[0m\n",
|
| 223 |
+
"\u001b[32m2025-10-12 15:05:10.332\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 2/3)\u001b[0m\n",
|
| 224 |
+
"\u001b[32m2025-10-12 15:05:20.344\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 3/3)\u001b[0m\n",
|
| 225 |
+
"\u001b[32m2025-10-12 15:05:20.344\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m149\u001b[0m - \u001b[1mMax retries reached.\u001b[0m\n",
|
| 226 |
+
"\u001b[32m2025-10-12 15:05:20.345\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36moptimize\u001b[0m:\u001b[36m112\u001b[0m - \u001b[1mScore for round 2: None\u001b[0m\n",
|
| 227 |
+
"\u001b[32m2025-10-12 15:05:20.347\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mevoagentx.optimizers.aflow_optimizer\u001b[0m:\u001b[36m_execute_with_retry\u001b[0m:\u001b[36m147\u001b[0m - \u001b[1mError occurred: 'round'. Retrying... (Attempt 1/3)\u001b[0m\n"
|
| 228 |
+
]
|
| 229 |
+
}
|
| 230 |
+
],
|
| 231 |
+
"source": [
|
| 232 |
+
"# Optimize the workflow\n",
|
| 233 |
+
"optimizer.optimize(humaneval)"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"cell_type": "code",
|
| 238 |
+
"execution_count": null,
|
| 239 |
+
"id": "1010d583",
|
| 240 |
+
"metadata": {
|
| 241 |
+
"scrolled": true
|
| 242 |
+
},
|
| 243 |
+
"outputs": [],
|
| 244 |
+
"source": [
|
| 245 |
+
"optimizer.test(humaneval)"
|
| 246 |
+
]
|
| 247 |
+
},
|
| 248 |
+
{
|
| 249 |
+
"cell_type": "code",
|
| 250 |
+
"execution_count": null,
|
| 251 |
+
"id": "becb5a82",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"outputs": [],
|
| 254 |
+
"source": [
|
| 255 |
+
"import pandas as pd"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "code",
|
| 260 |
+
"execution_count": 16,
|
| 261 |
+
"id": "5c076d29",
|
| 262 |
+
"metadata": {},
|
| 263 |
+
"outputs": [],
|
| 264 |
+
"source": [
|
| 265 |
+
"df = pd.read_json(\"/home/tl688/pitl688/selfevolve/AFlow/data/datasets/scicode_dev.jsonl\", lines=True)"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": 23,
|
| 271 |
+
"id": "481602a9",
|
| 272 |
+
"metadata": {},
|
| 273 |
+
"outputs": [
|
| 274 |
+
{
|
| 275 |
+
"data": {
|
| 276 |
+
"text/plain": [
|
| 277 |
+
"'def get_alpha(recvec, alpha_scaling=5):\\n \"\"\"\\n Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n \"\"\"\\n alpha = alpha_scaling * np.max(np.linalg.norm(recvec, axis=1))\\n return alpha'"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
"execution_count": 23,
|
| 281 |
+
"metadata": {},
|
| 282 |
+
"output_type": "execute_result"
|
| 283 |
+
}
|
| 284 |
+
],
|
| 285 |
+
"source": [
|
| 286 |
+
"df['ground_truth_code'].values[0]"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "code",
|
| 291 |
+
"execution_count": 21,
|
| 292 |
+
"id": "ffb0be7e",
|
| 293 |
+
"metadata": {},
|
| 294 |
+
"outputs": [
|
| 295 |
+
{
|
| 296 |
+
"data": {
|
| 297 |
+
"text/plain": [
|
| 298 |
+
"\"def get_alpha(recvec, alpha_scaling=5):\\n '''Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n '''\""
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
"execution_count": 21,
|
| 302 |
+
"metadata": {},
|
| 303 |
+
"output_type": "execute_result"
|
| 304 |
+
}
|
| 305 |
+
],
|
| 306 |
+
"source": [
|
| 307 |
+
"df['function_header'].values[0]"
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"cell_type": "code",
|
| 312 |
+
"execution_count": 24,
|
| 313 |
+
"id": "69acf613",
|
| 314 |
+
"metadata": {},
|
| 315 |
+
"outputs": [
|
| 316 |
+
{
|
| 317 |
+
"data": {
|
| 318 |
+
"text/plain": [
|
| 319 |
+
"'import numpy as np\\nfrom scipy.special import erfc'"
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
"execution_count": 24,
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"output_type": "execute_result"
|
| 325 |
+
}
|
| 326 |
+
],
|
| 327 |
+
"source": [
|
| 328 |
+
"df['required_dependencies'].values[0]"
|
| 329 |
+
]
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"cell_type": "code",
|
| 333 |
+
"execution_count": 25,
|
| 334 |
+
"id": "b5696e0e",
|
| 335 |
+
"metadata": {},
|
| 336 |
+
"outputs": [
|
| 337 |
+
{
|
| 338 |
+
"data": {
|
| 339 |
+
"text/plain": [
|
| 340 |
+
"Index(['step_number', 'step_description_prompt', 'step_background',\n",
|
| 341 |
+
" 'ground_truth_code', 'function_header', 'test_cases', 'return_line',\n",
|
| 342 |
+
" 'required_dependencies'],\n",
|
| 343 |
+
" dtype='object')"
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
"execution_count": 25,
|
| 347 |
+
"metadata": {},
|
| 348 |
+
"output_type": "execute_result"
|
| 349 |
+
}
|
| 350 |
+
],
|
| 351 |
+
"source": [
|
| 352 |
+
"df.columns"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"cell_type": "code",
|
| 357 |
+
"execution_count": 27,
|
| 358 |
+
"id": "0a3085a9",
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"outputs": [
|
| 361 |
+
{
|
| 362 |
+
"data": {
|
| 363 |
+
"text/plain": [
|
| 364 |
+
"\"def get_alpha(recvec, alpha_scaling=5):\\n '''Calculate the alpha value for the Ewald summation, scaled by a specified factor.\\n Parameters:\\n recvec (np.ndarray): A 3x3 array representing the reciprocal lattice vectors.\\n alpha_scaling (float): A scaling factor applied to the alpha value. Default is 5.\\n Returns:\\n float: The calculated alpha value.\\n '''\""
|
| 365 |
+
]
|
| 366 |
+
},
|
| 367 |
+
"execution_count": 27,
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"output_type": "execute_result"
|
| 370 |
+
}
|
| 371 |
+
],
|
| 372 |
+
"source": [
|
| 373 |
+
"df['function_header'].values[0]"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"cell_type": "code",
|
| 378 |
+
"execution_count": 28,
|
| 379 |
+
"id": "e6a76c86",
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"outputs": [
|
| 382 |
+
{
|
| 383 |
+
"data": {
|
| 384 |
+
"text/plain": [
|
| 385 |
+
"\"ref1 = -1.74756\\nEX1 = {\\n 'latvec': np.array([\\n [0.0, 1.0, 1.0],\\n [1.0, 0.0, 1.0],\\n [1.0, 1.0, 0.0]\\n ]),\\n 'atom_charges': np.array([1]),\\n 'atom_coords': np.array([\\n [0.0, 0.0, 0.0]\\n ]),\\n 'configs': np.array([\\n [1.0, 1.0, 1.0]\\n ]),\\n}\\nassert np.allclose(get_alpha(np.linalg.inv(EX1['latvec']).T), target)\""
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
"execution_count": 28,
|
| 389 |
+
"metadata": {},
|
| 390 |
+
"output_type": "execute_result"
|
| 391 |
+
}
|
| 392 |
+
],
|
| 393 |
+
"source": [
|
| 394 |
+
"df['test_cases'].values[0]"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "code",
|
| 399 |
+
"execution_count": null,
|
| 400 |
+
"id": "99775141",
|
| 401 |
+
"metadata": {},
|
| 402 |
+
"outputs": [],
|
| 403 |
+
"source": []
|
| 404 |
+
}
|
| 405 |
+
],
|
| 406 |
+
"metadata": {
|
| 407 |
+
"kernelspec": {
|
| 408 |
+
"display_name": "Python 3 (ipykernel)",
|
| 409 |
+
"language": "python",
|
| 410 |
+
"name": "python3"
|
| 411 |
+
},
|
| 412 |
+
"language_info": {
|
| 413 |
+
"codemirror_mode": {
|
| 414 |
+
"name": "ipython",
|
| 415 |
+
"version": 3
|
| 416 |
+
},
|
| 417 |
+
"file_extension": ".py",
|
| 418 |
+
"mimetype": "text/x-python",
|
| 419 |
+
"name": "python",
|
| 420 |
+
"nbconvert_exporter": "python",
|
| 421 |
+
"pygments_lexer": "ipython3",
|
| 422 |
+
"version": "3.11.13"
|
| 423 |
+
}
|
| 424 |
+
},
|
| 425 |
+
"nbformat": 4,
|
| 426 |
+
"nbformat_minor": 5
|
| 427 |
+
}
|
evoagentx/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
__version__ = '0.1.0'
|
evoagentx/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
evoagentx/actions/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .action import Action, ActionInput, ActionOutput
|
| 2 |
+
from .code_verification import CodeVerification
|
| 3 |
+
from .code_extraction import CodeExtraction
|
| 4 |
+
|
| 5 |
+
__all__ = ["Action", "ActionInput", "ActionOutput", "CodeVerification", "CodeExtraction"]
|
evoagentx/actions/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (494 Bytes). View file
|
|
|
evoagentx/actions/__pycache__/action.cpython-311.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
evoagentx/actions/__pycache__/agent_generation.cpython-311.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
evoagentx/actions/__pycache__/code_extraction.cpython-311.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
evoagentx/actions/__pycache__/code_verification.cpython-311.pyc
ADDED
|
Binary file (6.33 kB). View file
|
|
|
evoagentx/actions/__pycache__/customize_action.cpython-311.pyc
ADDED
|
Binary file (31.5 kB). View file
|
|
|
evoagentx/actions/__pycache__/task_planning.cpython-311.pyc
ADDED
|
Binary file (5.56 kB). View file
|
|
|
evoagentx/actions/action.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pydantic import model_validator
|
| 3 |
+
from pydantic_core import PydanticUndefined
|
| 4 |
+
from typing import Optional, Type, Tuple, Union, List, Any
|
| 5 |
+
|
| 6 |
+
from ..core.module import BaseModule
|
| 7 |
+
from ..core.module_utils import get_type_name
|
| 8 |
+
from ..core.registry import MODULE_REGISTRY
|
| 9 |
+
# from ..core.base_config import Parameter
|
| 10 |
+
from ..core.parser import Parser
|
| 11 |
+
from ..core.message import Message
|
| 12 |
+
from ..models.base_model import BaseLLM, LLMOutputParser
|
| 13 |
+
from ..tools.tool import Toolkit
|
| 14 |
+
from ..prompts.context_extraction import CONTEXT_EXTRACTION
|
| 15 |
+
from ..prompts.template import PromptTemplate
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ActionInput(LLMOutputParser):
|
| 19 |
+
"""Input specification and parsing for actions.
|
| 20 |
+
|
| 21 |
+
This class defines the input requirements for actions and provides methods
|
| 22 |
+
to generate structured input specifications. It inherits from LLMOutputParser
|
| 23 |
+
to allow parsing of LLM outputs into structured inputs for actions.
|
| 24 |
+
|
| 25 |
+
Notes:
|
| 26 |
+
Parameters in ActionInput should be defined in Pydantic Field format.
|
| 27 |
+
For optional variables, use format:
|
| 28 |
+
var: Optional[int] = Field(default=None, description="xxx")
|
| 29 |
+
Remember to add `default=None` for optional parameters.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def get_input_specification(cls, ignore_fields: List[str] = []) -> str:
|
| 34 |
+
"""Generate a JSON specification of the input requirements.
|
| 35 |
+
|
| 36 |
+
Examines the class fields and produces a structured specification of
|
| 37 |
+
the input parameters, including their types, descriptions, and whether
|
| 38 |
+
they are required.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
ignore_fields (List[str]): List of field names to exclude from the specification.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
A JSON string containing the input specification, or an empty string
|
| 45 |
+
if no fields are defined or all are ignored.
|
| 46 |
+
"""
|
| 47 |
+
fields_info = {}
|
| 48 |
+
attrs = cls.get_attrs()
|
| 49 |
+
for field_name, field_info in cls.model_fields.items():
|
| 50 |
+
if field_name in ignore_fields:
|
| 51 |
+
continue
|
| 52 |
+
if field_name not in attrs:
|
| 53 |
+
continue
|
| 54 |
+
field_type = get_type_name(field_info.annotation)
|
| 55 |
+
field_desc = field_info.description if field_info.description is not None else None
|
| 56 |
+
# field_required = field_info.is_required()
|
| 57 |
+
field_default = str(field_info.default) if field_info.default is not PydanticUndefined else None
|
| 58 |
+
field_required = True if field_default is None else False
|
| 59 |
+
description = field_type + ", "
|
| 60 |
+
if field_desc is not None:
|
| 61 |
+
description += (field_desc.strip() + ", ")
|
| 62 |
+
description += ("required" if field_required else "optional")
|
| 63 |
+
if field_default is not None:
|
| 64 |
+
description += (", Default value: " + field_default)
|
| 65 |
+
fields_info[field_name] = description
|
| 66 |
+
|
| 67 |
+
if len(fields_info) == 0:
|
| 68 |
+
return ""
|
| 69 |
+
fields_info_str = json.dumps(fields_info, indent=4)
|
| 70 |
+
return fields_info_str
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def get_required_input_names(cls) -> List[str]:
|
| 74 |
+
"""Get a list of all required input parameter names.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
List[str]: Names of all parameters that are required (don't have default values).
|
| 78 |
+
"""
|
| 79 |
+
required_fields = []
|
| 80 |
+
attrs = cls.get_attrs()
|
| 81 |
+
for field_name, field_info in cls.model_fields.items():
|
| 82 |
+
if field_name not in attrs:
|
| 83 |
+
continue
|
| 84 |
+
field_default = field_info.default
|
| 85 |
+
# A field is required if it doesn't have a default value
|
| 86 |
+
if field_default is PydanticUndefined:
|
| 87 |
+
required_fields.append(field_name)
|
| 88 |
+
return required_fields
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ActionOutput(LLMOutputParser):
    """Structured output produced by an action.

    Inherits from LLMOutputParser so that raw LLM output can be parsed into a
    structured action result, and adds a JSON rendering helper.
    """

    def to_str(self) -> str:
        """Render the structured data as a pretty-printed JSON string.

        Returns:
            str: JSON representation of the structured data, indented 4 spaces.
        """
        structured = self.get_structured_data()
        return json.dumps(structured, indent=4)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class Action(BaseModule):
    """Base class for all actions in the EvoAgentX framework.

    Actions represent discrete operations that can be performed by agents.
    They define inputs, outputs, and execution behavior, and can optionally
    use tools to accomplish their tasks.

    Attributes:
        name (str): Unique identifier for the action.
        description (str): Human-readable description of what the action does.
        prompt (Optional[str]): Optional prompt template for this action.
        prompt_template (Optional[PromptTemplate]): Optional structured prompt template.
        tools (Optional[List[Toolkit]]): Optional list of tools that can be used by this action.
        inputs_format (Optional[Type[ActionInput]]): Optional class defining the expected input structure.
        outputs_format (Optional[Type[Parser]]): Optional class defining the expected output structure.
    """

    name: str
    description: str
    prompt: Optional[str] = None
    prompt_template: Optional[PromptTemplate] = None
    tools: Optional[List[Toolkit]] = None  # specify the possible tools for the action
    inputs_format: Optional[Type[ActionInput]] = None  # specify the input format of the action
    outputs_format: Optional[Type[Parser]] = None  # specify the possible structured output format

    def init_module(self):
        """Initialize the action module.

        This method is called after the action is instantiated.
        Subclasses can override this to perform custom initialization.
        """
        pass

    def to_dict(self, exclude_none: bool = True, ignore: Optional[List[str]] = None, **kwargs) -> dict:
        """Convert the action to a dictionary for saving.

        Args:
            exclude_none (bool): Whether to drop fields whose value is None.
            ignore (Optional[List[str]]): Field names to omit from the result.
                Defaults to None (omit nothing).
            **kwargs: Extra options forwarded to ``BaseModule.to_dict``.

        Returns:
            dict: Serializable representation; format classes are stored by name.
        """
        # Normalize None to [] — avoids the mutable-default-argument pitfall.
        ignore = ignore if ignore is not None else []
        data = super().to_dict(exclude_none=exclude_none, ignore=ignore, **kwargs)
        # Store the format classes by name so they can be re-resolved on load
        # (see `validate_data`).
        if self.inputs_format:
            data["inputs_format"] = self.inputs_format.__name__
        if self.outputs_format:
            data["outputs_format"] = self.outputs_format.__name__
        # TODO: customize serialization for the tools
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: Any) -> Any:
        """Resolve format class names back to classes when loading from a file."""
        # A mode="before" validator may receive non-dict input (e.g. an existing
        # model instance); pass anything that is not a plain dict through untouched
        # instead of failing on the `in` checks below.
        if not isinstance(data, dict):
            return data
        if "inputs_format" in data and data["inputs_format"] and isinstance(data["inputs_format"], str):
            # only used when loading from a file
            data["inputs_format"] = MODULE_REGISTRY.get_module(data["inputs_format"])
        if "outputs_format" in data and data["outputs_format"] and isinstance(data["outputs_format"], str):
            # only used when loading from a file
            data["outputs_format"] = MODULE_REGISTRY.get_module(data["outputs_format"])
        # TODO: customize loading for the tools
        return data

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str] = None, return_prompt: bool = False, **kwargs) -> Optional[Union[Parser, Tuple[Parser, str]]]:
        """Execute the action to produce a result.

        This is the main entry point for executing an action. Subclasses must
        implement this method to define the action's behavior.

        Args:
            llm (Optional[BaseLLM]): The LLM used to execute the action.
            inputs (Optional[dict]): Input data for the action execution. The input data should be a dictionary that matches the input format of the provided prompt.
                For example, if the prompt contains a variable `{input_var}`, the `inputs` dictionary should have a key `input_var`, otherwise the variable will be set to empty string.
            sys_msg (Optional[str]): Optional system message for the LLM.
            return_prompt (bool): Whether to return the complete prompt passed to the LLM.
            **kwargs (Any): Additional keyword arguments for the execution.

        Returns:
            If `return_prompt` is False, the method returns a Parser object containing the structured result of the action.
            If `return_prompt` is True, the method returns a tuple containing the Parser object and the complete prompt passed to the LLM.
        """
        raise NotImplementedError(f"`execute` function of {type(self).__name__} is not implemented!")

    async def async_execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str] = None, return_prompt: bool = False, **kwargs) -> Optional[Union[Parser, Tuple[Parser, str]]]:
        """
        Asynchronous execution of the action.

        This method is the asynchronous counterpart of the `execute` method.
        It allows the action to be executed asynchronously using an LLM.
        """
        raise NotImplementedError(f"`async_execute` function of {type(self).__name__} is not implemented!")
|
| 192 |
+
|
| 193 |
+
class ContextExtraction(Action):
    """Action for extracting structured inputs from conversation context.

    Analyzes a list of conversation messages and asks the LLM to pull out
    the values required by another action's input specification, formatting
    them according to that action's `inputs_format`.
    """

    def __init__(self, **kwargs):
        # Fall back to the default name/description bundled with the prompt.
        name = kwargs.pop("name", CONTEXT_EXTRACTION["name"])
        description = kwargs.pop("description", CONTEXT_EXTRACTION["description"])
        super().__init__(name=name, description=description, **kwargs)

    def get_context_from_messages(self, messages: List[Message]) -> str:
        """Join the string forms of the messages, separated by blank lines."""
        return "\n\n".join(str(msg) for msg in messages)

    def execute(self, llm: Optional[BaseLLM] = None, action: Action = None, context: List[Message] = None, **kwargs) -> Union[dict, None]:
        """Extract structured inputs for an action from conversation context.

        Uses the LLM to analyze the conversation context and extract values
        matching the input requirements of the target action.

        Args:
            llm: The language model to use for extraction.
            action: The target action whose `inputs_format` defines what to extract.
            context: List of messages providing the conversation context.
            **kwargs: Additional keyword arguments.

        Returns:
            A dictionary with the extracted inputs for the target action, or
            None when extraction is not possible (missing action/context, an
            action without inputs, or an empty specification/context).
        """
        if action is None or context is None:
            return None

        input_cls: Type[ActionInput] = action.inputs_format
        if input_cls is None:
            # The target action does not declare any inputs.
            return None

        spec = input_cls.get_input_specification()
        rendered_context = self.get_context_from_messages(messages=context)
        if not spec or not rendered_context:
            return None

        prompt = CONTEXT_EXTRACTION["prompt"].format(
            context=rendered_context,
            action_name=action.name,
            action_description=action.description,
            action_inputs=spec,
        )

        parsed = llm.generate(
            prompt=prompt,
            system_message=CONTEXT_EXTRACTION["system_prompt"],
            parser=input_cls,
        )
        return parsed.get_structured_data()
|
evoagentx/actions/agent_generation.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from pydantic import Field, model_validator
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
|
| 5 |
+
from ..core.logging import logger
|
| 6 |
+
from ..core.module import BaseModule
|
| 7 |
+
from ..core.base_config import Parameter
|
| 8 |
+
from ..models.base_model import BaseLLM
|
| 9 |
+
from .action import Action, ActionInput, ActionOutput
|
| 10 |
+
from ..prompts.agent_generator import AGENT_GENERATION_ACTION
|
| 11 |
+
from ..prompts.tool_calling import AGENT_GENERATION_TOOLS_PROMPT
|
| 12 |
+
from ..utils.utils import normalize_text
|
| 13 |
+
|
| 14 |
+
class AgentGenerationInput(ActionInput):
    """
    Input specification for the agent generation action.
    """

    # Required inputs describing the overall workflow and the target sub-task.
    goal: str = Field(description="A detailed statement of the workflow's goal, explaining the objectives the entire workflow aims to achieve")
    workflow: str = Field(description="An overview of the entire workflow, detailing all sub-tasks with their respective names, descriptions, inputs, and outputs")
    task: str = Field(description="A detailed JSON representation of the sub-task requiring agent generation. It should include the task's name, description, inputs, and outputs.")

    # Optional refinement inputs; each defaults to None and is filled in with
    # "" by AgentGeneration.execute when absent.
    history: Optional[str] = Field(default=None, description="Optional field containing previously selected or generated agents.")
    suggestion: Optional[str] = Field(default=None, description="Optional suggestions to refine the generated agents.")
    existing_agents: Optional[str] = Field(default=None, description="Optional field containing the description of predefined agents, including each agent's name, role, and available actions.")
    tools: Optional[str] = Field(default=None, description="Optional field containing the description of tools that agents can use, including each tool's name and functionality.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class GeneratedAgent(BaseModule):
    """
    Representation of a generated agent with validation capabilities.
    """

    name: str
    description: str
    inputs: List[Parameter]
    outputs: List[Parameter]
    prompt: str
    tool_names: Optional[List[str]] = None

    @classmethod
    def find_output_name(cls, text: str, outputs: List[str]):
        """Return the entry of `outputs` most similar to `text`.

        Similarity is the number of (normalized) words shared between the two
        strings; ties resolve to the earliest entry. Assumes `outputs` is
        non-empty (callers pass the agent's declared outputs).
        """
        def sim(t1: str, t2: str):
            # Word-overlap count after normalization.
            t1_words = normalize_text(t1).split()
            t2_words = normalize_text(t2).split()
            return len(set(t1_words)&set(t2_words))

        similarities = [sim(text, output) for output in outputs]
        max_sim = max(similarities)
        return outputs[similarities.index(max_sim)]

    @model_validator(mode="after")
    @classmethod
    def validate_prompt(cls, agent: 'GeneratedAgent'):
        """Validate and fix the agent's prompt template.

        This validator ensures that:
        1. All input parameters are properly referenced in the prompt
        2. Input references use the correct format with braces
        3. All output sections match the defined output parameters

        If there are mismatches in the output sections, it attempts to
        fix them by finding the most similar output name.

        Args:
            agent: The GeneratedAgent instance to validate.

        Returns:
            The validated and potentially modified GeneratedAgent.

        Raises:
            ValueError: If inputs are missing from the prompt or output sections don't match the defined outputs.
        """
        # check whether all the inputs are present in the prompt
        input_names = [inp.name for inp in agent.inputs]
        prompt_has_inputs = [name in agent.prompt for name in input_names]
        if not all(prompt_has_inputs):
            missing_input_names = [name for name, has_input in zip(input_names, prompt_has_inputs) if not has_input]
            raise ValueError(f'The prompt miss inputs: {missing_input_names}')

        # check the format of the prompt to make sure it is wrapped in brackets.
        # Only the "### Instructions" ... "### Output Format" section is rewritten.
        pattern = r"### Instructions(.*?)### Output Format"
        prompt = agent.prompt

        def replace_with_braces(match):
            # Normalize each <input>name</input> to <input>{name}</input>,
            # tolerating zero or more pre-existing braces around the name
            # ({{* / }}* in the raw f-string match literal '{' / '}' runs).
            instructions = match.group(1)
            for name in input_names:
                instructions = re.sub(fr'<input>{{*\b{re.escape(name)}\b}}*</input>', fr'<input>{{{name}}}</input>', instructions)
            return "### Instructions" + instructions + "### Output Format"

        modified_prompt = re.sub(pattern, replace_with_braces, prompt, flags=re.DOTALL)
        agent.prompt = modified_prompt

        # check whether all the outputs are present in the prompt
        prompt = agent.prompt
        pattern = r"### Output Format(.*)"
        outputs_names = [out.name for out in agent.outputs]

        def fix_output_names(match):
            # Section headers of the form "## <name>" (excluding "Thought")
            # are the prompt's declared outputs.
            output_format = match.group(1)
            matches = re.findall(r"## ([^\n#]+)", output_format, flags=re.DOTALL)
            generated_outputs = [m.strip() for m in matches if m.strip() != "Thought"]
            # check the number of generated outputs and agent outputs
            if len(generated_outputs) != len(outputs_names):
                raise ValueError(f"The number of outputs in the prompt is different from that defined in the `outputs` field of the agent. The outputs in the prompt are: {generated_outputs}, while the outputs from the agent's `outputs` field are: {outputs_names}")
            # check whether the generated output names are the same as agent outputs
            for generated_output in generated_outputs:
                if generated_output not in outputs_names:
                    # Repair near-miss names by substituting the closest declared output.
                    most_similar_output_name = cls.find_output_name(text=generated_output, outputs=outputs_names)
                    output_format = output_format.replace(generated_output, most_similar_output_name)
                    logger.warning(f"Couldn't find output name in prompt ('{generated_output}') in agent's outputs. Replace it with the most similar agent output: '{most_similar_output_name}'")
            return "### Output Format" + output_format

        modified_prompt = re.sub(pattern, fix_output_names, prompt, flags=re.DOTALL)
        agent.prompt = modified_prompt

        return agent
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class AgentGenerationOutput(ActionOutput):
    """Output of the agent generation action.

    Holds both the names of reused (selected) agents and the fully specified
    newly generated agents.
    """

    selected_agents: List[str] = Field(description="A list of selected agent's names")
    # Fixed typo in the description: "agetns" -> "agents".
    generated_agents: List[GeneratedAgent] = Field(description="A list of generated agents to address a sub-task")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class AgentGeneration(Action):
    """
    Action for generating agent specifications for workflow tasks.

    Analyzes task requirements and produces agent specifications (prompts,
    inputs, outputs), either selecting from existing agents or generating
    new ones tailored to the sub-task.
    """

    def __init__(self, **kwargs):
        # Fall back to the defaults declared alongside the prompt when the
        # caller does not override them.
        name = kwargs.pop("name", AGENT_GENERATION_ACTION["name"])
        description = kwargs.pop("description", AGENT_GENERATION_ACTION["description"])
        prompt = kwargs.pop("prompt", AGENT_GENERATION_ACTION["prompt"])
        inputs_format = kwargs.pop("inputs_format", None) or AgentGenerationInput
        outputs_format = kwargs.pop("outputs_format", None) or AgentGenerationOutput
        tools = kwargs.pop("tools", None)
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)
        # Assigned after super().__init__ so the field is set explicitly.
        self.tools = tools

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str] = None, return_prompt: bool = False, **kwargs) -> AgentGenerationOutput:
        """Generate agent specifications for a sub-task.

        Fills the action prompt with the workflow/task inputs (plus a tool
        description when tools are configured) and asks the LLM to produce
        agents, parsed as an ``AgentGenerationOutput`` (JSON mode).

        Args:
            llm: The language model to use for generation.
            inputs: Input data containing workflow and task information.
            sys_msg: Optional system message for the language model.
            return_prompt: Whether to also return the prompt that was used.
            **kwargs: Additional keyword arguments.

        Returns:
            If return_prompt is False (default): the generated agents output.
            If return_prompt is True: a tuple of (generated agents, prompt used).

        Raises:
            ValueError: If the inputs are None or empty.
        """
        if not inputs:
            logger.error("AgentGeneration action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to AgentGeneration action is None or empty.')

        inputs_format: AgentGenerationInput = self.inputs_format
        outputs_format: AgentGenerationOutput = self.outputs_format

        # Collect every declared input attribute; missing keys become "".
        param_names = inputs_format.get_attrs()
        param_values = {param: inputs.get(param, "") for param in param_names}

        if self.tools:
            # Summarize each toolkit as {name: [schema descriptions]}.
            tool_description = []
            for tool in self.tools:
                descriptions = [s["function"]["description"] for s in tool.get_tool_schemas()]
                tool_description.append({tool.name: descriptions})
            param_values["tools"] = AGENT_GENERATION_TOOLS_PROMPT.format(tools_description=tool_description)

        prompt = self.prompt.format(**param_values)
        agents = llm.generate(
            prompt=prompt,
            system_message=sys_msg,
            parser=outputs_format,
            parse_mode="json"
        )

        if return_prompt:
            return agents, prompt

        return agents
|
| 198 |
+
|
evoagentx/actions/code_extraction.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, List, Dict
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
|
| 5 |
+
from ..models.base_model import BaseLLM, LLMOutputParser
|
| 6 |
+
from .action import Action, ActionInput, ActionOutput
|
| 7 |
+
from ..prompts.code_extraction import CODE_EXTRACTION
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CodeExtractionInput(ActionInput):
    """
    Input parameters for the CodeExtraction action.
    """
    # Raw text (e.g. an LLM response) containing the code blocks to extract.
    code_string: str = Field(description="The string containing code blocks to extract")
    # Directory on disk where extracted files are written.
    target_directory: str = Field(description="The directory path where extracted code files will be saved")
    # Optional project folder name — presumably used as a sub-folder of
    # target_directory (see execute: "Create project folder if name is provided").
    project_name: Optional[str] = Field(default=None, description="Optional name for the project folder")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CodeExtractionOutput(ActionOutput):
    """
    Output of the CodeExtraction action.
    """
    # Mapping of saved filename -> full path on disk.
    extracted_files: Dict[str, str] = Field(description="Map of filename to file path of saved files")
    # Full path of the detected entry-point file, when one could be identified.
    main_file: Optional[str] = Field(default=None, description="Path to the main file if identified")
    # Human-readable error description when any step of extraction failed.
    error: Optional[str] = Field(default=None, description="Error message if any operation failed")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CodeBlockInfo(LLMOutputParser):
    """
    Information about an extracted code block.
    """
    # Programming language of the block (as reported by the LLM).
    language: str = Field(description="Programming language of the code block")
    # Filename suggested for this block; used by save_code_blocks.
    filename: str = Field(description="Suggested filename for the code block")
    # The code itself, written verbatim to disk.
    content: str = Field(description="The actual code content")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CodeBlockList(LLMOutputParser):
    """
    List of code blocks extracted from text.
    """
    # All code blocks the LLM identified in the input text.
    code_blocks: List[CodeBlockInfo] = Field(description="List of code blocks")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CodeExtraction(Action):
|
| 45 |
+
"""
|
| 46 |
+
An action that extracts and organizes code blocks from text.
|
| 47 |
+
|
| 48 |
+
This action uses an LLM to analyze text containing code blocks, extract them,
|
| 49 |
+
suggest appropriate filenames, and save them to a specified directory. It can
|
| 50 |
+
also identify which file is likely the main entry point based on heuristics.
|
| 51 |
+
|
| 52 |
+
Attributes:
|
| 53 |
+
name: The name of the action.
|
| 54 |
+
description: A description of what the action does.
|
| 55 |
+
prompt: The prompt template used by the action.
|
| 56 |
+
inputs_format: The expected format of inputs to this action.
|
| 57 |
+
outputs_format: The format of the action's output.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, **kwargs):
|
| 61 |
+
|
| 62 |
+
name = kwargs.pop("name") if "name" in kwargs else CODE_EXTRACTION["name"]
|
| 63 |
+
description = kwargs.pop("description") if "description" in kwargs else CODE_EXTRACTION["description"]
|
| 64 |
+
prompt = kwargs.pop("prompt") if "prompt" in kwargs else CODE_EXTRACTION["prompt"]
|
| 65 |
+
# inputs_format = kwargs.pop("inputs_format") if "inputs_format" in kwargs else CodeExtractionInput
|
| 66 |
+
# outputs_format = kwargs.pop("outputs_format") if "outputs_format" in kwargs else CodeExtractionOutput
|
| 67 |
+
inputs_format = kwargs.pop("inputs_format", None) or CodeExtractionInput
|
| 68 |
+
outputs_format = kwargs.pop("outputs_format", None) or CodeExtractionOutput
|
| 69 |
+
super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)
|
| 70 |
+
|
| 71 |
+
def identify_main_file(self, saved_files: Dict[str, str]) -> Optional[str]:
|
| 72 |
+
"""Identify the main file from the saved files based on content and file type.
|
| 73 |
+
|
| 74 |
+
This method uses a combination of common filename conventions and content
|
| 75 |
+
analysis to determine which file is likely the main entry point of a project.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
saved_files: Dictionary mapping filenames to their full paths
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Path to the main file if found, None otherwise
|
| 82 |
+
|
| 83 |
+
"""
|
| 84 |
+
# Priority lookup for common main files by language
|
| 85 |
+
main_file_priorities = [
|
| 86 |
+
# HTML files
|
| 87 |
+
"index.html",
|
| 88 |
+
# Python files
|
| 89 |
+
"main.py",
|
| 90 |
+
"app.py",
|
| 91 |
+
# JavaScript files
|
| 92 |
+
"index.js",
|
| 93 |
+
"main.js",
|
| 94 |
+
"app.js",
|
| 95 |
+
# Java files
|
| 96 |
+
"Main.java",
|
| 97 |
+
# C/C++ files
|
| 98 |
+
"main.cpp",
|
| 99 |
+
"main.c",
|
| 100 |
+
# Go files
|
| 101 |
+
"main.go",
|
| 102 |
+
# Other common entry points
|
| 103 |
+
"index.php",
|
| 104 |
+
"Program.cs"
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
# First check priority list
|
| 108 |
+
for main_file in main_file_priorities:
|
| 109 |
+
if main_file in saved_files:
|
| 110 |
+
return saved_files[main_file]
|
| 111 |
+
|
| 112 |
+
# If no priority file found, use heuristics based on file extensions
|
| 113 |
+
|
| 114 |
+
# If we have HTML files, use the first one
|
| 115 |
+
html_files = {k: v for k, v in saved_files.items() if k.endswith('.html')}
|
| 116 |
+
if html_files:
|
| 117 |
+
return next(iter(html_files.values()))
|
| 118 |
+
|
| 119 |
+
# Check for Python files with "__main__" section
|
| 120 |
+
py_files = {k: v for k, v in saved_files.items() if k.endswith('.py')}
|
| 121 |
+
if py_files:
|
| 122 |
+
for filename, path in py_files.items():
|
| 123 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 124 |
+
content = f.read()
|
| 125 |
+
if "if __name__ == '__main__'" in content or 'if __name__ == "__main__"' in content:
|
| 126 |
+
return path
|
| 127 |
+
# If no main found, return the first Python file
|
| 128 |
+
if py_files:
|
| 129 |
+
return next(iter(py_files.values()))
|
| 130 |
+
|
| 131 |
+
# If we have Java files, look for one with a main method
|
| 132 |
+
java_files = {k: v for k, v in saved_files.items() if k.endswith('.java')}
|
| 133 |
+
if java_files:
|
| 134 |
+
for filename, path in java_files.items():
|
| 135 |
+
with open(path, 'r', encoding='utf-8') as f:
|
| 136 |
+
content = f.read()
|
| 137 |
+
if "public static void main" in content:
|
| 138 |
+
return path
|
| 139 |
+
# If no main found, return the first Java file
|
| 140 |
+
if java_files:
|
| 141 |
+
return next(iter(java_files.values()))
|
| 142 |
+
|
| 143 |
+
# For JavaScript applications
|
| 144 |
+
js_files = {k: v for k, v in saved_files.items() if k.endswith('.js')}
|
| 145 |
+
if js_files:
|
| 146 |
+
return next(iter(js_files.values()))
|
| 147 |
+
|
| 148 |
+
# If all else fails, return the first file
|
| 149 |
+
if saved_files:
|
| 150 |
+
return next(iter(saved_files.values()))
|
| 151 |
+
|
| 152 |
+
# No files found
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
def save_code_blocks(self, code_blocks: List[Dict], target_directory: str) -> Dict[str, str]:
|
| 156 |
+
"""Save code blocks to files in the target directory.
|
| 157 |
+
|
| 158 |
+
Creates the target directory if it doesn't exist and saves each code block
|
| 159 |
+
to a file with an appropriate name, handling filename conflicts.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
code_blocks: List of dictionaries containing code block information
|
| 163 |
+
target_directory: Directory path where files should be saved
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Dictionary mapping filenames to their full paths
|
| 167 |
+
"""
|
| 168 |
+
os.makedirs(target_directory, exist_ok=True)
|
| 169 |
+
saved_files = {}
|
| 170 |
+
|
| 171 |
+
for block in code_blocks:
|
| 172 |
+
filename = block.get("filename", "unknown.txt")
|
| 173 |
+
content = block.get("content", "")
|
| 174 |
+
|
| 175 |
+
# Skip empty blocks
|
| 176 |
+
if not content.strip():
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
# Handle filename conflicts
|
| 180 |
+
base_filename = filename
|
| 181 |
+
counter = 1
|
| 182 |
+
while filename in saved_files:
|
| 183 |
+
name_parts = base_filename.split('.')
|
| 184 |
+
if len(name_parts) > 1:
|
| 185 |
+
filename = f"{'.'.join(name_parts[:-1])}_{counter}.{name_parts[-1]}"
|
| 186 |
+
else:
|
| 187 |
+
filename = f"{base_filename}_{counter}"
|
| 188 |
+
counter += 1
|
| 189 |
+
|
| 190 |
+
# Save to file
|
| 191 |
+
file_path = os.path.join(target_directory, filename)
|
| 192 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 193 |
+
f.write(content)
|
| 194 |
+
|
| 195 |
+
# Add to map
|
| 196 |
+
saved_files[filename] = file_path
|
| 197 |
+
|
| 198 |
+
return saved_files
|
| 199 |
+
|
| 200 |
+
def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> CodeExtractionOutput:
    """Execute the CodeExtraction action.

    Extracts code blocks from the provided text using the specified LLM,
    saves them to the target directory, and identifies the main file.
    Invalid inputs and extraction failures are reported through the
    ``error`` field of the returned output instead of raising.

    Args:
        llm: The LLM to use for code extraction.
        inputs: Dictionary containing:
            - code_string: The string with code blocks to extract
            - target_directory: Where to save the files
            - project_name: Optional project folder name
        sys_msg: Optional system message override for the LLM.
        return_prompt: Whether to return the prompt along with the result.
        **kwargs (Any): Additional keyword arguments (unused).

    Returns:
        CodeExtractionOutput with extracted file information, or a
        ``(CodeExtractionOutput, prompt)`` tuple when ``return_prompt`` is
        True and extraction succeeds.
    """
    if not llm:
        error_msg = "CodeExtraction action requires an LLM."
        return CodeExtractionOutput(extracted_files={}, error=error_msg)

    if not inputs:
        error_msg = "CodeExtraction action received invalid `inputs`: None or empty."
        return CodeExtractionOutput(extracted_files={}, error=error_msg)

    code_string = inputs.get("code_string", "")
    target_directory = inputs.get("target_directory", "")
    project_name = inputs.get("project_name", None)

    if not code_string:
        error_msg = "No code string provided."
        return CodeExtractionOutput(extracted_files={}, error=error_msg)

    if not target_directory:
        error_msg = "No target directory provided."
        return CodeExtractionOutput(extracted_files={}, error=error_msg)

    # Nest the output under a project folder when a name is provided
    if project_name:
        project_dir = os.path.join(target_directory, project_name)
    else:
        project_dir = target_directory

    try:
        # Format the prompt once so the exact same string is used for
        # generation and (optionally) returned to the caller, instead of
        # re-formatting it a second time for the return value.
        prompt = self.prompt.format(code_string=code_string)
        system_message = CODE_EXTRACTION["system_prompt"] if sys_msg is None else sys_msg

        # Use LLM to extract code blocks and suggest filenames
        llm_response: CodeBlockList = llm.generate(
            prompt=prompt,
            system_message=system_message,
            parser=CodeBlockList,
            parse_mode="json"
        )
        code_blocks = llm_response.get_structured_data().get("code_blocks", [])

        # Save code blocks to files
        saved_files = self.save_code_blocks(code_blocks, project_dir)

        # Identify main file
        main_file = self.identify_main_file(saved_files)

        result = CodeExtractionOutput(
            extracted_files=saved_files,
            main_file=main_file
        )

        if return_prompt:
            return result, prompt

        return result

    except Exception as e:
        error_msg = f"Error extracting code: {str(e)}"
        return CodeExtractionOutput(extracted_files={}, error=error_msg)
|
evoagentx/actions/code_verification.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import Field
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from ..core.logging import logger
|
| 5 |
+
from ..core.module_utils import extract_code_blocks
|
| 6 |
+
from ..models.base_model import BaseLLM
|
| 7 |
+
from .action import Action, ActionInput, ActionOutput
|
| 8 |
+
from ..prompts.code_verification import CODE_VERIFICATION_ACTION
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CodeVerificationInput(ActionInput):
    """Input schema for the CodeVerification action."""

    # The code string to check; required.
    code: str = Field(description="The code string to be verified for correctness and completeness.")
    # Free-form requirements/specification text; optional.
    requirements: Optional[str] = Field(default=None, description="Optional field containing requirements or specifications for the code.")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CodeVerificationOutput(ActionOutput):
    """Structured result of a code-verification run.

    The analysis fields are optional because the fallback parsing path in
    `CodeVerification.execute` only recovers `verified_code` from the raw
    response; `verified_code` itself is always present.
    """

    analysis_summary: Optional[str] = Field(default=None, description="Brief summary of your findings, highlighting key issues or confirming overall quality.")
    issues_identified: Optional[str] = Field(default=None, description="Categorized list of issues found, with explanation of impact and severity.")
    thought_process: Optional[str] = Field(default=None, description="Detailed explanation of your verification reasoning and methodology applied.")
    modification_strategy: Optional[str] = Field(default=None, description="Describe the changes you made (or will make) to address the issues. Include any assumptions, design choices, or additional components you decided to add to make the code complete and robust.")
    verified_code: str = Field(description="The complete, corrected code if issues are found, or the original code if no issues are found.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CodeVerification(Action):
    """Action that asks an LLM to verify (and, if needed, correct) a piece of code."""

    def __init__(self, **kwargs):
        """Initialize from kwargs, falling back to the CODE_VERIFICATION_ACTION defaults.

        Uses `dict.pop(key, default)` instead of the `pop(...) if ... in kwargs`
        double-lookup pattern.
        """
        name = kwargs.pop("name", CODE_VERIFICATION_ACTION["name"])
        description = kwargs.pop("description", CODE_VERIFICATION_ACTION["description"])
        prompt = kwargs.pop("prompt", CODE_VERIFICATION_ACTION["prompt"])
        inputs_format = kwargs.pop("inputs_format", None) or CodeVerificationInput
        outputs_format = kwargs.pop("outputs_format", None) or CodeVerificationOutput
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> CodeVerificationOutput:
        """Run code verification via the LLM.

        Args:
            llm: The LLM used to verify the code.
            inputs: Dict with `code` and optionally `requirements`; missing
                keys are substituted with "Not Provided".
            sys_msg: Optional system message for the LLM.
            return_prompt: If True, also return the formatted prompt.

        Returns:
            CodeVerificationOutput, or `(CodeVerificationOutput, prompt)` when
            `return_prompt` is True.

        Raises:
            ValueError: If `inputs` is None/empty, or if neither structured
                parsing nor code-block extraction succeeds on the response.
        """
        if not inputs:
            logger.error("CodeVerification action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to CodeVerification action is None or empty.')

        prompt_params_names = ["code", "requirements"]
        prompt_params_values = {param: inputs.get(param, "Not Provided") for param in prompt_params_names}
        prompt = self.prompt.format(**prompt_params_values)
        response = llm.generate(prompt=prompt, system_message=sys_msg)

        try:
            # Preferred path: the response follows the "title" output format.
            verification_result = self.outputs_format.parse(response.content, parse_mode="title")
        except Exception:
            # Fallback: salvage any fenced code blocks from the raw response
            # and wrap them as the verified code.
            try:
                code_blocks = extract_code_blocks(response.content, return_type=True)
                code = "\n\n".join([f"```{code_type}\n{code}\n```" for code_type, code in code_blocks])
                verification_result = self.outputs_format(verified_code=code)
            except Exception:
                raise ValueError(f"Failed to extract code blocks from the response: {response.content}")

        if return_prompt:
            return verification_result, prompt

        return verification_result
|
evoagentx/actions/customize_action.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import Field
|
| 2 |
+
from typing import Optional, Any, Callable, List, Union
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
import asyncio
|
| 6 |
+
import inspect
|
| 7 |
+
import concurrent.futures
|
| 8 |
+
|
| 9 |
+
from ..core.logging import logger
|
| 10 |
+
from ..models.base_model import BaseLLM
|
| 11 |
+
from .action import Action
|
| 12 |
+
from ..core.message import Message
|
| 13 |
+
from ..prompts.template import StringTemplate, ChatTemplate
|
| 14 |
+
from ..prompts.tool_calling import OUTPUT_EXTRACTION_PROMPT, TOOL_CALLING_TEMPLATE, TOOL_CALLING_HISTORY_PROMPT, TOOL_CALLING_RETRY_PROMPT
|
| 15 |
+
from ..tools.tool import Toolkit
|
| 16 |
+
from ..core.registry import MODULE_REGISTRY
|
| 17 |
+
from ..models.base_model import LLMOutputParser
|
| 18 |
+
from ..core.module_utils import parse_json_from_llm_output, parse_json_from_text
|
| 19 |
+
|
| 20 |
+
class CustomizeAction(Action):
    """A configurable action that formats a prompt (or prompt template),
    optionally exposes toolkits to the LLM, and parses the LLM output into
    the declared output format.
    """

    # How the raw LLM output is parsed into `outputs_format`.
    parse_mode: Optional[str] = Field(default="title", description="the parse mode of the action, must be one of: ['title', 'str', 'json', 'xml', 'custom']")
    # Custom parser used when parse_mode is 'custom'; excluded from serialization.
    parse_func: Optional[Callable] = Field(default=None, exclude=True, description="the function to parse the LLM output. It receives the LLM output and returns a dict.")
    # Section-header pattern used by the 'title' parse mode.
    title_format: Optional[str] = Field(default="## {title}", exclude=True, description="the format of the title. It is used when the `parse_mode` is 'title'.")
    custom_output_format: Optional[str] = Field(default=None, exclude=True, description="the format of the output. It is used when the `prompt_template` is provided.")

    # Toolkits whose tools the action may invoke during execution.
    tools: Optional[List[Toolkit]] = Field(default=None, description="The tools that the action can use")
    conversation: Optional[Message] = Field(default=None, description="Current conversation state")

    # Upper bound on tool-calling rounds before the loop gives up.
    max_tool_try: int = Field(default=2, description="Maximum number of tool calling attempts allowed")
|
| 31 |
+
|
| 32 |
+
def __init__(self, **kwargs):
    """Initialize the action, validate prompt configuration and register tools.

    Raises:
        ValueError: If neither `prompt` nor `prompt_template` is provided.
    """
    name = kwargs.pop("name", "CustomizeAction")
    description = kwargs.pop("description", "Customized action that can use tools to accomplish its task")

    super().__init__(name=name, description=description, **kwargs)

    # Validate that at least one of prompt or prompt_template is provided
    if not self.prompt and not self.prompt_template:
        raise ValueError("`prompt` or `prompt_template` is required when creating CustomizeAction action")
    # Prioritize template and give warning if both are provided
    if self.prompt and self.prompt_template:
        logger.warning("Both `prompt` and `prompt_template` are provided for CustomizeAction action. Prioritizing `prompt_template` and ignoring `prompt`.")

    # Always initialize the tool lookup table, not only when tools are
    # provided: methods such as `_call_single_tool` read `self.tools_caller`
    # unconditionally and would otherwise hit AttributeError on a
    # tool-less action.
    self.tools_caller = {}
    if self.tools:
        self.add_tools(self.tools)
|
| 48 |
+
|
| 49 |
+
def prepare_action_prompt(
    self,
    inputs: Optional[dict] = None,
    system_prompt: Optional[str] = None,
    **kwargs
) -> Union[str, List[dict]]:
    """Build the prompt (or chat messages) used to run this action.

    Serializes each declared input to a string, then either substitutes
    them into the plain `prompt` (appending the tool-calling instructions
    when toolkits are attached) or delegates formatting to the
    `prompt_template`.

    Args:
        inputs: Dictionary of input parameter values.
        system_prompt: Optional system prompt passed through to the template.

    Returns:
        Union[str, List[dict]]: Formatted prompt ready for the LLM
        (a string, or chat messages when a ChatTemplate is configured).

    Raises:
        TypeError: If an input value is not a str, dict, or list.
    """
    inputs = inputs or {}

    # Serialize every declared input into an embeddable string.
    prompt_params_values = {}
    for name in self.inputs_format.get_attrs():
        raw = inputs.get(name, "")
        if isinstance(raw, str):
            serialized = raw
        elif isinstance(raw, (dict, list)):
            serialized = json.dumps(raw, indent=4)
        else:
            raise TypeError(f"The input type {type(raw)} is invalid! Valid types: [str, dict, list].")
        prompt_params_values[name] = serialized

    if not self.prompt:
        # Template mode: register the tools (if any) and let the template format everything.
        if self.tools:
            self.prompt_template.set_tools(self.tools)
        return self.prompt_template.format(
            system_prompt=system_prompt,
            values=prompt_params_values,
            inputs_format=self.inputs_format,
            outputs_format=self.outputs_format,
            parse_mode=self.parse_mode,
            title_format=self.title_format,
            custom_output_format=self.custom_output_format,
            tools=self.tools
        )

    # Plain-prompt mode: substitute values, then advertise the tool schemas.
    rendered = self.prompt.format(**prompt_params_values) if prompt_params_values else self.prompt
    if self.tools:
        tools_schemas = [
            schema["function"]
            for toolkit_schemas in (tool.get_tool_schemas() for tool in self.tools)
            for schema in toolkit_schemas
        ]
        rendered += "\n\n" + TOOL_CALLING_TEMPLATE.format(tools_description=tools_schemas)
    return rendered
|
| 106 |
+
|
| 107 |
+
def prepare_extraction_prompt(self, llm_output_content: str) -> str:
    """Build the fallback prompt that asks the LLM to extract the expected
    output fields from a raw response that failed direct parsing.

    Args:
        llm_output_content: Raw output content from the LLM.

    Returns:
        str: The formatted extraction prompt.
    """
    attr_descriptions: dict = self.outputs_format.get_attr_descriptions()
    # Render each expected field as a numbered "name + description" entry.
    described_outputs = [
        f"{idx}. {attr}\nDescription: {description}"
        for idx, (attr, description) in enumerate(attr_descriptions.items(), start=1)
    ]
    return OUTPUT_EXTRACTION_PROMPT.format(
        text=llm_output_content,
        output_description="\n\n".join(described_outputs),
    )
|
| 123 |
+
|
| 124 |
+
def _get_unique_class_name(self, candidate_name: str) -> str:
    """Return `candidate_name` if it is not yet registered; otherwise append
    the first `V<i>` suffix (V1, V2, ...) that makes it unique in
    MODULE_REGISTRY.
    """
    if not MODULE_REGISTRY.has_module(candidate_name):
        return candidate_name

    # Probe suffixed names until an unregistered one is found.
    suffix = 1
    while MODULE_REGISTRY.has_module(f"{candidate_name}V{suffix}"):
        suffix += 1
    return f"{candidate_name}V{suffix}"
|
| 139 |
+
|
| 140 |
+
def add_tools(self, tools: Union[Toolkit, List[Toolkit]]):
    """Register one or more toolkits with this action.

    For each toolkit, collects its callable tool functions into
    `self.tools_caller` (keyed by tool name) and appends the toolkit to
    `self.tools`. Toolkits whose tools are all already registered, or that
    yield no valid tools, are skipped with a log message. A failure in one
    toolkit does not abort registration of the others.

    Args:
        tools: A single Toolkit or a list of Toolkit instances.

    Raises:
        TypeError: If any element of `tools` is not a Toolkit.
    """
    if not tools:
        return
    if isinstance(tools,Toolkit):
        tools = [tools]
    if not all(isinstance(tool, Toolkit) for tool in tools):
        raise TypeError("`tools` must be a Toolkit or list of Toolkit instances.")
    # Lazily create the containers on first registration.
    if not self.tools:
        self.tools_caller = {}
        self.tools = []
    # self.tools += tools
    # tools_callers = [tool.get_tools() for tool in tools]
    # tools_callers = [j for i in tools_callers for j in i]
    # for tool_caller in tools_callers:
    # self.tools_caller[tool_caller.name] = tool_caller

    # avoid duplication & type checks
    for toolkit in tools:
        try:
            tool_callers = toolkit.get_tools()
            if not isinstance(tool_callers, list):
                logger.warning(f"Expected list of tool functions from '{toolkit.name}.get_tools()', got {type(tool_callers)}.")
                continue

            # add tool callers to the tools_caller dictionary
            # First pass: validate each tool (must have a name and be callable)
            # and collect the valid ones; duplicates are logged but still kept
            # so the later update() overwrites the previous registration.
            valid_tools_count = 0
            valid_tools_names, valid_tool_callers = [], []
            for tool_caller in tool_callers:
                tool_caller_name = getattr(tool_caller, "name", None)
                if not tool_caller_name or not callable(tool_caller):
                    logger.warning(f"Invalid tool function in '{toolkit.name}': missing name or not callable.")
                    continue
                if tool_caller_name in self.tools_caller:
                    logger.warning(f"Duplicate tool function '{tool_caller_name}' detected. Overwriting previous function.")
                # self.tools_caller[tool_caller_name] = tool_caller
                valid_tools_count += 1
                valid_tools_names.append(tool_caller_name)
                valid_tool_callers.append(tool_caller)

            if valid_tools_count == 0:
                logger.info(f"No valid tools found in toolkit '{toolkit.name}'. Skipping.")
                continue

            # Skip entirely when every valid tool is already registered;
            # note this check happens BEFORE the update below.
            if valid_tools_count > 0 and all(name in self.tools_caller for name in valid_tools_names):
                logger.info(f"All tools from toolkit '{toolkit.name}' are already added. Skipping.")
                continue

            if valid_tools_count > 0:
                self.tools_caller.update({name: caller for name, caller in zip(valid_tools_names, valid_tool_callers)})

            # only add toolkit if at least one valid tool is added and toolkit is not already added
            existing_toolkit_names = {tkt.name for tkt in self.tools}
            if valid_tools_count > 0 and toolkit.name not in existing_toolkit_names:
                self.tools.append(toolkit)
            if valid_tools_count > 0:
                logger.info(f"Added toolkit '{toolkit.name}' with {valid_tools_count} valid tools in {self.name}: {valid_tools_names}.")

        except Exception as e:
            # Best-effort: a broken toolkit must not prevent the others from loading.
            logger.error(f"Failed to load tools from toolkit '{toolkit.name}': {e}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _extract_tool_calls(self, llm_output: str, llm: Optional[BaseLLM] = None) -> List[dict]:
    """Extract tool-call specifications from `<ToolCalling>...</ToolCalling>`
    blocks in the LLM output.

    Each block is expected to contain JSON (a dict for one call or a list of
    dicts for several). When JSON parsing fails and `llm` is provided, the
    block is sent back to the LLM with a retry prompt to repair the JSON.

    Args:
        llm_output: Raw LLM response text to scan.
        llm: Optional LLM used to repair malformed tool-call JSON.

    Returns:
        List[dict]: Parsed tool-call dictionaries (possibly empty).
    """
    pattern = r"<ToolCalling>\s*(.*?)\s*</ToolCalling>"

    # Find all ToolCalling blocks in the output
    matches = re.findall(pattern, llm_output, re.DOTALL)

    if not matches:
        return []

    parsed_tool_calls = []
    for match_content in matches:
        try:
            json_content = match_content.strip()
            json_list = parse_json_from_text(json_content)
            if not json_list:
                logger.warning("No valid JSON found in ToolCalling block")
                continue
            # Only use the first JSON string from each block
            parsed_tool_call = json.loads(json_list[0])
            if isinstance(parsed_tool_call, dict):
                parsed_tool_calls.append(parsed_tool_call)
            elif isinstance(parsed_tool_call, list):
                parsed_tool_calls.extend(parsed_tool_call)
            else:
                logger.warning(f"Invalid tool call format: {parsed_tool_call}")
                continue
        except (json.JSONDecodeError, IndexError) as e:
            logger.warning(f"Failed to parse tool calls from LLM output: {e}")
            if llm is not None:
                # Ask the LLM to repair the malformed block, then parse again.
                retry_prompt = TOOL_CALLING_RETRY_PROMPT.format(text=match_content)
                try:
                    fixed_output = llm.generate(prompt=retry_prompt).content.strip()
                    logger.info(f"Retrying tool call parse with fixed output:\n{fixed_output}")

                    fixed_list = parse_json_from_text(fixed_output)
                    if fixed_list:
                        parsed_tool_call = json.loads(fixed_list[0])
                        if isinstance(parsed_tool_call, dict):
                            parsed_tool_calls.append(parsed_tool_call)
                        elif isinstance(parsed_tool_call, list):
                            parsed_tool_calls.extend(parsed_tool_call)
                except Exception as retry_err:
                    # Give up on this block; keep processing the others.
                    logger.error(f"Retry failed: {retry_err}")
                    continue
            else:
                continue

    return parsed_tool_calls
|
| 250 |
+
|
| 251 |
+
def _extract_output(self, llm_output: Any, llm: BaseLLM = None, **kwargs):
    """Parse the LLM response into the action's output format.

    First tries the configured parse settings (`parse_mode`, `parse_func`,
    `title_format`); if that fails, asks the LLM itself to extract the
    expected fields as JSON.
    """
    content = getattr(llm_output, "content", str(llm_output))

    # An action with no declared output fields just wraps the raw content.
    if not self.outputs_format.get_attrs():
        return self.outputs_format.parse(content=content)

    try:
        # Direct parse using this action's configured settings.
        return self.outputs_format.parse(
            content=content,
            parse_mode=self.parse_mode,
            parse_func=getattr(self, 'parse_func', None),
            title_format=getattr(self, 'title_format', "## {title}")
        )
    except Exception as e:
        logger.info(f"Failed to parse with action's parse settings: {e}")
        logger.info("Falling back to using LLM to extract outputs...")

        # Fallback: have the LLM re-extract the fields as JSON.
        extraction_prompt = self.prepare_extraction_prompt(content)
        extracted: LLMOutputParser = llm.generate(prompt=extraction_prompt)
        extracted_fields: dict = parse_json_from_llm_output(extracted.content)
        return self.outputs_format.from_dict(extracted_fields)
|
| 295 |
+
|
| 296 |
+
async def _async_extract_output(self, llm_output: Any, llm: BaseLLM = None, **kwargs):
    """Asynchronously parse the LLM response into the action's output format.

    First tries the configured parse settings (`parse_mode`, `parse_func`,
    `title_format`); if that fails, asks the LLM itself (via
    `async_generate`) to extract the expected fields as JSON.
    """
    content = getattr(llm_output, "content", str(llm_output))

    # An action with no declared output fields just wraps the raw content.
    if not self.outputs_format.get_attrs():
        return self.outputs_format.parse(content=content)

    try:
        # Direct parse using this action's configured settings.
        return self.outputs_format.parse(
            content=content,
            parse_mode=self.parse_mode,
            parse_func=getattr(self, 'parse_func', None),
            title_format=getattr(self, 'title_format', "## {title}")
        )
    except Exception as e:
        logger.info(f"Failed to parse with action's parse settings: {e}")
        logger.info("Falling back to using LLM to extract outputs...")

        # Fallback: have the LLM re-extract the fields as JSON.
        extraction_prompt = self.prepare_extraction_prompt(content)
        extracted = await llm.async_generate(prompt=extraction_prompt)
        extracted_fields: dict = parse_json_from_llm_output(extracted.content)
        return self.outputs_format.from_dict(extracted_fields)
|
| 340 |
+
|
| 341 |
+
def _call_single_tool(self, function_param: dict) -> tuple:
    """Execute one tool call described by `function_param`.

    Args:
        function_param: Dict with `function_name` and optional `function_args`.

    Returns:
        tuple: `(result, None)` on success, `(None, error_message)` on failure.
    """
    # Bind before the try block so the except handler can always reference
    # it; the original could raise NameError (masking the real error) if an
    # exception fired before `function_name` was assigned.
    function_name = None
    try:
        function_name = function_param.get("function_name")
        function_args = function_param.get("function_args") or {}

        if not function_name:
            return None, "No function name provided"

        callable_fn = self.tools_caller.get(function_name)
        if not callable(callable_fn):
            return None, f"Function '{function_name}' not found or not callable"

        print("_____________________ Start Function Calling _____________________")
        print(f"Executing function calling: {function_name} with parameters: {function_args}")
        result = callable_fn(**function_args)
        return result, None

    except Exception as e:
        logger.error(f"Error executing tool {function_name}: {e}")
        return None, f"Error executing tool {function_name}: {str(e)}"
|
| 361 |
+
|
| 362 |
+
def _calling_tools(self, tool_call_args: List[dict]) -> dict:
    """Run all requested tool calls in parallel threads and collect the
    results and error messages as they complete.
    """
    collected_results, collected_errors = [], []

    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = [pool.submit(self._call_single_tool, args) for args in tool_call_args]
        # Gather outcomes in completion order.
        for done in concurrent.futures.as_completed(futures):
            outcome, failure = done.result()
            if failure:
                collected_errors.append(failure)
            if outcome is not None:
                collected_results.append(outcome)

    return {"result": collected_results, "error": collected_errors}
|
| 378 |
+
|
| 379 |
+
async def _async_call_single_tool(self, function_param: dict) -> tuple:
    """Asynchronously execute one tool call described by `function_param`.

    Coroutine tools are awaited directly; synchronous tools run in the
    default executor so they don't block the event loop.

    Args:
        function_param: Dict with `function_name` and optional `function_args`.

    Returns:
        tuple: `(result, None)` on success, `(None, error_message)` on failure.
    """
    # Bind before the try block so the except handler can always reference
    # it; the original could raise NameError (masking the real error) if an
    # exception fired before `function_name` was assigned.
    function_name = None
    try:
        function_name = function_param.get("function_name")
        function_args = function_param.get("function_args") or {}

        if not function_name:
            return None, "No function name provided"

        callable_fn = self.tools_caller.get(function_name)
        if not callable(callable_fn):
            return None, f"Function '{function_name}' not found or not callable"

        print("_____________________ Start Function Calling _____________________")
        print(f"Executing function calling: {function_name} with parameters: {function_args}")

        if inspect.iscoroutinefunction(callable_fn):
            result = await callable_fn(**function_args)
        else:
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, lambda: callable_fn(**function_args))

        return result, None

    except Exception as e:
        logger.error(f"Error executing tool {function_name}: {e}")
        return None, f"Error executing tool {function_name}: {str(e)}"
|
| 405 |
+
|
| 406 |
+
async def _async_calling_tools(self, tool_call_args: List[dict]) -> dict:
    """Run all requested tool calls concurrently and split the outcomes
    into successful results and error messages.
    """
    outcomes = await asyncio.gather(
        *(self._async_call_single_tool(args) for args in tool_call_args)
    )

    successes = [value for value, failure in outcomes if failure is None and value is not None]
    failures = [failure for _, failure in outcomes if failure is not None]

    return {"result": successes, "error": failures}
|
| 415 |
+
|
| 416 |
+
def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, time_out = 0, **kwargs):
    """Run the action synchronously with an LLM tool-calling loop.

    Builds a conversation from the prompt template (or raw sys_msg), then
    repeatedly: generates an LLM response, extracts tool calls from it,
    executes them, and feeds the results back — until the LLM emits no tool
    call or `self.max_tool_try` iterations are exceeded.

    Args:
        llm: Language model used for generation and output extraction.
        inputs: Values for the action's declared input attributes. May be
            empty only when the action declares no inputs.
        sys_msg: Optional system message for the conversation.
        return_prompt: If True, also return the prompt that was used.
        time_out: Starting iteration counter (NOTE(review): despite the name
            this is an iteration count, not a duration — confirm with callers).

    Returns:
        The extracted output, or `(output, prompt)` when `return_prompt` is True.

    Raises:
        ValueError: If `inputs` is empty while inputs are required, or the
            prompt template has an unsupported type.
    """
    # Allow empty inputs if the action has no required input attributes
    input_attributes: dict = self.inputs_format.get_attr_descriptions()
    if not inputs and input_attributes:
        logger.error("CustomizeAction action received invalid `inputs`: None or empty.")
        raise ValueError('The `inputs` to CustomizeAction action is None or empty.')
    # Set inputs to empty dict if None and no inputs are required
    if inputs is None:
        inputs = {}
    final_llm_response = None

    if self.prompt_template:

        if isinstance(self.prompt_template, ChatTemplate):
            # must determine whether prompt_template is ChatTemplate first since ChatTemplate is a subclass of StringTemplate
            conversation = self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)
        elif isinstance(self.prompt_template, StringTemplate):
            conversation = [{"role": "system", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
        else:
            raise ValueError(f"`prompt_template` must be a StringTemplate or ChatTemplate instance, but got {type(self.prompt_template)}")
    else:
        conversation = [{"role": "system", "content": sys_msg}, {"role": "user", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]


    ## 1. get all the input parameters
    prompt_params_values = {k: inputs.get(k, "") for k in input_attributes.keys()}
    while True:
        ### Generate response from LLM
        if time_out > self.max_tool_try:
            # Get the appropriate prompt for return
            current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
            # Use the final LLM response if available, otherwise fall back to execution history
            # NOTE(review): the fallback stringifies the whole conversation list;
            # final_llm_response is an LLMResponse object — _extract_output
            # apparently accepts both. Confirm against its implementation.
            content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
            if return_prompt:
                return self._extract_output(content_to_extract, llm = llm), current_prompt
            return self._extract_output(content_to_extract, llm = llm)
        time_out += 1

        # Handle both string prompts and chat message lists
        llm_response = llm.generate(messages=conversation)
        conversation.append({"role": "assistant", "content": llm_response.content})

        # Store the final LLM response
        final_llm_response = llm_response

        # No tool call in the response means the LLM is done — exit the loop.
        tool_call_args = self._extract_tool_calls(llm_response.content)
        if not tool_call_args:
            break

        logger.info("Extracted tool call args:")
        logger.info(json.dumps(tool_call_args, indent=4))

        results = self._calling_tools(tool_call_args)

        logger.info("Tool call results:")
        # NOTE(review): results may contain non-JSON-serializable tool output;
        # the async variant guards this json.dumps with try/except — confirm.
        logger.info(json.dumps(results, indent=4))

        # Feed the tool results back to the LLM as part of the conversation
        # so the next generation can use them.
        conversation.append({"role": "assistant", "content": TOOL_CALLING_HISTORY_PROMPT.format(
            iteration_number=time_out,
            tool_call_args=f"{tool_call_args}",
            results=f"{results}"
        )})

    # Get the appropriate prompt for return
    current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
    # Use the final LLM response if available, otherwise fall back to execution history
    content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
    if return_prompt:
        return self._extract_output(content_to_extract, llm = llm), current_prompt
    return self._extract_output(content_to_extract, llm = llm)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
async def async_execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, time_out = 0, **kwargs):
    """Async counterpart of `execute`: LLM generation plus tool-calling loop.

    Mirrors `execute`, but generates with `llm.async_generate`, extracts
    output with `_async_extract_output`, and runs tool calls through the
    non-blocking `_async_calling_tools` helper.

    Args:
        llm: Language model used for generation and output extraction.
        inputs: Values for the action's declared input attributes. May be
            empty only when the action declares no inputs.
        sys_msg: Optional system message for the conversation.
        return_prompt: If True, also return the prompt that was used.
        time_out: Starting iteration counter (an iteration count, capped by
            `self.max_tool_try`; not a duration).

    Returns:
        The extracted output, or `(output, prompt)` when `return_prompt` is True.

    Raises:
        ValueError: If `inputs` is empty while inputs are required, or the
            prompt template has an unsupported type.
    """
    # Allow empty inputs if the action has no required input attributes
    input_attributes: dict = self.inputs_format.get_attr_descriptions()
    if not inputs and input_attributes:
        logger.error("CustomizeAction action received invalid `inputs`: None or empty.")
        raise ValueError('The `inputs` to CustomizeAction action is None or empty.')
    # Set inputs to empty dict if None and no inputs are required
    if inputs is None:
        inputs = {}
    final_llm_response = None

    if self.prompt_template:
        if isinstance(self.prompt_template, ChatTemplate):
            # must determine whether prompt_template is ChatTemplate first since ChatTemplate is a subclass of StringTemplate
            conversation = self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)
        elif isinstance(self.prompt_template, StringTemplate):
            conversation = [{"role": "system", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]
        else:
            raise ValueError(f"`prompt_template` must be a StringTemplate or ChatTemplate instance, but got {type(self.prompt_template)}")
    else:
        conversation = [{"role": "system", "content": sys_msg}, {"role": "user", "content": self.prepare_action_prompt(inputs=inputs, system_prompt=sys_msg)}]


    ## 1. get all the input parameters
    prompt_params_values = {k: inputs.get(k, "") for k in input_attributes.keys()}
    while True:
        ### Generate response from LLM
        if time_out > self.max_tool_try:
            # Get the appropriate prompt for return
            current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
            # Use the final LLM response if available, otherwise fall back to execution history
            content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
            if return_prompt:
                return await self._async_extract_output(content_to_extract, llm = llm), current_prompt
            return await self._async_extract_output(content_to_extract, llm = llm)
        time_out += 1

        # Handle both string prompts and chat message lists
        llm_response = await llm.async_generate(messages=conversation)
        conversation.append({"role": "assistant", "content": llm_response.content})

        # Store the final LLM response
        final_llm_response = llm_response

        # No tool call in the response means the LLM is done — exit the loop.
        tool_call_args = self._extract_tool_calls(llm_response.content)
        if not tool_call_args:
            break

        logger.info("Extracted tool call args:")
        logger.info(json.dumps(tool_call_args, indent=4))

        # FIX: previously called the blocking sync `self._calling_tools` from
        # inside this coroutine, stalling the event loop. Use the async helper,
        # which runs the calls concurrently and returns the same
        # {"result": [...], "error": [...]} shape.
        results = await self._async_calling_tools(tool_call_args)

        logger.info("Tool call results:")
        try:
            logger.info(json.dumps(results, indent=4))
        except Exception:
            # Tool output may not be JSON-serializable; fall back to repr.
            logger.info(str(results))

        # Feed the tool results back to the LLM as part of the conversation.
        conversation.append({"role": "assistant", "content": TOOL_CALLING_HISTORY_PROMPT.format(
            iteration_number=time_out,
            tool_call_args=f"{tool_call_args}",
            results=f"{results}"
        )})

    # Get the appropriate prompt for return
    current_prompt = self.prepare_action_prompt(inputs=prompt_params_values or {})
    # Use the final LLM response if available, otherwise fall back to execution history
    content_to_extract = final_llm_response if final_llm_response is not None else "{content}".format(content = conversation)
    if return_prompt:
        return await self._async_extract_output(content_to_extract, llm = llm), current_prompt
    return await self._async_extract_output(content_to_extract, llm = llm)
|
evoagentx/actions/task_planning.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import Field
|
| 2 |
+
from typing import Optional, List
|
| 3 |
+
|
| 4 |
+
from ..core.logging import logger
|
| 5 |
+
from ..models.base_model import BaseLLM
|
| 6 |
+
from .action import Action, ActionInput, ActionOutput
|
| 7 |
+
from ..prompts.task_planner import TASK_PLANNING_ACTION
|
| 8 |
+
from ..workflow.workflow_graph import WorkFlowNode
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TaskPlanningInput(ActionInput):
    """
    Input specification for the task planning action.
    """
    # The user's objective, stated in enough detail for the planner to decompose it.
    goal: str = Field(description="A clear and detailed description of the user's goal, specifying what needs to be achieved.")
    # A previously generated plan, if any; lets the planner revise instead of starting fresh.
    history: Optional[str] = Field(default=None, description="Optional field containing previously generated task plan.")
    # Free-form hints that steer how the goal is broken down.
    suggestion: Optional[str] = Field(default=None, description="Optional suggestions or ideas to guide the planning process.")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TaskPlanningOutput(ActionOutput):
    """
    Output structure for the task planning action.
    """
    # Each sub-task is a WorkFlowNode; together they should cover the full goal.
    sub_tasks: List[WorkFlowNode] = Field(description="A list of sub-tasks that collectively achieve user's goal.")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TaskPlanning(Action):
    """
    Action for planning a series of tasks to achieve a goal.
    """

    def __init__(self, **kwargs):
        # Fall back to the packaged TASK_PLANNING_ACTION defaults for any
        # field the caller did not supply explicitly.
        name = kwargs.pop("name", TASK_PLANNING_ACTION["name"])
        description = kwargs.pop("description", TASK_PLANNING_ACTION["description"])
        prompt = kwargs.pop("prompt", TASK_PLANNING_ACTION["prompt"])
        inputs_format = kwargs.pop("inputs_format", None) or TaskPlanningInput
        outputs_format = kwargs.pop("outputs_format", None) or TaskPlanningOutput
        super().__init__(name=name, description=description, prompt=prompt, inputs_format=inputs_format, outputs_format=outputs_format, **kwargs)

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> TaskPlanningOutput:
        """Generate a structured plan of sub-tasks for the user's goal.

        Formats the planning prompt from the goal (plus optional history and
        suggestion) and asks the LLM to produce output parsed as JSON into
        the configured outputs_format.

        Args:
            llm: The language model used for planning.
            inputs: Must contain "goal"; may contain "history" and "suggestion".
            sys_msg: Optional system message for the language model.
            return_prompt: When True, also return the formatted prompt.

        Returns:
            The generated task plan, or `(task plan, prompt)` when
            `return_prompt` is True.

        Raises:
            ValueError: If `inputs` is None or empty.
        """
        if not inputs:
            logger.error("TaskPlanning action received invalid `inputs`: None or empty.")
            raise ValueError('The `inputs` to TaskPlanning action is None or empty.')

        # Missing optional parameters are rendered as empty strings.
        params = {key: inputs.get(key, "") for key in ("goal", "history", "suggestion")}
        prompt = self.prompt.format(**params)

        task_plan = llm.generate(
            prompt=prompt,
            system_message=sys_msg,
            parser=self.outputs_format,
            parse_mode="json",
        )

        return (task_plan, prompt) if return_prompt else task_plan
|
evoagentx/agents/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .agent import Agent
|
| 2 |
+
from .customize_agent import CustomizeAgent
|
| 3 |
+
from .action_agent import ActionAgent
|
| 4 |
+
from .agent_manager import AgentManager
|
| 5 |
+
|
| 6 |
+
__all__ = ["Agent", "CustomizeAgent", "ActionAgent", "AgentManager"]
|
evoagentx/agents/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (485 Bytes). View file
|
|
|
evoagentx/agents/__pycache__/action_agent.cpython-311.pyc
ADDED
|
Binary file (25.6 kB). View file
|
|
|
evoagentx/agents/__pycache__/agent.cpython-311.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
evoagentx/agents/__pycache__/agent_generator.cpython-311.pyc
ADDED
|
Binary file (2.06 kB). View file
|
|
|
evoagentx/agents/__pycache__/agent_manager.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
evoagentx/agents/__pycache__/customize_agent.cpython-311.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
evoagentx/agents/__pycache__/task_planner.cpython-311.pyc
ADDED
|
Binary file (2.85 kB). View file
|
|
|
evoagentx/agents/__pycache__/workflow_reviewer.cpython-311.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
evoagentx/agents/action_agent.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
from pydantic import create_model, Field
|
| 4 |
+
from typing import Optional, Callable, Type, List, Any
|
| 5 |
+
|
| 6 |
+
from .agent import Agent
|
| 7 |
+
from ..core.logging import logger
|
| 8 |
+
from ..core.registry import MODULE_REGISTRY, ACTION_FUNCTION_REGISTRY
|
| 9 |
+
from ..models.model_configs import LLMConfig
|
| 10 |
+
from ..actions.action import Action, ActionOutput, ActionInput
|
| 11 |
+
from ..utils.utils import generate_dynamic_class_name, make_parent_folder
|
| 12 |
+
from ..core.message import Message, MessageType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ActionAgent(Agent):
|
| 16 |
+
"""
|
| 17 |
+
ActionAgent is a specialized agent that executes a provided function directly without LLM.
|
| 18 |
+
It creates an action that uses the provided function as the execution backbone.
|
| 19 |
+
|
| 20 |
+
Attributes:
|
| 21 |
+
name (str): The name of the agent.
|
| 22 |
+
description (str): A description of the agent's purpose and capabilities.
|
| 23 |
+
inputs (List[dict]): List of input specifications, where each dict contains:
|
| 24 |
+
- name (str): Name of the input parameter
|
| 25 |
+
- type (str): Type of the input
|
| 26 |
+
- description (str): Description of what the input represents
|
| 27 |
+
- required (bool, optional): Whether this input is required (default: True)
|
| 28 |
+
outputs (List[dict]): List of output specifications, where each dict contains:
|
| 29 |
+
- name (str): Name of the output field
|
| 30 |
+
- type (str): Type of the output
|
| 31 |
+
- description (str): Description of what the output represents
|
| 32 |
+
- required (bool, optional): Whether this output is required (default: True)
|
| 33 |
+
execute_func (Callable): The function to execute the agent.
|
| 34 |
+
async_execute_func (Callable, Optional): Async version of the function. If not provided,
|
| 35 |
+
an async wrapper will be automatically created around execute_func.
|
| 36 |
+
llm_config (LLMConfig, optional): Configuration for the language model (minimal usage).
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def __init__(
    self,
    name: str,
    description: str,
    inputs: List[dict],
    outputs: List[dict],
    execute_func: Callable,
    async_execute_func: Optional[Callable] = None,
    llm_config: Optional[LLMConfig] = None,
    **kwargs
):
    """Build an agent whose single action runs `execute_func` directly.

    Args:
        name: Agent name; also seeds the dynamic action/class names.
        description: Human-readable agent description.
        inputs: Input specs ({"name", "type", "description", optional "required"}).
        outputs: Output specs in the same shape as `inputs`.
        execute_func: Sync function executed by the action.
        async_execute_func: Optional async variant; when omitted, the sync
            function is run in an executor for async calls.
        llm_config: Optional LLM configuration (minimal usage).

    Raises:
        ValueError: If either function is not callable, or the input/output
            specs are malformed (see `_validate_inputs_outputs`).
    """
    # Validate inputs
    if not callable(execute_func):
        raise ValueError("execute_func must be callable")

    if async_execute_func is not None and not callable(async_execute_func):
        raise ValueError("async_execute_func must be callable")

    # Validate inputs and outputs
    self._validate_inputs_outputs(inputs, outputs)

    # Set is_human based on LLM availability
    # (without an LLM config the agent is treated as non-LLM-driven)
    is_human = llm_config is None

    # Initialize parent directly
    super().__init__(
        name=name,
        description=description,
        llm_config=llm_config,
        is_human=is_human,
        **kwargs
    )

    # Store function references and metadata
    self.execute_func = execute_func
    self.async_execute_func = async_execute_func
    self.inputs = inputs
    self.outputs = outputs

    # Create and add the function-based action
    action = self._create_function_action_with_params(
        name, execute_func, async_execute_func, inputs, outputs
    )
    self.add_action(action)
|
| 84 |
+
|
| 85 |
+
def init_llm(self):
    # Intentionally a no-op: ActionAgent executes a plain Python function,
    # so no language model needs to be initialised even when the parent
    # class would normally set one up.
    pass
|
| 87 |
+
|
| 88 |
+
def _validate_inputs_outputs(self, inputs: List[dict], outputs: List[dict]):
|
| 89 |
+
"""Validate the structure of inputs and outputs."""
|
| 90 |
+
# Allow empty inputs for functions that don't require any inputs
|
| 91 |
+
if inputs is None:
|
| 92 |
+
inputs = []
|
| 93 |
+
|
| 94 |
+
if outputs is None:
|
| 95 |
+
outputs = []
|
| 96 |
+
|
| 97 |
+
# Validate inputs structure
|
| 98 |
+
for i, input_field in enumerate(inputs):
|
| 99 |
+
if not isinstance(input_field, dict):
|
| 100 |
+
raise ValueError(f"Input field {i} must be a dictionary, got {type(input_field)}")
|
| 101 |
+
|
| 102 |
+
required_keys = ["name", "type", "description"]
|
| 103 |
+
for key in required_keys:
|
| 104 |
+
if key not in input_field:
|
| 105 |
+
raise ValueError(f"Input field {i} missing required key '{key}'")
|
| 106 |
+
|
| 107 |
+
if not isinstance(input_field["name"], str):
|
| 108 |
+
raise ValueError(f"Input field {i} 'name' must be a string, got {type(input_field['name'])}")
|
| 109 |
+
|
| 110 |
+
if not isinstance(input_field["type"], str):
|
| 111 |
+
raise ValueError(f"Input field {i} 'type' must be a string, got {type(input_field['type'])}")
|
| 112 |
+
|
| 113 |
+
if not isinstance(input_field["description"], str):
|
| 114 |
+
raise ValueError(f"Input field {i} 'description' must be a string, got {type(input_field['description'])}")
|
| 115 |
+
|
| 116 |
+
# Check for duplicate input names
|
| 117 |
+
input_names = [field["name"] for field in inputs]
|
| 118 |
+
if len(input_names) != len(set(input_names)):
|
| 119 |
+
raise ValueError(f"Duplicate input names found: {[name for name in input_names if input_names.count(name) > 1]}")
|
| 120 |
+
|
| 121 |
+
# Validate outputs structure
|
| 122 |
+
for i, output_field in enumerate(outputs):
|
| 123 |
+
if not isinstance(output_field, dict):
|
| 124 |
+
raise ValueError(f"Output field {i} must be a dictionary, got {type(output_field)}")
|
| 125 |
+
|
| 126 |
+
required_keys = ["name", "type", "description"]
|
| 127 |
+
for key in required_keys:
|
| 128 |
+
if key not in output_field:
|
| 129 |
+
raise ValueError(f"Output field {i} missing required key '{key}'")
|
| 130 |
+
|
| 131 |
+
if not isinstance(output_field["name"], str):
|
| 132 |
+
raise ValueError(f"Output field {i} 'name' must be a string, got {type(output_field['name'])}")
|
| 133 |
+
|
| 134 |
+
if not isinstance(output_field["type"], str):
|
| 135 |
+
raise ValueError(f"Output field {i} 'type' must be a string, got {type(output_field['type'])}")
|
| 136 |
+
|
| 137 |
+
if not isinstance(output_field["description"], str):
|
| 138 |
+
raise ValueError(f"Output field {i} 'description' must be a string, got {type(output_field['description'])}")
|
| 139 |
+
|
| 140 |
+
# Check for duplicate output names
|
| 141 |
+
output_names = [field["name"] for field in outputs]
|
| 142 |
+
if len(output_names) != len(set(output_names)):
|
| 143 |
+
raise ValueError(f"Duplicate output names found: {[name for name in output_names if output_names.count(name) > 1]}")
|
| 144 |
+
|
| 145 |
+
def _create_function_action_input_type(self, name: str, inputs: List[dict]) -> Type[ActionInput]:
    """Build a dynamic ActionInput subclass from the input specifications.

    Required fields become `str`; optional ones become `Optional[str]`
    defaulting to None. Field descriptions come straight from the specs.
    """
    field_defs = {
        spec["name"]: (
            (str, Field(description=spec["description"]))
            if spec.get("required", True)
            else (Optional[str], Field(default=None, description=spec["description"]))
        )
        for spec in inputs
    }

    cls_name = self._get_unique_class_name(
        generate_dynamic_class_name(f"{name} action_input")
    )
    return create_model(cls_name, **field_defs, __base__=ActionInput)
|
| 163 |
+
|
| 164 |
+
def _create_function_action_output_type(self, name: str, outputs: List[dict]) -> Type[ActionOutput]:
    """Build a dynamic ActionOutput subclass from the output specifications.

    Required fields accept `Any`; optional ones are `Optional[Any]`
    defaulting to None. Field descriptions come straight from the specs.
    """
    field_defs = {
        spec["name"]: (
            (Any, Field(description=spec["description"]))
            if spec.get("required", True)
            else (Optional[Any], Field(default=None, description=spec["description"]))
        )
        for spec in outputs
    }

    cls_name = self._get_unique_class_name(
        generate_dynamic_class_name(f"{name} action_output")
    )
    return create_model(cls_name, **field_defs, __base__=ActionOutput)
|
| 182 |
+
|
| 183 |
+
def _create_execute_method(self, execute_func: Callable):
    """Create the execute method for the action.

    Returns a closure (later bound onto the action instance) that validates
    and filters inputs, runs `execute_func`, and wraps the result — or any
    exception — in the action's outputs_format.
    """
    def execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs):
        # Validate inputs
        if inputs is None:
            inputs = {}

        # Validate that all required inputs are provided
        required_inputs = action_self.inputs_format.get_required_input_names()
        missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs]
        if missing_inputs:
            raise ValueError(f"Missing required inputs: {missing_inputs}")

        # Only forward inputs declared for this agent. The set of declared
        # names is built once here — previously the name list was rebuilt
        # on every iteration of the loop below.
        declared_names = {field["name"] for field in self.inputs}
        filtered_inputs = {}
        for input_name, input_value in inputs.items():
            if input_name in declared_names:
                filtered_inputs[input_name] = input_value
            else:
                logger.warning(f"Unexpected input '{input_name}' provided")

        # Execute function
        try:
            result = execute_func(**filtered_inputs)
        except Exception as e:
            # Create error output - try to use error field if it exists, otherwise use first available field
            try:
                output_fields = action_self.outputs_format.get_attrs()
                if "error" in output_fields:
                    error_output = action_self.outputs_format(
                        error=f"Function execution failed: {str(e)}"
                    )
                elif len(output_fields) > 0:
                    # Use the first field as error field
                    first_field = output_fields[0]
                    error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"})
                else:
                    # Fallback to creating a simple output with error message
                    error_output = action_self.outputs_format()
            except Exception as create_error:
                # If all else fails, create a minimal output
                logger.error(f"Failed to create error output: {create_error}")
                error_output = action_self.outputs_format()
            return error_output, "Function execution"

        # Create success output
        if isinstance(result, dict):
            # For dict results, create output directly
            output = action_self.outputs_format(**result)
        else:
            # For simple values, create output with the first field
            output_fields = action_self.outputs_format.get_attrs()
            if len(output_fields) > 0:
                first_field = output_fields[0]
                output = action_self.outputs_format(**{first_field: result})
            else:
                # Fallback to creating empty output
                output = action_self.outputs_format()

        return output, "Function execution"

    return execute_method
|
| 246 |
+
|
| 247 |
+
def _create_async_execute_method(self, async_execute_func: Callable, execute_func: Callable):
    """Create the async execute method for the action.

    Prefers `async_execute_func` when provided; otherwise runs the sync
    `execute_func` in the default thread-pool executor so the event loop
    is not blocked. Mirrors `_create_execute_method` otherwise.
    """
    async def async_execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs):
        # Validate inputs
        if inputs is None:
            inputs = {}

        # Validate that all required inputs are provided
        required_inputs = action_self.inputs_format.get_required_input_names()
        missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs]
        if missing_inputs:
            raise ValueError(f"Missing required inputs: {missing_inputs}")

        # Only forward inputs declared for this agent; declared-name set is
        # built once instead of per loop iteration.
        declared_names = {field["name"] for field in self.inputs}
        filtered_inputs = {}
        for input_name, input_value in inputs.items():
            if input_name in declared_names:
                filtered_inputs[input_name] = input_value
            else:
                logger.warning(f"Unexpected input '{input_name}' provided")

        # Execute async function
        try:
            if async_execute_func is not None:
                result = await async_execute_func(**filtered_inputs)
            else:
                # FIX: get_event_loop() is deprecated inside a running
                # coroutine — use the running loop directly.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(None, lambda: execute_func(**filtered_inputs))
        except Exception as e:
            # Create error output - try to use error field if it exists, otherwise use first available field
            try:
                output_fields = action_self.outputs_format.get_attrs()
                if "error" in output_fields:
                    error_output = action_self.outputs_format(
                        error=f"Async function execution failed: {str(e)}"
                    )
                elif len(output_fields) > 0:
                    # FIX: every other call site indexes get_attrs() directly
                    # (it returns a sequence of names); the previous
                    # `list(output_fields.keys())[0]` would crash here.
                    first_field = output_fields[0]
                    error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"})
                else:
                    # Fallback to creating a simple output with error message
                    error_output = action_self.outputs_format()
            except Exception as create_error:
                # If all else fails, create a minimal output
                logger.error(f"Failed to create error output: {create_error}")
                error_output = action_self.outputs_format()
            return error_output, "Async function execution"

        # Create success output
        if isinstance(result, dict):
            # For dict results, create output directly
            output = action_self.outputs_format(**result)
        else:
            # For simple values, create output with the first field
            output_fields = action_self.outputs_format.get_attrs()
            if len(output_fields) > 0:
                first_field = output_fields[0]
                output = action_self.outputs_format(**{first_field: result})
            else:
                # Fallback to creating empty output
                output = action_self.outputs_format()

        return output, "Async function execution"

    return async_execute_method
|
| 315 |
+
|
| 316 |
+
    def _create_function_action_with_params(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action:
        """Create an action that executes the provided function with given parameters.

        Builds dedicated input/output pydantic types from the declared specs,
        creates a fresh Action subclass with a registry-unique name, and binds
        the generated (a)sync execute callables onto the instance.

        Args:
            name: Base name used for the generated action/type class names.
            execute_func: Synchronous callable the action will run.
            async_execute_func: Asynchronous callable the action will run.
            inputs: Input field specifications for the action's input type.
            outputs: Output field specifications for the action's output type.

        Returns:
            Action: a configured action instance wrapping the function(s).
        """

        # Create input/output types
        action_input_type = self._create_function_action_input_type(name, inputs)
        action_output_type = self._create_function_action_output_type(name, outputs)

        # Create custom action class; the name is de-duplicated against the
        # module registry so repeated creation does not collide.
        action_cls_name = self._get_unique_class_name(
            generate_dynamic_class_name(f"{name} function action")
        )

        # Create action class with function execution
        function_action_cls = create_model(
            action_cls_name,
            __base__=Action
        )

        # Create action instance
        function_action = function_action_cls(
            name=action_cls_name,
            description=f"Executes {execute_func.__name__} function",
            inputs_format=action_input_type,
            outputs_format=action_output_type
        )

        # Override execute methods - bind them properly to the action instance
        execute_method = self._create_execute_method(execute_func)
        async_execute_method = self._create_async_execute_method(async_execute_func, execute_func)

        # Bind the methods to the action instance. `__get__` turns the plain
        # functions into bound methods so `self` inside them is this instance.
        function_action.execute = execute_method.__get__(function_action, type(function_action))
        function_action.async_execute = async_execute_method.__get__(function_action, type(function_action))

        return function_action
|
| 351 |
+
|
| 352 |
+
def _create_function_action(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action:
|
| 353 |
+
"""Create an action that executes the provided function."""
|
| 354 |
+
return self._create_function_action_with_params(
|
| 355 |
+
name,
|
| 356 |
+
execute_func,
|
| 357 |
+
async_execute_func,
|
| 358 |
+
inputs,
|
| 359 |
+
outputs
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
def get_config(self) -> dict:
|
| 363 |
+
"""Get configuration for the ActionAgent."""
|
| 364 |
+
# Get base config from Agent
|
| 365 |
+
config = super().get_config()
|
| 366 |
+
|
| 367 |
+
# Add ActionAgent-specific information
|
| 368 |
+
config.update({
|
| 369 |
+
"class_name": "ActionAgent",
|
| 370 |
+
"execute_func_name": self.execute_func.__name__ if self.execute_func else None,
|
| 371 |
+
"async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None,
|
| 372 |
+
"inputs": self.inputs,
|
| 373 |
+
"outputs": self.outputs
|
| 374 |
+
})
|
| 375 |
+
return config
|
| 376 |
+
|
| 377 |
+
def save_module(self, path: str, ignore: List[str] = [], **kwargs) -> str:
|
| 378 |
+
"""Save the ActionAgent configuration to a JSON file.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
path: File path where the configuration should be saved
|
| 382 |
+
ignore: List of keys to exclude from the saved configuration
|
| 383 |
+
**kwargs (Any): Additional parameters for the save operation
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
The path where the configuration was saved
|
| 387 |
+
"""
|
| 388 |
+
config = self.get_config()
|
| 389 |
+
|
| 390 |
+
# Add ActionAgent-specific information
|
| 391 |
+
config.update({
|
| 392 |
+
"class_name": "ActionAgent",
|
| 393 |
+
"execute_func_name": self.execute_func.__name__ if self.execute_func else None,
|
| 394 |
+
"async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None,
|
| 395 |
+
"inputs": self.inputs,
|
| 396 |
+
"outputs": self.outputs
|
| 397 |
+
})
|
| 398 |
+
|
| 399 |
+
# Remove non-serializable items
|
| 400 |
+
for ignore_key in ignore:
|
| 401 |
+
config.pop(ignore_key, None)
|
| 402 |
+
|
| 403 |
+
# Save to JSON file
|
| 404 |
+
make_parent_folder(path)
|
| 405 |
+
with open(path, 'w', encoding='utf-8') as f:
|
| 406 |
+
json.dump(config, f, indent=4, ensure_ascii=False)
|
| 407 |
+
|
| 408 |
+
return path
|
| 409 |
+
|
| 410 |
+
    @classmethod
    def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> "ActionAgent":
        """Load the ActionAgent from a JSON file.

        Functions are stored by name (callables are not JSON-serializable),
        so each named function must already be registered in
        ``ACTION_FUNCTION_REGISTRY`` before loading.

        Args:
            path: The path of the file
            llm_config: The LLMConfig instance (optional)
            **kwargs: Additional keyword arguments

        Returns:
            ActionAgent: The loaded agent instance

        Raises:
            KeyError: If required functions are not found in the registry
        """
        # Load configuration
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Extract function names
        execute_func_name = config.get("execute_func_name")
        async_execute_func_name = config.get("async_execute_func_name")

        # Retrieve functions from registry
        execute_func = None
        async_execute_func = None

        if execute_func_name:
            if not ACTION_FUNCTION_REGISTRY.has_function(execute_func_name):
                raise KeyError(f"Function '{execute_func_name}' not found in registry. Please register it first.")
            execute_func = ACTION_FUNCTION_REGISTRY.get_function(execute_func_name)

        if async_execute_func_name:
            if not ACTION_FUNCTION_REGISTRY.has_function(async_execute_func_name):
                raise KeyError(f"Function '{async_execute_func_name}' not found in registry. Please register it first.")
            async_execute_func = ACTION_FUNCTION_REGISTRY.get_function(async_execute_func_name)

        # Create agent; the constructor rebuilds the function action from the
        # resolved callables and the persisted input/output specs.
        agent = cls(
            name=config["name"],
            description=config["description"],
            inputs=config["inputs"],
            outputs=config["outputs"],
            execute_func=execute_func,
            async_execute_func=async_execute_func,
            llm_config=llm_config,
            **kwargs
        )

        return agent
|
| 460 |
+
|
| 461 |
+
def __call__(self, inputs: dict = None, return_msg_type: MessageType = MessageType.UNKNOWN, **kwargs) -> Message:
|
| 462 |
+
"""
|
| 463 |
+
Call the main function action.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
inputs (dict): The inputs to the function action.
|
| 467 |
+
return_msg_type (MessageType): The type of message to return.
|
| 468 |
+
**kwargs (Any): Additional keyword arguments.
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
Message: The output of the function action.
|
| 472 |
+
"""
|
| 473 |
+
inputs = inputs or {}
|
| 474 |
+
return super().__call__(action_name=self.main_action_name, action_input_data=inputs, return_msg_type=return_msg_type, **kwargs)
|
| 475 |
+
|
| 476 |
+
@property
|
| 477 |
+
def main_action_name(self) -> str:
|
| 478 |
+
"""
|
| 479 |
+
Get the name of the main function action for this agent.
|
| 480 |
+
|
| 481 |
+
Returns:
|
| 482 |
+
The name of the main function action
|
| 483 |
+
"""
|
| 484 |
+
for action in self.actions:
|
| 485 |
+
if action.name != self.cext_action_name:
|
| 486 |
+
return action.name
|
| 487 |
+
raise ValueError("Couldn't find the main action name!")
|
| 488 |
+
|
| 489 |
+
def _get_unique_class_name(self, candidate_name: str) -> str:
|
| 490 |
+
"""
|
| 491 |
+
Get a unique class name by checking if it already exists in the registry.
|
| 492 |
+
If it does, append "Vx" to make it unique.
|
| 493 |
+
"""
|
| 494 |
+
if not MODULE_REGISTRY.has_module(candidate_name):
|
| 495 |
+
return candidate_name
|
| 496 |
+
|
| 497 |
+
counter = 1
|
| 498 |
+
while True:
|
| 499 |
+
new_name = f"{candidate_name}V{counter}"
|
| 500 |
+
if not MODULE_REGISTRY.has_module(new_name):
|
| 501 |
+
return new_name
|
| 502 |
+
counter += 1
|
evoagentx/agents/agent.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import inspect
|
| 3 |
+
from pydantic import Field
|
| 4 |
+
from typing import Type, Optional, Union, Tuple, List, Any, Coroutine
|
| 5 |
+
|
| 6 |
+
from ..core.module import BaseModule
|
| 7 |
+
from ..core.module_utils import generate_id
|
| 8 |
+
from ..core.message import Message, MessageType
|
| 9 |
+
from ..core.registry import MODEL_REGISTRY
|
| 10 |
+
from ..models.model_configs import LLMConfig
|
| 11 |
+
from ..models.base_model import BaseLLM
|
| 12 |
+
from ..memory.memory import ShortTermMemory
|
| 13 |
+
from ..memory.long_term_memory import LongTermMemory
|
| 14 |
+
from ..memory.memory_manager import MemoryManager
|
| 15 |
+
from ..storages.base import StorageHandler
|
| 16 |
+
from ..actions.action import Action
|
| 17 |
+
from ..actions.action import ContextExtraction
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Agent(BaseModule):
    """
    Base class for all agents.

    Attributes:
        name (str): Unique identifier for the agent
        description (str): Human-readable description of the agent's purpose
        llm_config (Optional[LLMConfig]): Configuration for the language model. If provided, a new LLM instance will be created.
            Otherwise, the existing LLM instance specified in the `llm` field will be used.
        llm (Optional[BaseLLM]): Language model instance. If provided, the existing LLM instance will be used.
        agent_id (Optional[str]): Unique ID for the agent, auto-generated if not provided
        system_prompt (Optional[str]): System prompt for the Agent.
        actions (List[Action]): List of available actions
        n (Optional[int]): Number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.
        is_human (bool): Whether this agent represents a human user
        version (int): Version number of the agent, default is 0.
    """

    name: str # should be unique across agents
    description: str
    llm_config: Optional[LLMConfig] = None
    llm: Optional[BaseLLM] = None
    agent_id: Optional[str] = Field(default_factory=generate_id)
    system_prompt: Optional[str] = None
    short_term_memory: Optional[ShortTermMemory] = Field(default_factory=ShortTermMemory) # store short term memory for a single workflow.
    use_long_term_memory: Optional[bool] = False
    storage_handler: Optional[StorageHandler] = None  # required when use_long_term_memory is True (see init_long_term_memory)
    long_term_memory: Optional[LongTermMemory] = None  # lazily created in init_long_term_memory if not supplied
    long_term_memory_manager: Optional[MemoryManager] = None  # lazily created in init_long_term_memory if not supplied
    actions: List[Action] = Field(default=None)  # normalized to [] in init_module
    n: int = Field(default=None, description="number of latest messages used to provide context for action execution. It uses all the messages in short term memory by default.")
    is_human: bool = Field(default=False)
    version: int = 0
|
| 53 |
+
|
| 54 |
+
    def init_module(self):
        """Post-construction initialization hook.

        Sets up the LLM (non-human agents only), optional long-term memory,
        the action list/lookup map, and the context-extraction action.
        """
        if not self.is_human:
            self.init_llm()
        if self.use_long_term_memory:
            self.init_long_term_memory()
        # Normalize `actions` to a list and build the name -> action lookup.
        self.actions = [] if self.actions is None else self.actions
        self._action_map = {action.name: action for action in self.actions} if self.actions else dict()
        # Runtime-only fields excluded from serialization in save_module.
        self._save_ignore_fields = ["llm", "llm_config"]
        # Must run after _action_map exists: it registers an action via add_action.
        self.init_context_extractor()
|
| 63 |
+
|
| 64 |
+
# def __call__(self, *args, **kwargs) -> Message:
|
| 65 |
+
# """Make the agent callable and automatically choose between sync and async execution"""
|
| 66 |
+
# if asyncio.iscoroutinefunction(self.async_execute) and asyncio.get_event_loop().is_running():
|
| 67 |
+
# # If the operator is in an asynchronous environment and has an execute_async method, return a coroutine
|
| 68 |
+
# return self.async_execute(*args, **kwargs)
|
| 69 |
+
# # Otherwise, use the synchronous method
|
| 70 |
+
# return self.execute(*args, **kwargs)
|
| 71 |
+
|
| 72 |
+
def __call__(self, *args: Any, **kwargs: Any) -> Union[dict, Coroutine[Any, Any, dict]]:
|
| 73 |
+
"""Make the operator callable and automatically choose between sync and async execution."""
|
| 74 |
+
try:
|
| 75 |
+
# Safe way to check if we're inside an async environment
|
| 76 |
+
asyncio.get_running_loop()
|
| 77 |
+
return self.async_execute(*args, **kwargs)
|
| 78 |
+
except RuntimeError:
|
| 79 |
+
# No running loop — likely in sync context or worker thread
|
| 80 |
+
return self.execute(*args, **kwargs)
|
| 81 |
+
|
| 82 |
+
    def _prepare_execution(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        **kwargs
    ) -> Tuple[Action, dict]:
        """Prepare for action execution by updating memory and getting inputs.

        Helper method used by both execute and async_execute methods.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            **kwargs: Additional workflow parameters (wf_goal, wf_task, wf_task_desc)

        Returns:
            Tuple containing the action object and input data

        Raises:
            AssertionError: If neither msgs nor action_input_data is provided
        """
        assert msgs is not None or action_input_data is not None, "must provide either `msgs` or `action_input_data`"
        action = self.get_action(action_name=action_name)

        # update short-term memory
        if msgs is not None:
            # directly add messages to short-term memory
            self.short_term_memory.add_messages(msgs)
        if action_input_data is not None:
            # create a message from action_input_data and add it to short-term memory
            input_message = Message(
                content = action_input_data,
                next_actions = [action_name],
                msg_type = MessageType.INPUT,
                wf_goal = kwargs.get("wf_goal", None),
                wf_task = kwargs.get("wf_task", None),
                wf_task_desc = kwargs.get("wf_task_desc", None)
            )
            self.short_term_memory.add_message(input_message)

        # obtain action input data from short term memory if not provided
        # NOTE(review): `or` treats an empty dict as missing and triggers
        # extraction from memory — confirm that is the intended behavior.
        action_input_data = action_input_data or self.get_action_inputs(action=action)

        return action, action_input_data
|
| 128 |
+
|
| 129 |
+
def _create_output_message(
|
| 130 |
+
self,
|
| 131 |
+
action_output,
|
| 132 |
+
prompt: str,
|
| 133 |
+
action_name: str,
|
| 134 |
+
return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
|
| 135 |
+
**kwargs
|
| 136 |
+
) -> Message:
|
| 137 |
+
"""Create a message from execution results and update memory.
|
| 138 |
+
|
| 139 |
+
Helper method used by both execute and aexecute methods.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
action_output: The output from action execution
|
| 143 |
+
prompt: The prompt used for execution
|
| 144 |
+
action_name: The name of the executed action
|
| 145 |
+
return_msg_type: Message type for the return message
|
| 146 |
+
**kwargs: Additional workflow parameters
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
Message object containing execution results
|
| 150 |
+
"""
|
| 151 |
+
# formulate a message
|
| 152 |
+
message = Message(
|
| 153 |
+
content=action_output,
|
| 154 |
+
agent=self.name,
|
| 155 |
+
action=action_name,
|
| 156 |
+
prompt=prompt,
|
| 157 |
+
msg_type=return_msg_type,
|
| 158 |
+
wf_goal = kwargs.get("wf_goal", None),
|
| 159 |
+
wf_task = kwargs.get("wf_task", None),
|
| 160 |
+
wf_task_desc = kwargs.get("wf_task_desc", None)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# update short-term memory
|
| 164 |
+
self.short_term_memory.add_message(message)
|
| 165 |
+
|
| 166 |
+
return message
|
| 167 |
+
|
| 168 |
+
    async def async_execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, dict]]:
        """Execute an action asynchronously with the given context and return results.

        This is the async version of the execute method, allowing it to perform actions
        based on the current conversation context.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: When True, also return the resolved input data
            **kwargs (Any): Additional parameters, may include workflow information

        Returns:
            Message: A message containing the execution results (or a
            (Message, dict) tuple when ``return_action_input_data`` is True)
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # execute action asynchronously
        # NOTE(review): source-text inspection is a fragile way to detect an
        # unimplemented async_execute (e.g. it also matches the string in a
        # comment); a capability flag on Action would be more robust.
        async_execute_source = inspect.getsource(action.async_execute)
        if "NotImplementedError" in async_execute_source:
            # if the async_execute method is not implemented, use the execute method instead
            execution_results = action.execute(
                llm=self.llm,
                inputs=action_input_data,
                sys_msg=self.system_prompt,
                return_prompt=True,
                **kwargs
            )
        else:
            execution_results = await action.async_execute(
                llm=self.llm,
                inputs=action_input_data,
                sys_msg=self.system_prompt,
                return_prompt=True,
                **kwargs
            )
        # return_prompt=True makes the action return (output, prompt).
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message
|
| 230 |
+
|
| 231 |
+
    def execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.UNKNOWN,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, dict]]:
        """Execute an action with the given context and return results.

        This is the core method for agent functionality, allowing it to perform actions
        based on the current conversation context.

        Args:
            action_name: The name of the action to execute
            msgs: Optional list of messages providing context for the action
            action_input_data: Optional pre-extracted input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: When True, also return the resolved input data
            **kwargs (Any): Additional parameters, may include workflow information

        Returns:
            Message: A message containing the execution results (or a
            (Message, dict) tuple when ``return_action_input_data`` is True)
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # execute action
        execution_results = action.execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            **kwargs
        )
        # return_prompt=True makes the action return (output, prompt).
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message
|
| 282 |
+
|
| 283 |
+
def init_llm(self):
|
| 284 |
+
"""
|
| 285 |
+
Initialize the language model for the agent.
|
| 286 |
+
"""
|
| 287 |
+
# Only initialize LLM if not human and LLM is provided
|
| 288 |
+
if not self.is_human and (not self.llm_config and not self.llm):
|
| 289 |
+
raise ValueError("must provide `llm_config` or `llm` when `is_human` is False")
|
| 290 |
+
if not self.is_human and (self.llm_config or self.llm):
|
| 291 |
+
if self.llm_config and not self.llm:
|
| 292 |
+
llm_cls = MODEL_REGISTRY.get_model(self.llm_config.llm_type)
|
| 293 |
+
self.llm = llm_cls(config=self.llm_config)
|
| 294 |
+
if self.llm:
|
| 295 |
+
self.llm_config = self.llm.config
|
| 296 |
+
# If is_human=True or no LLM provided, self.llm remains None
|
| 297 |
+
|
| 298 |
+
def init_long_term_memory(self):
|
| 299 |
+
"""
|
| 300 |
+
Initialize long-term memory components.
|
| 301 |
+
"""
|
| 302 |
+
assert self.storage_handler is not None, "must provide ``storage_handler`` when use_long_term_memory=True"
|
| 303 |
+
# TODO revise the initialisation of long_term_memory and long_term_memory_manager
|
| 304 |
+
if not self.long_term_memory:
|
| 305 |
+
self.long_term_memory = LongTermMemory()
|
| 306 |
+
if not self.long_term_memory_manager:
|
| 307 |
+
self.long_term_memory_manager = MemoryManager(
|
| 308 |
+
storage_handler=self.storage_handler,
|
| 309 |
+
memory=self.long_term_memory
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
def init_context_extractor(self):
|
| 313 |
+
"""
|
| 314 |
+
Initialize the context extraction action.
|
| 315 |
+
"""
|
| 316 |
+
cext_action = ContextExtraction()
|
| 317 |
+
self.cext_action_name = cext_action.name
|
| 318 |
+
self.add_action(cext_action)
|
| 319 |
+
|
| 320 |
+
def add_action(self, action: Type[Action]):
|
| 321 |
+
"""
|
| 322 |
+
Add a new action to the agent's available actions.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
action: The action instance to add
|
| 326 |
+
"""
|
| 327 |
+
action_name = action.name
|
| 328 |
+
if action_name in self._action_map:
|
| 329 |
+
return
|
| 330 |
+
self.actions.append(action)
|
| 331 |
+
self._action_map[action_name] = action
|
| 332 |
+
|
| 333 |
+
def check_action_name(self, action_name: str):
|
| 334 |
+
"""
|
| 335 |
+
Check if an action name is valid for this agent.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
action_name: Name of the action to check
|
| 339 |
+
"""
|
| 340 |
+
if action_name not in self._action_map:
|
| 341 |
+
raise KeyError(f"'{action_name}' is an invalid action for {self.name}! Available action names: {list(self._action_map.keys())}")
|
| 342 |
+
|
| 343 |
+
    def get_action(self, action_name: str) -> Action:
        """
        Retrieves the Action instance associated with the given name.

        Args:
            action_name: Name of the action to retrieve

        Returns:
            The Action instance with the specified name

        Raises:
            KeyError: if the name is not registered (via check_action_name).
        """
        self.check_action_name(action_name=action_name)
        return self._action_map[action_name]
|
| 355 |
+
|
| 356 |
+
def get_action_name(self, action_cls: Type[Action]) -> str:
|
| 357 |
+
"""
|
| 358 |
+
Searches through the agent's actions to find one matching the specified type.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
action_cls: The Action class type to search for
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
The name of the matching action
|
| 365 |
+
"""
|
| 366 |
+
for name, action in self._action_map.items():
|
| 367 |
+
if isinstance(action, action_cls):
|
| 368 |
+
return name
|
| 369 |
+
raise ValueError(f"Couldn't find an action that matches Type '{action_cls.__name__}'")
|
| 370 |
+
|
| 371 |
+
    def get_action_inputs(self, action: Action) -> Union[dict, None]:
        """
        Uses the context extraction action to determine appropriate inputs
        for the specified action based on the conversation history.

        Args:
            action: The action for which to extract inputs

        Returns:
            Dictionary of extracted input data, or None if extraction fails
        """
        # Pull the latest `n` messages (all of them when n is None) as context
        # and let the context-extraction action map them to this action's inputs.
        context = self.short_term_memory.get(n=self.n)
        cext_action = self.get_action(self.cext_action_name)
        action_inputs = cext_action.execute(llm=self.llm, action=action, context=context)
        return action_inputs
|
| 387 |
+
|
| 388 |
+
def get_all_actions(self) -> List[Action]:
|
| 389 |
+
"""Get all actions except the context extraction action.
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
List of Action instances available for execution
|
| 393 |
+
"""
|
| 394 |
+
actions = [action for action in self.actions if action.name != self.cext_action_name]
|
| 395 |
+
return actions
|
| 396 |
+
|
| 397 |
+
def get_agent_profile(self, action_names: List[str] = None) -> str:
|
| 398 |
+
"""Generate a human-readable profile of the agent and its capabilities.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
action_names: Optional list of action names to include in the profile.
|
| 402 |
+
If None, all actions are included.
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
A formatted string containing the agent profile
|
| 406 |
+
"""
|
| 407 |
+
all_actions = self.get_all_actions()
|
| 408 |
+
if action_names is None:
|
| 409 |
+
# if `action_names` is None, return description of all actions
|
| 410 |
+
action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions])
|
| 411 |
+
else:
|
| 412 |
+
# otherwise, only return description of actions that matches `action_names`
|
| 413 |
+
action_descriptions = "\n".join([f" - {action.name}: {action.description}" for action in all_actions if action.name in action_names])
|
| 414 |
+
profile = f"Agent Name: {self.name}\nDescription: {self.description}\nAvailable Actions:\n{action_descriptions}"
|
| 415 |
+
return profile
|
| 416 |
+
|
| 417 |
+
    def clear_short_term_memory(self):
        """
        Remove all content from the agent's short-term memory.

        NOTE(review): currently a no-op — the body never touches
        ``self.short_term_memory``. Confirm whether ShortTermMemory exposes a
        clear/reset API and wire it up here; callers relying on this method
        today get no effect.
        """
        pass
|
| 422 |
+
|
| 423 |
+
def __eq__(self, other: "Agent"):
|
| 424 |
+
return self.agent_id == other.agent_id
|
| 425 |
+
|
| 426 |
+
def __hash__(self):
|
| 427 |
+
return self.agent_id
|
| 428 |
+
|
| 429 |
+
def get_prompts(self) -> dict:
|
| 430 |
+
"""
|
| 431 |
+
Get all the prompts of the agent.
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
dict: A dictionary with keys in the format 'agent_name::action_name' and values
|
| 435 |
+
containing the system_prompt and action prompt.
|
| 436 |
+
"""
|
| 437 |
+
prompts = {}
|
| 438 |
+
for action in self.get_all_actions():
|
| 439 |
+
prompts[action.name] = {
|
| 440 |
+
"system_prompt": self.system_prompt,
|
| 441 |
+
"prompt": action.prompt
|
| 442 |
+
}
|
| 443 |
+
return prompts
|
| 444 |
+
|
| 445 |
+
def set_prompt(self, action_name: str, prompt: str, system_prompt: Optional[str] = None) -> bool:
|
| 446 |
+
"""
|
| 447 |
+
Set the prompt for a specific action of this agent.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
action_name: Name of the action whose prompt should be updated
|
| 451 |
+
prompt: New prompt text to set for the action
|
| 452 |
+
system_prompt: Optional new system prompt to set for the agent
|
| 453 |
+
|
| 454 |
+
Returns:
|
| 455 |
+
bool: True if the prompt was successfully updated, False otherwise
|
| 456 |
+
|
| 457 |
+
Raises:
|
| 458 |
+
KeyError: If the action_name does not exist for this agent
|
| 459 |
+
"""
|
| 460 |
+
try:
|
| 461 |
+
action = self.get_action(action_name)
|
| 462 |
+
action.prompt = prompt
|
| 463 |
+
|
| 464 |
+
if system_prompt is not None:
|
| 465 |
+
self.system_prompt = system_prompt
|
| 466 |
+
|
| 467 |
+
return True
|
| 468 |
+
except KeyError:
|
| 469 |
+
raise KeyError(f"Action '{action_name}' not found in agent '{self.name}'")
|
| 470 |
+
|
| 471 |
+
def set_prompts(self, prompts: dict) -> bool:
    """
    Apply a batch of prompt updates, one entry per action.

    Args:
        prompts: Mapping of action name to a dict holding the new
            'prompt' and, optionally, a 'system_prompt'.

    Returns:
        bool: True if the prompts were successfully updated, False otherwise

    Raises:
        ValueError: If an entry is not a dict or lacks the 'prompt' key.
    """
    for action_name, entry in prompts.items():
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid prompt data for action '{action_name}'. Expected a dictionary with 'prompt' and 'system_prompt' (optional) keys.")
        if "prompt" not in entry:
            raise ValueError(f"Missing 'prompt' key in prompt data for action '{action_name}'.")
        self.set_prompt(action_name, entry["prompt"], entry.get("system_prompt", None))
    return True
|
| 490 |
+
|
| 491 |
+
def save_module(self, path: str, ignore: List[str] = [], **kwargs) -> str:
    """Save the agent to persistent storage.

    Args:
        path: Path where the agent should be saved
        ignore: List of field names to exclude from serialization
            (``ignore`` is only read, never mutated, so the shared
            mutable default is harmless here)
        **kwargs (Any): Additional parameters for the save operation

    Returns:
        str: The path where the agent was saved
    """
    ignore_fields = self._save_ignore_fields + ignore
    # Propagate the result so the documented `-> str` contract holds:
    # previously this method returned None despite its annotation.
    # NOTE(review): assumes the parent save_module returns the path — confirm.
    return super().save_module(path=path, ignore=ignore_fields, **kwargs)
|
| 504 |
+
|
| 505 |
+
@classmethod
def load_module(cls, path: str, llm_config: Optional[LLMConfig] = None, **kwargs) -> "Agent":
    """
    load the agent from local storage. Must provide `llm_config` when loading the agent from local storage.

    Args:
        path: The path of the file
        llm_config: The LLMConfig instance

    Returns:
        Agent: The loaded agent instance

    Note:
        The value returned by ``super().load_module`` is indexed with
        ``agent["llm_config"]`` below, so at this point it behaves like a
        dict — presumably converted to a full Agent later in the loading
        pipeline. TODO(review): confirm against BaseModule.load_module.
    """
    agent = super().load_module(path=path, **kwargs)
    if llm_config is not None:
        # inject/overwrite the serialized LLM configuration before instantiation
        agent["llm_config"] = llm_config.to_dict()
    return agent
|
| 521 |
+
|
| 522 |
+
def get_config(self) -> dict:
    """
    Return a serializable configuration snapshot of this agent.

    Returns:
        dict: A configuration dictionary that can be used to initialize
            a new Agent instance with the same properties as this one.
    """
    return self.to_dict()
|
evoagentx/agents/agent_generator.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .agent import Agent
|
| 2 |
+
from ..actions.agent_generation import AgentGeneration
|
| 3 |
+
from ..prompts.agent_generator import AGENT_GENERATOR
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AgentGenerator(Agent):

    """
    An agent responsible for generating agents for a task.
    """

    def __init__(self, **kwargs):
        # Fall back to the packaged AGENT_GENERATOR defaults for any field the
        # caller does not supply. `dict.pop(key, default)` replaces the
        # non-idiomatic `kwargs.pop(k) if k in kwargs else default` pattern;
        # behavior is identical.
        name = kwargs.pop("name", AGENT_GENERATOR["name"])
        description = kwargs.pop("description", AGENT_GENERATOR["description"])
        system_prompt = kwargs.pop("system_prompt", AGENT_GENERATOR["system_prompt"])
        if "actions" in kwargs:
            actions = kwargs.pop("actions")
        else:
            # Only consume "tools" when building the default action, mirroring
            # the original behavior (an explicit `actions` leaves `tools` in kwargs).
            actions = [AgentGeneration(tools=kwargs.pop("tools", []))]
        super().__init__(name=name, description=description, system_prompt=system_prompt, actions=actions, **kwargs)

    @property
    def agent_generation_action_name(self):
        # Resolve the name of the AgentGeneration action registered on this agent.
        return self.get_action_name(action_cls=AgentGeneration)
|
| 23 |
+
|
evoagentx/agents/agent_manager.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from typing import Union, Optional, Dict, List
|
| 4 |
+
from pydantic import Field
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
|
| 7 |
+
from .agent import Agent
|
| 8 |
+
# from .agent_generator import AgentGenerator
|
| 9 |
+
from .customize_agent import CustomizeAgent
|
| 10 |
+
from ..core.module import BaseModule
|
| 11 |
+
from ..core.decorators import atomic_method
|
| 12 |
+
from ..storages.base import StorageHandler
|
| 13 |
+
from ..models.model_configs import LLMConfig
|
| 14 |
+
from ..tools.tool import Toolkit, Tool
|
| 15 |
+
class AgentState(str, Enum):
    """Lifecycle states an agent can be in while managed by AgentManager.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """

    AVAILABLE = "available"
    RUNNING = "running"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AgentManager(BaseModule):
    """
    Responsible for creating and managing all Agent objects required for workflow operation.

    Attributes:
        storage_handler (StorageHandler): Used to load and save agents from/to storage.
        agents (List[Agent]): A list to keep track of all managed Agent instances.
        agent_states (Dict[str, AgentState]): A dictionary to track the state of each Agent by name.
    """
    agents: List[Agent] = Field(default_factory=list)
    agent_states: Dict[str, AgentState] = Field(default_factory=dict) # agent_name to AgentState mapping
    storage_handler: Optional[StorageHandler] = None # used to load and save agent from storage.
    # agent_generator: Optional[AgentGenerator] = None # used to generate agents for a specific subtask
    tools: Optional[List[Union[Toolkit, Tool]]] = None  # pool that `tool_names` entries are resolved against

    def init_module(self):
        # One Condition per agent name so waiters on different agents do not
        # wake each other; the manager-wide lock backs @atomic_method.
        self._lock = threading.Lock()
        self._state_conditions = {}
        if self.agents:
            for agent in self.agents:
                # Keep any pre-supplied state; default newcomers to AVAILABLE.
                self.agent_states[agent.name] = self.agent_states.get(agent.name, AgentState.AVAILABLE)
                if agent.name not in self._state_conditions:
                    self._state_conditions[agent.name] = threading.Condition()
            self.check_agents()

    def check_agents(self):
        """Validate agent list integrity and state consistency.

        Performs thorough validation of the agent manager's internal state:
        1. Checks for duplicate agent names
        2. Verifies that agent states exist for all agents
        3. Ensures agent list and state dictionary sizes match

        Raises:
            ValueError: On duplicate names, size mismatch, or missing states.
        """
        # check that the names of self.agents should be unique
        duplicate_agent_names = self.find_duplicate_agents(self.agents)
        if duplicate_agent_names:
            raise ValueError(f"The agents should be unique. Found duplicate agent names: {duplicate_agent_names}!")
        # check agent states
        if len(self.agents) != len(self.agent_states):
            raise ValueError(f"The lengths of self.agents ({len(self.agents)}) and self.agent_states ({len(self.agent_states)}) are different!")
        missing_agents = self.find_missing_agent_states()
        if missing_agents:
            raise ValueError(f"The following agents' states were not found: {missing_agents}")

    def find_duplicate_agents(self, agents: List[Agent]) -> List[str]:
        # return the names of duplicate agents based on agent.name
        unique_agent_names = set()
        duplicate_agent_names = set()
        for agent in agents:
            agent_name = agent.name
            if agent_name in unique_agent_names:
                duplicate_agent_names.add(agent_name)
            unique_agent_names.add(agent_name)
        return list(duplicate_agent_names)

    def find_missing_agent_states(self):
        # Names of managed agents that have no entry in self.agent_states.
        missing_agents = [agent.name for agent in self.agents if agent.name not in self.agent_states]
        return missing_agents

    def list_agents(self) -> List[str]:
        # Names of all currently managed agents, in insertion order.
        return [agent.name for agent in self.agents]

    def has_agent(self, agent_name: str) -> bool:
        """Check if an agent with the given name exists in the manager.

        Args:
            agent_name: The name of the agent to check

        Returns:
            True if an agent with the given name exists, False otherwise
        """
        all_agent_names = self.list_agents()
        return agent_name in all_agent_names

    @property
    def size(self):
        """
        Get the total number of agents managed by this manager.
        """
        return len(self.agents)

    def load_agent(self, agent_name: str, **kwargs) -> Agent:
        """Load an agent from local storage through storage_handler.

        Retrieves agent data from storage and creates an Agent instance.

        Args:
            agent_name: The name of the agent to load
            **kwargs (Any): Additional parameters for agent creation

        Returns:
            Agent instance with data loaded from storage

        Raises:
            ValueError: If no storage_handler is configured.
        """
        if not self.storage_handler:
            raise ValueError("must provide ``self.storage_handler`` to use ``load_agent``")
        agent_data = self.storage_handler.load_agent(agent_name=agent_name)
        agent: Agent = self.create_customize_agent(agent_data=agent_data)
        return agent

    def load_all_agents(self, **kwargs):
        """Load all agents from storage and add them to the manager.

        Retrieves all available agents from storage and adds them to the
        managed agents collection.

        Args:
            **kwargs (Any): Additional parameters passed to storage handler

        NOTE(review): not implemented yet — the body is ``pass``.
        """
        pass

    def update_tools(self, agent_data: dict) -> None:
        """
        Update agent_data with tools based on tool_names.

        Handles four scenarios:
        1. Neither tool_names nor tools exist: return directly
        2. Only tool_names exists: resolve tool_names to tools and set tools field
        3. Only tools exists: return directly (no action needed)
        4. Both exist: merge tool_names into existing tools (skip duplicates)

        Args:
            agent_data (dict): Agent configuration dictionary that may contain 'tool_names' and/or 'tools'

        Raises:
            ValueError: If tool_names exist but self.tools is None, or if requested tools are not found
        """
        tool_names = agent_data.get("tool_names", None)
        existing_tools = agent_data.get("tools", None)

        # Case 1: Neither tool_names nor tools exist
        if not tool_names and not existing_tools:
            return

        # Case 3: Only tools exist (no tool_names)
        if not tool_names and existing_tools:
            return

        # For cases 2 and 4: tool_names exists, need to resolve
        if self.tools is None:
            raise ValueError(
                f"Agent requires tools {tool_names}, but no tools are available in AgentManager. "
                f"Please set self.tools before creating agents with tool_names."
            )

        # Create tool mapping from available tools
        tool_mapping = {}
        for tool in self.tools:
            tool_mapping[tool.name] = tool

        # Case 2: Only tool_names exists - initialize empty tools list
        if tool_names and not existing_tools:
            existing_tools = []

        # Case 2 & 4: Process tool_names (either with empty or existing tools list)
        if tool_names:
            # Create a set of existing tool names for quick lookup
            existing_tool_names = {tool.name for tool in existing_tools}

            tools_to_add = []
            missing_tools = []

            for tool_name in tool_names:
                # Skip if tool already exists in tools
                if tool_name in existing_tool_names:
                    continue

                # Try to resolve new tool
                if tool_name in tool_mapping:
                    tools_to_add.append(tool_mapping[tool_name])
                else:
                    missing_tools.append(tool_name)

            if missing_tools:
                available_tools = list(tool_mapping.keys())
                raise ValueError(
                    f"The following tools are not available: {missing_tools}. "
                    f"Available tools: {available_tools}"
                )

            # Merge new tools with existing ones
            if tools_to_add:
                agent_data["tools"] = list(existing_tools) + tools_to_add

    def create_customize_agent(self, agent_data: dict, llm_config: Optional[Union[LLMConfig, dict]]=None, **kwargs) -> CustomizeAgent:
        """
        create a customized agent from the provided `agent_data`.

        Args:
            agent_data: The data used to create an Agent instance, must contain 'name', 'description' and 'prompt' keys.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent.
                It will be used as the default LLM for agents without a `llm_config` key.
                If not provided, the `agent_data` should contain a `llm_config` key.
                If provided and `agent_data` contains a `llm_config` key, the `llm_config` in `agent_data` will be used.
            **kwargs (Any): Additional parameters for agent creation

        Returns:
            Agent: the instantiated agent instance.

        Raises:
            ValueError: If neither `agent_data` nor the parameter carries an
                LLM config for a non-human agent.
        """

        # deepcopy so mutations below (llm_config/tools injection) never leak
        # back into the caller's dict.
        agent_data = deepcopy(agent_data)
        agent_llm_config = agent_data.get("llm_config", llm_config)
        if not agent_data.get("is_human", False) and not agent_llm_config:
            raise ValueError("`agent_data` should contain a `llm_config` key or `llm_config` should be provided.")

        if agent_llm_config:
            if isinstance(agent_llm_config, dict):
                agent_data["llm_config"] = agent_llm_config
            elif isinstance(agent_llm_config, LLMConfig):
                agent_data["llm_config"] = agent_llm_config.to_dict()

        # tool_mapping = {}
        # if self.tools is not None:
        #     for tool in self.tools:
        #         tool_mapping[tool.name] = tool
        # if agent_data.get("tool_names", None):
        #     agent_data["tools"] = [tool_mapping[tool_name] for tool_name in agent_data["tool_names"]]
        self.update_tools(agent_data=agent_data) # add `tools` field if needed
        return CustomizeAgent.from_dict(data=agent_data)

    def get_agent_name(self, agent: Union[str, dict, Agent]) -> str:
        """Extract agent name from different agent representations.

        Handles different ways to specify an agent (string name, dictionary, or
        Agent instance) and extracts the agent name.

        Args:
            agent: Agent specified as a string name, dictionary with 'name' key,
                or Agent instance

        Returns:
            The extracted agent name as a string
        """
        if isinstance(agent, str):
            agent_name = agent
        elif isinstance(agent, dict):
            agent_name = agent["name"]
        elif isinstance(agent, Agent):
            agent_name = agent.name
        else:
            raise ValueError(f"{type(agent)} is not a supported type for ``get_agent_name``. Supported types: [str, dict, Agent].")
        return agent_name

    def create_agent(self, agent: Union[str, dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs) -> Agent:
        # Resolve the heterogeneous `agent` spec (name / dict / instance) into
        # an Agent instance without registering it in the manager.

        if isinstance(agent, str):
            if self.storage_handler is None:
                # if self.storage_handler is None, the agent (str) must exist in self.agents. Otherwise, a dictionary or an Agent instance should be provided.
                if not self.has_agent(agent_name=agent):
                    raise ValueError(f"Agent ``{agent}`` does not exist! You should provide a dictionary or an Agent instance when ``self.storage_handler`` is not provided.")
                return self.get_agent(agent_name=agent)
            else:
                # if self.storage_handler is not None, the agent (str) must exist in the storage and will be loaded from the storage.
                agent_instance = self.load_agent(agent_name=agent)
        elif isinstance(agent, dict):
            if not agent.get("is_human", False) and (llm_config is None and "llm_config" not in agent):
                raise ValueError("When providing an agent as a dictionary, you must either include 'llm_config' in the dictionary or provide it as a parameter.")
            agent_instance = self.create_customize_agent(agent_data=agent, llm_config=llm_config, **kwargs)
        elif isinstance(agent, Agent):
            agent_instance = agent
        else:
            raise ValueError(f"{type(agent)} is not a supported input type of ``create_agent``. Supported types: [str, dict, Agent].")
        return agent_instance

    @atomic_method
    def add_agent(self, agent: Union[str, dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        add a single agent, ignore if the agent already exists (judged by the name of an agent).

        Args:
            agent: The agent to be added, specified as:
                - String: Agent name to load from storage
                - Dictionary: Agent specification to create a CustomizeAgent
                - Agent: Existing Agent instance to add directly
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent. Only used when the `agent` is a dictionary, used to create a CustomizeAgent.
            **kwargs (Any): Additional parameters for agent creation
        """
        # Check for 'tool' key and convert it to 'tools' if needed
        # if isinstance(agent, dict) and "tool_names" in agent:
        #     tools_mapping = {}
        #     if self.tools is not None:
        #         for tool in self.tools:
        #             tools_mapping[tool.name] = tool
        #     agent["tools"] = [tools_mapping[tool_name] for tool_name in agent["tool_names"]]
        #     agent["tools"] = [tool if isinstance(tool, Toolkit) else Toolkit(name=tool.name, tools=[tool]) for tool in agent["tools"]]

        agent_name = self.get_agent_name(agent=agent)
        if self.has_agent(agent_name=agent_name):
            return
        agent_instance = self.create_agent(agent=agent, llm_config=llm_config, **kwargs)
        self.agents.append(agent_instance)
        self.agent_states[agent_instance.name] = AgentState.AVAILABLE
        if agent_instance.name not in self._state_conditions:
            self._state_conditions[agent_instance.name] = threading.Condition()
        self.check_agents()

    def add_agents(self, agents: List[Union[str, dict, Agent]], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        add several agents by using self.add_agent().
        """
        for agent in agents:
            self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def add_agents_from_workflow(self, workflow_graph, llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Initialize agents from the nodes of a given WorkFlowGraph and add these agents to self.agents.

        Args:
            workflow_graph (WorkFlowGraph): The workflow graph containing nodes with agents information.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agents.
            **kwargs (Any): Additional parameters passed to add_agent
        """
        # Imported locally to avoid a circular import with the workflow package.
        from ..workflow.workflow_graph import WorkFlowGraph
        if not isinstance(workflow_graph, WorkFlowGraph):
            raise TypeError("workflow_graph must be an instance of WorkFlowGraph")
        for node in workflow_graph.nodes:
            if node.agents:
                for agent in node.agents:
                    self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def update_agents_from_workflow(self, workflow_graph, llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Update agents from a given WorkFlowGraph.

        Args:
            workflow_graph (WorkFlowGraph): The workflow graph containing nodes with agents information.
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agents.
            **kwargs: Additional parameters passed to update_agent
        """
        # Imported locally to avoid a circular import with the workflow package.
        from ..workflow.workflow_graph import WorkFlowGraph
        if not isinstance(workflow_graph, WorkFlowGraph):
            raise TypeError("workflow_graph must be an instance of WorkFlowGraph")
        for node in workflow_graph.nodes:
            if node.agents:
                for agent in node.agents:
                    agent_name = self.get_agent_name(agent=agent)
                    if self.has_agent(agent_name=agent_name):
                        # use the llm_config of the existing agent
                        agent_llm_config = self.get_agent(agent_name).llm_config
                        self.update_agent(agent=agent, llm_config=agent_llm_config, **kwargs)
                    else:
                        self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    def get_agent(self, agent_name: str, **kwargs) -> Agent:
        """Retrieve an agent by its name from managed agents.

        Searches the list of managed agents for an agent with the specified name.

        Args:
            agent_name: The name of the agent to retrieve
            **kwargs (Any): Additional parameters (unused)

        Returns:
            The Agent instance with the specified name

        Raises:
            ValueError: If no managed agent has the given name.
        """
        for agent in self.agents:
            if agent.name == agent_name:
                return agent
        raise ValueError(f"Agent ``{agent_name}`` does not exists!")

    def update_agent(self, agent: Union[dict, Agent], llm_config: Optional[LLMConfig]=None, **kwargs):
        """
        Update an agent in the manager.

        Implemented as remove-then-add, so the agent's state is reset to
        AVAILABLE by add_agent.

        Args:
            agent: The agent to be updated, specified as:
                - Dictionary: Agent specification to update a CustomizeAgent
                - Agent: Existing Agent instance to update
            llm_config (Optional[LLMConfig]): The LLM configuration to be used for the agent.
        """
        agent_name = self.get_agent_name(agent=agent)
        self.remove_agent(agent_name=agent_name)
        self.add_agent(agent=agent, llm_config=llm_config, **kwargs)

    @atomic_method
    def remove_agent(self, agent_name: str, remove_from_storage: bool=False, **kwargs):
        """
        Remove an agent from the manager and optionally from storage.

        Args:
            agent_name: The name of the agent to remove
            remove_from_storage: If True, also remove the agent from storage
            **kwargs (Any): Additional parameters passed to storage_handler.remove_agent
        """
        self.agents = [agent for agent in self.agents if agent.name != agent_name]
        # pop with default: removing an unknown agent is a no-op, not an error
        self.agent_states.pop(agent_name, None)
        self._state_conditions.pop(agent_name, None)
        if remove_from_storage:
            self.storage_handler.remove_agent(agent_name=agent_name, **kwargs)
        self.check_agents()

    def get_agent_state(self, agent_name: str) -> AgentState:
        """
        Get the state of a specific agent by its name.

        Args:
            agent_name: The name of the agent.

        Returns:
            AgentState: The current state of the agent.

        Raises:
            KeyError: If the agent has no tracked state.
        """
        return self.agent_states[agent_name]

    @atomic_method
    def set_agent_state(self, agent_name: str, new_state: AgentState) -> bool:
        """
        Changes an agent's state and notifies any threads waiting on that agent's state.
        Thread-safe operation for coordinating multi-threaded agent execution.

        Args:
            agent_name: The name of the agent
            new_state: The new state to set

        Returns:
            True if the state was updated successfully, False otherwise
        """

        # if agent_name in self.agent_states and isinstance(new_state, AgentState):
        #     # self.agent_states[agent_name] = new_state
        #     with self._state_conditions[agent_name]:
        #         self.agent_states[agent_name] = new_state
        #         self._state_conditions[agent_name].notify_all()
        #     self.check_agents()
        #     return True
        # else:
        #     return False
        if agent_name in self.agent_states and isinstance(new_state, AgentState):
            # Lazily create the Condition so states set before registration
            # (or after a reload) still work.
            if agent_name not in self._state_conditions:
                self._state_conditions[agent_name] = threading.Condition()
            # Update under the agent's Condition so wait_for_agent_available
            # observes the change and wakes.
            with self._state_conditions[agent_name]:
                self.agent_states[agent_name] = new_state
                self._state_conditions[agent_name].notify_all()
            return True
        return False

    def get_all_agent_states(self) -> Dict[str, AgentState]:
        """Get the states of all managed agents.

        Returns:
            Dict[str, AgentState]: A dictionary mapping agent names to their states.
        """
        return self.agent_states

    @atomic_method
    def save_all_agents(self, **kwargs):
        """Save all managed agents to persistent storage.

        Args:
            **kwargs (Any): Additional parameters passed to the storage handler

        NOTE(review): not implemented yet — the body is ``pass``.
        """
        pass

    @atomic_method
    def clear_agents(self):
        """
        Remove all agents from the manager.
        """
        self.agents = []
        self.agent_states = {}
        self._state_conditions = {}
        self.check_agents()

    def wait_for_agent_available(self, agent_name: str, timeout: Optional[float] = None) -> bool:
        """Wait for an agent to be available.

        Args:
            agent_name: The name of the agent to wait for
            timeout: Maximum time to wait in seconds, or None to wait indefinitely

        Returns:
            True if the agent became available, False if timed out
        """
        if agent_name not in self._state_conditions:
            self._state_conditions[agent_name] = threading.Condition()
        condition = self._state_conditions[agent_name]

        with condition:
            # wait_for re-checks the predicate on every notify_all from
            # set_agent_state; .get() keeps this safe for unknown agents.
            return condition.wait_for(
                lambda: self.agent_states.get(agent_name) == AgentState.AVAILABLE,
                timeout=timeout
            )

    def copy(self) -> "AgentManager":
        """
        Create a shallow copy of the AgentManager.

        NOTE(review): the same ``agents`` list object is passed through;
        whether the new manager aliases or copies it depends on
        BaseModule's initialization — confirm before relying on isolation.
        """
        return AgentManager(agents=self.agents, storage_handler=self.storage_handler)
|
evoagentx/agents/customize_agent.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import inspect
|
| 3 |
+
from pydantic import create_model, Field
|
| 4 |
+
from typing import Optional, Callable, Type, List, Any, Union, Dict
|
| 5 |
+
|
| 6 |
+
from .agent import Agent
|
| 7 |
+
from ..core.logging import logger
|
| 8 |
+
from ..core.registry import MODULE_REGISTRY, PARSE_FUNCTION_REGISTRY
|
| 9 |
+
from ..core.message import Message, MessageType
|
| 10 |
+
from ..models.model_configs import LLMConfig
|
| 11 |
+
from ..models.base_model import PARSER_VALID_MODE
|
| 12 |
+
from ..prompts.utils import DEFAULT_SYSTEM_PROMPT
|
| 13 |
+
from ..prompts.template import PromptTemplate
|
| 14 |
+
from ..actions.action import Action, ActionOutput
|
| 15 |
+
from ..utils.utils import generate_dynamic_class_name, make_parent_folder
|
| 16 |
+
from ..actions.customize_action import CustomizeAction
|
| 17 |
+
from ..actions.action import ActionInput
|
| 18 |
+
from ..tools.tool import Toolkit, Tool
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CustomizeAgent(Agent):

    """
    CustomizeAgent provides a flexible framework for creating specialized LLM-powered agents without
    writing custom code. It enables the creation of agents with well-defined inputs and outputs,
    custom prompt templates, and configurable parsing strategies.

    Attributes:
        name (str): The name of the agent.
        description (str): A description of the agent's purpose and capabilities.
        prompt_template (PromptTemplate, optional): The prompt template that will be used for the agent's primary action.
        prompt (str, optional): The prompt template that will be used for the agent's primary action.
            Should contain placeholders in the format `{input_name}` for each input parameter.
        llm_config (LLMConfig, optional): Configuration for the language model.
        inputs (List[dict], optional): List of input specifications, where each dict (e.g., `{"name": str, "type": str, "description": str, ["required": bool]}`) contains:
            - name (str): Name of the input parameter
            - type (str): Type of the input
            - description (str): Description of what the input represents
            - required (bool, optional): Whether this input is required (default: True)
        outputs (List[dict], optional): List of output specifications, where each dict (e.g., `{"name": str, "type": str, "description": str, ["required": bool]}`) contains:
            - name (str): Name of the output field
            - type (str): Type of the output
            - description (str): Description of what the output represents
            - required (bool, optional): Whether this output is required (default: True)
        system_prompt (str, optional): The system prompt for the LLM. Defaults to DEFAULT_SYSTEM_PROMPT.
        output_parser (Type[ActionOutput], optional): A custom class for parsing the LLM's output.
            Must be a subclass of ActionOutput.
        parse_mode (str, optional): Mode for parsing LLM output. Options are:
            - "title": Parse outputs using section titles (default)
            - "str": Parse as plain text
            - "json": Parse as JSON
            - "xml": Parse as XML
            - "custom": Use a custom parsing function
        parse_func (Callable, optional): Custom function for parsing LLM output when parse_mode is "custom".
            Must accept a "content" parameter and return a dictionary.
        title_format (str, optional): Format string for title parsing mode with {title} placeholder.
            Default is "## {title}".
        tools (list[Toolkit], optional): List of tools to be used by the agent.
        max_tool_calls (int, optional): Maximum number of tool calls. Defaults to 5.
        custom_output_format (str, optional): Specify the output format. Only used when `prompt_template` is used.
            If not provided, the output format will be constructed from the `outputs` specification and `parse_mode`.
    """
    def __init__(
        self,
        name: str,
        description: str,
        prompt: Optional[str] = None,
        prompt_template: Optional[PromptTemplate] = None,
        llm_config: Optional[LLMConfig] = None,
        inputs: Optional[List[dict]] = None,
        outputs: Optional[List[dict]] = None,
        system_prompt: Optional[str] = None,
        output_parser: Optional[Type[ActionOutput]] = None,
        parse_mode: Optional[str] = "title",
        parse_func: Optional[Callable] = None,
        title_format: Optional[str] = None,
        tools: Optional[List[Union[Toolkit, Tool]]] = None,
        max_tool_calls: Optional[int] = 5,
        custom_output_format: Optional[str] = None,
        **kwargs
    ):
        system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        inputs = inputs or []
        outputs = outputs or []
        if tools is not None:
            # Keep the original (possibly bare Tool) objects by name BEFORE
            # wrapping them into Toolkits, so `get_config` can round-trip the
            # exact objects the caller passed in (see `_raw_tool_map` usage).
            raw_tool_map = {tool.name: tool for tool in tools}
            # Normalize: bare Tool instances are wrapped into single-tool Toolkits.
            tools = [tool if isinstance(tool, Toolkit) else Toolkit(name=tool.name, tools=[tool]) for tool in tools]
        else:
            raw_tool_map = None

        # `prompt_template` wins when both are supplied.
        if prompt is not None and prompt_template is not None:
            logger.warning("Both `prompt` and `prompt_template` are provided in `CustomizeAgent`. `prompt_template` will be used.")
            prompt = None

        # A string parse_func means the agent is being loaded from a file;
        # resolve it through the registry of @register_parse_function functions.
        if isinstance(parse_func, str):
            if not PARSE_FUNCTION_REGISTRY.has_function(parse_func):
                raise ValueError(f"parse function `{parse_func}` is not registered! To instantiate a CustomizeAgent from a file, you should use decorator `@register_parse_function` to register the parse function.")
            parse_func = PARSE_FUNCTION_REGISTRY.get_function(parse_func)

        # Likewise, a string output_parser is resolved to a class via the module registry.
        if isinstance(output_parser, str):
            output_parser = MODULE_REGISTRY.get_module(output_parser)

        # set default title format
        if parse_mode == "title" and title_format is None:
            title_format = "## {title}"

        # NOTE: `validate_data` and `create_customize_action` are invoked
        # before super().__init__; this works because neither reads instance
        # state (they only use their arguments and module-level registries).
        # validate the data
        self.validate_data(
            prompt = prompt,
            prompt_template = prompt_template,
            inputs = inputs,
            outputs = outputs,
            output_parser = output_parser,
            parse_mode = parse_mode,
            parse_func = parse_func,
            title_format = title_format
        )

        customize_action = self.create_customize_action(
            name=name,
            desc=description,
            prompt=prompt,
            prompt_template=prompt_template,
            inputs=inputs,
            outputs=outputs,
            parse_mode=parse_mode,
            parse_func=parse_func,
            output_parser=output_parser,
            title_format=title_format,
            custom_output_format=custom_output_format,
            tools=tools,
            max_tool_calls=max_tool_calls
        )
        super().__init__(
            name=name,
            description=description,
            llm_config=llm_config,
            system_prompt=system_prompt,
            actions=[customize_action],
            **kwargs
        )
        self._store_inputs_outputs_info(inputs, outputs, raw_tool_map)
        self.output_parser = output_parser
        self.parse_mode = parse_mode
        self.parse_func = parse_func
        self.title_format = title_format
        self.tools = tools
        self.max_tool_calls = max_tool_calls
        self.custom_output_format = custom_output_format

    def _add_tools(self, tools: List[Toolkit]):
        # Delegate to the primary custom action, which owns the tool list.
        self.get_action(self.customize_action_name).add_tools(tools)

    @property
    def customize_action_name(self) -> str:
        """
        Get the name of the primary custom action for this agent.

        Returns:
            The name of the primary custom action
        """
        # The agent holds the custom action plus (presumably) a built-in
        # action named `self.cext_action_name` inherited from Agent — the
        # custom one is whichever action is NOT that built-in. Confirm
        # `cext_action_name` semantics against the Agent base class.
        for action in self.actions:
            if action.name != self.cext_action_name:
                return action.name
        raise ValueError("Couldn't find the customize action name!")

    @property
    def action(self) -> Action:
        """
        Get the primary custom action for this agent.

        Returns:
            The primary custom action
        """
        return self.get_action(self.customize_action_name)

    @property
    def prompt(self) -> str:
        """
        Get the prompt for the primary custom action.

        Returns:
            The prompt for the primary custom action
        """
        return self.action.prompt

    @property
    def prompt_template(self) -> PromptTemplate:
        """
        Get the prompt template for the primary custom action.

        Returns:
            The prompt template for the primary custom action
        """
        return self.action.prompt_template

    def validate_data(self, prompt: str, prompt_template: PromptTemplate, inputs: List[dict], outputs: List[dict], output_parser: Type[ActionOutput], parse_mode: str, parse_func: Callable, title_format: str):
        """Validate the constructor arguments before the agent is built.

        Raises:
            ValueError: If neither prompt nor prompt_template is given, if
                parse_mode is invalid, if parse_func is missing/malformed for
                'custom' mode, or if title_format lacks the {title} placeholder.
            KeyError: If a declared input has no matching `{name}` placeholder
                in `prompt` (checked only when `prompt_template` is absent).
        """
        # check if the prompt is provided
        if prompt is None and prompt_template is None:
            raise ValueError("`prompt` or `prompt_template` is required when creating a CustomizeAgent.")

        # check if all the inputs are in the prompt (only used when prompt_template is not provided)
        if prompt_template is None and inputs:
            all_input_names = [input_item["name"] for input_item in inputs]
            inputs_names_not_in_prompt = [name for name in all_input_names if f'{{{name}}}' not in prompt]
            if inputs_names_not_in_prompt:
                raise KeyError(f"The following inputs are not found in the prompt: {inputs_names_not_in_prompt}.")

        # check if the output_parser is valid
        if output_parser is not None:
            self._check_output_parser(outputs, output_parser)

        # check the parse_mode, parse_func, and title_format
        if parse_mode not in PARSER_VALID_MODE:
            raise ValueError(f"'{parse_mode}' is an invalid value for `parse_mode`. Available choices: {PARSER_VALID_MODE}.")

        if parse_mode == "custom":
            if parse_func is None:
                raise ValueError("`parse_func` (a callable function with an input argument `content`) must be provided when `parse_mode` is 'custom'.")

        if parse_func is not None:
            if not callable(parse_func):
                raise ValueError("`parse_func` must be a callable function with an input argument `content`.")
            signature = inspect.signature(parse_func)
            if "content" not in signature.parameters:
                raise ValueError("`parse_func` must have an input argument `content`.")
            # An unregistered parse_func still works at runtime, but cannot be
            # restored when the agent is reloaded from a file — warn only.
            if not PARSE_FUNCTION_REGISTRY.has_function(parse_func.__name__):
                logger.warning(
                    f"parse function `{parse_func.__name__}` is not registered. This can cause issues when loading the agent from a file. "
                    f"It is recommended to register the parse function using `register_parse_function`:\n"
                    f"from evoagentx.core.registry import register_parse_function\n"
                    f"@register_parse_function\n"
                    f"def {parse_func.__name__}(content: str) -> dict:\n"
                    r"    return {'output_name': output_value}"
                )

        if title_format is not None:
            if parse_mode != "title":
                logger.warning(f"`title_format` will not be used because `parse_mode` is '{parse_mode}', not 'title'. Set `parse_mode='title'` to use title formatting.")
            if r'{title}' not in title_format:
                raise ValueError(r"`title_format` must contain the placeholder `{title}`.")

    def create_customize_action(
        self,
        name: str,
        desc: str,
        prompt: str,
        prompt_template: PromptTemplate,
        inputs: List[dict],
        outputs: List[dict],
        parse_mode: str,
        parse_func: Optional[Callable] = None,
        output_parser: Optional[ActionOutput] = None,
        title_format: Optional[str] = "## {title}",
        custom_output_format: Optional[str] = None,
        tools: Optional[List[Toolkit]] = None,
        max_tool_calls: Optional[int] = 5
    ) -> Action:
        """Create a custom action based on the provided specifications.

        This method dynamically generates an Action class and instance with:
        - Input parameters defined by the inputs specification
        - Output format defined by the outputs specification (or the given
          `output_parser` class when supplied)
        - Parsing behavior configured by parse_mode / parse_func / title_format

        Args:
            name: Base name for the action (class names are derived from it)
            desc: Description of the action
            prompt: Prompt template for the action
            prompt_template: Prompt template for the action
            inputs: List of input field specifications
            outputs: List of output field specifications
            parse_mode: Mode to use for parsing LLM output
            parse_func: Optional custom parsing function
            output_parser: Optional custom output parser class
            title_format: Format string used in 'title' parse mode
            custom_output_format: Optional explicit output-format text
            tools: Optional list of tools
            max_tool_calls: Maximum number of tool calls allowed

        Returns:
            A newly created Action instance
        """
        assert prompt is not None or prompt_template is not None, "must provide `prompt` or `prompt_template` when creating CustomizeAgent"

        # create the action input type: a dynamic ActionInput subclass with
        # one str field per declared input (Optional when not required).
        action_input_fields = {}
        for field in inputs:
            required = field.get("required", True)
            if required:
                action_input_fields[field["name"]] = (str, Field(description=field["description"]))
            else:
                action_input_fields[field["name"]] = (Optional[str], Field(default=None, description=field["description"]))

        action_input_type = create_model(
            self._get_unique_class_name(
                generate_dynamic_class_name(name+" action_input")
            ),
            **action_input_fields,
            __base__=ActionInput
        )

        # create the action output type (skipped when the caller provides a
        # ready-made output_parser class).
        if output_parser is None:
            action_output_fields = {}
            for field in outputs:
                required = field.get("required", True)
                if required:
                    action_output_fields[field["name"]] = (Any, Field(description=field["description"]))
                else:
                    action_output_fields[field["name"]] = (Optional[Any], Field(default=None, description=field["description"]))
            action_output_type = create_model(
                self._get_unique_class_name(
                    generate_dynamic_class_name(name+" action_output")
                ),
                **action_output_fields,
                __base__=ActionOutput,
                # get_content_data=customize_get_content_data,
                # to_str=customize_to_str
            )
        else:
            # self._check_output_parser(outputs, output_parser)
            action_output_type = output_parser

        action_cls_name = self._get_unique_class_name(
            generate_dynamic_class_name(name+" action")
        )

        # Create CustomizeAction-based action with parsing properties only
        customize_action_cls = create_model(
            action_cls_name,
            __base__=CustomizeAction
        )

        customize_action = customize_action_cls(
            name=action_cls_name,
            description=desc,
            prompt=prompt,
            prompt_template=prompt_template,
            inputs_format=action_input_type,
            outputs_format=action_output_type,
            parse_mode=parse_mode,
            parse_func=parse_func,
            title_format=title_format,
            custom_output_format=custom_output_format,
            max_tool_try=max_tool_calls,
            tools=tools
        )

        return customize_action

    def _check_output_parser(self, outputs: List[dict], output_parser: Type[ActionOutput]):
        """Validate that `output_parser` is an ActionOutput subclass whose
        fields are all present in the `outputs` specification."""
        if output_parser is not None:
            if not isinstance(output_parser, type):
                raise TypeError(f"output_parser must be a class, but got {type(output_parser).__name__}")
            if not issubclass(output_parser, ActionOutput):
                raise ValueError(f"`output_parser` must be a class and a subclass of `ActionOutput`, but got `{output_parser.__name__}`.")

            # check if the output parser is compatible with the outputs:
            # every parser field must appear in the declared output names.
            output_parser_fields = output_parser.get_attrs()
            all_output_names = [output_item["name"] for output_item in outputs]
            for field in output_parser_fields:
                if field not in all_output_names:
                    raise ValueError(
                        f"The output parser `{output_parser.__name__}` is not compatible with the `outputs`.\n"
                        f"The output parser fields: {output_parser_fields}.\n"
                        f"The outputs: {all_output_names}.\n"
                        f"All the fields in the output parser must be present in the outputs."
                    )

    def _store_inputs_outputs_info(self, inputs: List[dict], outputs: List[dict], tool_map: Dict[str, Union[Toolkit, Tool]]):
        """Cache the declared input/output types & required flags (used by
        `get_customize_agent_info`) and the raw tool map (used by `get_config`)."""
        self._action_input_types, self._action_input_required = {}, {}
        for field in inputs:
            required = field.get("required", True)
            self._action_input_types[field["name"]] = field["type"]
            self._action_input_required[field["name"]] = required
        self._action_output_types, self._action_output_required = {}, {}
        for field in outputs:
            required = field.get("required", True)
            self._action_output_types[field["name"]] = field["type"]
            self._action_output_required[field["name"]] = required
        self._raw_tool_map = tool_map

    def __call__(self, inputs: dict = None, return_msg_type: MessageType = MessageType.UNKNOWN, **kwargs) -> Message:
        """
        Call the customize action.

        Args:
            inputs (dict): The inputs to the customize action.
            return_msg_type (MessageType): Message type for the returned Message.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            Message: The message wrapping the output of the customize action.
        """
        # return self.execute(action_name=self.customize_action_name, action_input_data=inputs, **kwargs)
        inputs = inputs or {}
        return super().__call__(action_name=self.customize_action_name, action_input_data=inputs, return_msg_type=return_msg_type, **kwargs)

    def get_customize_agent_info(self) -> dict:
        """
        Get the information of the customize agent as a JSON-serializable dict
        (the same shape consumed by `load_module` / written by `save_module`).
        """
        customize_action = self.get_action(self.customize_action_name)
        action_input_params = customize_action.inputs_format.get_attrs()
        action_output_params = customize_action.outputs_format.get_attrs()

        config = {
            "class_name": "CustomizeAgent",
            "name": self.name,
            "description": self.description,
            "prompt": customize_action.prompt,
            "prompt_template": customize_action.prompt_template.to_dict() if customize_action.prompt_template is not None else None,
            # "llm_config": self.llm_config.to_dict(exclude_none=True),
            "inputs": [
                {
                    "name": field,
                    "type": self._action_input_types[field],
                    "description": field_info.description,
                    "required": self._action_input_required[field]
                }
                for field, field_info in customize_action.inputs_format.model_fields.items() if field in action_input_params
            ],
            "outputs": [
                {
                    "name": field,
                    "type": self._action_output_types[field],
                    "description": field_info.description,
                    "required": self._action_output_required[field]
                }
                for field, field_info in customize_action.outputs_format.model_fields.items() if field in action_output_params
            ],
            "system_prompt": self.system_prompt,
            # callables/classes are persisted by (registered) name only
            "output_parser": self.output_parser.__name__ if self.output_parser is not None else None,
            "parse_mode": self.parse_mode,
            "parse_func": self.parse_func.__name__ if self.parse_func is not None else None,
            "title_format": self.title_format,
            "tool_names": [tool.name for tool in customize_action.tools] if customize_action.tools else [],
            "max_tool_calls": self.max_tool_calls,
            "custom_output_format": self.custom_output_format
        }
        return config

    @classmethod
    def load_module(cls, path: str, llm_config: LLMConfig = None, tools: List[Union[Toolkit, Tool]] = None, **kwargs) -> "CustomizeAgent":
        """
        Load the agent from local storage. Must provide `llm_config` when
        loading the agent from local storage. If the saved config contains a
        non-empty `tool_names` list, `tools` must also be provided and must
        contain a tool/toolkit for every saved name.

        Args:
            path: The path of the file
            llm_config: The LLMConfig instance
            tools: Tool/Toolkit instances matched by name against the saved
                `tool_names`; bare Tools are wrapped into single-tool Toolkits.

        Returns:
            CustomizeAgent: The loaded agent instance.
            NOTE(review): the body treats `agent` as a dict (`.get`,
            `agent["tools"]`), so `super().load_module` apparently returns the
            raw config mapping here rather than an instance — confirm against
            the base class.
        """
        match_dict = {}
        agent = super().load_module(path=path, llm_config=llm_config, **kwargs)
        if tools:
            match_dict = {tool.name:tool for tool in tools}
        if agent.get("tool_names", None):
            assert tools is not None, "must provide `tools: List[Union[Toolkit, Tool]]` when using `load_module` or `from_file` to load the agent from local storage and `tool_names` is not None or empty"
            # KeyError here means a saved tool name was not supplied in `tools`.
            added_tools = [match_dict[tool_name] for tool_name in agent["tool_names"]]
            agent["tools"] = [tool if isinstance(tool, Toolkit) else Toolkit(name=tool.name, tools=[tool]) for tool in added_tools]
        return agent

    def save_module(self, path: str, ignore: List[str] = [], **kwargs)-> str:
        """Save the customize agent's configuration to a JSON file.

        Args:
            path: File path where the configuration should be saved
            ignore: List of keys to exclude from the saved configuration.
                NOTE(review): mutable default argument — harmless here since
                it is only iterated, but worth changing to `None`.
            **kwargs (Any): Additional parameters for the save operation

        Returns:
            The path where the configuration was saved
        """
        config = self.get_customize_agent_info()

        for ignore_key in ignore:
            config.pop(ignore_key, None)

        # Save to JSON file
        make_parent_folder(path)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4, ensure_ascii=False)

        return path

    def _get_unique_class_name(self, candidate_name: str) -> str:
        """
        Get a unique class name by checking if it already exists in the registry.
        If it does, append "Vx" (x = 1, 2, ...) until the name is unused.
        """
        if not MODULE_REGISTRY.has_module(candidate_name):
            return candidate_name

        i = 1
        while True:
            unique_name = f"{candidate_name}V{i}"
            if not MODULE_REGISTRY.has_module(unique_name):
                break
            i += 1
        return unique_name

    def get_config(self) -> dict:
        """
        Get a dictionary containing all necessary configuration to recreate this agent.

        Returns:
            dict: A configuration dictionary that can be used to initialize a new Agent instance
                with the same properties as this one. Unlike the JSON form,
                `tools` holds the original live objects (restored from
                `_raw_tool_map`), not just their names.
        """
        config = self.get_customize_agent_info()
        config["llm_config"] = self.llm_config.to_dict()
        tool_names = config.pop("tool_names", None)
        if tool_names:
            config["tools"] = [self._raw_tool_map[name] for name in tool_names]
        return config
|
| 522 |
+
|
evoagentx/agents/long_term_memory_agent.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import asyncio
|
| 3 |
+
from uuid import uuid4
|
| 4 |
+
from pydantic import Field
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Optional, List, Tuple, Dict, Union
|
| 7 |
+
|
| 8 |
+
from evoagentx.agents import Agent
|
| 9 |
+
from evoagentx.core.parser import Parser
|
| 10 |
+
from evoagentx.models import BaseLLM
|
| 11 |
+
from evoagentx.core.logging import logger
|
| 12 |
+
from evoagentx.models import OpenAILLMConfig
|
| 13 |
+
from evoagentx.storages.base import StorageHandler
|
| 14 |
+
from evoagentx.core.message import Message, MessageType
|
| 15 |
+
from evoagentx.memory.memory_manager import MemoryManager
|
| 16 |
+
from evoagentx.memory.long_term_memory import LongTermMemory
|
| 17 |
+
from evoagentx.actions.action import Action, ActionInput, ActionOutput
|
| 18 |
+
from evoagentx.rag.rag_config import RAGConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MemoryActionInput(ActionInput):
    # Input schema for MemoryAction: the raw user prompt plus optional
    # retrieval controls (conversation scoping, result count, metadata filters).
    user_prompt: str = Field(description="The user's input prompt")
    conversation_id: Optional[str] = Field(description="ID for tracking conversation", default=None)
    top_k: Optional[int] = Field(description="Number of memory results to retrieve", default=5)
    metadata_filters: Optional[Dict] = Field(description="Filters for memory retrieval", default=None)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MemoryActionOutput(ActionOutput):
    """Output schema for MemoryAction.

    ``MemoryAction.async_execute`` instantiates this model with both
    ``response`` and ``memory_ids``; the original model only declared
    ``response``, which fails validation on strict pydantic models.
    """
    response: str = Field(description="The agent's response based on memory and prompt")
    # IDs of the memory records stored for this exchange (filled in by MemoryAction).
    memory_ids: Optional[List[str]] = Field(default=None, description="IDs of memories stored for this response")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MemoryAction(Action):
    """Action that answers a user prompt using long-term-memory context.

    The action asks the ``MemoryManager`` to build a context message for the
    conversation, formats context + prompt into the prompt template, queries
    the LLM, and stores the LLM response back into long-term memory.
    """

    def __init__(
        self,
        name: str = "MemoryAction",
        description: str = "Action that processes user input with long-term memory context",
        prompt: str = "Based on the following context and user prompt, provide a relevant response:\n\nContext: {context}\n\nUser Prompt: {user_prompt}\n\n",
        inputs_format: ActionInput = None,
        outputs_format: ActionOutput = None,
        **kwargs
    ):
        # Fall back to the memory-specific schemas when none are supplied.
        inputs_format = inputs_format or MemoryActionInput
        outputs_format = outputs_format or MemoryActionOutput
        super().__init__(
            name=name,
            description=description,
            prompt=prompt,
            inputs_format=inputs_format,
            outputs_format=outputs_format,
            **kwargs
        )

    def execute(self, llm: BaseLLM | None = None,
                inputs: Dict | None = None,
                sys_msg: str | None = None,
                return_prompt: bool = False,
                memory_manager: Optional[MemoryManager] = None,
                **kwargs
                ) -> Parser | Tuple[Parser, str] | None:
        """Synchronous wrapper around :meth:`async_execute`.

        NOTE: ``asyncio.run`` raises ``RuntimeError`` when called from a
        running event loop; use :meth:`async_execute` directly in async code.
        (The original annotated the return as ``Tuple[Parser | str]`` — a
        1-tuple of a union — fixed here to ``Tuple[Parser, str]``.)
        """
        return asyncio.run(self.async_execute(llm, inputs, sys_msg, return_prompt, memory_manager, **kwargs))

    async def async_execute(
        self,
        llm: Optional["BaseLLM"] = None,
        inputs: Optional[Dict] = None,
        sys_msg: Optional[str] = None,
        return_prompt: bool = False,
        memory_manager: Optional[MemoryManager] = None,
        **kwargs
    ) -> Union[MemoryActionOutput, tuple]:
        """Run the memory-augmented generation pipeline.

        Args:
            llm: Model used to generate the response.
            inputs: Raw input data validated against ``inputs_format``.
            sys_msg: Optional system message passed to the LLM.
            return_prompt: When True, also return the fully-formatted prompt.
            memory_manager: Required manager used for retrieval and storage.

        Returns:
            MemoryActionOutput, or ``(MemoryActionOutput, prompt)`` when
            ``return_prompt`` is True.

        Raises:
            ValueError: If ``memory_manager`` is not provided.
        """
        if not memory_manager:
            logger.error("MemoryManager is required for MemoryAction execution")
            raise ValueError("MemoryManager is required for MemoryAction")

        # Guard against a missing inputs dict (the original crashed on **None).
        action_input = self.inputs_format(**(inputs or {}))
        user_prompt = action_input.user_prompt
        conversation_id = action_input.conversation_id
        if not conversation_id:
            conversation_id = str(uuid4())
            logger.warning("No conversation_id provided; generated a new UUID4 for this session")
        top_k = action_input.top_k
        metadata_filters = action_input.metadata_filters

        # Ask the memory manager to assemble a context message from stored
        # memories relevant to this prompt.
        message = await memory_manager.create_conversation_message(
            user_prompt=user_prompt,
            conversation_id=conversation_id,
            top_k=top_k,
            metadata_filters=metadata_filters
        )

        # Build the template variables from the validated input, plus the
        # retrieved context; missing attributes render as "undefined".
        action_input_attrs = self.inputs_format.get_attrs()
        action_input_data = {attr: getattr(action_input, attr, "undefined") for attr in action_input_attrs}
        action_input_data["context"] = message.content
        prompt = self.prompt.format(**action_input_data)
        logger.info(f"The New Created Message by LongTermMemory:\n\n{prompt}")

        output = await llm.async_generate(
            prompt=prompt,
            system_message=sys_msg,
            parser=self.outputs_format,
            parse_mode='str'
        )

        # Persist the model response into long-term memory, carrying the
        # memory IDs that informed it.
        response_message = Message(
            content=output.content,
            msg_type=MessageType.RESPONSE,
            timestamp=datetime.now().isoformat(),
            conversation_id=conversation_id,
            memory_ids=message.memory_ids
        )
        memory_ids = await memory_manager.handle_memory(
            action="add",
            data=response_message,
        )

        # Prepare the final output
        final_output = self.outputs_format(
            response=output.content,
            memory_ids=memory_ids
        )

        if return_prompt:
            return final_output, prompt
        return final_output
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class MemoryAgent(Agent):
    """Agent that grounds its responses in long-term memory.

    On construction it wires a :class:`LongTermMemory` store and a
    :class:`MemoryManager`, and registers a single :class:`MemoryAction`.
    Each execution retrieves memory context for the prompt and persists the
    exchange back into the store.
    """

    memory_manager: Optional[MemoryManager] = Field(default=None, description="Manager for long-term memory operations")
    inputs: List[Dict] = Field(default_factory=list, description="Input specifications for the memory action")
    outputs: List[Dict] = Field(default_factory=list, description="Output specifications for the memory action")

    def __init__(
        self,
        name: str = "MemoryAgent",
        description: str = "An agent that uses long-term memory to provide context-aware responses",
        inputs: Optional[List[Dict]] = None,
        outputs: Optional[List[Dict]] = None,
        llm_config: Optional[OpenAILLMConfig] = None,
        storage_handler: Optional[StorageHandler] = None,
        rag_config: Optional[RAGConfig] = None,
        conversation_id: Optional[str] = None,
        system_prompt: Optional[str] = None,
        prompt: str = "Based on the following context and user prompt, provide a relevant response:\n\nContext: {context}\n\nUser Prompt: {user_prompt}",
        **kwargs
    ):
        # Define inputs and outputs inspired by CustomizeAgent
        inputs = inputs or []
        outputs = outputs or []

        # Initialize base Agent with provided parameters
        super().__init__(
            name=name,
            description=description,
            llm_config=llm_config,
            system_prompt=system_prompt,
            storage_handler=storage_handler,
            inputs=inputs,
            outputs=outputs,
            **kwargs
        )

        # conversation_id (if given) becomes the default corpus namespace.
        self.long_term_memory = LongTermMemory(
            storage_handler=storage_handler,
            rag_config=rag_config,
            default_corpus_id=conversation_id
        )
        self.memory_manager = MemoryManager(
            memory=self.long_term_memory,
            llm=llm_config.get_llm() if llm_config else None,
            use_llm_management=True
        )

        # Initialize inputs and outputs
        self.inputs = inputs
        self.outputs = outputs

        # Initialize actions list and register the single memory-aware action.
        self.actions = []
        self._action_map = {}
        memory_action = MemoryAction(
            name="MemoryAction",
            description="Action that processes user input with long-term memory context",
            prompt=prompt,
            inputs_format=MemoryActionInput,
            outputs_format=MemoryActionOutput
        )
        self.add_action(memory_action)

    def _schedule_memory_write(self, message: Message) -> None:
        """Persist ``message`` to long-term memory from sync or async context.

        Fixes a crash in the original code: ``asyncio.create_task`` raises
        ``RuntimeError`` when no event loop is running — which is exactly the
        state on the synchronous ``execute`` path (after ``asyncio.run`` has
        returned). With a running loop the write is scheduled; without one it
        is run to completion.
        """
        coro = self.memory_manager.handle_memory(action="add", data=message)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(coro)
        else:
            asyncio.create_task(coro)

    def _create_output_message(
        self,
        action_output,
        action_name: str,
        action_input_data: Optional[Dict],
        prompt: str,
        return_msg_type: MessageType = MessageType.RESPONSE,
        **kwargs
    ) -> Message:
        """Build the output Message and persist the exchange to memory.

        NOTE(review): MemoryAction.async_execute already stores the response
        message itself, so the response may be written twice — confirm whether
        the double write is intended before changing it.
        """
        msg = super()._create_output_message(
            action_output=action_output,
            action_name=action_name,
            action_input_data=action_input_data,
            prompt=prompt,
            return_msg_type=return_msg_type,
            **kwargs
        )

        if action_input_data and "user_prompt" in action_input_data:
            user_msg = Message(
                content=action_input_data["user_prompt"],
                msg_type=MessageType.REQUEST,
                conversation_id=msg.conversation_id
            )
            self._schedule_memory_write(user_msg)

        response_msg = Message(
            content=action_output.response if hasattr(action_output, "response") else str(action_output),
            msg_type=MessageType.RESPONSE,
            conversation_id=msg.conversation_id
        )
        self._schedule_memory_write(response_msg)

        return msg

    async def async_execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[Dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.RESPONSE,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, Dict]]:
        """
        Execute an action asynchronously with memory management.

        Args:
            action_name: Name of the action to execute
            msgs: Optional list of messages providing context
            action_input_data: Optional input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: Whether to return the action input data
            **kwargs: Additional parameters

        Returns:
            Message or tuple: The execution result, optionally with input data
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # Execute action with memory_manager
        execution_results = await action.async_execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            memory_manager=self.memory_manager,
            **kwargs
        )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            action_input_data=action_input_data,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def execute(
        self,
        action_name: str,
        msgs: Optional[List[Message]] = None,
        action_input_data: Optional[Dict] = None,
        return_msg_type: Optional[MessageType] = MessageType.RESPONSE,
        return_action_input_data: Optional[bool] = False,
        **kwargs
    ) -> Union[Message, Tuple[Message, Dict]]:
        """
        Execute an action synchronously with memory management.

        Args:
            action_name: Name of the action to execute
            msgs: Optional list of messages providing context
            action_input_data: Optional input data for the action
            return_msg_type: Message type for the return message
            return_action_input_data: Whether to return the action input data
            **kwargs: Additional parameters

        Returns:
            Message or tuple: The execution result, optionally with input data
        """
        action, action_input_data = self._prepare_execution(
            action_name=action_name,
            msgs=msgs,
            action_input_data=action_input_data,
            **kwargs
        )

        # Execute action with memory_manager
        execution_results = action.execute(
            llm=self.llm,
            inputs=action_input_data,
            sys_msg=self.system_prompt,
            return_prompt=True,
            memory_manager=self.memory_manager,
            **kwargs
        )
        action_output, prompt = execution_results

        message = self._create_output_message(
            action_output=action_output,
            prompt=prompt,
            action_name=action_name,
            return_msg_type=return_msg_type,
            action_input_data=action_input_data,
            **kwargs
        )
        if return_action_input_data:
            return message, action_input_data
        return message

    def chat(
        self,
        user_prompt: str,
        *,
        conversation_id: Optional[str] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[dict] = None,
        return_message: bool = True,
        **kwargs
    ):
        """Single-turn synchronous chat through MemoryAction.

        Returns the full Message when ``return_message`` is True, otherwise
        just its text content.
        """
        action_input_data = {
            "user_prompt": user_prompt,
            "conversation_id": conversation_id or self._default_conversation_id(),
            "top_k": top_k if top_k is not None else 3,
            "metadata_filters": metadata_filters or {},
        }
        msg = self.execute(
            action_name="MemoryAction",
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            **kwargs
        )
        return msg if return_message else (getattr(msg, "content", None) or str(msg))

    async def async_chat(
        self,
        user_prompt: str,
        *,
        conversation_id: Optional[str] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[dict] = None,
        return_message: bool = True,
        **kwargs
    ):
        """Single-turn asynchronous chat through MemoryAction."""
        action_input_data = {
            "user_prompt": user_prompt,
            "conversation_id": conversation_id or self._default_conversation_id(),
            "top_k": top_k if top_k is not None else 3,
            "metadata_filters": metadata_filters or {},
        }
        msg = await self.async_execute(
            action_name="MemoryAction",
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            **kwargs
        )
        return msg if return_message else (getattr(msg, "content", None) or str(msg))

    def _default_conversation_id(self) -> str:
        """
        Session scope: By default, a new uuid4() is returned (new session).
        User/global scope: Reuse LongTermMemory.default_corpus_id (stable namespace).
        Note: The final ID is still uniformly managed by MemoryAgent._prepare_execution() (which will override based on the scope).
        """
        scope = getattr(self, "conversation_scope", "session")
        if scope == "session":
            return str(uuid4())
        return getattr(getattr(self, "long_term_memory", None), "default_corpus_id", None) or "global_corpus"

    async def interactive_chat(
        self,
        conversation_id: Optional[str] = None,
        top_k: int = 3,
        metadata_filters: Optional[dict] = None
    ):
        """
        In interactive chat, each round of input will:
        1. Retrieve from memory
        2. Generate a response based on historical context
        3. Write the input/output to long-term memory and refresh the index
        """
        conversation_id = conversation_id or self._default_conversation_id()
        metadata_filters = metadata_filters or {}

        print("💬 MemoryAgent has been started (type 'exit' to quit)\n")

        while True:
            user_prompt = input("You: ").strip()
            if user_prompt.lower() in ["exit", "quit"]:
                print("🔚 Conversation ended")
                break

            # Retrieve historical context
            # NOTE(review): async_chat triggers MemoryAction, which performs
            # its own retrieval — context is fetched twice per turn and the
            # context-stuffed prompt is what gets stored. Confirm intent.
            retrieved_memories = await self.memory_manager.handle_memory(
                action="search",
                user_prompt=user_prompt,
                top_k=top_k,
                metadata_filters=metadata_filters
            )

            context_texts = []
            for msg, _ in retrieved_memories:
                if hasattr(msg, "content") and msg.content:
                    context_texts.append(msg.content)
            context_str = "\n".join(context_texts)

            # Concatenate the historical context into the user input and invoke async_chat
            full_prompt = f"Context:\n{context_str}\n\nUser: {user_prompt}" if context_str else user_prompt
            msg = await self.async_chat(
                user_prompt=full_prompt,
                conversation_id=conversation_id,
                top_k=top_k,
                metadata_filters=metadata_filters
            )

            print(f"Agent: {msg.content}\n")

            # Refresh the index to ensure it can be retrieved in the next round
            if hasattr(self.memory_manager, "handle_memory_flush"):
                await self.memory_manager.handle_memory_flush()
            else:
                await asyncio.sleep(0.1)

    def save_module(self, path: str, ignore: Optional[List[str]] = None, **kwargs) -> str:
        """
        Save the agent's configuration to a JSON file, excluding memory_manager by default.

        Args:
            path: File path to save the configuration
            ignore: Keys to exclude; defaults to ["llm", "llm_config", "memory_manager"].
                (Replaces the original mutable default argument.)
            **kwargs: Additional parameters for saving

        Returns:
            str: The path where the configuration was saved
        """
        if ignore is None:
            ignore = ["llm", "llm_config", "memory_manager"]
        return super().save_module(path=path, ignore=ignore, **kwargs)

    @classmethod
    def from_file(cls, path: str, llm_config: OpenAILLMConfig, storage_handler: Optional[StorageHandler] = None, rag_config: Optional[RAGConfig] = None, **kwargs) -> "MemoryAgent":
        """
        Load a MemoryAgent from a JSON configuration file.

        Args:
            path: Path to the JSON configuration file
            llm_config: LLM configuration
            storage_handler: Optional storage handler
            rag_config: Optional RAG configuration
            **kwargs: Additional parameters

        Returns:
            MemoryAgent: The loaded agent instance
        """
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        init_kwargs = dict(
            name=config.get("name", "MemoryAgent"),
            description=config.get("description", "An agent that uses long-term memory"),
            llm_config=llm_config,
            storage_handler=storage_handler,
            rag_config=rag_config,
            system_prompt=config.get("system_prompt"),
            use_long_term_memory=config.get("use_long_term_memory", True),
        )
        # Only override the class-default prompt when the file provides one;
        # the original passed prompt=None through, clobbering the default.
        if config.get("prompt"):
            init_kwargs["prompt"] = config["prompt"]
        init_kwargs.update(kwargs)
        return cls(**init_kwargs)
|
evoagentx/agents/task_planner.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .agent import Agent
|
| 2 |
+
from ..actions.task_planning import TaskPlanning
|
| 3 |
+
from ..prompts.task_planner import TASK_PLANNER
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TaskPlanner(Agent):
    """An agent responsible for planning and decomposing high-level tasks into smaller sub-tasks.

    The TaskPlanner agent analyzes complex goals and breaks them down into a structured
    sequence of smaller, more manageable tasks. It serves as a critical component in the
    workflow by creating execution plans that other specialized agents can follow.

    Attributes:
        name (str): Name of the task planner agent, defaults to the value in TASK_PLANNER
        description (str): Description of the agent's purpose and capabilities, defaults to the value in TASK_PLANNER
        system_prompt (str): System prompt guiding the agent's behavior, defaults to the value in TASK_PLANNER
        actions (List[Action]): List of actions the agent can perform, defaults to [TaskPlanning()]
    """
    def __init__(self, **kwargs):
        # Idiomatic dict.pop with a default replaces the original
        # `kwargs.pop(k) if k in kwargs else default` dance.
        name = kwargs.pop("name", TASK_PLANNER["name"])
        description = kwargs.pop("description", TASK_PLANNER["description"])
        system_prompt = kwargs.pop("system_prompt", TASK_PLANNER["system_prompt"])
        # Build the default action lazily so TaskPlanning() is only
        # constructed when the caller did not supply actions.
        actions = kwargs.pop("actions", None)
        if actions is None:
            actions = [TaskPlanning()]
        super().__init__(name=name, description=description, system_prompt=system_prompt, actions=actions, **kwargs)

    @property
    def task_planning_action_name(self):
        """Get the name of the TaskPlanning action associated with this agent.

        Returns:
            The name of the TaskPlanning action in this agent's action registry
        """
        return self.get_action_name(action_cls=TaskPlanning)
|
| 35 |
+
|
evoagentx/agents/workflow_reviewer.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List
|
| 2 |
+
from .agent import Agent
|
| 3 |
+
from ..core.message import Message # MessageType
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class WorkFlowReviewer(Agent):
    """Placeholder agent intended to review workflow plans and agents.

    Not implemented yet: every call to :meth:`execute` raises
    ``NotImplementedError``.
    """

    def execute(self, action_name: str, msgs: Optional[List[Message]] = None, action_input_data: Optional[dict] = None, **kwargs) -> Message:
        # Deliberate stub — fail loudly rather than silently no-op.
        raise NotImplementedError("WorkflowReviewer is not implemented yet.")
|
evoagentx/app/__init__.py
ADDED
|
File without changes
|
evoagentx/app/api.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API routes for EvoAgentX application.
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
| 5 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
| 6 |
+
from typing import List, Dict, Any # , Optional
|
| 7 |
+
from fastapi import Response
|
| 8 |
+
|
| 9 |
+
from datetime import timedelta
|
| 10 |
+
|
| 11 |
+
from evoagentx.app.config import settings
|
| 12 |
+
from evoagentx.app.schemas import (
|
| 13 |
+
AgentCreate, AgentUpdate, AgentResponse,
|
| 14 |
+
WorkflowCreate, WorkflowUpdate, WorkflowResponse,
|
| 15 |
+
ExecutionCreate, ExecutionResponse,
|
| 16 |
+
PaginationParams, SearchParams,
|
| 17 |
+
Token, UserCreate, UserResponse, # UserLogin,
|
| 18 |
+
)
|
| 19 |
+
from evoagentx.app.services import AgentService, WorkflowService, WorkflowExecutionService
|
| 20 |
+
from evoagentx.app.security import (
|
| 21 |
+
create_access_token,
|
| 22 |
+
authenticate_user,
|
| 23 |
+
create_user,
|
| 24 |
+
get_current_active_user,
|
| 25 |
+
get_current_admin_user
|
| 26 |
+
)
|
| 27 |
+
from evoagentx.app.db import Database, ExecutionStatus
|
| 28 |
+
|
| 29 |
+
# Create routers for different route groups.
# All routers share the global API prefix so every route mounts under the
# same base path; they are split only to group tags/endpoints logically.
auth_router = APIRouter(prefix=settings.API_PREFIX)
agents_router = APIRouter(prefix=settings.API_PREFIX)
workflows_router = APIRouter(prefix=settings.API_PREFIX)
executions_router = APIRouter(prefix=settings.API_PREFIX)
system_router = APIRouter(prefix=settings.API_PREFIX)
|
| 35 |
+
|
| 36 |
+
# Authentication Routes
|
| 37 |
+
@auth_router.post("/auth/register", response_model=UserResponse, tags=["Authentication"])
async def register_user(user: UserCreate):
    """Register a new user."""
    created = await create_user(user)
    return created
|
| 41 |
+
|
| 42 |
+
@auth_router.post("/auth/login", response_model=Token, tags=["Authentication"])
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Login and return access token."""
    account = await authenticate_user(form_data.username, form_data.password)
    if not account:
        # Generic message so the response does not leak which field was wrong.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        subject=account['email'],
        expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
|
| 63 |
+
|
| 64 |
+
# Agent Routes
|
| 65 |
+
@agents_router.post("/agents", response_model=AgentResponse, tags=["Agents"])
# NOTE(review): sibling create_workflow declares status_code=201; this route
# returns the default 200 on creation — confirm whether that is intended.
async def create_agent(
    agent: AgentCreate,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Create a new agent."""
    try:
        created = await AgentService.create_agent(
            agent,
            user_id=str(current_user['_id'])
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    # ObjectId is not JSON-serializable; stringify before building the model.
    created["_id"] = str(created["_id"])
    return AgentResponse(**created)
|
| 81 |
+
|
| 82 |
+
@agents_router.get("/agents/{agent_id}", response_model=AgentResponse, tags=["Agents"])
async def get_agent(
    agent_id: str,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Retrieve a specific agent by ID."""
    found = await AgentService.get_agent(agent_id)
    if not found:
        raise HTTPException(status_code=404, detail="Agent not found")
    # Stringify the Mongo ObjectId before model validation.
    found["_id"] = str(found["_id"])
    return AgentResponse(**found)
|
| 93 |
+
|
| 94 |
+
@agents_router.put("/agents/{agent_id}", response_model=AgentResponse, tags=["Agents"])
async def update_agent(
    agent_id: str,
    agent_update: AgentUpdate,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Update an existing agent."""
    try:
        result = await AgentService.update_agent(agent_id, agent_update)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if not result:
        raise HTTPException(status_code=404, detail="Agent not found")
    result["_id"] = str(result["_id"])
    return AgentResponse(**result)
|
| 109 |
+
|
| 110 |
+
@agents_router.get("/agents", response_model=List[AgentResponse], tags=["Agents"])
async def list_agents(
    pagination: PaginationParams = Depends(),
    search: SearchParams = Depends(),
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """List agents with optional pagination and search."""
    agents, _total = await AgentService.list_agents(pagination, search)
    # Stringify each ObjectId so the response models validate; the total
    # count is currently unused by this endpoint.
    return [
        AgentResponse(**{**agent, "_id": str(agent["_id"])})
        for agent in agents
    ]
|
| 122 |
+
|
| 123 |
+
@agents_router.delete("/agents/{agent_id}", status_code=204, tags=["Agents"])
async def delete_agent(
    agent_id: str,
    current_user: Dict[str, Any] = Depends(get_current_admin_user)
):
    """Delete an agent (admin-only).

    Returns an explicit empty 204 Response, matching delete_workflow: the
    original returned None, which FastAPI may attempt to serialize into a
    body on a 204 route.
    """
    try:
        success = await AgentService.delete_agent(agent_id)
        if not success:
            raise HTTPException(status_code=404, detail="Agent not found")
        return Response(status_code=204)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Workflow Routes
|
| 140 |
+
@workflows_router.post("/workflows", response_model=WorkflowResponse,status_code=201, tags=["Workflows"])
async def create_workflow(
    workflow: WorkflowCreate,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Create a new workflow."""
    try:
        created = await WorkflowService.create_workflow(
            workflow,
            user_id=str(current_user['_id'])
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    # Stringify the ObjectId for consistency with the other endpoints.
    created["_id"] = str(created["_id"])
    return WorkflowResponse(**created)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@workflows_router.get("/workflows/{workflow_id}", response_model=WorkflowResponse, tags=["Workflows"])
async def get_workflow(
    workflow_id: str,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Retrieve a specific workflow by ID."""
    found = await WorkflowService.get_workflow(workflow_id)
    if not found:
        raise HTTPException(status_code=404, detail="Workflow not found")
    # Convert ObjectId to string before model validation.
    found["_id"] = str(found["_id"])
    return WorkflowResponse(**found)
|
| 171 |
+
|
| 172 |
+
@workflows_router.put("/workflows/{workflow_id}", response_model=WorkflowResponse, tags=["Workflows"])
async def update_workflow(
    workflow_id: str,
    workflow_update: WorkflowUpdate,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Update an existing workflow."""
    try:
        document = await WorkflowService.update_workflow(workflow_id, workflow_update)
        if not document:
            raise HTTPException(status_code=404, detail="Workflow not found")

        # ObjectId -> str before handing the document to the schema.
        document["_id"] = str(document["_id"])
        return WorkflowResponse(**document)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 188 |
+
|
| 189 |
+
@workflows_router.delete("/workflows/{workflow_id}", status_code=204, tags=["Workflows"])
async def delete_workflow(
    workflow_id: str,
    current_user: Dict[str, Any] = Depends(get_current_admin_user)
):
    """Delete a workflow (admin-only)."""
    try:
        deleted = await WorkflowExecutionServiceAlias if False else await WorkflowService.delete_workflow(workflow_id)
        if not deleted:
            raise HTTPException(status_code=404, detail="Workflow not found")
        # Explicit empty body with the 204 status.
        return Response(status_code=204)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@workflows_router.get("/workflows", response_model=List[WorkflowResponse], tags=["Workflows"])
async def list_workflows(
    pagination: PaginationParams = Depends(),
    search: SearchParams = Depends(),
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """List workflows with optional pagination and search."""
    documents, _total = await WorkflowService.list_workflows(pagination, search)

    # Build responses one by one, stringifying each ObjectId on the way.
    responses = []
    for document in documents:
        responses.append(WorkflowResponse(**{**document, "_id": str(document["_id"])}))
    return responses
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Workflow Execution Routes
|
| 223 |
+
@executions_router.post("/executions", response_model=ExecutionResponse, status_code=202)
async def create_execution(
    execution: ExecutionCreate,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Create and start a workflow execution."""
    try:
        requester = str(current_user['_id'])
        record = await WorkflowExecutionService.create_execution(
            execution_data=execution,
            user_id=requester
        )
        # The response schema expects a string id, not an ObjectId.
        record["_id"] = str(record["_id"])
        return ExecutionResponse(**record)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
@executions_router.get("/executions/{execution_id}", response_model=ExecutionResponse)
async def get_execution(
    execution_id: str,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Retrieve a specific workflow execution by ID."""
    try:
        record = await WorkflowExecutionService.get_execution(execution_id)
        if not record:
            raise HTTPException(status_code=404, detail="Execution not found")
        # ObjectId -> str for the response schema.
        record["_id"] = str(record["_id"])
        return ExecutionResponse(**record)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@executions_router.post("/executions/{execution_id}/stop", response_model=ExecutionResponse)
async def stop_execution(
    execution_id: str,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Stop (cancel) a workflow execution by flipping its status to CANCELLED."""
    try:
        record = await WorkflowExecutionService.update_execution_status(
            execution_id=execution_id,
            status=ExecutionStatus.CANCELLED
        )
        if not record:
            raise HTTPException(status_code=404, detail="Execution not found")
        # Stringify the ObjectId so the schema can validate it.
        record["_id"] = str(record["_id"])
        return ExecutionResponse(**record)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@executions_router.get("/executions", response_model=List[ExecutionResponse])
async def list_executions(
    pagination: PaginationParams = Depends(),
    search: SearchParams = Depends(),
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """List workflow executions with optional pagination and search."""
    records, _total = await WorkflowExecutionService.list_executions(
        params=pagination,
        search=search
    )
    # Mongo ids are ObjectIds; the schema wants strings.
    responses = []
    for record in records:
        record["_id"] = str(record["_id"])
        responses.append(ExecutionResponse(**record))
    return responses
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@executions_router.get("/executions/{execution_id}/logs", response_model=List[Dict[str, Any]])
async def get_execution_logs(
    execution_id: str,
    pagination: PaginationParams = Depends(),
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """Retrieve logs for a specific execution."""
    entries, _total = await WorkflowExecutionService.get_execution_logs(execution_id, params=pagination)
    # Stringify each log entry's ObjectId so the plain-dict response serializes.
    for entry in entries:
        entry["_id"] = str(entry["_id"])
    return entries
|
| 307 |
+
|
| 308 |
+
# Health Check Route
@system_router.get("/health", tags=["System"])
async def health_check():
    """Simple health check endpoint; pings MongoDB to verify connectivity."""
    try:
        # A Mongo `ping` confirms the connection is alive; extend with more
        # checks (queues, caches) as the system grows.
        await Database.db.command('ping')
        return {
            "status": "healthy",
            "version": "1.0.0"
        }
    except Exception as e:
        # Surface database outages as a 500 so monitors notice.
        raise HTTPException(status_code=500, detail=f"Database connection error: {str(e)}")
|
| 321 |
+
|
| 322 |
+
# Export the routers: this module's public API is exactly these five routers.
__all__ = [
    'auth_router',
    'agents_router',
    'workflows_router',
    'executions_router',
    'system_router',
]
|
evoagentx/app/app.env
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# .env file
|
| 2 |
+
APP_NAME=EvoAgentX
|
| 3 |
+
DEBUG=True
|
| 4 |
+
API_PREFIX=/api/v1
|
| 5 |
+
HOST=0.0.0.0
|
| 6 |
+
PORT=8000
|
| 7 |
+
|
| 8 |
+
# MongoDB settings
|
| 9 |
+
# SECURITY: real-looking credentials are committed here — rotate them and load
# this value from a secret store or untracked local .env instead.
MONGODB_URL=mongodb+srv://eax:eax@cluster0.1lkbi0y.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0
|
| 10 |
+
MONGODB_DB_NAME=evoagentx
|
| 11 |
+
|
| 12 |
+
# JWT Authentication
|
| 13 |
+
SECRET_KEY=your-secret-key
|
| 14 |
+
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
| 15 |
+
ALGORITHM=HS256
|
| 16 |
+
|
| 17 |
+
# Logging
|
| 18 |
+
LOG_LEVEL=INFO
|
| 19 |
+
|
| 20 |
+
# CORS settings — env files take KEY=value assignments, not Python type
# annotations; pydantic-settings parses List[str] values from JSON.
ALLOWED_HOSTS=["localhost","127.0.0.1"]
CORS_ORIGINS=["http://localhost:3000","http://localhost:8000"]
|
evoagentx/app/config.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the EvoAgentX application.
|
| 3 |
+
"""
|
| 4 |
+
# import os
from typing import Optional, Dict, Any, List

from pydantic import BaseModel, Field, field_validator, validator  # noqa: F401  (validator kept for back-compat)
from pydantic_settings import BaseSettings
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
    """Application settings loaded from the environment / the `.env` file.

    Every field without a default is required at startup; instantiation
    fails fast if the environment is incomplete.
    """

    # Application settings
    APP_NAME: str
    DEBUG: bool
    API_PREFIX: str
    HOST: str
    PORT: int

    # MongoDB settings
    MONGODB_URL: str
    MONGODB_DB_NAME: str

    # JWT Authentication
    SECRET_KEY: str
    ACCESS_TOKEN_EXPIRE_MINUTES: int
    ALGORITHM: str

    # Logging configuration
    LOG_LEVEL: str

    # Add CORS settings
    CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
    CORS_ALLOW_CREDENTIALS: bool = True

    class Config:
        env_file = ".env"
        case_sensitive = True
        # NOTE(review): `env_delimiter` is not a documented pydantic-settings
        # option — list values are parsed as JSON from the env. Confirm intent
        # and remove if unused.
        env_delimiter = ","


# Global settings instance
settings = Settings()
|
| 42 |
+
|
| 43 |
+
# Agent and Workflow configuration
class AgentConfig(BaseModel):
    """Base configuration for an LLM agent.

    Attributes hold the model identity, sampling parameters, and any
    provider-specific extras; `api_key_env_var` names the environment
    variable holding the credential rather than storing it directly.
    """
    model_name: str
    temperature: float = 0.7
    max_tokens: int = 2048
    api_key_env_var: Optional[str] = None
    system_prompt: Optional[str] = None
    extra_params: Dict[str, Any] = Field(default_factory=dict)

    # requirements.txt pins pydantic 2.x, where `@validator` is deprecated;
    # `@field_validator` is the supported v2 equivalent with the same effect.
    @field_validator('temperature')
    @classmethod
    def validate_temperature(cls, v):
        if v < 0 or v > 1:
            raise ValueError('Temperature must be between 0 and 1')
        return v
|
| 58 |
+
|
| 59 |
+
class WorkflowStepConfig(BaseModel):
    """Configuration for a single step in a workflow.

    Maps one agent action into the workflow: `input_mapping` /
    `output_mapping` translate between workflow context keys and the
    agent's own parameter names.
    """
    step_id: str
    agent_id: str
    action: str
    input_mapping: Dict[str, str] = Field(default_factory=dict)
    output_mapping: Dict[str, str] = Field(default_factory=dict)
    timeout_seconds: int = 300   # per-step timeout
    retry_count: int = 3         # retries before the step is considered failed
|
| 68 |
+
|
| 69 |
+
class WorkflowConfig(BaseModel):
    """Configuration for a workflow composed of agent steps."""
    name: str
    description: Optional[str] = None
    steps: List[WorkflowStepConfig]
    parallel_execution: bool = False      # run steps concurrently when True
    timeout_seconds: int = 3600  # Default to 1 hour total timeout
|
| 76 |
+
|
| 77 |
+
class ExecutionConfig(BaseModel):
    """Configuration for a workflow execution."""
    workflow_id: str
    input_params: Dict[str, Any] = Field(default_factory=dict)
    user_id: Optional[str] = None
    priority: int = 1  # Higher number means higher priority
    callback_url: Optional[str] = None  # optional webhook to notify on completion
|
evoagentx/app/db.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database connection and models for EvoAgentX.
|
| 3 |
+
"""
|
| 4 |
+
# import asyncio
|
| 5 |
+
import logging
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Optional, List, Dict, Any # , Union
|
| 9 |
+
from motor.motor_asyncio import AsyncIOMotorClient
|
| 10 |
+
from pymongo import ASCENDING, TEXT
|
| 11 |
+
from pydantic_core import core_schema
|
| 12 |
+
from bson import ObjectId
|
| 13 |
+
from pydantic import GetCoreSchemaHandler
|
| 14 |
+
from pydantic import Field, BaseModel
|
| 15 |
+
from evoagentx.app.config import settings
|
| 16 |
+
|
| 17 |
+
# Setup logger
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Custom PyObjectId for MongoDB ObjectId compatibility with Pydantic
class PyObjectId(ObjectId):
    """ObjectId subclass that pydantic v2 can validate from strings."""

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
        # Accept a string, then run `validate` to turn it into an ObjectId.
        return core_schema.no_info_after_validator_function(cls.validate, core_schema.str_schema())

    @classmethod
    def validate(cls, v):
        # Reject anything bson itself would not accept as an ObjectId.
        if not ObjectId.is_valid(v):
            raise ValueError("Invalid ObjectId")
        return ObjectId(v)
|
| 31 |
+
|
| 32 |
+
# Base model with ObjectId handling
class MongoBaseModel(BaseModel):
    """Base for all Mongo-backed models: maps Mongo's `_id` to `id`."""

    # Mongo stores the primary key as `_id`; expose it as `id` in Python.
    id: Optional[PyObjectId] = Field(alias="_id", default=None)

    model_config = {
        "protected_namespaces": (),
        "populate_by_name": True,  # Replace `allow_population_by_field_name`
        "arbitrary_types_allowed": True,  # Keep custom types like ObjectId
        "json_encoders": {
            ObjectId: str  # Ensure ObjectId is serialized as a string
        }
    }
|
| 44 |
+
|
| 45 |
+
# Status Enums
class AgentStatus(str, Enum):
    """Lifecycle states for an agent record."""
    CREATED = "created"
    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"
|
| 51 |
+
|
| 52 |
+
class WorkflowStatus(str, Enum):
    """Lifecycle states for a workflow definition."""
    CREATED = "created"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
| 58 |
+
|
| 59 |
+
class ExecutionStatus(str, Enum):
    """Lifecycle states for a single workflow execution."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
    CANCELLED = "cancelled"
|
| 66 |
+
|
| 67 |
+
# Database Models
class Agent(MongoBaseModel):
    """Persisted agent document (collection: `agents`)."""

    # NOTE(review): this overrides the base class's Optional[PyObjectId] id
    # with a required str — confirm that is intentional.
    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = None
    config: Dict[str, Any]
    state: Dict[str, Any] = Field(default_factory=dict)
    runtime_params: Dict[str, Any] = Field(default_factory=dict)
    status: AgentStatus = AgentStatus.CREATED
    # NOTE(review): datetime.utcnow is deprecated in recent Pythons and yields
    # naive datetimes — consider datetime.now(timezone.utc) project-wide.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    created_by: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
|
| 80 |
+
|
| 81 |
+
class Workflow(MongoBaseModel):
    """Persisted workflow document (collection: `workflows`)."""

    # NOTE(review): overrides the base Optional[PyObjectId] id with a
    # required str — confirm that is intentional.
    id: str = Field(..., alias="_id")
    name: str
    description: Optional[str] = None
    definition: Dict[str, Any]
    agent_ids: List[str] = Field(default_factory=list)
    status: WorkflowStatus = WorkflowStatus.CREATED
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    created_by: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    version: int = 1  # incremented on definition changes, presumably — verify against service layer
|
| 93 |
+
|
| 94 |
+
class ExecutionLog(MongoBaseModel):
    """One log entry emitted during a workflow execution (collection: `execution_logs`)."""
    workflow_id: str
    execution_id: str
    step_id: Optional[str] = None
    agent_id: Optional[str] = None
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    level: str = "INFO"  # log severity label
    message: str
    details: Dict[str, Any] = Field(default_factory=dict)
|
| 103 |
+
|
| 104 |
+
class WorkflowExecution(MongoBaseModel):
    """A single run of a workflow (collection: `workflow_executions`)."""
    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    input_params: Dict[str, Any] = Field(default_factory=dict)
    results: Dict[str, Any] = Field(default_factory=dict)
    created_by: Optional[str] = None
    # Per-step outputs keyed by step_id.
    step_results: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    current_step: Optional[str] = None
    error_message: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 116 |
+
|
| 117 |
+
# Database client
class Database:
    """Class-level holder for the Motor client, database, and collections.

    All attributes are populated by `connect()` and cleared only on process
    exit; callers use e.g. `Database.agents` after startup.
    """

    client: AsyncIOMotorClient = None
    db = None

    # Collections (bound in connect())
    agents = None
    workflows = None
    executions = None
    logs = None

    @classmethod
    async def connect(cls):
        """Connect to MongoDB"""
        logger.info(f"Connecting to MongoDB at {settings.MONGODB_URL}...")
        cls.client = AsyncIOMotorClient(settings.MONGODB_URL)
        cls.db = cls.client[settings.MONGODB_DB_NAME]

        # Set up collections
        cls.agents = cls.db.agents
        cls.workflows = cls.db.workflows
        cls.executions = cls.db.workflow_executions
        cls.logs = cls.db.execution_logs

        # Create indexes
        await cls._create_indexes()

        logger.info("Connected to MongoDB successfully")

    @classmethod
    async def disconnect(cls):
        """Disconnect from MongoDB"""
        # Safe to call even if connect() never ran (client stays None).
        if cls.client:
            cls.client.close()
            logger.info("Disconnected from MongoDB")

    @classmethod
    async def _create_indexes(cls):
        """Create indexes for collections"""
        # Agent indexes: unique names, full-text search on name/description.
        await cls.agents.create_index([("name", ASCENDING)], unique=True)
        await cls.agents.create_index([("name", TEXT), ("description", TEXT)])
        await cls.agents.create_index([("created_at", ASCENDING)])
        await cls.agents.create_index([("tags", ASCENDING)])

        # Workflow indexes (name is NOT unique here, unlike agents).
        await cls.workflows.create_index([("name", ASCENDING)])
        await cls.workflows.create_index([("name", TEXT), ("description", TEXT)])
        await cls.workflows.create_index([("created_at", ASCENDING)])
        await cls.workflows.create_index([("agent_ids", ASCENDING)])
        await cls.workflows.create_index([("tags", ASCENDING)])

        # Execution indexes
        await cls.executions.create_index([("workflow_id", ASCENDING)])
        await cls.executions.create_index([("created_at", ASCENDING)])
        await cls.executions.create_index([("status", ASCENDING)])

        # Log indexes
        await cls.logs.create_index([("execution_id", ASCENDING)])
        await cls.logs.create_index([("timestamp", ASCENDING)])
        await cls.logs.create_index([("workflow_id", ASCENDING), ("execution_id", ASCENDING)])
|
evoagentx/app/main.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main application entry point for EvoAgentX.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
# import asyncio
|
| 6 |
+
from contextlib import asynccontextmanager
|
| 7 |
+
|
| 8 |
+
import uvicorn
|
| 9 |
+
from fastapi import FastAPI, Request
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from fastapi.responses import JSONResponse
|
| 12 |
+
from fastapi.exceptions import RequestValidationError, HTTPException
|
| 13 |
+
|
| 14 |
+
from evoagentx.app.config import settings
|
| 15 |
+
from evoagentx.app.db import Database
|
| 16 |
+
from evoagentx.app.security import init_users_collection
|
| 17 |
+
from evoagentx.app.api import (
|
| 18 |
+
auth_router,
|
| 19 |
+
agents_router,
|
| 20 |
+
workflows_router,
|
| 21 |
+
executions_router,
|
| 22 |
+
system_router
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Configure logging
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
level=getattr(logging, settings.LOG_LEVEL.upper()),
|
| 28 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Lifespan context manager for startup and shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Async context manager to handle application startup and shutdown events.

    Startup: connects to MongoDB and seeds the users collection.
    Shutdown: disconnects from MongoDB (runs via `finally` even if startup
    failed, which is safe because Database.disconnect checks for a client).
    """
    # Startup tasks
    try:
        # Connect to database
        await Database.connect()

        # Initialize users collection and create admin user if not exists
        await init_users_collection()

        logger.info("Application startup completed successfully")
        yield
    except Exception as e:
        # Re-raise so FastAPI aborts startup rather than serving half-initialized.
        logger.error(f"Error during application startup: {e}")
        raise
    finally:
        # Shutdown tasks
        try:
            await Database.disconnect()
            logger.info("Application shutdown completed successfully")
        except Exception as e:
            logger.error(f"Error during application shutdown: {e}")
|
| 58 |
+
|
| 59 |
+
# Create FastAPI application
app = FastAPI(
    title="EvoAgentX API",
    description="API for EvoAgentX platform",
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    # Honour the configured flag instead of hard-coding True. The default in
    # Settings is True, so behaviour is unchanged for existing deployments.
    allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(auth_router)
app.include_router(agents_router)
app.include_router(workflows_router)
app.include_router(executions_router)
app.include_router(system_router)
|
| 84 |
+
|
| 85 |
+
# Global exception handlers
|
| 86 |
+
@app.exception_handler(RequestValidationError)
|
| 87 |
+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
| 88 |
+
"""
|
| 89 |
+
Custom validation error handler to provide more detailed error responses.
|
| 90 |
+
"""
|
| 91 |
+
return JSONResponse(
|
| 92 |
+
status_code=422,
|
| 93 |
+
content={
|
| 94 |
+
"status": "error",
|
| 95 |
+
"message": "Validation error",
|
| 96 |
+
"errors": exc.errors()
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Custom HTTP exception handler to provide consistent error responses.
    """
    payload = {
        "status": "error",
        "message": exc.detail,
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
|
| 112 |
+
|
| 113 |
+
# Root endpoint for health check
@app.get("/")
async def root():
    """
    Root endpoint for application health check.
    """
    # NOTE(review): version "0.1.0" here differs from the FastAPI app metadata
    # ("1.0.0") and the /health endpoint — confirm which is canonical.
    payload = {
        "app_name": settings.APP_NAME,
        "status": "healthy",
        "version": "0.1.0",
    }
    return payload
|
| 124 |
+
|
| 125 |
+
# Workflow logging and monitoring endpoint
@app.get("/metrics")
async def get_metrics():
    """
    Endpoint to retrieve system metrics and stats.

    Counts documents across the agent / workflow / execution collections and
    returns them grouped by resource; degrades to an error payload (HTTP 200)
    if the counts fail.
    """
    try:
        # Agent counts
        agent_count = await Database.agents.count_documents({})
        active_agent_count = await Database.agents.count_documents({"status": "active"})

        # Workflow counts
        workflow_count = await Database.workflows.count_documents({})
        running_workflow_count = await Database.workflows.count_documents({"status": "running"})

        # Execution counts
        execution_count = await Database.executions.count_documents({})
        failed_execution_count = await Database.executions.count_documents({"status": "failed"})

        return {
            "agents": {
                "total": agent_count,
                "active": active_agent_count
            },
            "workflows": {
                "total": workflow_count,
                "running": running_workflow_count
            },
            "executions": {
                "total": execution_count,
                "failed": failed_execution_count
            }
        }
    except Exception as e:
        logger.error(f"Error retrieving metrics: {e}")
        return {
            "status": "error",
            "message": "Unable to retrieve metrics"
        }
|
| 165 |
+
|
| 166 |
+
# Run the application if this script is executed directly
if __name__ == "__main__":
    # Build uvicorn options from settings so CLI runs and env config agree.
    server_options = {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": settings.DEBUG,
        "log_level": settings.LOG_LEVEL.lower(),
    }
    # Start the server (import-string form is required for reload to work).
    uvicorn.run("evoagentx.app.main:app", **server_options)
|
evoagentx/app/requirements.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI and ASGI server
|
| 2 |
+
fastapi==0.115.10
|
| 3 |
+
uvicorn==0.22.0
|
| 4 |
+
pydantic==2.7.0
|
| 5 |
+
pydantic-settings==2.8.1
|
| 6 |
+
|
| 7 |
+
# MongoDB ODM
|
| 8 |
+
motor==3.3.1
|
| 9 |
+
pymongo==4.6.0
|
| 10 |
+
sqlalchemy==2.0.38
|
| 11 |
+
|
| 12 |
+
python-jose==3.3.0
|
| 13 |
+
passlib==1.7.4
|
| 14 |
+
python-multipart==0.0.6
|
| 15 |
+
bcrypt==4.0.1
|
| 16 |
+
celery==5.3.4
|
| 17 |
+
redis==5.0.0
|
| 18 |
+
pytest==7.4.2
|
| 19 |
+
pytest-asyncio==0.21.0
|
| 20 |
+
httpx==0.24.1
|
| 21 |
+
asgi-lifespan==1.0.1
|
| 22 |
+
python-dotenv==1.0.0
|
| 23 |
+
loguru==0.7.3
|
evoagentx/app/schemas.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for request/response validation in the EvoAgentX API.
|
| 3 |
+
"""
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Optional, List, Dict, Any # , Union
|
| 6 |
+
from pydantic import BaseModel, Field # , validator
|
| 7 |
+
from bson import ObjectId
|
| 8 |
+
from evoagentx.app.db import AgentStatus, WorkflowStatus, ExecutionStatus
|
| 9 |
+
|
| 10 |
+
# Helper for ObjectId validation
class PyObjectId(ObjectId):
    """String-validated ObjectId for API schemas.

    The original implementation used pydantic v1 hooks
    (`__get_validators__` / `__modify_schema__`), which pydantic 2.x
    (pinned in requirements.txt) silently ignores. This version registers
    the same validation through the v2 core-schema hook, matching the
    PyObjectId defined in evoagentx.app.db.
    """

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler):
        # pydantic_core ships with pydantic 2.x; local import avoids touching
        # this module's top-level import block.
        from pydantic_core import core_schema
        # Validate as a string first, then convert to a real ObjectId.
        return core_schema.no_info_after_validator_function(cls.validate, core_schema.str_schema())

    @classmethod
    def validate(cls, v):
        if not ObjectId.is_valid(v):
            raise ValueError("Invalid ObjectId")
        return ObjectId(v)
|
| 25 |
+
|
| 26 |
+
# Base Schema Models
class BaseSchema(BaseModel):
    """Shared base for API request/response schemas.

    Uses pydantic v2 `model_config`; the previous `class Config` used the
    v1-only key `allow_population_by_field_name`, which pydantic 2.x
    (pinned in requirements.txt) does not honour — `populate_by_name` is
    the v2 replacement with the same meaning.
    """

    model_config = {
        "populate_by_name": True,
        "arbitrary_types_allowed": True,
        "json_encoders": {
            ObjectId: str,
            datetime: lambda dt: dt.isoformat(),
        },
    }
|
| 35 |
+
|
| 36 |
+
# Agent Schemas
class AgentCreate(BaseSchema):
    """Request payload for creating a new agent."""
    name: str
    description: Optional[str] = None
    config: Dict[str, Any]
    runtime_params: Dict[str, Any] = Field(default_factory=dict)
    tags: List[str] = Field(default_factory=list)
|
| 43 |
+
|
| 44 |
+
class AgentUpdate(BaseSchema):
    """Partial-update payload for an agent; None fields are left untouched."""
    name: Optional[str] = None
    description: Optional[str] = None
    config: Optional[Dict[str, Any]] = None
    runtime_params: Optional[Dict[str, Any]] = None
    status: Optional[AgentStatus] = None
    tags: Optional[List[str]] = None
|
| 51 |
+
|
| 52 |
+
class AgentResponse(BaseSchema):
|
| 53 |
+
id: str = Field(..., alias="_id")
|
| 54 |
+
name: str
|
| 55 |
+
description: Optional[str] = None
|
| 56 |
+
config: Dict[str, Any]
|
| 57 |
+
status: AgentStatus
|
| 58 |
+
runtime_params: Dict[str, Any]
|
| 59 |
+
created_at: datetime
|
| 60 |
+
updated_at: datetime
|
| 61 |
+
created_by: Optional[str] = None
|
| 62 |
+
tags: List[str]
|
| 63 |
+
|
| 64 |
+
# Workflow Schemas
|
| 65 |
+
class WorkflowStepDefinition(BaseSchema):
|
| 66 |
+
step_id: str
|
| 67 |
+
agent_id: str
|
| 68 |
+
action: str
|
| 69 |
+
input_mapping: Dict[str, str] = Field(default_factory=dict)
|
| 70 |
+
output_mapping: Dict[str, str] = Field(default_factory=dict)
|
| 71 |
+
timeout_seconds: int = 300
|
| 72 |
+
retry_count: int = 3
|
| 73 |
+
depends_on: List[str] = Field(default_factory=list)
|
| 74 |
+
|
| 75 |
+
class WorkflowCreate(BaseSchema):
|
| 76 |
+
name: str
|
| 77 |
+
description: Optional[str] = None
|
| 78 |
+
definition: Dict[str, Any]
|
| 79 |
+
tags: List[str] = Field(default_factory=list)
|
| 80 |
+
|
| 81 |
+
class WorkflowUpdate(BaseSchema):
|
| 82 |
+
name: Optional[str] = None
|
| 83 |
+
description: Optional[str] = None
|
| 84 |
+
definition: Optional[Dict[str, Any]] = None
|
| 85 |
+
status: Optional[WorkflowStatus] = None
|
| 86 |
+
tags: Optional[List[str]] = None
|
| 87 |
+
|
| 88 |
+
class WorkflowResponse(BaseSchema):
|
| 89 |
+
id: str = Field(..., alias="_id")
|
| 90 |
+
name: str
|
| 91 |
+
description: Optional[str] = None
|
| 92 |
+
definition: Dict[str, Any]
|
| 93 |
+
agent_ids: List[str]
|
| 94 |
+
status: WorkflowStatus
|
| 95 |
+
created_at: datetime
|
| 96 |
+
updated_at: datetime
|
| 97 |
+
created_by: Optional[str] = None
|
| 98 |
+
tags: List[str]
|
| 99 |
+
version: int
|
| 100 |
+
|
| 101 |
+
# Execution Schemas
|
| 102 |
+
class ExecutionCreate(BaseSchema):
|
| 103 |
+
workflow_id: str
|
| 104 |
+
input_params: Dict[str, Any] = Field(default_factory=dict)
|
| 105 |
+
callback_url: Optional[str] = None
|
| 106 |
+
|
| 107 |
+
class ExecutionResponse(BaseSchema):
|
| 108 |
+
id: str = Field(..., alias="_id")
|
| 109 |
+
workflow_id: str
|
| 110 |
+
status: ExecutionStatus
|
| 111 |
+
start_time: Optional[datetime] = None
|
| 112 |
+
end_time: Optional[datetime] = None
|
| 113 |
+
input_params: Dict[str, Any]
|
| 114 |
+
results: Dict[str, Any]
|
| 115 |
+
created_by: Optional[str] = None
|
| 116 |
+
step_results: Dict[str, Dict[str, Any]]
|
| 117 |
+
current_step: Optional[str] = None
|
| 118 |
+
error_message: Optional[str] = None
|
| 119 |
+
created_at: datetime
|
| 120 |
+
|
| 121 |
+
class ExecutionLogResponse(BaseSchema):
|
| 122 |
+
id: str = Field(..., alias="_id")
|
| 123 |
+
workflow_id: str
|
| 124 |
+
execution_id: str
|
| 125 |
+
step_id: Optional[str] = None
|
| 126 |
+
agent_id: Optional[str] = None
|
| 127 |
+
timestamp: datetime
|
| 128 |
+
level: str
|
| 129 |
+
message: str
|
| 130 |
+
details: Dict[str, Any]
|
| 131 |
+
|
| 132 |
+
# User auth schemas
class Token(BaseSchema):
    # OAuth2 bearer-token response body.
    access_token: str
    token_type: str

class TokenPayload(BaseSchema):
    # Decoded JWT claims; either may be absent on malformed tokens.
    sub: Optional[str] = None
    exp: Optional[int] = None

class UserCreate(BaseSchema):
    # Registration payload; the password is hashed before storage.
    email: str
    password: str
    full_name: Optional[str] = None

class UserLogin(BaseSchema):
    # Login payload.
    email: str
    password: str

class UserResponse(BaseSchema):
    # User document as returned by the API (never exposes the password hash).
    id: str = Field(..., alias="_id")
    email: str
    full_name: Optional[str] = None
    is_active: bool
    is_admin: bool
    created_at: datetime

# Query parameters
class PaginationParams(BaseSchema):
    # Standard skip/limit pagination window.
    skip: int = 0
    limit: int = 100

class SearchParams(BaseSchema):
    # Optional free-text / tag / status / date-range filters for list endpoints.
    query: Optional[str] = None
    tags: Optional[List[str]] = None
    status: Optional[str] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
evoagentx/app/security.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Security components for authentication and authorization.
|
| 3 |
+
"""
|
| 4 |
+
import jwt
from datetime import datetime, timedelta
from typing import Optional, Dict, Any  # , List

from bson import ObjectId
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel, Field, ValidationError
from pymongo.errors import DuplicateKeyError

from evoagentx.app.config import settings
from evoagentx.app.db import Database
from evoagentx.app.schemas import TokenPayload, UserCreate, UserResponse
| 17 |
+
|
| 18 |
+
# Password hashing
|
| 19 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 20 |
+
|
| 21 |
+
# OAuth2 scheme for token authentication
|
| 22 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/auth/login")
|
| 23 |
+
|
| 24 |
+
# User model for database
class UserInDB(BaseModel):
    """Shape of a user document as stored in the `users` collection."""
    # NOTE(review): pydantic v1 treats leading-underscore fields as private,
    # so `_id` is never populated from input data — confirm callers do not
    # rely on reading it from this model.
    _id: Optional[ObjectId] = None
    email: str
    hashed_password: str
    full_name: Optional[str] = None
    is_active: bool = True
    is_admin: bool = False
    # Fix: the original default `= datetime.utcnow()` was evaluated once at
    # import time, stamping every instance with module-load time. A factory
    # takes a fresh timestamp per instance.
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 33 |
+
|
| 34 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    is_match: bool = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
| 37 |
+
|
| 38 |
+
def get_password_hash(password: str) -> str:
    """Produce a bcrypt hash of *password* suitable for persistence."""
    hashed: str = pwd_context.hash(password)
    return hashed
|
| 41 |
+
|
| 42 |
+
async def get_user_by_email(email: str) -> Optional[Dict[str, Any]]:
    """Look up a user document by email address; None when no user exists."""
    users_collection = Database.db.users
    return await users_collection.find_one({"email": email})
|
| 45 |
+
|
| 46 |
+
async def authenticate_user(email: str, password: str) -> Optional[Dict[str, Any]]:
    """Return the user document when credentials are valid and the account
    is active; otherwise None (the caller turns that into a 401)."""
    candidate = await get_user_by_email(email)
    if not candidate:
        return None
    password_ok = verify_password(password, candidate["hashed_password"])
    if password_ok and candidate.get("is_active", True):
        return candidate
    return None
|
| 56 |
+
|
| 57 |
+
async def create_user(user_create: UserCreate) -> UserResponse:
    """Register a new user.

    Raises:
        HTTPException 400: the email address is already registered.
    """
    # Fast-path duplicate check; the unique index on `email` is the real
    # guard against concurrent registrations (see DuplicateKeyError below).
    if await get_user_by_email(user_create.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    payload = user_create.dict()
    document = {
        "email": payload["email"],
        "hashed_password": get_password_hash(payload.pop("password")),
        "full_name": payload.get("full_name"),
        "is_active": True,
        "is_admin": False,
        "created_at": datetime.utcnow()
    }

    try:
        inserted = await Database.db.users.insert_one(document)
    except DuplicateKeyError:
        # Lost a race against a concurrent registration for the same email.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered"
        )

    document["_id"] = inserted.inserted_id
    return UserResponse(**document)
|
| 89 |
+
|
| 90 |
+
def create_access_token(subject: str, expires_delta: Optional[timedelta] = None) -> str:
    """Build a signed JWT whose `sub` claim is *subject*.

    When *expires_delta* is omitted (or zero/falsy) the configured default
    lifetime (ACCESS_TOKEN_EXPIRE_MINUTES) is used, matching the original
    truthiness check.
    """
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {"exp": datetime.utcnow() + lifetime, "sub": subject}
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
| 100 |
+
|
| 101 |
+
async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
    """Resolve the JWT bearer *token* to the corresponding user document.

    Raises:
        HTTPException 401: token has expired.
        HTTPException 403: token is malformed, unsigned, or missing claims.
        HTTPException 404: token is valid but the user no longer exists.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = TokenPayload(**payload)

        # Fix: TokenPayload.exp is Optional. The original passed None to
        # datetime.fromtimestamp(), raising a TypeError that the except
        # clause below does not catch (HTTP 500). A token without an exp
        # claim is treated as invalid credentials instead.
        # (HTTPException is not in the except tuple, so it propagates.)
        if token_data.exp is None:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )

        if datetime.fromtimestamp(token_data.exp) < datetime.utcnow():
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token expired",
                headers={"WWW-Authenticate": "Bearer"},
            )
    except (jwt.PyJWTError, ValidationError):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = await get_user_by_email(token_data.sub)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found"
        )

    return user
|
| 130 |
+
|
| 131 |
+
async def get_current_active_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """Pass through the authenticated user, rejecting deactivated accounts."""
    if current_user.get("is_active", True):
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
|
| 136 |
+
|
| 137 |
+
async def get_current_admin_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """Pass through the authenticated user only if flagged as an admin."""
    if current_user.get("is_admin", False):
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not enough permissions"
    )
|
| 145 |
+
|
| 146 |
+
# Initialize the users collection
async def init_users_collection():
    """Create the unique email index and, if absent, a bootstrap admin user.

    The original built a UserCreate model and an intermediate dict that were
    then discarded except for the password; that dead indirection is removed.
    """
    await Database.db.users.create_index("email", unique=True)

    admin_email = "admin@clayx.ai"
    if await get_user_by_email(admin_email):
        return

    # SECURITY: hard-coded bootstrap credential — must be changed/overridden
    # before any production deployment.
    new_admin = {
        "email": admin_email,
        "hashed_password": get_password_hash("adminpassword"),
        "full_name": "Admin User",
        "is_active": True,
        "is_admin": True,
        "created_at": datetime.utcnow()
    }

    await Database.db.users.insert_one(new_admin)
|
evoagentx/app/services.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Business logic for agents, workflows, and executions.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
# import asyncio
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 8 |
+
from bson import ObjectId
|
| 9 |
+
|
| 10 |
+
from evoagentx.app.db import (
|
| 11 |
+
Database, # Agent, Workflow, WorkflowExecution, ExecutionLog,
|
| 12 |
+
AgentStatus, WorkflowStatus, ExecutionStatus
|
| 13 |
+
)
|
| 14 |
+
from evoagentx.app.schemas import (
|
| 15 |
+
AgentCreate, AgentUpdate, WorkflowCreate, WorkflowUpdate,
|
| 16 |
+
ExecutionCreate, PaginationParams, SearchParams
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Agent Service
class AgentService:
    """CRUD operations for agent documents stored in ``Database.agents``."""

    @staticmethod
    def _build_search_query(search: Optional[SearchParams]) -> Dict[str, Any]:
        """Translate optional SearchParams into a MongoDB filter document.

        Extracted from list_agents: the original built the created_at range
        with a three-way if/elif chain; building the range dict field by
        field is equivalent and removes the branching.
        """
        query: Dict[str, Any] = {}
        if not search:
            return query
        if search.query:
            query["$text"] = {"$search": search.query}
        if search.tags:
            query["tags"] = {"$all": search.tags}
        if search.status:
            query["status"] = search.status
        created_range: Dict[str, Any] = {}
        if search.start_date:
            created_range["$gte"] = search.start_date
        if search.end_date:
            created_range["$lte"] = search.end_date
        if created_range:
            query["created_at"] = created_range
        return query

    @staticmethod
    async def create_agent(agent_data: AgentCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Insert a new agent document.

        Raises:
            ValueError: an agent with the same name already exists.
        """
        agent_dict = agent_data.dict()
        agent_dict["created_by"] = user_id
        agent_dict["created_at"] = datetime.utcnow()
        agent_dict["updated_at"] = agent_dict["created_at"]
        agent_dict["status"] = AgentStatus.CREATED

        # Best-effort duplicate-name check (no unique index is visible here,
        # so a race between check and insert is still possible — TODO confirm).
        existing_agent = await Database.agents.find_one({"name": agent_dict["name"]})
        if existing_agent:
            raise ValueError(f"Agent with name '{agent_dict['name']}' already exists")

        result = await Database.agents.insert_one(agent_dict)
        agent_dict["_id"] = result.inserted_id

        logger.info(f"Created agent {agent_dict['name']} with ID {result.inserted_id}")

        return agent_dict

    @staticmethod
    async def get_agent(agent_id: str) -> Optional[Dict[str, Any]]:
        """Fetch an agent by ObjectId string; ValueError on a malformed ID."""
        if not ObjectId.is_valid(agent_id):
            raise ValueError(f"Invalid agent ID: {agent_id}")

        agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
        return agent

    @staticmethod
    async def get_agent_by_name(name: str) -> Optional[Dict[str, Any]]:
        """Fetch an agent by its name, or None when no such agent exists."""
        return await Database.agents.find_one({"name": name})

    @staticmethod
    async def update_agent(agent_id: str, agent_data: AgentUpdate) -> Optional[Dict[str, Any]]:
        """Apply a partial update; returns the updated document, or None if absent.

        Raises:
            ValueError: malformed ID, or renaming to a name already in use.
        """
        if not ObjectId.is_valid(agent_id):
            raise ValueError(f"Invalid agent ID: {agent_id}")

        agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
        if not agent:
            return None

        update_data = agent_data.dict(exclude_unset=True)
        update_data["updated_at"] = datetime.utcnow()

        if "name" in update_data:
            # A rename must not collide with another agent's name.
            existing = await Database.agents.find_one({
                "name": update_data["name"],
                "_id": {"$ne": ObjectId(agent_id)}
            })
            if existing:
                raise ValueError(f"Agent with name '{update_data['name']}' already exists")

        await Database.agents.update_one(
            {"_id": ObjectId(agent_id)},
            {"$set": update_data}
        )

        updated_agent = await Database.agents.find_one({"_id": ObjectId(agent_id)})
        logger.info(f"Updated agent {agent_id}")

        return updated_agent

    @staticmethod
    async def delete_agent(agent_id: str) -> bool:
        """Delete an agent unless any workflow still references it.

        Returns True when a document was deleted, False when none matched.
        """
        if not ObjectId.is_valid(agent_id):
            raise ValueError(f"Invalid agent ID: {agent_id}")

        # Refuse deletion while the agent is referenced by workflows.
        workflow_count = await Database.workflows.count_documents({"agent_ids": agent_id})
        if workflow_count > 0:
            raise ValueError(f"Cannot delete agent {agent_id} as it is used in {workflow_count} workflows")

        result = await Database.agents.delete_one({"_id": ObjectId(agent_id)})
        if result.deleted_count:
            logger.info(f"Deleted agent {agent_id}")
            return True
        return False

    @staticmethod
    async def list_agents(
        params: PaginationParams,
        search: Optional[SearchParams] = None
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Return (agents, total) newest-first, honouring pagination and search."""
        query = AgentService._build_search_query(search)

        total = await Database.agents.count_documents(query)

        cursor = Database.agents.find(query)\
            .sort("created_at", -1)\
            .skip(params.skip)\
            .limit(params.limit)

        agents = await cursor.to_list(length=params.limit)
        return agents, total
|
| 144 |
+
|
| 145 |
+
# Workflow Service
class WorkflowService:
    """CRUD operations for workflow documents stored in ``Database.workflows``."""

    @staticmethod
    def _build_search_query(search: Optional[SearchParams]) -> Dict[str, Any]:
        """Translate optional SearchParams into a MongoDB filter document."""
        query: Dict[str, Any] = {}
        if not search:
            return query
        if search.query:
            query["$text"] = {"$search": search.query}
        if search.tags:
            query["tags"] = {"$all": search.tags}
        if search.status:
            query["status"] = search.status
        created_range: Dict[str, Any] = {}
        if search.start_date:
            created_range["$gte"] = search.start_date
        if search.end_date:
            created_range["$lte"] = search.end_date
        if created_range:
            query["created_at"] = created_range
        return query

    @staticmethod
    async def _extract_agent_ids(definition: Dict[str, Any]) -> List[str]:
        """Collect the agent IDs referenced by a workflow definition's steps.

        Extracted helper: this loop was duplicated verbatim in create_workflow
        and update_workflow.

        Raises:
            ValueError: a referenced agent does not exist.
        """
        agent_ids = set()
        for step in definition.get("steps", []):
            if "agent_id" in step:
                agent_id = step["agent_id"]
                agent = await AgentService.get_agent(agent_id)
                if not agent:
                    raise ValueError(f"Agent with ID {agent_id} does not exist")
                agent_ids.add(agent_id)
        return list(agent_ids)

    @staticmethod
    async def create_workflow(workflow_data: WorkflowCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Insert a new workflow, validating referenced agents and name uniqueness.

        Raises:
            ValueError: duplicate name, or a step references a missing agent.
        """
        workflow_dict = workflow_data.dict()
        workflow_dict["created_by"] = user_id
        workflow_dict["created_at"] = datetime.utcnow()
        workflow_dict["updated_at"] = workflow_dict["created_at"]
        workflow_dict["status"] = WorkflowStatus.CREATED
        workflow_dict["version"] = 1

        # Validate and record every agent referenced by the definition.
        workflow_dict["agent_ids"] = await WorkflowService._extract_agent_ids(
            workflow_dict["definition"]
        )

        existing = await Database.workflows.find_one({"name": workflow_dict["name"]})
        if existing:
            raise ValueError(f"Workflow with name '{workflow_dict['name']}' already exists")

        result = await Database.workflows.insert_one(workflow_dict)
        workflow_dict["_id"] = result.inserted_id

        logger.info(f"Created workflow {workflow_dict['name']} with ID {result.inserted_id}")

        return workflow_dict

    @staticmethod
    async def get_workflow(workflow_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a workflow by ObjectId string; ValueError on a malformed ID."""
        if not ObjectId.is_valid(workflow_id):
            raise ValueError(f"Invalid workflow ID: {workflow_id}")
        workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
        return workflow

    @staticmethod
    async def get_workflow_by_name(name: str) -> Optional[Dict[str, Any]]:
        """Fetch a workflow by its name, or None when no such workflow exists."""
        return await Database.workflows.find_one({"name": name})

    @staticmethod
    async def update_workflow(workflow_id: str, workflow_data: WorkflowUpdate) -> Optional[Dict[str, Any]]:
        """Apply a partial update; bumps the version when the definition changes.

        Returns the updated document, or None when the workflow does not exist.

        Raises:
            ValueError: malformed ID, rename collision, or a missing agent.
        """
        if not ObjectId.is_valid(workflow_id):
            raise ValueError(f"Invalid workflow ID: {workflow_id}")

        workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
        if not workflow:
            return None

        update_data = workflow_data.dict(exclude_unset=True)
        update_data["updated_at"] = datetime.utcnow()

        if "definition" in update_data:
            # A definition change is a new version; re-derive the agent list.
            update_data["version"] = workflow.get("version", 1) + 1
            update_data["agent_ids"] = await WorkflowService._extract_agent_ids(
                update_data["definition"]
            )

        if "name" in update_data:
            existing = await Database.workflows.find_one({
                "name": update_data["name"],
                "_id": {"$ne": ObjectId(workflow_id)}
            })
            if existing:
                raise ValueError(f"Workflow with name '{update_data['name']}' already exists")

        await Database.workflows.update_one(
            {"_id": ObjectId(workflow_id)},
            {"$set": update_data}
        )

        updated_workflow = await Database.workflows.find_one({"_id": ObjectId(workflow_id)})
        logger.info(f"Updated workflow {workflow_id}")

        return updated_workflow

    @staticmethod
    async def delete_workflow(workflow_id: str) -> bool:
        """Delete a workflow plus its executions and logs.

        Refuses while any execution is PENDING or RUNNING. Returns True when
        a document was deleted, False when none matched.
        """
        if not ObjectId.is_valid(workflow_id):
            raise ValueError(f"Invalid workflow ID: {workflow_id}")

        recent_executions = await Database.executions.count_documents({
            "workflow_id": workflow_id,
            "status": {"$in": [
                ExecutionStatus.PENDING,
                ExecutionStatus.RUNNING
            ]}
        })

        if recent_executions > 0:
            raise ValueError(f"Cannot delete workflow {workflow_id} with {recent_executions} active executions")

        result = await Database.workflows.delete_one({"_id": ObjectId(workflow_id)})
        if result.deleted_count:
            # Cascade: remove the workflow's execution history as well.
            await Database.logs.delete_many({"workflow_id": workflow_id})
            await Database.executions.delete_many({"workflow_id": workflow_id})

            logger.info(f"Deleted workflow {workflow_id}")
            return True
        return False

    @staticmethod
    async def list_workflows(
        params: PaginationParams,
        search: Optional[SearchParams] = None
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Return (workflows, total) newest-first with pagination and search."""
        query = WorkflowService._build_search_query(search)

        total = await Database.workflows.count_documents(query)

        cursor = Database.workflows.find(query)\
            .sort("created_at", -1)\
            .skip(params.skip)\
            .limit(params.limit)

        workflows = await cursor.to_list(length=params.limit)
        return workflows, total
|
| 313 |
+
|
| 314 |
+
# Workflow Execution Service
class WorkflowExecutionService:
    """Lifecycle management for workflow execution records and their logs."""

    @staticmethod
    async def create_execution(execution_data: ExecutionCreate, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Create a PENDING execution record for an existing workflow.

        Raises:
            ValueError: the referenced workflow does not exist.
        """
        workflow = await WorkflowService.get_workflow(execution_data.workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {execution_data.workflow_id} not found")

        execution_dict = {
            "workflow_id": execution_data.workflow_id,
            "status": ExecutionStatus.PENDING,
            "start_time": datetime.utcnow(),
            "input_params": execution_data.input_params,
            "created_by": user_id,
            "created_at": datetime.utcnow(),
            "step_results": {},
            "current_step": None,
            "results": {},
            "error_message": None
        }

        result = await Database.executions.insert_one(execution_dict)
        execution_dict["_id"] = result.inserted_id

        logger.info(f"Created workflow execution {result.inserted_id}")

        # Optional: queue execution for async processing (e.g. via Celery):
        # await execute_workflow_async.delay(execution_dict)

        return execution_dict

    @staticmethod
    async def get_execution(execution_id: str) -> Optional[Dict[str, Any]]:
        """Fetch an execution by ObjectId string; ValueError on a malformed ID."""
        if not ObjectId.is_valid(execution_id):
            raise ValueError(f"Invalid execution ID: {execution_id}")

        execution = await Database.executions.find_one({"_id": ObjectId(execution_id)})
        return execution

    @staticmethod
    async def update_execution_status(execution_id: str, status: ExecutionStatus, error_message: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Transition an execution's status, stamping end_time on terminal states.

        Returns the updated document, or None when the execution is absent.
        """
        if not ObjectId.is_valid(execution_id):
            raise ValueError(f"Invalid execution ID: {execution_id}")

        update_data = {
            "status": status,
            "updated_at": datetime.utcnow()
        }

        # Terminal states close out the execution's wall-clock window.
        if status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED, ExecutionStatus.CANCELLED]:
            update_data["end_time"] = datetime.utcnow()

        if error_message:
            update_data["error_message"] = error_message

        result = await Database.executions.find_one_and_update(
            {"_id": ObjectId(execution_id)},
            {"$set": update_data},
            return_document=True
        )

        return result

    @staticmethod
    async def list_executions(
        workflow_id: Optional[str] = None,
        params: Optional[PaginationParams] = None,
        search: Optional[SearchParams] = None
    ) -> Tuple[List[Dict[str, Any]], int]:
        """List executions newest-first with pagination and optional filters.

        Fix: the original used a mutable default `params=PaginationParams()`,
        a single instance shared across every call. Defaulting to None and
        creating a fresh instance per call is backward-compatible and safe.
        """
        if params is None:
            params = PaginationParams()

        query: Dict[str, Any] = {}

        if workflow_id:
            query["workflow_id"] = workflow_id

        if search:
            if search.status:
                query["status"] = search.status
            created_range: Dict[str, Any] = {}
            if search.start_date:
                created_range["$gte"] = search.start_date
            if search.end_date:
                created_range["$lte"] = search.end_date
            if created_range:
                query["created_at"] = created_range

        total = await Database.executions.count_documents(query)

        cursor = Database.executions.find(query)\
            .sort("created_at", -1)\
            .skip(params.skip)\
            .limit(params.limit)

        executions = await cursor.to_list(length=params.limit)
        return executions, total

    @staticmethod
    async def log_execution_event(
        workflow_id: str,
        execution_id: str,
        message: str,
        step_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        level: str = "INFO",
        details: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Append a structured log entry for an execution and return it."""
        log_entry = {
            "workflow_id": workflow_id,
            "execution_id": execution_id,
            "step_id": step_id,
            "agent_id": agent_id,
            "timestamp": datetime.utcnow(),
            "level": level,
            "message": message,
            "details": details or {}
        }

        result = await Database.logs.insert_one(log_entry)
        log_entry["_id"] = result.inserted_id

        return log_entry

    @staticmethod
    async def get_execution_logs(
        execution_id: str,
        params: Optional[PaginationParams] = None
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Return (logs, total) for an execution in chronological order.

        Same mutable-default fix as list_executions.
        """
        if params is None:
            params = PaginationParams()

        query = {"execution_id": execution_id}

        total = await Database.logs.count_documents(query)

        cursor = Database.logs.find(query)\
            .sort("timestamp", 1)\
            .skip(params.skip)\
            .limit(params.limit)

        logs = await cursor.to_list(length=params.limit)
        return logs, total
|
evoagentx/benchmark/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
evoagentx/benchmark/.ipynb_checkpoints/test_load_json-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "385cefbd",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"ename": "ImportError",
|
| 11 |
+
"evalue": "attempted relative import with no known parent package",
|
| 12 |
+
"output_type": "error",
|
| 13 |
+
"traceback": [
|
| 14 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 15 |
+
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
|
| 16 |
+
"Cell \u001b[0;32mIn[1], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mshutil\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Union, Any, Callable, List, Dict, Tuple\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mscicode\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m \u001b[38;5;66;03m# Many SciCode tests use numpy\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbenchmark\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodingBenchmark\n",
|
| 17 |
+
"File \u001b[0;32m/gpfs/radev/pi/ying_rex/tl688/selfevolve/EvoAgentX/evoagentx/benchmark/scicode.py:10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mscicode\u001b[39;00m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m \u001b[38;5;66;03m# Many SciCode tests use numpy\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbenchmark\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CodingBenchmark\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlogging\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m logger\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m download_file\n",
|
| 18 |
+
"\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
|
| 19 |
+
]
|
| 20 |
+
}
|
| 21 |
+
],
|
| 22 |
+
"source": [
|
| 23 |
+
"import os\n",
|
| 24 |
+
"import re\n",
|
| 25 |
+
"import gzip\n",
|
| 26 |
+
"import shutil\n",
|
| 27 |
+
"from typing import Union, Any, Callable, List, Dict, Tuple\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import scicode\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"import numpy as np # Many SciCode tests use numpy\n",
|
| 32 |
+
"from .benchmark import CodingBenchmark\n",
|
| 33 |
+
"from ..core.logging import logger\n",
|
| 34 |
+
"from ..utils.utils import download_file\n",
|
| 35 |
+
"from ..core.module_utils import load_json\n",
|
| 36 |
+
"from ..utils.aflow_utils.data_utils import AFLOW_DATASET_FILES_MAP, download_aflow_benchmark_data\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"# ----------------------------\n",
|
| 40 |
+
"# Raw SciCode (community) data\n",
|
| 41 |
+
"# ----------------------------\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"SCICODE_DEFAULT_URL = \"https://raw.githubusercontent.com/scicode-bench/scicode/main/data/scicode.jsonl.gz\" # If you mirror elsewhere, update here.\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"def download_raw_scicode_data(save_folder: str, url: str = SCICODE_DEFAULT_URL) -> str:\n",
|
| 47 |
+
" \"\"\"\n",
|
| 48 |
+
" Download and unzip the raw SciCode jsonl(.gz) to `save_folder`.\n",
|
| 49 |
+
"\n",
|
| 50 |
+
" Returns:\n",
|
| 51 |
+
" str: Path to the unzipped jsonl file.\n",
|
| 52 |
+
" \"\"\"\n",
|
| 53 |
+
" os.makedirs(save_folder, exist_ok=True)\n",
|
| 54 |
+
" gz_path = os.path.join(save_folder, \"scicode.jsonl.gz\")\n",
|
| 55 |
+
" jsonl_path = os.path.join(save_folder, \"scicode.jsonl\")\n",
|
| 56 |
+
"\n",
|
| 57 |
+
" logger.info(f\"Downloading SciCode data from {url} ...\")\n",
|
| 58 |
+
" download_file(url=url, save_file=gz_path)\n",
|
| 59 |
+
"\n",
|
| 60 |
+
" logger.info(\"Unzipping SciCode data ...\")\n",
|
| 61 |
+
" with gzip.open(gz_path, \"rb\") as f_in, open(jsonl_path, \"wb\") as f_out:\n",
|
| 62 |
+
" shutil.copyfileobj(f_in, f_out)\n",
|
| 63 |
+
" if os.path.exists(gz_path):\n",
|
| 64 |
+
" os.remove(gz_path)\n",
|
| 65 |
+
"\n",
|
| 66 |
+
" return jsonl_path\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# ----------------------------\n",
|
| 70 |
+
"# Schema helpers\n",
|
| 71 |
+
"# ----------------------------\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"def _extract_entry_point_from_header(header: str) -> str:\n",
|
| 74 |
+
" \"\"\"\n",
|
| 75 |
+
" Given a SciCode 'function_header' string like:\n",
|
| 76 |
+
" \"def get_alpha(recvec, alpha_scaling=5):\\n '''...'''\"\n",
|
| 77 |
+
" return \"get_alpha\".\n",
|
| 78 |
+
" \"\"\"\n",
|
| 79 |
+
" m = re.search(r\"def\\s+([A-Za-z_][A-Za-z0-9_]*)\\s*\\(\", header)\n",
|
| 80 |
+
" if not m:\n",
|
| 81 |
+
" raise ValueError(\"Could not parse entry point from function_header\")\n",
|
| 82 |
+
" return m.group(1)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"def _coerce_scicode_row_to_examples(row: Dict[str, Any]) -> List[Dict[str, Any]]:\n",
|
| 86 |
+
" \"\"\"\n",
|
| 87 |
+
" SciCode rows may contain a single task or multiple step tasks.\n",
|
| 88 |
+
" We normalize them to a list of examples with a unified structure:\n",
|
| 89 |
+
" {\n",
|
| 90 |
+
" \"task_id\": \"SciCode/<name>#<sub_id>\",\n",
|
| 91 |
+
" \"prompt\": <function_header + optional docstring block>,\n",
|
| 92 |
+
" \"entry_point\": <func_name>,\n",
|
| 93 |
+
" \"canonical_solution\": <ground_truth_code>,\n",
|
| 94 |
+
" \"tests\": List[str], # list of python test snippets\n",
|
| 95 |
+
" \"imports\": str # optional import prelude (e.g., 'import numpy as np')\n",
|
| 96 |
+
" }\n",
|
| 97 |
+
" \"\"\"\n",
|
| 98 |
+
" examples: List[Dict[str, Any]] = []\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" name = str(row[0]) if 0 in row or isinstance(row, list) else str(row.get(\"name\", \"unknown\"))\n",
|
| 101 |
+
" # Different dumps can be list-based or dict-based; support both:\n",
|
| 102 |
+
" if isinstance(row, list):\n",
|
| 103 |
+
" # Heuristic index layout (based on the example provided by the user):\n",
|
| 104 |
+
" # [name, <maybe_int>, description, <maybe empty>, docstring, imports, steps(list[dict]) or code, tests(list[str]) or None]\n",
|
| 105 |
+
" # We will try to find keys by semantic type\n",
|
| 106 |
+
" description = None\n",
|
| 107 |
+
" doc_or_header = None\n",
|
| 108 |
+
" imports_block = None\n",
|
| 109 |
+
" steps_or_code = None\n",
|
| 110 |
+
" tests = None\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" # Try assigning by scanning\n",
|
| 113 |
+
" for item in row:\n",
|
| 114 |
+
" if isinstance(item, str) and item.strip().startswith('\"\"\"'):\n",
|
| 115 |
+
" # docstring/prompt block for the top-level task\n",
|
| 116 |
+
" doc_or_header = item\n",
|
| 117 |
+
" elif isinstance(item, str) and (item.startswith(\"import \") or \"from \" in item):\n",
|
| 118 |
+
" imports_block = item\n",
|
| 119 |
+
" elif isinstance(item, list):\n",
|
| 120 |
+
" # Could be steps OR tests\n",
|
| 121 |
+
" if item and isinstance(item[0], dict) and \"function_header\" in item[0]:\n",
|
| 122 |
+
" steps_or_code = item\n",
|
| 123 |
+
" elif item and isinstance(item[0], str) and item[0].strip().startswith((\"ref\", \"assert\", \"from \")):\n",
|
| 124 |
+
" tests = item\n",
|
| 125 |
+
" elif isinstance(item, dict):\n",
|
| 126 |
+
" # Some SciCode variants may directly be dicts per step; treat as steps\n",
|
| 127 |
+
" steps_or_code = [item]\n",
|
| 128 |
+
"\n",
|
| 129 |
+
" # If we have step dictionaries, produce one example per step\n",
|
| 130 |
+
" if isinstance(steps_or_code, list) and steps_or_code and isinstance(steps_or_code[0], dict):\n",
|
| 131 |
+
" for idx, step in enumerate(steps_or_code):\n",
|
| 132 |
+
" header = step.get(\"function_header\") or step.get(\"header\") or \"\"\n",
|
| 133 |
+
" code = step.get(\"ground_truth_code\") or step.get(\"solution\") or \"\"\n",
|
| 134 |
+
" step_tests = step.get(\"test_cases\") or []\n",
|
| 135 |
+
" entry_point = _extract_entry_point_from_header(header)\n",
|
| 136 |
+
" prompt = header # keep header as the model prompt (header + docstring already embedded)\n",
|
| 137 |
+
" examples.append(\n",
|
| 138 |
+
" {\n",
|
| 139 |
+
" \"task_id\": f\"SciCode/{name}#step{idx+1}\",\n",
|
| 140 |
+
" \"prompt\": prompt,\n",
|
| 141 |
+
" \"entry_point\": entry_point,\n",
|
| 142 |
+
" \"canonical_solution\": code,\n",
|
| 143 |
+
" \"tests\": step_tests,\n",
|
| 144 |
+
" \"imports\": imports_block or \"\",\n",
|
| 145 |
+
" }\n",
|
| 146 |
+
" )\n",
|
| 147 |
+
" else:\n",
|
| 148 |
+
" # Single task variant: expect a combined \"function_header\" + \"ground_truth_code\" + \"test_cases\" in the row\n",
|
| 149 |
+
" # Try to detect them from the large code string block if present.\n",
|
| 150 |
+
" # Fall back to no-op if missing.\n",
|
| 151 |
+
" # NOTE: The user’s example shows a consolidated block near the end; we’ll try to parse it.\n",
|
| 152 |
+
" code_blob = None\n",
|
| 153 |
+
" for item in row:\n",
|
| 154 |
+
" if isinstance(item, str) and \"def \" in item and \"return\" in item:\n",
|
| 155 |
+
" code_blob = item\n",
|
| 156 |
+
" break\n",
|
| 157 |
+
" # Try to split the big blob into multiple functions; evaluate the last one as the main if we cannot find header separately.\n",
|
| 158 |
+
" if code_blob:\n",
|
| 159 |
+
" # Heuristic: the last \"def ...\" in the blob is the target entry point\n",
|
| 160 |
+
" headers = list(re.finditer(r\"(?ms)^(def\\s+[A-Za-z_][A-Za-z0-9_]*\\s*\\(.*?\\):\\s*\\n)\", code_blob))\n",
|
| 161 |
+
" if headers:\n",
|
| 162 |
+
" last_header = headers[-1].group(1)\n",
|
| 163 |
+
" entry_point = _extract_entry_point_from_header(last_header)\n",
|
| 164 |
+
" else:\n",
|
| 165 |
+
" entry_point = \"solution\"\n",
|
| 166 |
+
"\n",
|
| 167 |
+
" # We will treat entire blob as canonical_solution and create a minimal prompt from the docstring if present\n",
|
| 168 |
+
" prompt = doc_or_header or f\"def {entry_point}(*args, **kwargs):\\n '''Fill in the function body.'''\\n ...\"\n",
|
| 169 |
+
" examples.append(\n",
|
| 170 |
+
" {\n",
|
| 171 |
+
" \"task_id\": f\"SciCode/{name}\",\n",
|
| 172 |
+
" \"prompt\": prompt,\n",
|
| 173 |
+
" \"entry_point\": entry_point,\n",
|
| 174 |
+
" \"canonical_solution\": code_blob,\n",
|
| 175 |
+
" \"tests\": tests or [],\n",
|
| 176 |
+
" \"imports\": imports_block or \"\",\n",
|
| 177 |
+
" }\n",
|
| 178 |
+
" )\n",
|
| 179 |
+
"\n",
|
| 180 |
+
" else:\n",
|
| 181 |
+
" # Dict-style row (fallback): expect keys by name\n",
|
| 182 |
+
" steps = row.get(\"steps\", [])\n",
|
| 183 |
+
" imports_block = row.get(\"imports\", \"\")\n",
|
| 184 |
+
" task_name = row.get(\"name\", \"unknown\")\n",
|
| 185 |
+
"\n",
|
| 186 |
+
" if steps:\n",
|
| 187 |
+
" for idx, step in enumerate(steps):\n",
|
| 188 |
+
" header = step.get(\"function_header\", \"\")\n",
|
| 189 |
+
" code = step.get(\"ground_truth_code\", \"\")\n",
|
| 190 |
+
" step_tests = step.get(\"test_cases\", [])\n",
|
| 191 |
+
" entry_point = _extract_entry_point_from_header(header)\n",
|
| 192 |
+
" examples.append(\n",
|
| 193 |
+
" {\n",
|
| 194 |
+
" \"task_id\": f\"SciCode/{task_name}#step{idx+1}\",\n",
|
| 195 |
+
" \"prompt\": header,\n",
|
| 196 |
+
" \"entry_point\": entry_point,\n",
|
| 197 |
+
" \"canonical_solution\": code,\n",
|
| 198 |
+
" \"tests\": step_tests,\n",
|
| 199 |
+
" \"imports\": imports_block or \"\",\n",
|
| 200 |
+
" }\n",
|
| 201 |
+
" )\n",
|
| 202 |
+
" else:\n",
|
| 203 |
+
" header = row.get(\"function_header\", \"\")\n",
|
| 204 |
+
" code = row.get(\"ground_truth_code\", \"\")\n",
|
| 205 |
+
" tests = row.get(\"test_cases\", [])\n",
|
| 206 |
+
" entry_point = _extract_entry_point_from_header(header) if header else \"solution\"\n",
|
| 207 |
+
" prompt = header or f\"def {entry_point}(*args, **kwargs):\\n pass\"\n",
|
| 208 |
+
" examples.append(\n",
|
| 209 |
+
" {\n",
|
| 210 |
+
" \"task_id\": f\"SciCode/{task_name}\",\n",
|
| 211 |
+
" \"prompt\": prompt,\n",
|
| 212 |
+
" \"entry_point\": entry_point,\n",
|
| 213 |
+
" \"canonical_solution\": code,\n",
|
| 214 |
+
" \"tests\": tests,\n",
|
| 215 |
+
" \"imports\": imports_block or \"\",\n",
|
| 216 |
+
" }\n",
|
| 217 |
+
" )\n",
|
| 218 |
+
"\n",
|
| 219 |
+
" return examples\n",
|
| 220 |
+
"\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"def load_scicode_data(jsonl_path: str) -> List[Dict[str, Any]]:\n",
|
| 223 |
+
" \"\"\"\n",
|
| 224 |
+
" Load SciCode jsonl and expand into normalized examples.\n",
|
| 225 |
+
" \"\"\"\n",
|
| 226 |
+
" raw = load_json(jsonl_path, type=\"jsonl\")\n",
|
| 227 |
+
" all_examples: List[Dict[str, Any]] = []\n",
|
| 228 |
+
" for row in raw:\n",
|
| 229 |
+
" try:\n",
|
| 230 |
+
" all_examples.extend(_coerce_scicode_row_to_examples(row))\n",
|
| 231 |
+
" except Exception as e:\n",
|
| 232 |
+
" logger.warning(f\"[SciCode] Skipping a malformed row due to: {e}\")\n",
|
| 233 |
+
" return all_examples\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"# ----------------------------\n",
|
| 237 |
+
"# Benchmark classes\n",
|
| 238 |
+
"# ----------------------------\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"class SciCode(CodingBenchmark):\n",
|
| 241 |
+
" \"\"\"\n",
|
| 242 |
+
" Benchmark class for evaluating code generation on SciCode.\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" SciCode problems provide:\n",
|
| 245 |
+
" - function_header (prompt stub)\n",
|
| 246 |
+
" - ground_truth_code (reference implementation)\n",
|
| 247 |
+
" - test_cases (list[str] of python asserts)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
" We normalize each item and evaluate by executing the candidate implementation\n",
|
| 250 |
+
" against the provided test cases. Since many SciCode tests reference a variable\n",
|
| 251 |
+
" named `target`, we heuristically pre-compute `target` from the reference\n",
|
| 252 |
+
" implementation when necessary, or set it to True for boolean-allclose tests.\n",
|
| 253 |
+
" \"\"\"\n",
|
| 254 |
+
"\n",
|
| 255 |
+
" def __init__(self, path: str = None, mode: str = \"all\", timeout: int = 60, k: Union[int, list] = 1, **kwargs):\n",
|
| 256 |
+
" path = os.path.expanduser(path or \"~/.evoagentx/data/scicode\")\n",
|
| 257 |
+
" self.k = k\n",
|
| 258 |
+
" super().__init__(name=type(self).__name__, path=path, mode=mode, timeout=timeout, **kwargs)\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" # ---------- Data loading ----------\n",
|
| 261 |
+
"\n",
|
| 262 |
+
" def _load_data(self):\n",
|
| 263 |
+
" data_path = os.path.join(self.path, \"scicode.jsonl\")\n",
|
| 264 |
+
" if not os.path.exists(data_path):\n",
|
| 265 |
+
" data_path = download_raw_scicode_data(self.path)\n",
|
| 266 |
+
"\n",
|
| 267 |
+
" # For SciCode, we place everything into \"test\" split by default.\n",
|
| 268 |
+
"\n",
|
| 269 |
+
" if self.mode in (\"dev\", \"all\"):\n",
|
| 270 |
+
" self._dev_data = load_scicode_data(\"/home/tl688/pitl688/selfevolve/SciCode/eval/data/subproblems_dev.jsonl\")\n",
|
| 271 |
+
" if self.mode in (\"test\", \"all\"):\n",
|
| 272 |
+
" self._test_data = load_scicode_data(\"/home/tl688/pitl688/selfevolve/SciCode/eval/data/subproblems_test.jsonl\")\n",
|
| 273 |
+
"\n",
|
| 274 |
+
" def _get_label(self, example: Any):\n",
|
| 275 |
+
" \"\"\"\n",
|
| 276 |
+
" For SciCode we treat the label as the full test suite plus metadata.\n",
|
| 277 |
+
" \"\"\"\n",
|
| 278 |
+
" return {\n",
|
| 279 |
+
" \"task_id\": example[\"task_id\"],\n",
|
| 280 |
+
" \"entry_point\": example[\"entry_point\"],\n",
|
| 281 |
+
" \"tests\": example.get(\"tests\", []),\n",
|
| 282 |
+
" \"canonical_solution\": example.get(\"canonical_solution\", \"\"),\n",
|
| 283 |
+
" \"imports\": example.get(\"imports\", \"\"),\n",
|
| 284 |
+
" }\n",
|
| 285 |
+
"\n",
|
| 286 |
+
" def _get_id(self, example: Any):\n",
|
| 287 |
+
" return example[\"task_id\"]\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" # ---------- Evaluation ----------\n",
|
| 290 |
+
"\n",
|
| 291 |
+
" @staticmethod\n",
|
| 292 |
+
" def _build_reference_namespace(imports: str, canonical_solution: str) -> Dict[str, Any]:\n",
|
| 293 |
+
" \"\"\"\n",
|
| 294 |
+
" Build an execution namespace that defines the reference function.\n",
|
| 295 |
+
" \"\"\"\n",
|
| 296 |
+
" ns: Dict[str, Any] = {\"np\": np, \"scicode\":scicode}\n",
|
| 297 |
+
" if imports:\n",
|
| 298 |
+
" exec(imports, ns, ns) # e.g., \"import numpy as np\\nfrom scipy.special import erfc\"\n",
|
| 299 |
+
" if canonical_solution:\n",
|
| 300 |
+
" exec(canonical_solution, ns, ns)\n",
|
| 301 |
+
" return ns\n",
|
| 302 |
+
"\n",
|
| 303 |
+
" @staticmethod\n",
|
| 304 |
+
" def _extract_candidate_exprs_from_test(test_src: str) -> List[str]:\n",
|
| 305 |
+
" \"\"\"\n",
|
| 306 |
+
" Heuristically extract expressions that are compared against `target` inside np.allclose(..., target)\n",
|
| 307 |
+
" or equality checks like \"== target\" / \", target)\" etc. Returns a list of python expressions (as strings)\n",
|
| 308 |
+
" that we should evaluate with the *reference* implementation to generate `target`.\n",
|
| 309 |
+
"\n",
|
| 310 |
+
" This is a pragmatic parser covering the most common SciCode patterns.\n",
|
| 311 |
+
" \"\"\"\n",
|
| 312 |
+
" exprs: List[str] = []\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" # Pattern A: np.allclose( <expr>, target )\n",
|
| 315 |
+
" for m in re.finditer(r\"np\\.allclose\\s*\\(\\s*(?P<expr>.+?)\\s*,\\s*target\\s*\\)\", test_src, flags=re.DOTALL):\n",
|
| 316 |
+
" exprs.append(m.group(\"expr\"))\n",
|
| 317 |
+
"\n",
|
| 318 |
+
" # Pattern B: assert <expr> == target\n",
|
| 319 |
+
" for m in re.finditer(r\"assert\\s+(?P<expr>.+?)\\s*==\\s*target\", test_src):\n",
|
| 320 |
+
" exprs.append(m.group(\"expr\"))\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" # Pattern C: assert <expr>, target (when the first arg should be True)\n",
|
| 323 |
+
" # In this case, target is expected to be True; no need to compute it.\n",
|
| 324 |
+
" # We'll handle by leaving exprs empty and later default target=True.\n",
|
| 325 |
+
"\n",
|
| 326 |
+
" # Pattern D: Using slices like target[0], target[1] — we try to recover by\n",
|
| 327 |
+
" # extracting both left-hand expressions in the same line in order:\n",
|
| 328 |
+
" # np.allclose(func(...)[0], target[0]) and np.allclose(func(...)[1], target[1])\n",
|
| 329 |
+
" # Already captured by Pattern A; expr may include \"[0]\" or \"[1]\".\n",
|
| 330 |
+
" return exprs\n",
|
| 331 |
+
"\n",
|
| 332 |
+
" @staticmethod\n",
|
| 333 |
+
" def _compute_target_list(exprs: List[str], ref_ns: Dict[str, Any]) -> Any:\n",
|
| 334 |
+
" \"\"\"\n",
|
| 335 |
+
" Given a list of expressions (strings), evaluate them in the reference namespace.\n",
|
| 336 |
+
" If multiple expressions are found, we pack them into a tuple in the same order.\n",
|
| 337 |
+
" If no expression found, return True (to support tests of the form `assert <bool>, target`).\n",
|
| 338 |
+
" \"\"\"\n",
|
| 339 |
+
" if not exprs:\n",
|
| 340 |
+
" return True\n",
|
| 341 |
+
" values = []\n",
|
| 342 |
+
" for ex in exprs:\n",
|
| 343 |
+
" # Safety: limit builtins\n",
|
| 344 |
+
" local_ns: Dict[str, Any] = {}\n",
|
| 345 |
+
" val = eval(ex, ref_ns, local_ns)\n",
|
| 346 |
+
" values.append(val)\n",
|
| 347 |
+
" if len(values) == 1:\n",
|
| 348 |
+
" return values[0]\n",
|
| 349 |
+
" return tuple(values)\n",
|
| 350 |
+
"\n",
|
| 351 |
+
" def _make_harness(self, task_id: str, entry_point: str, imports: str, canonical_solution: str, tests: List[str], candidate_src: str) -> str:\n",
|
| 352 |
+
" \"\"\"\n",
|
| 353 |
+
" Construct an executable harness that:\n",
|
| 354 |
+
" 1) Defines imports\n",
|
| 355 |
+
" 2) Defines candidate implementation (prompt + candidate completion)\n",
|
| 356 |
+
" 3) Pre-computes `target` using the reference implementation for each test (heuristics)\n",
|
| 357 |
+
" 4) Executes the original test snippet with `target` bound.\n",
|
| 358 |
+
" We run each test independently within the same process, stopping on first failure.\n",
|
| 359 |
+
" \"\"\"\n",
|
| 360 |
+
" # We'll build a block that iterates tests in Python.\n",
|
| 361 |
+
" # We cannot dynamically pass `target` into a raw `assert` snippet without executing it;\n",
|
| 362 |
+
" # so for each test, we will:\n",
|
| 363 |
+
" # a) compute target in a separate namespace using reference function,\n",
|
| 364 |
+
" # b) then execute the original test with the candidate function and that target.\n",
|
| 365 |
+
" # This is orchestrated by the benchmark runtime (not inside the user env).\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" # NOTE: actual orchestration happens in `evaluate()` by repeated calls to `check_solution`;\n",
|
| 368 |
+
" # here we only prepare the body (candidate code). The unit tests are executed by the\n",
|
| 369 |
+
" # framework’s sand-boxed executor using `test` passed in.\n",
|
| 370 |
+
"\n",
|
| 371 |
+
" # We keep the candidate_src as-is. The imports are prepended at runtime via the test body.\n",
|
| 372 |
+
" return candidate_src\n",
|
| 373 |
+
"\n",
|
| 374 |
+
" def handle_special_cases(self, task_id: str, solution: str, test: str) -> Tuple[str, str]:\n",
|
| 375 |
+
" \"\"\"\n",
|
| 376 |
+
" Hook: adjust solution/test for edge cases in SciCode, if needed.\n",
|
| 377 |
+
" Currently, we leave as-is and fallback to the base handler.\n",
|
| 378 |
+
" \"\"\"\n",
|
| 379 |
+
" return super().handle_special_cases(task_id=task_id, solution=solution, test=test)\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" def evaluate(self, prediction: Any, label: Any) -> dict:\n",
|
| 382 |
+
" \"\"\"\n",
|
| 383 |
+
" Evaluate candidate solution(s) against SciCode test cases.\n",
|
| 384 |
+
"\n",
|
| 385 |
+
" Strategy:\n",
|
| 386 |
+
" - For each candidate solution:\n",
|
| 387 |
+
" - For each test snippet:\n",
|
| 388 |
+
" 1) Build reference namespace; compute `target` (heuristics).\n",
|
| 389 |
+
" 2) Build candidate code by concatenating example['prompt'] + candidate solution.\n",
|
| 390 |
+
" 3) Execute the test with `target` and candidate in the sandbox via `check_solution`.\n",
|
| 391 |
+
"\n",
|
| 392 |
+
" - Aggregate per-test pass/fail into a single boolean for the example.\n",
|
| 393 |
+
" - Compute pass@k across candidates.\n",
|
| 394 |
+
" \"\"\"\n",
|
| 395 |
+
" prediction, label = self._check_evaluation_inputs(prediction, label)\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" results = []\n",
|
| 398 |
+
" for solution in prediction:\n",
|
| 399 |
+
" # Each `label` item corresponds to the SAME example in our usage (benchmark runs per example),\n",
|
| 400 |
+
" # but we preserve the structure consistent with the base class.\n",
|
| 401 |
+
" solution_states = []\n",
|
| 402 |
+
" for label_data in label:\n",
|
| 403 |
+
" task_id = label_data[\"task_id\"]\n",
|
| 404 |
+
" entry_point = label_data[\"entry_point\"]\n",
|
| 405 |
+
" tests = label_data.get(\"tests\", [])\n",
|
| 406 |
+
" imports = label_data.get(\"imports\", \"\")\n",
|
| 407 |
+
" canonical_solution = label_data.get(\"canonical_solution\", \"\")\n",
|
| 408 |
+
"\n",
|
| 409 |
+
" # Build reference env for computing `target`\n",
|
| 410 |
+
" ref_ns = self._build_reference_namespace(imports=imports, canonical_solution=canonical_solution)\n",
|
| 411 |
+
"\n",
|
| 412 |
+
" # Build candidate code (prompt + solution)\n",
|
| 413 |
+
" prompt = self.get_example_by_id(task_id)[\"prompt\"]\n",
|
| 414 |
+
" candidate_code = prompt + \"\\n\" + solution\n",
|
| 415 |
+
"\n",
|
| 416 |
+
" # Run each test individually; any failure => whole example fails\n",
|
| 417 |
+
" all_ok = True\n",
|
| 418 |
+
" for raw_test in tests if tests else [\"# no tests provided\\nassert True, True\"]:\n",
|
| 419 |
+
" # Heuristically precompute `target`\n",
|
| 420 |
+
" exprs = self._extract_candidate_exprs_from_test(raw_test)\n",
|
| 421 |
+
" try:\n",
|
| 422 |
+
" target_value = self._compute_target_list(exprs, ref_ns)\n",
|
| 423 |
+
" except Exception as e:\n",
|
| 424 |
+
" # If we cannot compute target from the reference, fall back to True\n",
|
| 425 |
+
" logger.warning(f\"[SciCode] Fallback target=True for {task_id} due to: {e}\")\n",
|
| 426 |
+
" target_value = True\n",
|
| 427 |
+
"\n",
|
| 428 |
+
" # Compose a runnable unit-test block:\n",
|
| 429 |
+
" # We inject `imports`, bind `target`, then execute the original test code.\n",
|
| 430 |
+
" unit_test = (\n",
|
| 431 |
+
" (imports or \"\")\n",
|
| 432 |
+
" + \"\\n\"\n",
|
| 433 |
+
" + \"target = __TARGET_VALUE__\\n\"\n",
|
| 434 |
+
" + raw_test\n",
|
| 435 |
+
" )\n",
|
| 436 |
+
"\n",
|
| 437 |
+
" # Because `check_solution` runs code in separate exec, we stringify the target safely.\n",
|
| 438 |
+
" # We'll register a placeholder and pass the real object via the executor's globals.\n",
|
| 439 |
+
" # Our base framework doesn't support direct object injection; so we serialize small types.\n",
|
| 440 |
+
" # For numpy arrays/tuples we rely on repr + eval. If that fails, we degrade to boolean.\n",
|
| 441 |
+
" try:\n",
|
| 442 |
+
" # Light-weight serializer for numpy arrays / tuples / lists / scalars\n",
|
| 443 |
+
" def _pyrepr(obj):\n",
|
| 444 |
+
" if isinstance(obj, np.ndarray):\n",
|
| 445 |
+
" return f\"np.array({repr(obj.tolist())})\"\n",
|
| 446 |
+
" return repr(obj)\n",
|
| 447 |
+
"\n",
|
| 448 |
+
" unit_test = unit_test.replace(\n",
|
| 449 |
+
" \"__TARGET_VALUE__\", _pyrepr(target_value)\n",
|
| 450 |
+
" )\n",
|
| 451 |
+
" except Exception:\n",
|
| 452 |
+
" unit_test = unit_test.replace(\"__TARGET_VALUE__\", \"True\")\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" # Optional special-case patching hook\n",
|
| 455 |
+
" candidate_code_patched, unit_test_patched = self.handle_special_cases(\n",
|
| 456 |
+
" task_id=task_id, solution=candidate_code, test=unit_test\n",
|
| 457 |
+
" )\n",
|
| 458 |
+
"\n",
|
| 459 |
+
" # Execute\n",
|
| 460 |
+
" state, message = self.check_solution(\n",
|
| 461 |
+
" task_id=task_id,\n",
|
| 462 |
+
" solution=candidate_code_patched,\n",
|
| 463 |
+
" test=unit_test_patched,\n",
|
| 464 |
+
" entry_point=entry_point,\n",
|
| 465 |
+
" )\n",
|
| 466 |
+
" if state != self.SUCCESS:\n",
|
| 467 |
+
" all_ok = False\n",
|
| 468 |
+
" break\n",
|
| 469 |
+
"\n",
|
| 470 |
+
" solution_states.append(self.SUCCESS if all_ok else self.FAILURE)\n",
|
| 471 |
+
" results.append(len(solution_states) == len(label) and all(s == self.SUCCESS for s in solution_states))\n",
|
| 472 |
+
"\n",
|
| 473 |
+
" k_list = [self.k] if isinstance(self.k, int) else self.k\n",
|
| 474 |
+
" pass_at_k = self.compute_pass_at_k(results, k_list)\n",
|
| 475 |
+
" return pass_at_k\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"\n",
|
| 478 |
+
"class AFlowSciCode(SciCode):\n",
|
| 479 |
+
" \"\"\"\n",
|
| 480 |
+
" AFlow-specific implementation of SciCode benchmark.\n",
|
| 481 |
+
" Uses AFLOW_DATASET_FILES_MAP['scicode'] for split files (if provided by your distribution).\n",
|
| 482 |
+
" \"\"\"\n",
|
| 483 |
+
"\n",
|
| 484 |
+
" def __init__(self, path: str = None, mode: str = \"all\", timeout: int = 60, k: Union[int, list] = 1, **kwargs):\n",
|
| 485 |
+
" path = os.path.expanduser(path or \"~/.evoagentx/data/aflow/scicode\")\n",
|
| 486 |
+
" super().__init__(path=path, mode=mode, timeout=timeout, k=k, **kwargs)\n",
|
| 487 |
+
"\n",
|
| 488 |
+
" def _load_data_from_file(self, file_name: str):\n",
|
| 489 |
+
" if file_name is None:\n",
|
| 490 |
+
" return None\n",
|
| 491 |
+
" file_path = os.path.join(self.path, file_name)\n",
|
| 492 |
+
" if not os.path.exists(file_path):\n",
|
| 493 |
+
" logger.info(\"Downloading AFlow SciCode split files ...\")\n",
|
| 494 |
+
" download_aflow_benchmark_data(dataset=\"scicode\", save_folder=self.path)\n",
|
| 495 |
+
" return load_json(path=file_path, type=\"jsonl\")\n",
|
| 496 |
+
"\n",
|
| 497 |
+
" def _load_data(self):\n",
|
| 498 |
+
" # Prefer AFLOW split files when available; otherwise fall back to raw download.\n",
|
| 499 |
+
" if \"scicode\" not in AFLOW_DATASET_FILES_MAP:\n",
|
| 500 |
+
" logger.warning(\"AFLOW_DATASET_FILES_MAP has no entry for 'scicode'; falling back to raw SciCode jsonl.\")\n",
|
| 501 |
+
" return super()._load_data()\n",
|
| 502 |
+
"\n",
|
| 503 |
+
" splits = AFLOW_DATASET_FILES_MAP[\"scicode\"]\n",
|
| 504 |
+
" data_all: Dict[str, List[Dict[str, Any]]] = {}\n",
|
| 505 |
+
"\n",
|
| 506 |
+
" for split in (\"train\", \"dev\", \"test\"):\n",
|
| 507 |
+
" fname = splits.get(split)\n",
|
| 508 |
+
" if fname:\n",
|
| 509 |
+
" logger.info(f\"Loading {split} data from {fname}\")\n",
|
| 510 |
+
" raw_split = self._load_data_from_file(file_name=fname)\n",
|
| 511 |
+
" # Normalize rows to examples\n",
|
| 512 |
+
" examples: List[Dict[str, Any]] = []\n",
|
| 513 |
+
" for row in raw_split or []:\n",
|
| 514 |
+
" try:\n",
|
| 515 |
+
" examples.extend(_coerce_scicode_row_to_examples(row))\n",
|
| 516 |
+
" except Exception as e:\n",
|
| 517 |
+
" logger.warning(f\"[AFlowSciCode] Skipping a malformed row in {split} due to: {e}\")\n",
|
| 518 |
+
" data_all[split] = examples\n",
|
| 519 |
+
" else:\n",
|
| 520 |
+
" data_all[split] = None\n",
|
| 521 |
+
"\n",
|
| 522 |
+
" if self.mode in (\"train\", \"all\"):\n",
|
| 523 |
+
" self._train_data = data_all.get(\"train\")\n",
|
| 524 |
+
" if self.mode in (\"dev\", \"all\"):\n",
|
| 525 |
+
" self._dev_data = data_all.get(\"dev\")\n",
|
| 526 |
+
" if self.mode in (\"test\", \"all\"):\n",
|
| 527 |
+
" self._test_data = data_all.get(\"test\")\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" async def async_evaluate(self, graph: Callable, example: Any) -> float:\n",
|
| 530 |
+
" \"\"\"\n",
|
| 531 |
+
" Generate a solution asynchronously and return pass@1 for the example.\n",
|
| 532 |
+
" \"\"\"\n",
|
| 533 |
+
" prompt, entry_point = example[\"prompt\"], example[\"entry_point\"]\n",
|
| 534 |
+
" solution = await graph(prompt, entry_point)\n",
|
| 535 |
+
" label = self._get_label(example)\n",
|
| 536 |
+
" metrics = await super().async_evaluate(prediction=solution, label=label)\n",
|
| 537 |
+
" return metrics.get(\"pass@1\", 0.0)\n"
|
| 538 |
+
]
|
| 539 |
+
},
|
| 540 |
+
{
|
| 541 |
+
"cell_type": "code",
|
| 542 |
+
"execution_count": null,
|
| 543 |
+
"id": "a2bca001",
|
| 544 |
+
"metadata": {},
|
| 545 |
+
"outputs": [],
|
| 546 |
+
"source": []
|
| 547 |
+
}
|
| 548 |
+
],
|
| 549 |
+
"metadata": {
|
| 550 |
+
"kernelspec": {
|
| 551 |
+
"display_name": "Python 3 (ipykernel)",
|
| 552 |
+
"language": "python",
|
| 553 |
+
"name": "python3"
|
| 554 |
+
},
|
| 555 |
+
"language_info": {
|
| 556 |
+
"codemirror_mode": {
|
| 557 |
+
"name": "ipython",
|
| 558 |
+
"version": 3
|
| 559 |
+
},
|
| 560 |
+
"file_extension": ".py",
|
| 561 |
+
"mimetype": "text/x-python",
|
| 562 |
+
"name": "python",
|
| 563 |
+
"nbconvert_exporter": "python",
|
| 564 |
+
"pygments_lexer": "ipython3",
|
| 565 |
+
"version": "3.11.0"
|
| 566 |
+
}
|
| 567 |
+
},
|
| 568 |
+
"nbformat": 4,
|
| 569 |
+
"nbformat_minor": 5
|
| 570 |
+
}
|
evoagentx/benchmark/README.md
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Benchmark
|
| 2 |
+
|
| 3 |
+
## Benchmark Overview
|
| 4 |
+
|
| 5 |
+
This repository provides a set of benchmarks to facilitate the evaluation of different agent-based systems. Below is a summary of the benchmarks currently included, along with basic dataset statistics:
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
| Task | Dataset Name | # Train | # Dev | # Test |
|
| 9 |
+
| ------------------------- | --------------- | --------- | ------- | ------ |
|
| 10 |
+
| QA | NQ | 79,168 | 8,757 | 3,610 |
|
| 11 |
+
| Multi-Hop QA | HotPotQA | 90,447 | 7,405 | / |
|
| 12 |
+
| Math | GSM8K | 7,473 | / | 1,319 |
|
| 13 |
+
| Math | MATH | 7,500 | / | 5,000 |
|
| 14 |
+
| Code Generation | HumanEval | / | / | 164 |
|
| 15 |
+
| Code Generation | MBPP | / | / | 427 |
|
| 16 |
+
| Code Generation | LiveCodeBench(v1~v5) | / | / | 400~880 |
|
| 17 |
+
| Code Execution | LiveCodeBench | / | / | 479 |
|
| 18 |
+
| Test Output Prediction | LiveCodeBench | / | / | 442 |
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
Below, we introduce the preprocessing steps and evaluation metrics for each benchmark.
|
| 22 |
+
|
| 23 |
+
- [Question Answering](#question-answering)
|
| 24 |
+
- [NQ](#nq)
|
| 25 |
+
- [HotPotQA](#hotpotqa)
|
| 26 |
+
- [Math](#math)
|
| 27 |
+
- [GSM8K](#gsm8k)
|
| 28 |
+
- [MATH](#math)
|
| 29 |
+
- [Code Generation](#code-generation)
|
| 30 |
+
- [HumanEval](#humaneval)
|
| 31 |
+
- [MBPP](#mbpp)
|
| 32 |
+
- [LiveCodeBench](#livecodebench)
|
| 33 |
+
|
| 34 |
+
## Preprocessing and Evaluation Metrics
|
| 35 |
+
|
| 36 |
+
### Question Answering
|
| 37 |
+
|
| 38 |
+
For the QA datasets, we use Exact Match (EM), F1, and Accuracy (ACC) as evaluation metrics by default. EM requires the predicted answer to be exactly the same as the ground truth answer. ACC requires that the predicted answer contains the ground-truth answer, which is useful when the LLM is used to generate the answer.
|
| 39 |
+
|
| 40 |
+
#### NQ
|
| 41 |
+
[Natural Questions (NQ)](https://github.com/google-research-datasets/natural-questions) contains questions from the Google search engine and the answers, annotated by human annotators, are paragraphs or entities in the Wikipedia page of the top 5 search results. We use the dataset splits provided by the [DPR](https://github.com/facebookresearch/DPR) repository, which contains 79,168 training, 8,757 development, and 3,610 test examples.
|
| 42 |
+
|
| 43 |
+
You can load the dataset using the following code:
|
| 44 |
+
```python
|
| 45 |
+
from evoagentx.benchmark import NQ
|
| 46 |
+
nq_dataset = NQ() # optional: path="/path/to/save_data"
|
| 47 |
+
test_data = nq_dataset.get_test_data()
|
| 48 |
+
```
|
| 49 |
+
Each example in the dataset is in the following format:
|
| 50 |
+
```json
|
| 51 |
+
{
|
| 52 |
+
"id": "test-1",
|
| 53 |
+
"question": "the question",
|
| 54 |
+
"answers": ["possible answers"]
|
| 55 |
+
}
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#### HotPotQA
|
| 60 |
+
[HotPotQA](https://hotpotqa.github.io/) is a multi-hop QA dataset that requires multi-step reasoning to answer the question. We use the distractor setting of the dataset. Each example contains a question, an answer, some context that contains both supporting and distractor information, and supporting facts. We only include the training and development sets, as the test set is not publicly available.
|
| 61 |
+
|
| 62 |
+
You can load the dataset using the following code:
|
| 63 |
+
```python
|
| 64 |
+
from evoagentx.benchmark import HotPotQA
|
| 65 |
+
hotpotqa_dataset = HotPotQA() # optional: path="/path/to/save_data"
|
| 66 |
+
test_data = hotpotqa_dataset.get_test_data()
|
| 67 |
+
```
|
| 68 |
+
Each example in the dataset is in the following format, where the second element (int) of a supporting_fact is the index of the sentence in the context that supports the answer.
|
| 69 |
+
```json
|
| 70 |
+
{
|
| 71 |
+
"_id": "the id of the example",
|
| 72 |
+
"question": "the question",
|
| 73 |
+
"answer": "the answer",
|
| 74 |
+
"context": [["context_title", ["context_sentence", "another_sentence"]]],
|
| 75 |
+
"supporting_facts": [["supporting_title", 0]]
|
| 76 |
+
}
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
### Math
|
| 81 |
+
|
| 82 |
+
For the math datasets, we use the solve rate as the evaluation metric. The solve rate is the ratio of the number of examples that are solved correctly to the total number of examples.
|
| 83 |
+
|
| 84 |
+
#### GSM8K
|
| 85 |
+
[GSM8K](https://github.com/openai/grade-school-math) consists of high quality grade school math problems created by human problem writers. These problems require multi-step mathematical reasoning to solve. We use the dataset splits provided by the original repository, which contains 7.5K training problems and 1K test problems.
|
| 86 |
+
|
| 87 |
+
You can load the dataset using the following code:
|
| 88 |
+
```python
|
| 89 |
+
from evoagentx.benchmark import GSM8K
|
| 90 |
+
gsm8k_dataset = GSM8K() # optional: path="/path/to/save_data"
|
| 91 |
+
test_data = gsm8k_dataset.get_test_data()
|
| 92 |
+
```
|
| 93 |
+
Each example in the dataset is in the following format:
|
| 94 |
+
```json
|
| 95 |
+
{
|
| 96 |
+
"id": "test-1",
|
| 97 |
+
"question": "the question",
|
| 98 |
+
"answer": "the answer"
|
| 99 |
+
}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
#### MATH
|
| 103 |
+
The [Mathematics Aptitude Test of Heuristics (MATH)](https://github.com/hendrycks/math) dataset consists of problems from mathematics competitions, including the AMC 10, AMC 12, AIME, etc. Each problem in MATH has a step-by-step solution. We use the dataset splits provided by the original repository, which contains 7.5K training problems and 5K test problems.
|
| 104 |
+
|
| 105 |
+
You can load the dataset using the following code:
|
| 106 |
+
```python
|
| 107 |
+
from evoagentx.benchmark import MATH
|
| 108 |
+
math_dataset = MATH() # optional: path="/path/to/save_data"
|
| 109 |
+
test_data = math_dataset.get_test_data()
|
| 110 |
+
```
|
| 111 |
+
Each example in the dataset is in the following format. For the `level` field, valid values are: "Level 1", "Level 2", "Level 3", "Level 4", "Level 5", and "Level ?". The `type` field can be one of: "Geometry", "Algebra", "Intermediate Algebra", "Counting & Probability", "Precalculus", "Number Theory", or "Prealgebra".
|
| 112 |
+
|
| 113 |
+
```json
|
| 114 |
+
{
|
| 115 |
+
"id": "test-1",
|
| 116 |
+
"problem": "the problem",
|
| 117 |
+
"solution": "the solution",
|
| 118 |
+
"level": "Level 1",
|
| 119 |
+
"type": "Algebra"
|
| 120 |
+
}
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Code Generation
|
| 124 |
+
For the code generation benchmarks, we use pass@k as the evaluation metric, where k is the number of solutions for each problem. By default, k is set to 1.
|
| 125 |
+
|
| 126 |
+
#### HumanEval
|
| 127 |
+
[HumanEval](https://github.com/openai/human-eval) is a dataset of 164 coding problems from the HumanEval benchmark. Each problem contains a function signature, a canonical solution, and a set of unit tests.
|
| 128 |
+
|
| 129 |
+
You can load the dataset using the following code:
|
| 130 |
+
```python
|
| 131 |
+
from evoagentx.benchmark import HumanEval
|
| 132 |
+
humaneval_dataset = HumanEval() # optional: path="/path/to/save_data"
|
| 133 |
+
test_data = humaneval_dataset.get_test_data()
|
| 134 |
+
```
|
| 135 |
+
Each example in the dataset is in the following format:
|
| 136 |
+
```json
|
| 137 |
+
{
|
| 138 |
+
"task_id": "HumanEval/0",
|
| 139 |
+
"prompt": "the prompt of the problem",
|
| 140 |
+
"entry_point": "the name of the function to be tested",
|
| 141 |
+
"canonical_solution": "the canonical solution of the problem",
|
| 142 |
+
"test": "the unit tests of the problem"
|
| 143 |
+
}
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
#### MBPP
|
| 147 |
+
[Mostly Basic Python Problems (MBPP)](https://github.com/google-research/google-research/tree/master/mbpp) consists of hundreds of entry-level Python programming problems. Each problem consists of a task description, code solution and 3 automated test cases. We use the [sanitized subset](https://github.com/google-research/google-research/blob/master/mbpp/sanitized-mbpp.json) of the MBPP dataset, which consists of 427 problems with data that are hand-verified by the authors. To facilitate the evaluation, we convert the MBPP dataset into the HumanEval format.
|
| 148 |
+
|
| 149 |
+
You can load the dataset using the following code:
|
| 150 |
+
```python
|
| 151 |
+
from evoagentx.benchmark import MBPP
|
| 152 |
+
mbpp_dataset = MBPP() # optional: path="/path/to/save_data"
|
| 153 |
+
test_data = mbpp_dataset.get_test_data()
|
| 154 |
+
```
|
| 155 |
+
Each example in the dataset is in the following format, where we keep the original MBPP `task_id`.
|
| 156 |
+
```json
|
| 157 |
+
{
|
| 158 |
+
"task_id": 2,
|
| 159 |
+
"prompt": "the prompt of the problem",
|
| 160 |
+
"entry_point": "the name of the function to be tested",
|
| 161 |
+
"canonical_solution": "the canonical solution of the problem",
|
| 162 |
+
"test": "the unit tests of the problem"
|
| 163 |
+
}
|
| 164 |
+
```
|
| 165 |
+
You can also access the original MBPP attributes such as "code", "test_list" in the example by using `example["code"]`.
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
#### LiveCodeBench
|
| 169 |
+
[LiveCodeBench](https://livecodebench.github.io/) is a contamination-free evaluation benchmark of LLMs for code that continuously collects new problems over time. Particularly, LiveCodeBench also focuses on broader code-related capabilities, such as code execution, and test output prediction, beyond mere code generation. Currently, LiveCodeBench hosts over three hundred high-quality coding problems published between May 2023 and February 2024.
|
| 170 |
+
|
| 171 |
+
You can load the dataset using the following code, where `scenario` can be one of [`code_generation`, `test_output_prediction`, `code_execution`] indicating different tasks. `version` denotes different versions of the code generation datasets, which is only available for `code_generation` scenario, and can be one of `["release_v1", "release_v2", "release_v3", "release_v4", "release_v5", "release_latest"]`. Please refer to the [LiveCodeBench](https://livecodebench.github.io/) repository for more details.
|
| 172 |
+
|
| 173 |
+
```python
|
| 174 |
+
from evoagentx.benchmark import LiveCodeBench
|
| 175 |
+
livecodebench_dataset = LiveCodeBench(scenario="code_generation", version="release_v1") # optional: path="/path/to/save_data"
|
| 176 |
+
test_data = livecodebench_dataset.get_test_data()
|
| 177 |
+
```
|
| 178 |
+
|
evoagentx/benchmark/Untitled.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
evoagentx/benchmark/WorfBench.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
from typing import Any, Dict, Callable, List
|
| 5 |
+
from .benchmark import Benchmark
|
| 6 |
+
from .measures import exact_match_score, f1_score, acc_score
|
| 7 |
+
from ..core.logging import logger
|
| 8 |
+
from ..core.module_utils import load_json
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
# WorfBench dataset file mapping
|
| 12 |
+
WORFBENCH_FILES_MAP = {
|
| 13 |
+
"train": "worfbench_train.json",
|
| 14 |
+
"test": "worfbench_test.json"
|
| 15 |
+
}
|
| 16 |
+
VALID_WORFBENCH_FILES = list(WORFBENCH_FILES_MAP.values())
|
| 17 |
+
|
| 18 |
+
def evaluate_workflow_sequence(prediction: List[Any], ground_truth: List[Any]) -> float:
    """Compute the F1 score between a predicted and a reference workflow step sequence.

    Args:
        prediction: Predicted sequence of workflow steps.
        ground_truth: Reference sequence of workflow steps.

    Returns:
        The F1 score as computed by ``measures.f1_score``.
    """
    # f1_score is already imported at module level from .measures; the
    # redundant function-local re-import has been removed.
    return f1_score(prediction=prediction, ground_truth=ground_truth)
|
| 22 |
+
|
| 23 |
+
def evaluate_workflow_graph(prediction: Dict[str, Any], ground_truth: Dict[str, Any]) -> float:
    """Evaluate F1 score for graph workflow.

    Compares the predicted workflow graph against the reference on two
    levels — node sets and edge sets — and returns the average of the
    two F1 scores.
    """
    def harmonic_f1(precision, recall):
        # F1 is the harmonic mean of precision and recall; 0 when both vanish.
        total = precision + recall
        return 2 * (precision * recall) / total if total > 0 else 0

    predicted_nodes = set(prediction.get("nodes", []))
    reference_nodes = set(ground_truth.get("nodes", []))
    predicted_edges = {tuple(edge) for edge in prediction.get("edges", [])}
    reference_edges = {tuple(edge) for edge in ground_truth.get("edges", [])}

    shared_nodes = predicted_nodes & reference_nodes
    shared_edges = predicted_edges & reference_edges

    node_f1 = harmonic_f1(
        len(shared_nodes) / len(predicted_nodes) if predicted_nodes else 0,
        len(shared_nodes) / len(reference_nodes) if reference_nodes else 0,
    )
    edge_f1 = harmonic_f1(
        len(shared_edges) / len(predicted_edges) if predicted_edges else 0,
        len(shared_edges) / len(reference_edges) if reference_edges else 0,
    )

    # Final score averages node-level and edge-level F1.
    return (node_f1 + edge_f1) / 2
|
| 39 |
+
|
| 40 |
+
def download_worfbench_data(dataset: str, save_folder: str) -> None:
    """
    Download WorfBench dataset splits from Hugging Face and cache them locally.

    Each split is fetched with ``datasets.load_dataset`` and materialized as a
    JSON array so later loads need no network access. Existing files are kept.

    Args:
        dataset (str): Dataset name ("worfbench"); used for logging only.
        save_folder (str): Directory to save data.

    Raises:
        Exception: Re-raises any error encountered while downloading or saving.
    """
    datasets_map = {
        "train": {"repo_id": "zjunlp/WorFBench_train", "filename": "worfbench_train.json", "split": "train"},
        "test": {"repo_id": "zjunlp/WorFBench_test", "filename": "worfbench_test.json", "split": "test"}
    }

    os.makedirs(save_folder, exist_ok=True)
    for split, info in datasets_map.items():
        repo_id = info["repo_id"]
        filename = info["filename"]
        dataset_split = info["split"]
        save_path = os.path.join(save_folder, filename)

        if os.path.exists(save_path):
            logger.info(f"File {save_path} already exists, skipping download.")
            continue

        logger.info(f"Downloading {split} split of {dataset} from {repo_id}...")
        try:
            ds = load_dataset(repo_id, split=dataset_split)
            # Materialize the HF dataset as a plain list so it serializes to JSON.
            data = [item for item in ds]
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            # BUG FIX: these messages previously logged the literal placeholder
            # "(unknown)" instead of identifying the file involved.
            logger.info(f"Successfully downloaded and saved {filename} to {save_path}")
        except Exception as e:
            logger.error(f"Failed to download or save {filename}: {e}")
            raise
|
| 75 |
+
|
| 76 |
+
class WorfBench(Benchmark):
    """
    WorfBench evaluation class for assessing LLM agents on complex workflow generation tasks.

    Assumed example structure:
    {
        "id": str,
        "task": str,
        "context": list of dicts (e.g., [{"title": str, "content": list of str}]),
        "expected_output": str or dict (sequence or graph),
        "type": str,
        "level": str
    }
    """
    def __init__(self, path: str = None, mode: str = "test", **kwargs):
        # Default cache location mirrors the other benchmarks in this package.
        path = os.path.expanduser(path or "~/.worfbench/data")
        super().__init__(name=type(self).__name__, path=path, mode=mode, **kwargs)

    def _load_data_from_file(self, file_name: str) -> Any:
        """Load one split file, downloading the dataset first if it is missing.

        Returns the parsed JSON content (a list of example dicts, as written by
        ``download_worfbench_data``), or None when the file cannot be obtained
        or parsed.
        """
        if file_name is None:
            return None
        file_path = os.path.join(self.path, file_name)
        if not os.path.exists(file_path):
            download_worfbench_data(dataset="worfbench", save_folder=self.path)
            if not os.path.exists(file_path):
                logger.error(f"File {file_path} still does not exist after download attempt!")
                return None
        logger.info(f"Loading WorfBench data from {file_path} ...")
        data = load_json(path=file_path, type="json")
        if data is None:
            logger.error(f"Failed to load data from {file_path}")
            return None
        return data

    def _load_data(self) -> None:
        """Populate train/test data according to ``self.mode``.

        "dev" mode loads the train split and keeps a deterministic random
        ~10% subset of it.
        """
        if self.mode in ["train", "dev"]:
            self._train_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["train"])
            if self.mode == "dev" and self._train_data:
                # BUG FIX: the split files hold a JSON *list* of examples
                # (download_worfbench_data dumps a list), but the previous code
                # assumed a dict of parallel lists and crashed on `.keys()`.
                # Subsample the list directly instead.
                random.seed(42)  # fixed seed -> reproducible dev split
                n_dev = len(self._train_data) // 10 or 1
                indices = list(range(len(self._train_data)))
                random.shuffle(indices)
                self._train_data = [self._train_data[i] for i in indices[:n_dev]]
        if self.mode == "test":
            self._test_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["test"])

    def _get_label(self, example: Dict) -> Any:
        """Return the gold workflow (sequence or graph) for *example*."""
        return example.get("expected_output", "")

    def _get_id(self, example: Dict) -> Any:
        """Return the example's identifier ("" when absent)."""
        return example.get("id", "")

    def evaluate(self, prediction: Any, label: Any) -> Dict:
        """Score *prediction* against *label* and return {"em", "f1", "acc"}.

        F1 dispatches on the value types: list vs list -> sequence F1,
        dict vs dict -> graph F1, otherwise token-level F1 on the
        stringified values.
        """
        if isinstance(prediction, list) and isinstance(label, list):
            f1 = evaluate_workflow_sequence(prediction, label)
        elif isinstance(prediction, dict) and isinstance(label, dict):
            f1 = evaluate_workflow_graph(prediction, label)
        else:
            f1 = f1_score(prediction=str(prediction), ground_truth=str(label))
        # NOTE(review): em/acc receive the raw (possibly non-string) values;
        # confirm the measures in .measures tolerate lists/dicts.
        em = exact_match_score(prediction=prediction, ground_truth=label)
        acc = acc_score(prediction=prediction, ground_truths=[label])
        return {"em": em, "f1": f1, "acc": acc}

    async def async_evaluate(self, graph: Callable, example: Dict) -> float:
        """Generate a workflow for *example* via *graph* and return its F1."""
        task = example.get("task", "")
        context = "\n".join(
            f"{ctx.get('title', '')}: {' '.join(ctx.get('content', []))}"
            for ctx in example.get("context", [])
            if isinstance(ctx, dict)
        )
        inputs = f"Task: {task}\nContext: {context}\nGenerate workflow:\nAnswer:"
        try:
            generated_workflow = await graph(inputs)
        except Exception as e:
            # Best-effort: a generation failure scores 0 rather than aborting the run.
            logger.error(f"Error generating workflow: {e}")
            generated_workflow = ""
        label = self._get_label(example)
        metrics = self.evaluate(prediction=generated_workflow, label=label)
        return metrics["f1"]
|