Spaces:
Running
Running
Commit
·
0913c52
0
Parent(s):
clean HF Space commit (no binary history)
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +215 -0
- .pre-commit-config.yaml +22 -0
- .python-version +1 -0
- .vscode/settings.json +23 -0
- CLAUDE.md +257 -0
- Dockerfile +20 -0
- LICENSE +201 -0
- README.md +49 -0
- bench/mlebench_workflow.py +197 -0
- bench/register_models/gemini.py +180 -0
- bench/register_models/gpt.py +215 -0
- case-studies/case1/task.md +0 -0
- case-studies/case2/task.md +0 -0
- pyproject.toml +94 -0
- reasoning_bank/README.md +5 -0
- reasoning_bank/mem_induction.py +365 -0
- reasoning_bank/mem_manage.py +207 -0
- requirements.txt +3 -0
- scievo/__init__.py +0 -0
- scievo/agents/__init__.py +0 -0
- scievo/agents/critic_agent/__init__.py +2 -0
- scievo/agents/critic_agent/build.py +39 -0
- scievo/agents/critic_agent/execute.py +289 -0
- scievo/agents/critic_agent/state.py +24 -0
- scievo/agents/data_agent/__init__.py +2 -0
- scievo/agents/data_agent/build.py +178 -0
- scievo/agents/data_agent/execute.py +487 -0
- scievo/agents/data_agent/paper_subagent/__init__.py +10 -0
- scievo/agents/data_agent/paper_subagent/build.py +47 -0
- scievo/agents/data_agent/paper_subagent/execute.py +436 -0
- scievo/agents/data_agent/paper_subagent/state.py +27 -0
- scievo/agents/data_agent/plan.py +176 -0
- scievo/agents/data_agent/state.py +33 -0
- scievo/agents/dummy_agent.py +33 -0
- scievo/agents/experiment_agent/__init__.py +15 -0
- scievo/agents/experiment_agent/build.py +67 -0
- scievo/agents/experiment_agent/coding_subagent_v2/__init__.py +11 -0
- scievo/agents/experiment_agent/coding_subagent_v2/build.py +29 -0
- scievo/agents/experiment_agent/coding_subagent_v2/execute.py +161 -0
- scievo/agents/experiment_agent/coding_subagent_v2/state.py +160 -0
- scievo/agents/experiment_agent/coding_subagent_v3_claude/__init__.py +11 -0
- scievo/agents/experiment_agent/coding_subagent_v3_claude/build.py +29 -0
- scievo/agents/experiment_agent/coding_subagent_v3_claude/execute.py +189 -0
- scievo/agents/experiment_agent/coding_subagent_v3_claude/state.py +31 -0
- scievo/agents/experiment_agent/exec_subagent/__init__.py +12 -0
- scievo/agents/experiment_agent/exec_subagent/build.py +96 -0
- scievo/agents/experiment_agent/exec_subagent/execute.py +502 -0
- scievo/agents/experiment_agent/exec_subagent/state.py +57 -0
- scievo/agents/experiment_agent/execute.py +513 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# temporary files
|
| 210 |
+
tmp_*
|
| 211 |
+
rsync_tmp_*
|
| 212 |
+
.aider*
|
| 213 |
+
data_analysis.md
|
| 214 |
+
software-agent-sdk
|
| 215 |
+
env
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v6.0.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
- id: end-of-file-fixer
|
| 7 |
+
- id: name-tests-test
|
| 8 |
+
- id: requirements-txt-fixer
|
| 9 |
+
- repo: https://github.com/pycqa/isort
|
| 10 |
+
rev: 5.13.2
|
| 11 |
+
hooks:
|
| 12 |
+
- id: isort
|
| 13 |
+
args: ["--profile", "black", "--line-length=100", "--python-version=310"]
|
| 14 |
+
- repo: https://github.com/psf/black
|
| 15 |
+
rev: 25.1.0
|
| 16 |
+
hooks:
|
| 17 |
+
- id: black
|
| 18 |
+
args: ["--line-length=100", "--target-version=py310"]
|
| 19 |
+
- repo: https://github.com/kynan/nbstripout
|
| 20 |
+
rev: 0.8.2
|
| 21 |
+
hooks:
|
| 22 |
+
- id: nbstripout
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
.vscode/settings.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"sync-rsync.local": "/home/link/github/SciEvo",
|
| 3 |
+
"sync-rsync.remote": "klin07@astral-8xA30-jump:~/rsync/SciEvo",
|
| 4 |
+
// "sync-rsync.remote": "ubuntu@192.9.141.47:~/rsync/mle-bench/agents/scievo/SciEvo",
|
| 5 |
+
"sync-rsync.onSave": false,
|
| 6 |
+
"sync-rsync.onSaveIndividual": true,
|
| 7 |
+
"sync-rsync.delete": false,
|
| 8 |
+
"sync-rsync.exclude": [
|
| 9 |
+
".vscode/",
|
| 10 |
+
"**/__pycache__/",
|
| 11 |
+
"**/*.pyc",
|
| 12 |
+
".git/",
|
| 13 |
+
".ipynb_checkpoints/",
|
| 14 |
+
"**/.venv/",
|
| 15 |
+
"**/.DS_Store",
|
| 16 |
+
"**/.pytest_cache/",
|
| 17 |
+
// tmp files
|
| 18 |
+
"**/tmp_*",
|
| 19 |
+
// project specific
|
| 20 |
+
"sth_large/"
|
| 21 |
+
],
|
| 22 |
+
"sync-rsync.useWSL": true
|
| 23 |
+
}
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
SciEvo is a multi-agent framework for automated scientific experimentation. It orchestrates data analysis and experimental code generation through specialized agents that can search papers, generate code, execute experiments, and maintain long-term memory of insights.
|
| 8 |
+
|
| 9 |
+
## Setup and Environment
|
| 10 |
+
|
| 11 |
+
### Initial Setup
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
# Install dependencies (choose based on your platform)
|
| 15 |
+
# For macOS
|
| 16 |
+
uv sync --extra mac
|
| 17 |
+
|
| 18 |
+
# For CPU-only
|
| 19 |
+
uv sync --extra cpu
|
| 20 |
+
|
| 21 |
+
# For CUDA 12.8
|
| 22 |
+
uv sync --extra cu128
|
| 23 |
+
|
| 24 |
+
# Install pre-commit hooks
|
| 25 |
+
pip install pre-commit
|
| 26 |
+
pre-commit install
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Environment Configuration
|
| 30 |
+
|
| 31 |
+
Copy `.env.template` to `.env` and configure:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
cp .env.template .env
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Required environment variables:
|
| 38 |
+
- `OPENAI_API_KEY` - OpenAI API access
|
| 39 |
+
- `GEMINI_API_KEY` - Google Gemini API access
|
| 40 |
+
- `BRAIN_DIR` - Session storage location (default: `./tmp_brain`)
|
| 41 |
+
|
| 42 |
+
Optional configurations (see `.env.template` for full list):
|
| 43 |
+
- `REASONING_BANK_ENABLED` - Enable long-term memory consolidation
|
| 44 |
+
- `HISTORY_AUTO_COMPRESSION` - Auto-compress conversation history
|
| 45 |
+
- `CRITIC_ENABLED` - Enable agent output critique
|
| 46 |
+
- `CODING_AGENT_VERSION` - v2 or v3
|
| 47 |
+
- `AIDER_*` - Aider code editor configuration
|
| 48 |
+
- `OPENHANDS_MODEL` - Model for OpenHands integration
|
| 49 |
+
|
| 50 |
+
### Code Formatting
|
| 51 |
+
|
| 52 |
+
This project uses:
|
| 53 |
+
- **black** (line length: 100, target: py310)
|
| 54 |
+
- **isort** (profile: black, line length: 100)
|
| 55 |
+
- **nbstripout** for cleaning notebooks
|
| 56 |
+
|
| 57 |
+
Run formatting manually:
|
| 58 |
+
```bash
|
| 59 |
+
pre-commit run --all-files
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## Running Workflows
|
| 63 |
+
|
| 64 |
+
### Full Workflow (Data Analysis + Experiment)
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python -m scievo.run_workflow full <data_path> <workspace_path> "<user_query>" [repo_source]
|
| 68 |
+
|
| 69 |
+
# Example
|
| 70 |
+
python -m scievo.run_workflow full data.csv ./workspace "Train SVR model for regression"
|
| 71 |
+
|
| 72 |
+
# With options
|
| 73 |
+
python -m scievo.run_workflow full data.csv ./workspace "Train model" \
|
| 74 |
+
--data-recursion-limit 100 \
|
| 75 |
+
--experiment-recursion-limit 100 \
|
| 76 |
+
--session-name my_experiment
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### Data Analysis Only
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
python -m scievo.run_workflow data <data_path> <workspace_path> [--recursion-limit N] [--session-name NAME]
|
| 83 |
+
|
| 84 |
+
# Example
|
| 85 |
+
python -m scievo.run_workflow data data.csv ./workspace --session-name my_analysis
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### Experiment Only (Requires Existing Analysis)
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
python -m scievo.run_workflow experiment <workspace_path> "<user_query>" [data_analysis_path] [--recursion-limit N]
|
| 92 |
+
|
| 93 |
+
# Example (uses data_analysis.md from workspace)
|
| 94 |
+
python -m scievo.run_workflow experiment ./workspace "Train SVR model"
|
| 95 |
+
|
| 96 |
+
# With custom analysis file
|
| 97 |
+
python -m scievo.run_workflow experiment ./workspace "Train model" ./my_analysis.md
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## Architecture Overview
|
| 101 |
+
|
| 102 |
+
### Core Components
|
| 103 |
+
|
| 104 |
+
**`scievo/core/`** - Infrastructure and shared utilities
|
| 105 |
+
- `types.py` - Core message types, state management (ToolsetState, HistoryState, RBankState, ExecState)
|
| 106 |
+
- `brain.py` - Singleton session manager coordinating shared application state
|
| 107 |
+
- `llms.py` - Model registry with completion/response API wrappers (supports rate limiting, embeddings)
|
| 108 |
+
- `exec/` - Command execution sessions (SessionManager, PTYSession)
|
| 109 |
+
- `code_env.py` - Workspace context manager (LocalEnv)
|
| 110 |
+
- `utils.py` - TOON/JSON parsing, markdown extraction
|
| 111 |
+
- `constant.py` - Configuration flags and defaults
|
| 112 |
+
|
| 113 |
+
**`scievo/tools/`** - 20+ tool integrations
|
| 114 |
+
- Core: `fs_tool`, `shell_tool`, `exec_tool`
|
| 115 |
+
- Search: `arxiv_tool`, `dataset_search_tool`, `metric_search_tool`, `web_tool`
|
| 116 |
+
- Code: `coder_tool`, `cursor_tool`, `claude_code_tool`, `claude_agent_sdk_tool`, `openhands_tool`
|
| 117 |
+
- Other: `github_tool`, `ideation_tool`, `history_tool`, `state_tool`, `todo_tool`, `env_tool`
|
| 118 |
+
- Registry: `Tool` base class with JSON schemas, `ToolRegistry` singleton
|
| 119 |
+
|
| 120 |
+
**`scievo/agents/`** - Agent implementations using LangGraph
|
| 121 |
+
- `data_agent/` - Analyzes data, generates `data_analysis.md`, searches papers/datasets
|
| 122 |
+
- Flow: START → planner → gateway (router) → llm_chat/tool_calling/mem_extraction → replanner → finalize → END
|
| 123 |
+
- Sub-agents: `paper_subagent/` for academic search
|
| 124 |
+
- `experiment_agent/` - Generates and executes experimental code
|
| 125 |
+
- Flow: START → init → coding → exec → summary → analysis → revision_judge → END
|
| 126 |
+
- Sub-agents: CodingSubagent, ExecSubagent, SummarySubagent
|
| 127 |
+
- `ideation_agent/` - Research idea generation
|
| 128 |
+
- `critic_agent/` - Output quality review
|
| 129 |
+
|
| 130 |
+
**`scievo/workflows/`** - Workflow orchestration
|
| 131 |
+
- `full_workflow.py` - Chains DataAgent → ExperimentAgent
|
| 132 |
+
- `data_workflow.py` - Standalone DataAgent execution
|
| 133 |
+
- `experiment_workflow.py` - Standalone ExperimentAgent execution
|
| 134 |
+
- `run_workflow.py` - CLI entry point with three subcommands (backward compatibility layer)
|
| 135 |
+
|
| 136 |
+
**`scievo/prompts/`** - Prompt management
|
| 137 |
+
- `prompt_data.py` - Dataclass-based organization (DataPrompts, ExperimentPrompts, etc.)
|
| 138 |
+
- YAML files with Jinja2 templating for dynamic content
|
| 139 |
+
|
| 140 |
+
**`scievo/rbank/`** - ReasoningBank (Long-term Memory)
|
| 141 |
+
- `memo.py` - Persistent memory with embeddings for similarity search
|
| 142 |
+
- `subgraph/` - Memory consolidation subgraph
|
| 143 |
+
- Three memory tiers: short-term (session), long-term (cross-project), project-specific
|
| 144 |
+
|
| 145 |
+
### Key Architectural Patterns
|
| 146 |
+
|
| 147 |
+
1. **Singleton Pattern** - Brain, ModelRegistry, SessionManager, ToolRegistry ensure single instances
|
| 148 |
+
2. **State Graph Pattern** (LangGraph) - Agents as stateful graphs with nodes (steps) and edges (transitions)
|
| 149 |
+
3. **Sub-agent Composition** - Complex agents orchestrate specialized sub-agents
|
| 150 |
+
4. **History Compression** - Automatic message summarization to manage token usage
|
| 151 |
+
5. **Tool Registry** - Self-registering tools with JSON schemas for LLM consumption
|
| 152 |
+
6. **Memory Consolidation** - Periodic extraction of insights into long-term, project, and short-term memory
|
| 153 |
+
|
| 154 |
+
### Data Flow
|
| 155 |
+
|
| 156 |
+
```
|
| 157 |
+
run_workflow.py CLI
|
| 158 |
+
↓
|
| 159 |
+
FullWorkflow
|
| 160 |
+
├─→ DataWorkflow
|
| 161 |
+
│ ├─→ DataAgent (planner → execution loop → finalize)
|
| 162 |
+
│ │ └─→ PaperSubagent (searches papers/datasets)
|
| 163 |
+
│ └─→ Output: data_analysis.md
|
| 164 |
+
│
|
| 165 |
+
└─→ ExperimentWorkflow
|
| 166 |
+
├─→ ExperimentAgent (init → coding → exec → summary → revision loop)
|
| 167 |
+
│ ├─→ CodingSubagent
|
| 168 |
+
│ ├─→ ExecSubagent
|
| 169 |
+
│ └─→ SummarySubagent
|
| 170 |
+
└─→ Output: metrics, final_summary
|
| 171 |
+
|
| 172 |
+
All agents use: Brain, ModelRegistry, ToolRegistry, Prompts, ReasoningBank
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
## Development Guidelines
|
| 176 |
+
|
| 177 |
+
### Agent State Management
|
| 178 |
+
|
| 179 |
+
Agents use LangGraph state objects that extend core state types:
|
| 180 |
+
- `HistoryState` - Message history with compression support
|
| 181 |
+
- `ToolsetState` - Available tools
|
| 182 |
+
- `RBankState` - Memory directories
|
| 183 |
+
- `ExecState` - Execution sessions
|
| 184 |
+
|
| 185 |
+
State is passed through node functions and updated via returns.
|
| 186 |
+
|
| 187 |
+
### Adding New Tools
|
| 188 |
+
|
| 189 |
+
1. Create tool in `scievo/tools/` directory
|
| 190 |
+
2. Inherit from `Tool` base class
|
| 191 |
+
3. Define `json_schema` property
|
| 192 |
+
4. Implement tool logic
|
| 193 |
+
5. Tool auto-registers on import via `ToolRegistry`
|
| 194 |
+
|
| 195 |
+
### Working with Memory
|
| 196 |
+
|
| 197 |
+
- Enable via `REASONING_BANK_ENABLED=true` in `.env`
|
| 198 |
+
- Extraction frequency controlled by `MEM_EXTRACTION_ROUND_FREQ`
|
| 199 |
+
- Three directories: short-term, long-term (MEM_LONG_TERM_DIR), project (MEM_PROJECT_DIR)
|
| 200 |
+
- Memories stored as markdown with embeddings for retrieval
|
| 201 |
+
|
| 202 |
+
### History Management
|
| 203 |
+
|
| 204 |
+
- Auto-compression enabled via `HISTORY_AUTO_COMPRESSION=true`
|
| 205 |
+
- Triggers at `HISTORY_AUTO_COMPRESSION_TOKEN_THRESHOLD` (default: 64000)
|
| 206 |
+
- Keeps `HISTORY_AUTO_COMPRESSION_KEEP_RATIO` (default: 0.33) of messages
|
| 207 |
+
- Compression patches stored in `HistoryState.history_patches`
|
| 208 |
+
|
| 209 |
+
## File Locations
|
| 210 |
+
|
| 211 |
+
- Workflow implementations: `scievo/workflows/`
|
| 212 |
+
- Agent logic: `scievo/agents/{agent_name}/`
|
| 213 |
+
- Tool definitions: `scievo/tools/`
|
| 214 |
+
- Prompts: `scievo/prompts/` (YAML files) + `prompt_data.py` (dataclasses)
|
| 215 |
+
- Core infrastructure: `scievo/core/`
|
| 216 |
+
- Memory: Configured via `BRAIN_DIR`, `MEM_LONG_TERM_DIR`, `MEM_PROJECT_DIR`
|
| 217 |
+
- Generated outputs: Within workspace directory specified in CLI
|
| 218 |
+
|
| 219 |
+
## Testing and Debugging
|
| 220 |
+
|
| 221 |
+
### Jupyter Notebooks
|
| 222 |
+
|
| 223 |
+
Development notebooks are prefixed with `tmp_*`:
|
| 224 |
+
- `tmp_workflow_w_ideation.ipynb` - Full workflow with ideation
|
| 225 |
+
- `tmp_ideation_test.ipynb` - Ideation agent testing
|
| 226 |
+
- `tmp_paper_agent_test.ipynb` - Paper search testing
|
| 227 |
+
- Other `tmp_*.ipynb` files for component testing
|
| 228 |
+
|
| 229 |
+
### Logging
|
| 230 |
+
|
| 231 |
+
Control verbosity via `.env`:
|
| 232 |
+
```bash
|
| 233 |
+
LOGURU_LEVEL=DEBUG # or INFO
|
| 234 |
+
LOG_MEM_SUBGRAPH=true # Memory consolidation logs
|
| 235 |
+
LOG_SYSTEM_PROMPT=false # Show system prompts
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
### Running Partial Workflows
|
| 239 |
+
|
| 240 |
+
Use mode-specific commands for testing individual components:
|
| 241 |
+
```bash
|
| 242 |
+
# Test only data analysis
|
| 243 |
+
python -m scievo.run_workflow data test_data/sample.csv ./debug_workspace
|
| 244 |
+
|
| 245 |
+
# Test experiment with existing analysis
|
| 246 |
+
python -m scievo.run_workflow experiment ./debug_workspace "Test query"
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
## Important Notes
|
| 250 |
+
|
| 251 |
+
- **Python Version**: Requires Python >=3.13 (see `pyproject.toml`)
|
| 252 |
+
- **Package Manager**: Uses `uv` for dependency management
|
| 253 |
+
- **PyTorch**: Platform-specific installation via custom indices (see `pyproject.toml` [tool.uv.sources])
|
| 254 |
+
- **Optional Dependencies**: OpenHands (`openhands-sdk`, `openhands-tools`) - enable via `SCIEVO_ENABLE_OPENHANDS`
|
| 255 |
+
- **Pre-commit Hooks**: Always run before committing to maintain code style
|
| 256 |
+
- **Temporary Files**: `tmp_*` directories and notebooks are for development, not production
|
| 257 |
+
- **Brain Directory**: Session state persists in `BRAIN_DIR` - can accumulate over time
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13.5-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
curl \
|
| 8 |
+
git \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt ./
|
| 12 |
+
COPY src/ ./src/
|
| 13 |
+
|
| 14 |
+
RUN pip3 install -r requirements.txt
|
| 15 |
+
|
| 16 |
+
EXPOSE 8501
|
| 17 |
+
|
| 18 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
+
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SciEvo
|
| 2 |
+
|
| 3 |
+
```shell
|
| 4 |
+
# for cpu
|
| 5 |
+
uv sync --extra cpu
|
| 6 |
+
|
| 7 |
+
# for mac
|
| 8 |
+
uv sync --extra mac
|
| 9 |
+
|
| 10 |
+
# for gpu
|
| 11 |
+
uv sync --extra cu128
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
Optional: install Claude Code (for `claude_code` toolset):
|
| 15 |
+
|
| 16 |
+
- Ensure the `claude` CLI is installed and authenticated on your machine.
|
| 17 |
+
- If your `claude` command needs extra flags, set `CLAUDE_CODE_CMD`, e.g.:
|
| 18 |
+
|
| 19 |
+
```shell
|
| 20 |
+
export CLAUDE_CODE_CMD="claude"
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Optional: install Claude Agent SDK (for `claude_agent_sdk` toolset):
|
| 24 |
+
|
| 25 |
+
- Docs: `https://platform.claude.com/docs/en/agent-sdk/overview`
|
| 26 |
+
- Install:
|
| 27 |
+
|
| 28 |
+
```shell
|
| 29 |
+
pip install claude-agent-sdk
|
| 30 |
+
export ANTHROPIC_API_KEY="..."
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Development Guide
|
| 34 |
+
|
| 35 |
+
First, install `pre-commit`:
|
| 36 |
+
```shell
|
| 37 |
+
pip install pre-commit
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Install `pre-commit` to format code:
|
| 41 |
+
```shell
|
| 42 |
+
pre-commit install
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Then, copy `.env.template` to `.env` and fill in the necessary values.
|
| 46 |
+
```
|
| 47 |
+
OPENAI_API_KEY=<your_openai_api_key>
|
| 48 |
+
GEMINI_API_KEY=<your_gemini_api_key>
|
| 49 |
+
```
|
bench/mlebench_workflow.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MLE-Bench Workflow
|
| 3 |
+
|
| 4 |
+
Simple wrapper for running SciEvo FullWorkflow on MLE-Bench competition tasks.
|
| 5 |
+
|
| 6 |
+
MLE-Bench provides:
|
| 7 |
+
- instructions.md: Specific task instructions (used as user_query)
|
| 8 |
+
- description.md: Overall task background description
|
| 9 |
+
|
| 10 |
+
This wrapper registers models, reads these files, builds user_query, and invokes FullWorkflow.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Add parent directory to path to find scievo and bench modules
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
from loguru import logger
|
| 20 |
+
|
| 21 |
+
from bench.register_models.gemini import (
|
| 22 |
+
register_gemini_low_medium_models,
|
| 23 |
+
register_gemini_medium_high_models,
|
| 24 |
+
)
|
| 25 |
+
from bench.register_models.gpt import (
|
| 26 |
+
register_gpt_low_medium_models,
|
| 27 |
+
register_gpt_medium_high_models,
|
| 28 |
+
)
|
| 29 |
+
from scievo.workflows.full_workflow import run_full_workflow
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_mlebench_user_query(
    instructions_path: Path,
    description_path: Path,
) -> tuple[str, str]:
    """
    Build user query and data description from MLE-Bench task files.

    Args:
        instructions_path: Path to instructions.md (task instructions).
        description_path: Path to description.md (task background).

    Returns:
        Tuple of (user_query, data_desc):
        - user_query: Task instructions for the experiment.
        - data_desc: Task description for data analysis context.

    Raises:
        FileNotFoundError: If either input file does not exist.
    """

    def _read_required(path: Path, label: str) -> str:
        # Both task files are mandatory; fail fast with a descriptive message
        # instead of letting a later read raise a less specific error.
        if not path.exists():
            raise FileNotFoundError(f"{label} file not found: {path}")
        return path.read_text(encoding="utf-8")

    # instructions.md becomes the user query; description.md supplies the
    # data-analysis context.
    user_query = _read_required(instructions_path, "Instructions")
    data_desc = _read_required(description_path, "Description")

    return user_query, data_desc
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
|
| 66 |
+
import argparse
|
| 67 |
+
|
| 68 |
+
parser = argparse.ArgumentParser(
|
| 69 |
+
description="MLE-Bench Workflow - Run SciEvo on MLE-Bench competition tasks",
|
| 70 |
+
prog="python -m bench.mlebench_workflow",
|
| 71 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 72 |
+
epilog="""
|
| 73 |
+
Examples:
|
| 74 |
+
# Basic usage
|
| 75 |
+
python -m bench.mlebench_workflow \\
|
| 76 |
+
-i competition/instructions.md \\
|
| 77 |
+
-d competition/description.md \\
|
| 78 |
+
--data competition/data \\
|
| 79 |
+
-w workspace
|
| 80 |
+
|
| 81 |
+
# With custom settings
|
| 82 |
+
python -m bench.mlebench_workflow \\
|
| 83 |
+
-i competition/instructions.md \\
|
| 84 |
+
-d competition/description.md \\
|
| 85 |
+
--data competition/data \\
|
| 86 |
+
-w workspace \\
|
| 87 |
+
--max-revisions 10 \\
|
| 88 |
+
--session-name my_experiment
|
| 89 |
+
""",
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Required arguments
|
| 93 |
+
parser.add_argument(
|
| 94 |
+
"--instructions",
|
| 95 |
+
"-i",
|
| 96 |
+
required=True,
|
| 97 |
+
help="Path to instructions.md (task instructions)",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--description",
|
| 101 |
+
"-d",
|
| 102 |
+
required=True,
|
| 103 |
+
help="Path to description.md (task background)",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--data",
|
| 107 |
+
required=True,
|
| 108 |
+
help="Path to the data directory or file",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--workspace",
|
| 112 |
+
"-w",
|
| 113 |
+
required=True,
|
| 114 |
+
help="Workspace directory for the experiment",
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Optional arguments
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--repo-source",
|
| 120 |
+
default=None,
|
| 121 |
+
help="Optional repository source (local path or git URL)",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--max-revisions",
|
| 125 |
+
type=int,
|
| 126 |
+
default=3,
|
| 127 |
+
help="Maximum revision loops (default: 3)",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--data-recursion-limit",
|
| 131 |
+
type=int,
|
| 132 |
+
default=512,
|
| 133 |
+
help="Recursion limit for DataAgent (default: 512)",
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--experiment-recursion-limit",
|
| 137 |
+
type=int,
|
| 138 |
+
default=512,
|
| 139 |
+
help="Recursion limit for ExperimentAgent (default: 512)",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--session-name",
|
| 143 |
+
default=None,
|
| 144 |
+
help="Custom session name (otherwise uses timestamp)",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--models",
|
| 148 |
+
choices=[
|
| 149 |
+
"gpt-low-medium",
|
| 150 |
+
"gpt-medium-high",
|
| 151 |
+
"gemini-low-medium",
|
| 152 |
+
"gemini-medium-high",
|
| 153 |
+
],
|
| 154 |
+
default="gemini-low-medium",
|
| 155 |
+
help="Model configuration to use (default: gemini-low-medium)",
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
args = parser.parse_args()
|
| 159 |
+
|
| 160 |
+
# Register models based on choice
|
| 161 |
+
logger.info(f"Registering models: {args.models}")
|
| 162 |
+
match args.models:
|
| 163 |
+
case "gpt-low-medium":
|
| 164 |
+
register_gpt_low_medium_models()
|
| 165 |
+
case "gpt-medium-high":
|
| 166 |
+
register_gpt_medium_high_models()
|
| 167 |
+
case "gemini-low-medium":
|
| 168 |
+
register_gemini_low_medium_models()
|
| 169 |
+
case "gemini-medium-high":
|
| 170 |
+
register_gemini_medium_high_models()
|
| 171 |
+
|
| 172 |
+
# Build user query and data description from MLE-Bench files
|
| 173 |
+
logger.info("Building user query from MLE-Bench task files...")
|
| 174 |
+
user_query, data_desc = build_mlebench_user_query(
|
| 175 |
+
instructions_path=Path(args.instructions),
|
| 176 |
+
description_path=Path(args.description),
|
| 177 |
+
)
|
| 178 |
+
logger.info(f"User query built: {len(user_query)} chars")
|
| 179 |
+
logger.info(f"Data description built: {len(data_desc)} chars")
|
| 180 |
+
|
| 181 |
+
# Run FullWorkflow
|
| 182 |
+
result = run_full_workflow(
|
| 183 |
+
data_path=args.data,
|
| 184 |
+
workspace_path=args.workspace,
|
| 185 |
+
user_query=user_query,
|
| 186 |
+
data_desc=data_desc,
|
| 187 |
+
repo_source=args.repo_source,
|
| 188 |
+
max_revisions=args.max_revisions,
|
| 189 |
+
data_agent_recursion_limit=args.data_recursion_limit,
|
| 190 |
+
experiment_agent_recursion_limit=args.experiment_recursion_limit,
|
| 191 |
+
session_name=args.session_name,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Save summary
|
| 195 |
+
result.save_summary()
|
| 196 |
+
|
| 197 |
+
print(f"\nStatus: {result.final_status}")
|
bench/register_models/gemini.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
from scievo.core.llms import ModelRegistry
|
| 9 |
+
|
| 10 |
+
LOW_COST_MODEL = "gemini/gemini-2.5-flash-lite"
|
| 11 |
+
MEDIUM_COST_MODEL = "gemini/gemini-2.5-flash"
|
| 12 |
+
HIGH_COST_MODEL = "gemini/gemini-2.5-pro"
|
| 13 |
+
|
| 14 |
+
OPENAI_KEY = os.getenv("OPENAI_API_KEY")
|
| 15 |
+
GEMINI_KEY = os.getenv("GEMINI_API_KEY")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def register_gemini_low_medium_models(reasoning: str = "low"):
    """Register Gemini low and medium cost models in the ModelRegistry."""
    # Sampling settings shared by the deterministic-leaning roles.
    sampling = {"temperature": 0.3, "top_p": 0.9}

    # (role name, model id, api key, extra kwargs), in registration order.
    specs = [
        ("data", LOW_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("plan", MEDIUM_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning, **sampling}),
        ("critic", LOW_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning, **sampling}),
        ("mem", LOW_COST_MODEL, GEMINI_KEY, {}),
        # NOTE: Use OpenAI embeddings for better performance
        ("embed", "text-embedding-3-small", OPENAI_KEY, {}),
        ("history", LOW_COST_MODEL, GEMINI_KEY, {}),
        ("experiment_agent", MEDIUM_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_coding", MEDIUM_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_execute", MEDIUM_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_monitor", LOW_COST_MODEL, GEMINI_KEY, dict(sampling)),
        ("experiment_summary", LOW_COST_MODEL, GEMINI_KEY, {"reasoning_effort": "low"}),
    ]
    for role, model, key, extra in specs:
        ModelRegistry.register(name=role, model=model, api_key=key, **extra)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def register_gemini_medium_high_models(reasoning: str = "low"):
    """Register Gemini medium and high cost models in the ModelRegistry."""
    # Sampling settings shared by the deterministic-leaning roles.
    sampling = {"temperature": 0.3, "top_p": 0.9}

    # (role name, model id, api key, extra kwargs), in registration order.
    specs = [
        ("data", MEDIUM_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("plan", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning, **sampling}),
        ("critic", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning, **sampling}),
        ("mem", MEDIUM_COST_MODEL, GEMINI_KEY, {}),
        ("embed", "text-embedding-3-small", OPENAI_KEY, {}),
        ("history", MEDIUM_COST_MODEL, GEMINI_KEY, {}),
        ("experiment_agent", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_coding", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_execute", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": reasoning}),
        ("experiment_monitor", MEDIUM_COST_MODEL, GEMINI_KEY, dict(sampling)),
        ("experiment_summary", HIGH_COST_MODEL, GEMINI_KEY, {"reasoning_effort": "low"}),
    ]
    for role, model, key, extra in specs:
        ModelRegistry.register(name=role, model=model, api_key=key, **extra)
|
bench/register_models/gpt.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
from scievo.core.llms import ModelRegistry
|
| 9 |
+
|
| 10 |
+
LOW_COST_MODEL = "gpt-5-nano"
|
| 11 |
+
MEDIUM_COST_MODEL = "gpt-5-mini"
|
| 12 |
+
HIGH_COST_MODEL = "gpt-5.2"
|
| 13 |
+
|
| 14 |
+
OPENAI_KEY = os.getenv("OPENAI_API_KEY")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def register_gpt_low_medium_models(reasoning: str = "low"):
    """Register GPT low and medium cost models in the ModelRegistry."""
    # (role name, model id, reasoning config or None), in registration order.
    specs = [
        ("data", LOW_COST_MODEL, {"effort": reasoning, "summary": "detailed"}),
        ("plan", MEDIUM_COST_MODEL, {"effort": reasoning}),
        ("critic", LOW_COST_MODEL, {"effort": reasoning}),
        ("mem", LOW_COST_MODEL, {"effort": "minimal"}),
        # NOTE: Use OpenAI embeddings for better performance
        ("embed", "text-embedding-3-small", None),
        ("history", LOW_COST_MODEL, {"effort": "minimal"}),
        ("experiment_agent", MEDIUM_COST_MODEL, {"effort": reasoning}),
        ("experiment_coding", MEDIUM_COST_MODEL, {"effort": reasoning}),
        ("experiment_execute", MEDIUM_COST_MODEL, {"effort": reasoning}),
        ("experiment_monitor", LOW_COST_MODEL, {"effort": "minimal"}),
        ("experiment_summary", LOW_COST_MODEL, {"effort": "low"}),
    ]
    for role, model, reasoning_cfg in specs:
        # The embedding model takes no reasoning configuration.
        extra = {} if reasoning_cfg is None else {"reasoning": reasoning_cfg}
        ModelRegistry.register(name=role, model=model, api_key=OPENAI_KEY, **extra)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def register_gpt_medium_high_models(reasoning: str = "low"):
    """Register GPT medium and high cost models in the ModelRegistry.

    Args:
        reasoning: reasoning-effort level for the primary agent models;
            auxiliary models ("mem", "history", monitors/summaries) use
            fixed efforts.
    """
    # (name, model, reasoning-config) triples in the exact order of the
    # original per-call registrations.  A None reasoning config marks the
    # embeddings entry, which is registered without a `reasoning` argument.
    specs = [
        ("data", MEDIUM_COST_MODEL, {"effort": reasoning, "summary": "detailed"}),
        ("plan", HIGH_COST_MODEL, {"effort": reasoning}),
        ("critic", MEDIUM_COST_MODEL, {"effort": reasoning}),
        ("mem", MEDIUM_COST_MODEL, {"effort": "minimal"}),
        # NOTE: Use OpenAI embeddings for better performance
        ("embed", "text-embedding-3-small", None),
        ("history", MEDIUM_COST_MODEL, {"effort": "minimal"}),
        ("experiment_agent", HIGH_COST_MODEL, {"effort": reasoning}),
        ("experiment_coding", HIGH_COST_MODEL, {"effort": reasoning}),
        ("experiment_execute", HIGH_COST_MODEL, {"effort": reasoning}),
        ("experiment_monitor", MEDIUM_COST_MODEL, {"effort": "minimal"}),
        ("experiment_summary", MEDIUM_COST_MODEL, {"effort": "low"}),
    ]
    for name, model, reasoning_cfg in specs:
        if reasoning_cfg is None:
            ModelRegistry.register(name=name, model=model, api_key=OPENAI_KEY)
        else:
            ModelRegistry.register(
                name=name,
                model=model,
                api_key=OPENAI_KEY,
                reasoning=reasoning_cfg,
            )
|
case-studies/case1/task.md
ADDED
|
File without changes
|
case-studies/case2/task.md
ADDED
|
File without changes
|
pyproject.toml
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "scievo"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Add your description here"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.13"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"beautifulsoup4>=4.14.2",
|
| 13 |
+
"ddgs>=9.6.1",
|
| 14 |
+
"epam-indigo==1.35.0",
|
| 15 |
+
"feedparser>=6.0.12",
|
| 16 |
+
"filetype>=1.2.0",
|
| 17 |
+
"jinja2>=3.1.6",
|
| 18 |
+
"json-repair>=0.53.0",
|
| 19 |
+
"langchain-text-splitters>=1.0.0",
|
| 20 |
+
"langgraph>=1.0.2",
|
| 21 |
+
"litellm>=1.79.0,<1.80.0",
|
| 22 |
+
"loguru>=0.7.3",
|
| 23 |
+
"numpy>=2.3.4",
|
| 24 |
+
"openhands-sdk==1.3.0",
|
| 25 |
+
"openhands-tools==1.3.0",
|
| 26 |
+
"pandas>=2.3.3",
|
| 27 |
+
"pexpect>=4.9.0",
|
| 28 |
+
"pillow>=12.0.0",
|
| 29 |
+
"pydantic>=2.12.3",
|
| 30 |
+
"pyfunctional>=1.5.0",
|
| 31 |
+
"python-toon>=0.1.2",
|
| 32 |
+
"pyyaml>=6.0.3",
|
| 33 |
+
"rich>=14.2.0",
|
| 34 |
+
"scikit-learn>=1.8.0",
|
| 35 |
+
"tiktoken>=0.12.0",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[dependency-groups]
|
| 39 |
+
dev = [
|
| 40 |
+
"jupyterlab>=4.4.10",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
[project.optional-dependencies]
|
| 44 |
+
cpu = [
|
| 45 |
+
"torch>=2.9.0",
|
| 46 |
+
"torchvision",
|
| 47 |
+
]
|
| 48 |
+
cu128 = [
|
| 49 |
+
"torch>=2.9.0",
|
| 50 |
+
"torchvision",
|
| 51 |
+
]
|
| 52 |
+
mac = [
|
| 53 |
+
"torch>=2.9.0",
|
| 54 |
+
"torchvision",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
[tool.setuptools]
|
| 58 |
+
packages = { find = { include = ["scievo", "scievo.*"] } }
|
| 59 |
+
|
| 60 |
+
[tool.uv]
|
| 61 |
+
conflicts = [
|
| 62 |
+
[
|
| 63 |
+
{ extra = "cpu" },
|
| 64 |
+
{ extra = "cu128" },
|
| 65 |
+
{ extra = "mac" },
|
| 66 |
+
],
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
[tool.uv.sources]
|
| 70 |
+
torch = [
|
| 71 |
+
{ index = "pytorch-cpu", extra = "cpu" },
|
| 72 |
+
{ index = "pytorch-cu128", extra = "cu128" },
|
| 73 |
+
{ index = "pytorch-mac", extra = "mac" },
|
| 74 |
+
]
|
| 75 |
+
torchvision = [
|
| 76 |
+
{ index = "pytorch-cpu", extra = "cpu" },
|
| 77 |
+
{ index = "pytorch-cu128", extra = "cu128" },
|
| 78 |
+
{ index = "pytorch-mac", extra = "mac" },
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
[[tool.uv.index]]
|
| 82 |
+
name = "pytorch-cpu"
|
| 83 |
+
url = "https://download.pytorch.org/whl/cpu"
|
| 84 |
+
explicit = true
|
| 85 |
+
|
| 86 |
+
[[tool.uv.index]]
|
| 87 |
+
name = "pytorch-cu128"
|
| 88 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 89 |
+
explicit = true
|
| 90 |
+
|
| 91 |
+
[[tool.uv.index]]
|
| 92 |
+
name = "pytorch-mac"
|
| 93 |
+
url = "https://pypi.org/simple"
|
| 94 |
+
explicit = true
|
reasoning_bank/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReasoningBank Core (Reference Copy)
|
| 2 |
+
|
| 3 |
+
- This directory contains a minimal copy of the core component of Google's ReasoningBank.
|
| 4 |
+
- It is included only as a reminder/reference of what their core looks like.
|
| 5 |
+
- For the complete, authoritative source and updates, please refer to the original Google project.
|
reasoning_bank/mem_induction.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
"""Run mini-SWE-agent on SWE-bench instances in batch mode."""
|
| 4 |
+
# Read this first: https://mini-swe-agent.com/latest/usage/swebench/ (usage docs)
|
| 5 |
+
|
| 6 |
+
import concurrent.futures
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
import traceback
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import typer
|
| 17 |
+
import yaml
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from google import genai
|
| 20 |
+
from google.genai.types import GenerateContentConfig, HttpOptions
|
| 21 |
+
from jinja2 import Template
|
| 22 |
+
from minisweagent import Environment
|
| 23 |
+
from minisweagent.agents.default import DefaultAgent
|
| 24 |
+
from minisweagent.config import builtin_config_dir, get_config_path
|
| 25 |
+
from minisweagent.environments import get_environment
|
| 26 |
+
from minisweagent.memory.instruction import FAILED_SI, SUCCESSFUL_SI
|
| 27 |
+
from minisweagent.memory.memory_management import select_memory
|
| 28 |
+
from minisweagent.models import get_model
|
| 29 |
+
from minisweagent.run.extra.utils.batch_progress import RunBatchProgressManager
|
| 30 |
+
from minisweagent.run.utils.save import save_traj
|
| 31 |
+
from minisweagent.utils.log import add_file_handler, logger
|
| 32 |
+
from rich.live import Live
|
| 33 |
+
|
| 34 |
+
client = genai.Client(http_options=HttpOptions(api_version="v1"))
|
| 35 |
+
|
| 36 |
+
_HELP_TEXT = """Run mini-SWE-agent on SWEBench instances.
|
| 37 |
+
|
| 38 |
+
[not dim]
|
| 39 |
+
More information about the usage: [bold green]https://mini-swe-agent.com/latest/usage/swebench/[/bold green]
|
| 40 |
+
[/not dim]
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
app = typer.Typer(rich_markup_mode="rich", add_completion=False)
|
| 44 |
+
|
| 45 |
+
DATASET_MAPPING = {
|
| 46 |
+
"full": "princeton-nlp/SWE-Bench",
|
| 47 |
+
"verified": "princeton-nlp/SWE-Bench_Verified",
|
| 48 |
+
"lite": "princeton-nlp/SWE-Bench_Lite",
|
| 49 |
+
"multimodal": "princeton-nlp/SWE-Bench_Multimodal",
|
| 50 |
+
"multilingual": "swe-bench/SWE-Bench_Multilingual",
|
| 51 |
+
"smith": "SWE-bench/SWE-smith",
|
| 52 |
+
"_test": "klieret/swe-bench-dummy-test-dataset",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
_OUTPUT_FILE_LOCK = threading.Lock()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ProgressTrackingAgent(DefaultAgent):
    """Thin DefaultAgent subclass that publishes per-step progress updates."""

    def __init__(
        self, *args, progress_manager: RunBatchProgressManager, instance_id: str = "", **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.progress_manager: RunBatchProgressManager = progress_manager
        self.instance_id = instance_id

    def step(self) -> dict:
        """Report the upcoming step number and running cost, then delegate."""
        status = f"Step {self.model.n_calls + 1:3d} (${self.model.cost:.2f})"
        self.progress_manager.update_instance_status(self.instance_id, status)
        return super().step()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_swebench_docker_image_name(instance: dict) -> str:
    """Get the Docker image name for a SWEBench instance.

    An explicit "image_name" field wins; otherwise the name is derived
    from the instance id using the official swebench naming scheme.
    """
    explicit = instance.get("image_name", None)
    if explicit is not None:
        return explicit
    # Docker doesn't allow double underscore, so we replace them with a magic token
    docker_safe_id = instance["instance_id"].replace("__", "_1776_")
    return f"swebench/sweb.eval.x86_64.{docker_safe_id}:latest".lower()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_sb_environment(config: dict, instance: dict) -> Environment:
    """Build the execution environment for one instance and run its startup command.

    Defaults the environment class to "docker"; "singularity" gets the
    docker:// URI prefix.  Raises RuntimeError when the configured
    startup command exits non-zero.
    """
    env_config = config.setdefault("environment", {})
    env_config["environment_class"] = env_config.get("environment_class", "docker")
    image_name = get_swebench_docker_image_name(instance)
    env_class = env_config["environment_class"]
    if env_class == "docker":
        env_config["image"] = image_name
    elif env_class == "singularity":
        env_config["image"] = "docker://" + image_name
    env = get_environment(env_config)
    startup_command = config.get("run", {}).get("env_startup_command")
    if startup_command:
        rendered = Template(startup_command).render(**instance)
        out = env.execute(rendered)
        if out["returncode"] != 0:
            raise RuntimeError(f"Error executing startup command: {out}")
    return env
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def update_preds_file(output_path: Path, instance_id: str, model_name: str, result: str):
    """Insert or overwrite one instance's prediction in the shared JSON file.

    Serialized through _OUTPUT_FILE_LOCK so worker threads don't clobber
    each other's read-modify-write cycles.
    """
    with _OUTPUT_FILE_LOCK:
        preds = json.loads(output_path.read_text()) if output_path.exists() else {}
        preds[instance_id] = {
            "model_name_or_path": model_name,
            "instance_id": instance_id,
            "model_patch": result,
        }
        output_path.write_text(json.dumps(preds, indent=2))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def remove_from_preds_file(output_path: Path, instance_id: str):
    """Remove an instance from the predictions file, if present.

    FIX: the existence check used to happen OUTSIDE _OUTPUT_FILE_LOCK,
    racing with a concurrent update_preds_file() that could create or
    rewrite the file between the check and the read.  The whole
    check-read-write cycle is now inside the lock.
    """
    with _OUTPUT_FILE_LOCK:
        if not output_path.exists():
            return
        output_data = json.loads(output_path.read_text())
        # pop() tolerates an absent key; the file is rewritten either way,
        # matching the original's unconditional write.
        output_data.pop(instance_id, None)
        output_path.write_text(json.dumps(output_data, indent=2))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def llm_generate(prompt: list[dict], model, verbose: bool = False, si: str | None = None) -> list[str]:
    """Call the Gemini model to generate memory items.

    FIX: the docstring said "gpt" but the call goes through the Gemini
    client; the return annotation said ``str`` but the function always
    returns the text split into a list; ``si`` was typed ``str`` with a
    None default.

    Args:
        prompt: content payload forwarded to the Gemini API.
        model: model name to query.
        verbose: if True, print the prompt and the raw response.
        si: optional system instruction; whitespace-stripped before use.

    Returns:
        The response text split on blank lines (one memory item per chunk).
    """
    if verbose:
        print("Prompt:\n", prompt, "\n\n")
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=GenerateContentConfig(
            temperature=1.0,
            max_output_tokens=65536,
            system_instruction=si.strip() if si else None,
        ),
    )
    text = response.text
    if verbose:
        print(text)
    return text.split("\n\n")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def process_instance(
    instance: dict,
    output_dir: Path,
    config: dict,
    progress_manager: RunBatchProgressManager,
) -> None:
    """Process a single SWEBench instance end-to-end.

    Runs the agent (with memory retrieved from the per-model reasoning
    bank), saves the trajectory and prediction, then judges the run and
    appends freshly distilled memory items to the bank.

    FIX: the original opened ``./memory/<model>.jsonl`` without ensuring
    the ``./memory`` directory exists, crashing with FileNotFoundError on
    a fresh checkout; the directory is now created up front.
    """
    instance_id = instance["instance_id"]
    instance_dir = output_dir / instance_id
    # avoid inconsistent state if something here fails and there's leftover previous files
    remove_from_preds_file(output_dir / "preds.json", instance_id)
    (instance_dir / f"{instance_id}.traj.json").unlink(missing_ok=True)
    model = get_model(config=config.get("model", {}))
    task = instance["problem_statement"]

    os.makedirs("./memory", exist_ok=True)  # FIX: directory may not exist yet
    memory_path = f"./memory/{model.config.model_name}.jsonl"
    if not os.path.exists(memory_path):
        open(memory_path, "w").close()  # create an empty file

    with open(memory_path, "r") as f:
        memory_bank = [json.loads(line) for line in f.readlines()]

    res = select_memory(
        1,
        reasoning_bank=memory_bank,
        cur_query=task,
        task_id=instance_id,
        cache_path=f"./memory/{model.config.model_name}_embeddings.jsonl",
        prefer_model="gemini",
    )

    if not res:
        selected_memory = ""
    else:
        # flatten the memory_items of every retrieved bank entry
        mem_items = []
        for item in res:
            for i in item["memory_items"]:
                mem_items.append(i)
        selected_memory = "\n\n".join(mem_items)

    progress_manager.on_instance_start(instance_id)
    progress_manager.update_instance_status(instance_id, "Pulling/starting docker")

    agent = None
    extra_info = None

    try:
        env = get_sb_environment(config, instance)
        agent = ProgressTrackingAgent(
            model,
            env,
            progress_manager=progress_manager,
            instance_id=instance_id,
            **config.get("agent", {}),
        )
        exit_status, result = agent.run(task, selected_memory=selected_memory)
    except Exception as e:
        logger.error(f"Error processing instance {instance_id}: {e}", exc_info=True)
        exit_status, result = type(e).__name__, str(e)
        extra_info = {"traceback": traceback.format_exc()}
    finally:
        save_traj(
            agent,
            instance_dir / f"{instance_id}.traj.json",
            exit_status=exit_status,
            result=result,
            extra_info=extra_info,
            instance_id=instance_id,
            print_fct=logger.info,
        )
        update_preds_file(output_dir / "preds.json", instance_id, model.config.model_name, result)
        progress_manager.on_instance_end(instance_id, exit_status)

    # read trajectory and extract memory
    with open(instance_dir / f"{instance_id}.traj.json", "r") as f:
        messages = json.load(f)["messages"]
    trajectory = "\n".join([m["content"] for m in messages if m["role"] != "system"])
    status = llm_judge_status(task, trajectory, model.config.model_name)

    trajectory = f"**Query:** {task}\n\n**Trajectory:**\n{trajectory}"
    # distill memory items with a success- or failure-specific system instruction
    if status:
        generated_memory_item = llm_generate(
            trajectory, model.config.model_name, True, si=SUCCESSFUL_SI
        )
    else:
        generated_memory_item = llm_generate(
            trajectory, model.config.model_name, True, si=FAILED_SI
        )

    with open(memory_path, "a") as f:
        f.write(
            json.dumps(
                {
                    "task_id": instance_id,
                    "query": task,
                    "memory_items": generated_memory_item,
                    "status": "success" if status else "fail",
                }
            )
            + "\n"
        )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def llm_judge_status(task: str, trajectory: str, model: str) -> bool:
    """Ask the model whether the trajectory completed the task.

    FIX: the return type was annotated ``str`` but the function always
    returned a bool; the annotation now matches.

    Returns:
        True when the judge's (lowercased) answer contains "success".
    """
    prompt = f"Task: {task}\n\nTrajectory:\n{trajectory}\n\nDid the agent successfully complete the task? Answer with 'success' or 'fail' only."
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=GenerateContentConfig(
            temperature=0.0,
            max_output_tokens=65536,
            system_instruction="You are a helpful assistant that judges whether the agent successfully completed the task.",
        ),
    )
    return "success" in response.text.strip().lower()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def filter_instances(
    instances: list[dict], *, filter_spec: str, slice_spec: str = "", shuffle: bool = False
) -> list[dict]:
    """Filter and slice a list of SWEBench instances.

    Shuffling is deterministic: the list is sorted by instance_id, then
    shuffled with seed 42.  ``filter_spec`` is matched (re.match) against
    each instance_id; ``slice_spec`` uses Python slice syntax, e.g. "0:5".

    FIX: the slice log used to compare the post-slice count against the
    PRE-filter count, so it fired (and misreported the delta) whenever
    the regex filter had already removed instances; it now compares
    against the post-filter count.
    """
    if shuffle:
        instances = sorted(instances.copy(), key=lambda x: x["instance_id"])
        random.seed(42)
        random.shuffle(instances)
    before_filter = len(instances)
    instances = [
        instance for instance in instances if re.match(filter_spec, instance["instance_id"])
    ]
    after_filter = len(instances)
    if after_filter != before_filter:
        logger.info(f"Instance filter: {before_filter} -> {after_filter} instances")
    if slice_spec:
        values = [int(x) if x else None for x in slice_spec.split(":")]
        instances = instances[slice(*values)]
        if (after_slice := len(instances)) != after_filter:
            logger.info(f"Instance slice: {after_filter} -> {after_slice} instances")
    return instances
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# fmt: off
@app.command(help=_HELP_TEXT)
def main(
    subset: str = typer.Option("lite", "--subset", help="SWEBench subset to use or path to a dataset", rich_help_panel="Data selection"),
    split: str = typer.Option("dev", "--split", help="Dataset split", rich_help_panel="Data selection"),
    slice_spec: str = typer.Option("", "--slice", help="Slice specification (e.g., '0:5' for first 5 instances)", rich_help_panel="Data selection"),
    filter_spec: str = typer.Option("", "--filter", help="Filter instance IDs by regex", rich_help_panel="Data selection"),
    shuffle: bool = typer.Option(False, "--shuffle", help="Shuffle instances", rich_help_panel="Data selection"),
    output: str = typer.Option("", "-o", "--output", help="Output directory", rich_help_panel="Basic"),
    workers: int = typer.Option(1, "-w", "--workers", help="Number of worker threads for parallel processing", rich_help_panel="Basic"),
    model: str | None = typer.Option(None, "-m", "--model", help="Model to use", rich_help_panel="Basic"),
    # FIX: "-c" was declared for BOTH --model-class and --config, making the
    # short flag ambiguous; "-c" is kept for --config and dropped here.
    model_class: str | None = typer.Option(None, "--model-class", help="Model class to use (e.g., 'anthropic' or 'minisweagent.models.anthropic.AnthropicModel')", rich_help_panel="Advanced"),
    redo_existing: bool = typer.Option(False, "--redo-existing", help="Redo existing instances", rich_help_panel="Data selection"),
    config_spec: Path = typer.Option( builtin_config_dir / "extra" / "swebench.yaml", "-c", "--config", help="Path to a config file", rich_help_panel="Basic"),
    environment_class: str | None = typer.Option( None, "--environment-class", help="Environment type to use. Recommended are docker or singularity", rich_help_panel="Advanced"),
) -> None:
    # fmt: on
    """Load the dataset, build the run config, and fan instances out to workers."""
    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Results will be saved to {output_path}")
    add_file_handler(output_path / "minisweagent.log")

    dataset_path = DATASET_MAPPING.get(subset, subset)
    logger.info(f"Loading dataset {dataset_path}, split {split}...")
    instances = list(load_dataset(dataset_path, split=split))

    instances = filter_instances(instances, filter_spec=filter_spec, slice_spec=slice_spec, shuffle=shuffle)
    if not redo_existing and (output_path / "preds.json").exists():
        existing_instances = list(json.loads((output_path / "preds.json").read_text()).keys())
        logger.info(f"Skipping {len(existing_instances)} existing instances")
        instances = [instance for instance in instances if instance["instance_id"] not in existing_instances]
    logger.info(f"Running on {len(instances)} instances...")

    config = yaml.safe_load(get_config_path(config_spec).read_text())
    if environment_class is not None:
        config.setdefault("environment", {})["environment_class"] = environment_class
    if model is not None:
        config.setdefault("model", {})["model_name"] = model
    if model_class is not None:
        config.setdefault("model", {})["model_class"] = model_class

    progress_manager = RunBatchProgressManager(len(instances), output_path / f"exit_statuses_{time.time()}.yaml")

    def process_futures(futures: dict[concurrent.futures.Future, str]):
        # Drain futures, logging uncaught per-instance exceptions.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except concurrent.futures.CancelledError:
                pass
            except Exception as e:
                instance_id = futures[future]
                logger.error(f"Error in future for instance {instance_id}: {e}", exc_info=True)
                progress_manager.on_uncaught_exception(instance_id, e)

    with Live(progress_manager.render_group, refresh_per_second=4):
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = {
                executor.submit(process_instance, instance, output_path, config, progress_manager): instance[
                    "instance_id"
                ]
                for instance in instances
            }
            try:
                process_futures(futures)
            except KeyboardInterrupt:
                logger.info("Cancelling all pending jobs. Press ^C again to exit immediately.")
                for future in futures:
                    if not future.running() and not future.done():
                        future.cancel()
                process_futures(futures)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
if __name__ == "__main__":
|
| 365 |
+
app()
|
reasoning_bank/mem_manage.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
# from transformers import AutoTokenizer, AutoModel
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
from google import genai
|
| 16 |
+
from google.genai.types import EmbedContentConfig
|
| 17 |
+
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
|
| 18 |
+
|
| 19 |
+
client = genai.Client()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_embeddings(texts: list) -> list:
    """Embed a batch of texts with Google GenAI (gemini-embedding-001).

    Returns one embedding vector per input text, 3072-dimensional,
    computed with the RETRIEVAL_DOCUMENT task type.
    """
    embed_config = EmbedContentConfig(
        task_type="RETRIEVAL_DOCUMENT",
        output_dimensionality=3072,
        title="Memory Embeddings",
    )
    response = client.models.embed_content(
        model="gemini-embedding-001",
        contents=texts,
        config=embed_config,
    )
    return [item.embedding for item in response.embeddings]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def l2_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Scale `x` to unit L2 norm along `dim` (zero-safe via the same 1e-12 floor F.normalize uses)."""
    denom = x.norm(p=2, dim=dim, keepdim=True).clamp_min(1e-12)
    return x / denom
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def embed_query_with_qwen(query: str) -> torch.Tensor:
    """Embed `query` with Qwen3-Embedding-8B; returns an L2-normalized (1, D) CPU tensor.

    FIX: AutoTokenizer/AutoModel were referenced although the
    ``transformers`` import is commented out at module level, so calling
    this raised NameError; they are now imported locally (also defers the
    heavy import until actually needed).  The return annotation promised
    ``Tuple[tensor, model_name, dim]`` but only the tensor was ever
    returned; the annotation/docstring now match.
    """
    from transformers import AutoModel, AutoTokenizer  # deferred heavy import

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-8B", padding_side="left")
    model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-8B")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    batch = tokenizer([query], max_length=1024, padding=True, truncation=True, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        out = model(**batch)
    last_hidden = out.last_hidden_state  # (1, L, D)
    # masked mean-pool over the sequence axis, ignoring padding
    masked = last_hidden.masked_fill(~batch["attention_mask"][..., None].bool(), 0.0)
    pooled = masked.sum(dim=1) / batch["attention_mask"].sum(dim=1)[..., None]  # (1, D)
    pooled = pooled.to("cpu")
    return l2_normalize(pooled, dim=1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def embed_query_with_gemini(
|
| 62 |
+
query: str, dimensionality: int = 3072
|
| 63 |
+
) -> Tuple[torch.Tensor, str, int]:
|
| 64 |
+
"""Returns (1, D) torch tensor (on CPU), model_name, dim."""
|
| 65 |
+
|
| 66 |
+
model_name = "gemini-embedding-001"
|
| 67 |
+
model = TextEmbeddingModel.from_pretrained(model_name)
|
| 68 |
+
text_input = TextEmbeddingInput(query, "RETRIEVAL_DOCUMENT")
|
| 69 |
+
|
| 70 |
+
resp = model.get_embeddings([text_input], output_dimensionality=dimensionality)
|
| 71 |
+
|
| 72 |
+
# vertexai returns a list of TextEmbedding objects with .values
|
| 73 |
+
vec = torch.tensor([resp[0].values], dtype=torch.float32) # (1, D)
|
| 74 |
+
|
| 75 |
+
return vec
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_cached_embeddings(path: str) -> Tuple[List[str], List[str], torch.Tensor]:
|
| 79 |
+
"""
|
| 80 |
+
Load cached embeddings from JSONL.
|
| 81 |
+
Returns: ids, texts, torch.Tensor (N, D) normalized
|
| 82 |
+
Each line must contain keys: id, text, embedding
|
| 83 |
+
"""
|
| 84 |
+
ids, texts, vecs = [], [], []
|
| 85 |
+
if not os.path.exists(path):
|
| 86 |
+
logger.warning(f"Cache file not found: {path}, creating an empty cache.")
|
| 87 |
+
open(path, "w").close() # create an empty file
|
| 88 |
+
return ids, texts, torch.empty(0)
|
| 89 |
+
|
| 90 |
+
with open(path, "r") as f:
|
| 91 |
+
for line in f:
|
| 92 |
+
if not line.strip():
|
| 93 |
+
continue
|
| 94 |
+
obj = json.loads(line)
|
| 95 |
+
ids.append(obj["id"])
|
| 96 |
+
texts.append(obj.get("text", ""))
|
| 97 |
+
vecs.append(obj["embedding"])
|
| 98 |
+
|
| 99 |
+
if len(vecs) == 0:
|
| 100 |
+
return ids, texts, torch.empty(0)
|
| 101 |
+
|
| 102 |
+
emb = torch.tensor(vecs, dtype=torch.float32) # (N, D)
|
| 103 |
+
emb = l2_normalize(emb, dim=1)
|
| 104 |
+
|
| 105 |
+
return ids, texts, emb
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
| 109 |
+
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
| 110 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query with its task instruction in the instruction-aware
    embedding format ('Instruct: ...\\nQuery: ...')."""
    return "Instruct: " + task_description + "\nQuery: " + query
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def formalize(queries):
    """
    Normalize *queries* into an (items, indices) pair.

    Returns the queries as a list together with a parallel list of their
    positional indices ``[0, 1, ..., len-1]``.
    """
    # The original loop shadowed the builtin ``id`` and rebuilt both lists
    # element by element; list() + range() expresses the same result directly.
    items = list(queries)
    return items, list(range(len(items)))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def select_memory(
    n: int,
    reasoning_bank: List[Dict],
    cur_query: str,
    task_id: str = None,
    cache_path: str = "./memories/embeddings.jsonl",
    prefer_model: str = "gemini",
) -> List[Dict]:
    """
    Select the top-``n`` reasoning-bank entries most similar to ``cur_query``.

    Similarity comes from ``screening``, which uses ONLY the cached
    embeddings; nothing is recomputed here.  Returns a list of
    reasoning-bank entries ordered by descending similarity (empty list
    when the cache has no entries).  The original version returned ``{}``
    in the empty case and a list otherwise; it now consistently returns a
    list (both are falsy, so truthiness checks are unaffected).
    """
    if n > 10:
        logger.error("the number of return experiences shouldn't be greater than 10")

    id2score, ordered_ids = screening(
        cur_query=cur_query, task_id=task_id, cache_path=cache_path, prefer_model=prefer_model
    )

    if not ordered_ids:
        return []

    # Index the bank once so each selected id resolves in O(1) instead of
    # rescanning the whole bank per id (the original nested loop was O(n*m)).
    # setdefault keeps the FIRST entry per task_id, matching the original
    # scan-with-break semantics.
    by_task_id: Dict = {}
    for item in reasoning_bank:
        by_task_id.setdefault(item["task_id"], item)

    return [by_task_id[sid] for sid in ordered_ids[:n] if sid in by_task_id]
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def screening(
    cur_query: str,
    cache_path: str,
    task_id: str = None,
    prefer_model: str = "",
) -> Tuple[List[Tuple[str, float]], List[str]]:
    """
    Compute similarity of current query against cached embeddings, optionally append the query embedding to cache.

    Returns ``(id2score, ordered_ids)``: ``id2score`` is a list of
    ``(cache_id, score)`` pairs sorted by descending score, and
    ``ordered_ids`` are the same ids, stringified, in that order.

    Side effect: the embedding of ``cur_query`` is ALWAYS appended to the
    cache file at ``cache_path`` under ``task_id`` before scoring, so the
    query becomes part of the cache for future calls.
    """
    cache_ids, cache_texts, cache_emb = load_cached_embeddings(cache_path)

    # choose embedding method to match the cache
    use_qwen = "Qwen" in prefer_model

    # q_vec is only used for the cache record below; it does not take part
    # in the similarity scoring.
    if use_qwen:
        q_vec = embed_query_with_qwen(cur_query)
    else:
        q_vec = embed_query_with_gemini(cur_query, dimensionality=3072)

    # write current query embeddings to cache
    record = {
        "id": task_id,
        "text": cur_query,
        "embedding": q_vec.squeeze(0).tolist(),
    }
    with open(cache_path, "a") as f:
        f.write(json.dumps(record) + "\n")
    logger.info(f"Appended new query embedding to cache: webarena.{task_id}")

    # Nothing to score against: the cache was empty before this call.
    if len(cache_emb) == 0:
        logger.warning(f"No cached embeddings found in {cache_path}.")
        return [], []

    # add instruction-aware embedding for calculation
    task = "Given the prior software engineering queries, your task is to analyze a current query's intent and select relevant prior queries that could help resolve it."

    instruction_query = get_detailed_instruct(task, cur_query)
    # NOTE(review): scoring always embeds with Gemini here, even when
    # prefer_model selected Qwen above (that vector was only cached) --
    # confirm this mixed-model behavior is intended.
    instruct_vec = embed_query_with_gemini(instruction_query, dimensionality=3072)
    instruct_vec = l2_normalize(instruct_vec, dim=1)

    # Calculate similarity scores for embeddings and current query
    # cache_emb is already L2-normalized by load_cached_embeddings, so the
    # dot product is a cosine similarity, scaled to ~[-100, 100].
    scores = (instruct_vec @ cache_emb.T).squeeze(0) * 100.0  # (N,)
    id2score = list(zip(cache_ids, scores.tolist()))
    id2score.sort(key=lambda x: x[1], reverse=True)

    return id2score, [str(i) for i, _ in id2score]
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair
|
| 2 |
+
pandas
|
| 3 |
+
streamlit
|
scievo/__init__.py
ADDED
|
File without changes
|
scievo/agents/__init__.py
ADDED
|
File without changes
|
scievo/agents/critic_agent/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .build import build
|
| 2 |
+
from .state import CriticAgentState
|
scievo/agents/critic_agent/build.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from . import execute
|
| 5 |
+
from .state import CriticAgentState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@logger.catch
def build():
    """Assemble the critic agent's LangGraph state machine."""
    graph = StateGraph(CriticAgentState)

    # Register every node up front (dict preserves insertion order).
    handlers = {
        "create_first_user_msg": execute.create_first_user_msg_node,
        "gateway": execute.gateway_node,
        "llm_chat": execute.llm_chat_node,
        "tool_calling": execute.tool_calling_node,
        "summary": execute.summary_node,
    }
    for node_name, handler in handlers.items():
        graph.add_node(node_name, handler)

    # Entry point feeds the gateway through the first-message builder.
    graph.add_edge(START, "create_first_user_msg")
    graph.add_edge("create_first_user_msg", "gateway")

    # The gateway routes on the last message in history.
    graph.add_conditional_edges(
        "gateway",
        execute.gateway_conditional,
        ["llm_chat", "tool_calling", "summary"],
    )

    # Chat and tool execution loop back to the gateway.
    graph.add_edge("llm_chat", "gateway")
    graph.add_edge("tool_calling", "gateway")

    # The summary node terminates the run.
    graph.add_edge("summary", END)

    return graph
|
scievo/agents/critic_agent/execute.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent for criticizing and giving feedback on the agent's actions
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING, TypeVar
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
from scievo.core import constant
|
| 10 |
+
from scievo.core.llms import ModelRegistry
|
| 11 |
+
from scievo.core.types import Message
|
| 12 |
+
from scievo.core.utils import wrap_dict_to_toon, wrap_text_with_block
|
| 13 |
+
from scievo.prompts import PROMPTS
|
| 14 |
+
from scievo.rbank.subgraph import mem_retrieval
|
| 15 |
+
from scievo.tools import Tool, ToolRegistry
|
| 16 |
+
|
| 17 |
+
from .state import CriticAgentState
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
    from scievo.core.types import HistoryState, RBankState
    from scievo.rbank.memo import Memo

# BUG FIX: HistoryState/RBankState are imported only under TYPE_CHECKING, so
# they do not exist at runtime; naming them directly in TypeVar(...) raised
# NameError when this module was imported.  TypeVar accepts forward-reference
# strings for its constraints, which are resolved lazily.
MemHistoryMixin = TypeVar("MemHistoryMixin", "HistoryState", "RBankState")
|
| 24 |
+
|
| 25 |
+
LLM_NAME = "critic"
|
| 26 |
+
AGENT_NAME = "critic"
|
| 27 |
+
|
| 28 |
+
BUILTIN_TOOLSETS = [
|
| 29 |
+
# "todo",
|
| 30 |
+
"state",
|
| 31 |
+
"history",
|
| 32 |
+
"web",
|
| 33 |
+
]
|
| 34 |
+
ALLOWED_TOOLSETS = ["fs", "web"]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_first_user_msg_node(agent_state: CriticAgentState) -> CriticAgentState:
    """Build the opening user message: the caller's plan plus the trajectory
    to be criticized, rendered through the critic user-prompt template."""
    logger.debug("create_first_user_msg_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("create_first_user_msg")

    # Flatten every input message into one delimited plain-text trajectory.
    trajectory_text: str = "\n".join(
        f"--- Message {idx} Begin ---\n{message.to_plain_text()}\n--- Message {idx} End ---"
        for idx, message in enumerate(agent_state.input_msgs)
    )

    rendered = PROMPTS.critic.user_prompt.render(
        plan_text=agent_state.plan,
        trajectory_text=trajectory_text,
        is_data_agent=agent_state.is_data_agent,
        is_exp_agent=agent_state.is_exp_agent,
    )

    # Seed the conversation with the rendered prompt as the first user turn.
    first_msg = Message(role="user", content=rendered, agent_sender=AGENT_NAME)
    agent_state.add_message(first_msg)

    return agent_state
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def gateway_node(agent_state: CriticAgentState) -> CriticAgentState:
    """Pass-through routing node; the actual branching decision is made by
    ``gateway_conditional`` on the conditional edges."""
    # NOTE: Same as data agent
    logger.trace("gateway_node of Agent {}", AGENT_NAME)
    return agent_state
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def gateway_conditional(agent_state: CriticAgentState) -> str:
    """Route on the last message: pending tool calls go to tool execution,
    user/tool turns go to the LLM, and an assistant turn ends the round."""
    # NOTE: Same as data agent
    last_msg = agent_state.patched_history[-1]

    tool_calls = last_msg.tool_calls
    if tool_calls and len(tool_calls) > 0:
        return "tool_calling"

    role = last_msg.role
    if role in ("user", "tool"):
        return "llm_chat"
    if role == "assistant":
        # finish this round of critic, go to "summary" node
        return "summary"
    raise ValueError(f"Unknown message role: {role}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
mem_retrieval_subgraph = mem_retrieval.build()
|
| 87 |
+
mem_retrieval_subgraph_compiled = mem_retrieval_subgraph.compile()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def llm_chat_node(agent_state: CriticAgentState) -> CriticAgentState:
    """Run one critic LLM turn: optionally retrieve memos, render the system
    prompt with state/toolset/memory context, and append the completion."""
    logger.debug("llm_chat_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("llm_chat")

    # Subset of agent state exposed to the prompt template.
    selected_state = {
        "current_activated_toolsets": agent_state.toolsets,
    }

    # retrieve memos
    if constant.REASONING_BANK_ENABLED:
        try:
            # Search short-term memory plus any configured long-term /
            # project memory directories.
            mem_dirs = [agent_state.sess_dir / "short_term"]
            if d := agent_state.long_term_mem_dir:
                mem_dirs.append(d)
            if d := agent_state.project_mem_dir:
                mem_dirs.append(d)
            res = mem_retrieval_subgraph_compiled.invoke(
                mem_retrieval.MemRetrievalState(
                    input_msgs=agent_state.patched_history,
                    mem_dirs=mem_dirs,
                    max_num_memos=constant.MEM_RETRIEVAL_MAX_NUM_MEMOS,
                )
            )
            memos: list[Memo] = res.get("output_memos", [])
            # Deferred import -- presumably avoids a circular import with
            # the data agent module; TODO confirm.
            from scievo.agents.data_agent.execute import _memos_to_markdown

            memory_text = _memos_to_markdown(memos)
        except Exception:
            # Memory retrieval is best-effort: log and continue without it.
            logger.exception("mem_retrieval_error")
            memory_text = None
    else:
        memory_text = None

    # update system prompt
    system_prompt = PROMPTS.critic.system_prompt.render(
        state_text=wrap_dict_to_toon(selected_state),
        toolsets_desc=ToolRegistry.get_toolsets_desc(BUILTIN_TOOLSETS + ALLOWED_TOOLSETS),
        memory_text=wrap_text_with_block(memory_text, "markdown"),
        is_data_agent=agent_state.is_data_agent,
        is_exp_agent=agent_state.is_exp_agent,
    )

    # construct tools: state-activated toolsets plus the builtin ones.
    tools: dict[str, Tool] = {}
    for toolset in agent_state.toolsets:
        tools.update(ToolRegistry.get_toolset(toolset))
    for toolset in BUILTIN_TOOLSETS:
        tools.update(ToolRegistry.get_toolset(toolset))

    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=(
            # Wrap in a Message only to reuse with_log's conditional logging.
            Message(role="system", content=system_prompt)
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
        tools=[tool.name for tool in tools.values()],
    ).with_log()
    agent_state.add_message(msg)

    return agent_state
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def tool_calling_node(agent_state: CriticAgentState) -> CriticAgentState:
    """Execute tool calls from the last message and update the graph state.

    For every tool call in the last assistant message: resolve the tool,
    parse its JSON arguments, inject agent state / call context when the
    tool's signature asks for them, run it inside the agent's local
    environment, and append the result (or an error string) as a ``tool``
    message.  Errors never abort the loop; each call reports back.
    """
    # BUG FIX: ``json`` is not imported at module level in this file, so the
    # json.loads below raised NameError at runtime; import it (and inspect,
    # previously imported inside the loop) locally here.
    import inspect
    import json

    logger.debug("tool_calling_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("tool_calling")
    # Get the last message which contains tool calls
    last_msg = agent_state.patched_history[-1]

    if not last_msg.tool_calls:
        raise ValueError("No tool calls found in the last message")

    # construct tools: state-activated toolsets plus the builtin ones.
    tools: dict[str, Tool] = {}
    for toolset in agent_state.toolsets:
        tools.update(ToolRegistry.get_toolset(toolset))
    for toolset in BUILTIN_TOOLSETS:
        tools.update(ToolRegistry.get_toolset(toolset))

    function_map = {tool.name: tool.func for tool in tools.values()}

    def _add_tool_response(tool_call, tool_name: str, content: str) -> None:
        # Every outcome -- success or error -- is reported back to the model
        # as a tool message tied to the originating call id.
        agent_state.add_message(
            Message(
                role="tool",
                tool_name=tool_name,
                tool_call_id=tool_call.id,
                content=content,
            ).with_log()
        )

    # Execute each tool call
    for tool_call in last_msg.tool_calls:
        tool_name = tool_call.function.name

        # Check if tool exists in function map
        if tool_name not in function_map:
            _add_tool_response(tool_call, tool_name, f"Tool {tool_name} not found")
            continue

        # Parse tool arguments
        try:
            args = json.loads(tool_call.function.arguments)
            assert isinstance(args, dict)
        except json.JSONDecodeError as e:
            _add_tool_response(tool_call, tool_name, f"Invalid JSON in tool arguments: {e}")
            continue
        except AssertionError as e:
            _add_tool_response(tool_call, tool_name, f"Invalid tool arguments: {e}")
            continue

        # Execute the tool
        try:
            func = function_map[tool_name]

            # Inject the agent state / call context when the tool's
            # signature declares the corresponding parameter names.
            sig = inspect.signature(func)
            if constant.__AGENT_STATE_NAME__ in sig.parameters:
                args.update({constant.__AGENT_STATE_NAME__: agent_state})
            if constant.__CTX_NAME__ in sig.parameters:
                args.update({constant.__CTX_NAME__: {"current_agent": AGENT_NAME}})

            # Execute the tool in the agent's local environment
            with agent_state.local_env:
                result = func(**args)

            _add_tool_response(tool_call, tool_name, str(result))

        except Exception as e:
            _add_tool_response(tool_call, tool_name, f"Tool {tool_name} execution failed: {e}")

    return agent_state
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def summary_node(agent_state: CriticAgentState) -> CriticAgentState:
    """Ask the critic LLM for its final feedback summary and store the
    resulting message on the state as ``critic_msg``."""
    logger.debug("summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("summary")

    source_flags = {
        "is_data_agent": agent_state.is_data_agent,
        "is_exp_agent": agent_state.is_exp_agent,
    }

    # System prompt for the summary turn; no toolsets are offered here.
    sys_prompt_text = PROMPTS.critic.system_prompt.render(toolsets_desc={}, **source_flags)

    # The explicit user request that triggers the summary.
    summary_request = PROMPTS.critic.user_prompt_summary.render(**source_flags)
    agent_state.add_message(
        Message(role="user", content=summary_request, agent_sender=AGENT_NAME)
    )

    # Run the completion and record the answer as the critic's output.
    summary_msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=sys_prompt_text,
        agent_sender=AGENT_NAME,
    ).with_log()
    agent_state.add_message(summary_msg)
    agent_state.critic_msg = summary_msg

    return agent_state
|
scievo/agents/critic_agent/state.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import model_validator
|
| 2 |
+
|
| 3 |
+
from scievo.core.types import HistoryState, Message, RBankState, ToolsetState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CriticAgentState(HistoryState, ToolsetState, RBankState):
    """State for the critic agent graph.

    Exactly one of ``is_data_agent`` / ``is_exp_agent`` must be True,
    identifying which agent produced the trajectory under review; this is
    enforced by ``check_agent_source``.
    """

    # messages to be criticized (input)
    input_msgs: list[Message]
    # current plan of the caller (input)
    plan: str | None = None
    # whether the input messages are from data agent (input)
    is_data_agent: bool = False
    # whether the input messages are from experiment agent (input)
    is_exp_agent: bool = False
    # critics (output)
    critic_msg: Message | None = None

    @model_validator(mode="after")
    def check_agent_source(self):
        # Reject ambiguous (both set) or missing (neither set) provenance:
        # downstream prompts are rendered differently per source agent.
        if self.is_data_agent and self.is_exp_agent:
            raise ValueError("CriticAgentState: both is_data_agent and is_exp_agent are True")
        if not self.is_data_agent and not self.is_exp_agent:
            raise ValueError("CriticAgentState: both is_data_agent and is_exp_agent are False")
        return self
|
scievo/agents/data_agent/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .build import build
|
| 2 |
+
from .state import DataAgentState
|
scievo/agents/data_agent/build.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from scievo.core import constant
|
| 5 |
+
from scievo.core.types import Message
|
| 6 |
+
from scievo.rbank.subgraph import mem_consolidation
|
| 7 |
+
|
| 8 |
+
from . import execute, plan
|
| 9 |
+
from .paper_subagent import build as paper_subagent_build
|
| 10 |
+
from .paper_subagent.state import PaperSearchAgentState
|
| 11 |
+
from .state import DataAgentState
|
| 12 |
+
|
| 13 |
+
mem_consolidation_subgraph = mem_consolidation.build()
|
| 14 |
+
mem_consolidation_subgraph_compiled = mem_consolidation_subgraph.compile()
|
| 15 |
+
|
| 16 |
+
paper_subagent_graph = paper_subagent_build()
|
| 17 |
+
paper_subagent_graph_compiled = paper_subagent_graph.compile()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def finialize_node(agent_state: DataAgentState) -> DataAgentState:
    """A finalization node to do any final processing before ending the graph.

    (The 'finialize' spelling is kept: ``build`` registers this function
    by this name.)
    """
    done = len(agent_state.past_plans)
    left = len(agent_state.remaining_plans)
    summary = f"Finalization complete. Plans completed: {done}, Remaining: {left}"
    agent_state.intermediate_state.append({"node_name": "finalize", "output": summary})
    return agent_state
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_paper_subagent(agent_state: DataAgentState) -> DataAgentState:
    """Run paper subagent to search for relevant papers, datasets, and metrics.

    On success, copies the subagent's findings (papers, datasets, metrics,
    summary) onto the data agent state; on any failure, logs and records the
    error instead.  Either way an ``intermediate_state`` entry is appended.
    """
    logger.debug("run_paper_subagent of Agent data")

    paper_state = PaperSearchAgentState(
        user_query=agent_state.user_query,
        data_summary=agent_state.data_desc,
    )

    try:
        # invoke() returns a plain dict; re-validate it into the typed state.
        result_state = paper_subagent_graph_compiled.invoke(paper_state)
        result_state = PaperSearchAgentState(**result_state)

        agent_state.papers = result_state.papers
        agent_state.datasets = result_state.datasets
        agent_state.metrics = result_state.metrics
        agent_state.paper_search_summary = result_state.output_summary

        agent_state.intermediate_state.append(
            {
                "node_name": "paper_subagent",
                "output": f"Paper subagent completed. Found {len(result_state.papers)} papers, {len(result_state.datasets)} datasets, {len(result_state.metrics)} metrics.\n\nSummary: {result_state.output_summary or 'No summary'}",
            }
        )

        if result_state.output_summary:
            # NOTE(review): uses ``agent=`` here while other modules pass
            # ``agent_sender=`` to Message -- confirm both are accepted.
            agent_state.add_message(
                Message(
                    role="assistant",
                    content=f"[Paper Search Results]\n{result_state.output_summary}",
                    agent="paper_subagent",
                ).with_log()
            )
    except Exception as e:
        # Paper search is best-effort: record the failure and continue.
        logger.exception("paper_subagent_error")
        error_msg = f"Paper subagent error: {e}"
        agent_state.add_message(
            Message(
                role="assistant",
                content=error_msg,
                agent="paper_subagent",
            ).with_log()
        )
        agent_state.intermediate_state.append(
            {
                "node_name": "paper_subagent",
                "output": error_msg,
            }
        )

    return agent_state
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def prepare_for_talk_mode(agent_state: DataAgentState) -> DataAgentState:
    """Transition the agent into talk mode: reset the plan to a single
    user-response step and (best-effort) consolidate short-term memory."""
    assert agent_state.talk_mode
    agent_state.remaining_plans = ["Response to users' query."]

    mem_output = "Memory consolidation skipped"
    # consolidate mems
    if constant.REASONING_BANK_ENABLED:
        try:
            mem_consolidation_subgraph_compiled.invoke(
                mem_consolidation.MemConsolidationState(
                    mem_dir=agent_state.sess_dir / "short_term",
                    long_term_mem_dir=agent_state.long_term_mem_dir,
                    project_mem_dir=agent_state.project_mem_dir,
                )
            )
            mem_output = "Memory consolidation completed"
        except Exception as e:
            # Consolidation failures are surfaced as a message but do not
            # block entering talk mode.
            error_msg = f"mem_consolidation_error: {e}"
            agent_state.add_message(
                Message(
                    role="assistant",
                    content=error_msg,
                    agent="noname",
                ).with_log()
            )
            mem_output = error_msg

    # Record the outcome regardless of which branch ran.
    agent_state.intermediate_state.append(
        {
            "node_name": "prepare_for_talk_mode",
            "output": mem_output,
        }
    )

    return agent_state
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@logger.catch
def build():
    """Assemble the data agent's LangGraph state machine.

    Flow: paper search -> planner -> gateway loop (LLM chat, tool calls,
    memory extraction, history compression) -> critic -> replanner, which
    either loops back to the gateway or proceeds to finalize/summary and
    talk-mode preparation.
    """
    g = StateGraph(DataAgentState)

    # nodes
    g.add_node("paper_subagent", run_paper_subagent)
    g.add_node("planner", plan.planner_node)
    g.add_node("replanner", plan.replanner_node)

    g.add_node("gateway", execute.gateway_node)
    g.add_node("llm_chat", execute.llm_chat_node)
    g.add_node("tool_calling", execute.tool_calling_node)
    g.add_node("mem_extraction", execute.mem_extraction_node)
    g.add_node("history_compression", execute.history_compression_node)
    # g.add_node("critic", execute.critic_node) # not used for now
    g.add_node("critic_before_replan", execute.critic_node)
    g.add_node("finalize", finialize_node)
    g.add_node("generate_summary", execute.generate_summary_node)
    g.add_node("prepare_for_talk_mode", prepare_for_talk_mode)

    # edges from gateway to nodes
    g.add_edge(START, "paper_subagent")
    g.add_edge("paper_subagent", "planner")
    g.add_edge("planner", "gateway")
    g.add_conditional_edges(
        "gateway",
        execute.gateway_conditional,
        [
            "llm_chat",
            "tool_calling",
            "mem_extraction",
            "history_compression",
            "critic_before_replan",  # plan END
        ],
    )

    # edges from nodes to gateway
    g.add_edge("llm_chat", "gateway")
    g.add_edge("tool_calling", "gateway")
    g.add_edge("mem_extraction", "gateway")
    g.add_edge("history_compression", "gateway")

    # every plan round is criticized before replanning
    g.add_edge("critic_before_replan", "replanner")

    # edges from gateway to replanner
    g.add_conditional_edges(
        "replanner",
        plan.should_replan,
        [
            "gateway",
            "finalize",
        ],
    )
    # edges from nodes to end
    g.add_edge("finalize", "generate_summary")
    g.add_edge("generate_summary", "prepare_for_talk_mode")
    g.add_edge("prepare_for_talk_mode", END)
    return g
|
scievo/agents/data_agent/execute.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent for data understanding and processing
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, TypeVar
|
| 8 |
+
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
from scievo import history_compression
|
| 12 |
+
from scievo.agents import critic_agent
|
| 13 |
+
from scievo.core import constant
|
| 14 |
+
from scievo.core.errors import sprint_chained_exception
|
| 15 |
+
from scievo.core.llms import ModelRegistry
|
| 16 |
+
from scievo.core.types import HistoryState, Message, RBankState
|
| 17 |
+
from scievo.core.utils import wrap_dict_to_toon, wrap_text_with_block
|
| 18 |
+
from scievo.prompts import PROMPTS
|
| 19 |
+
from scievo.rbank.subgraph import mem_extraction, mem_retrieval
|
| 20 |
+
from scievo.tools import Tool, ToolRegistry
|
| 21 |
+
|
| 22 |
+
from .state import DataAgentState
|
| 23 |
+
|
| 24 |
+
if TYPE_CHECKING:
|
| 25 |
+
from scievo.rbank.memo import Memo
|
| 26 |
+
|
| 27 |
+
MemHistoryMixin = TypeVar("MemHistoryMixin", HistoryState, RBankState)
|
| 28 |
+
|
| 29 |
+
LLM_NAME = "data"
|
| 30 |
+
AGENT_NAME = "data"
|
| 31 |
+
|
| 32 |
+
BUILTIN_TOOLSETS = [
|
| 33 |
+
# "todo",
|
| 34 |
+
"state",
|
| 35 |
+
"history",
|
| 36 |
+
"fs",
|
| 37 |
+
]
|
| 38 |
+
ALLOWED_TOOLSETS = ["web"]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def gateway_node(agent_state: DataAgentState) -> DataAgentState:
    """No-op pass-through node.

    The gateway exists only so conditional edges have a source; the actual
    routing decision is made by `gateway_conditional`.
    """
    logger.trace("gateway_node of Agent {}", AGENT_NAME)
    return agent_state
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def gateway_conditional(agent_state: DataAgentState) -> str:
    """Pick the next node based on config flags and the latest message.

    Priority order: history compression, memory extraction, then routing on
    the role/tool-calls of the last patched message.
    """
    # History compression fires when enabled, not run within the last two
    # nodes, and the token budget has been exceeded.
    recent_nodes = agent_state.node_history[-2:]
    needs_compression = (
        constant.HISTORY_AUTO_COMPRESSION
        and "history_compression" not in recent_nodes
        and agent_state.total_patched_tokens > constant.HISTORY_AUTO_COMPRESSION_TOKEN_THRESHOLD
    )
    if needs_compression:
        return "history_compression"

    # Periodic memory extraction, skipped if it just ran.
    history_nodes = agent_state.node_history
    due_for_extraction = (
        constant.REASONING_BANK_ENABLED
        and len(history_nodes) > 0
        and history_nodes[-1] != "mem_extraction"
        and agent_state.round > 0
        and agent_state.round % constant.MEM_EXTRACTION_ROUND_FREQ == 0
    )
    if due_for_extraction:
        return "mem_extraction"

    if not agent_state.patched_history:
        logger.warning("patched_history is empty, returning llm_chat")
        return "llm_chat"

    last_msg = agent_state.patched_history[-1]
    # Pending tool calls always take precedence over role-based routing.
    if last_msg.tool_calls:
        return "tool_calling"

    role = last_msg.role
    if role in ("user", "tool"):
        return "llm_chat"
    if role == "assistant":
        return "critic_before_replan"
    raise ValueError(f"Unknown message role: {last_msg.role}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Build and compile the memory-retrieval subgraph once at import time so that
# every `llm_chat_node` invocation reuses the same compiled graph.
mem_retrieval_subgraph = mem_retrieval.build()
mem_retrieval_subgraph_compiled = mem_retrieval_subgraph.compile()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _memos_to_markdown(memos: list["Memo"]) -> str:
    """Render retrieved memos as one markdown document, one section per memo."""
    if not memos:
        return "No memory retrieved."
    sections = [
        f"# Memo {idx + 1}\n\n{memo.to_markdown()}\n\n" for idx, memo in enumerate(memos)
    ]
    return "".join(sections)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def llm_chat_node(agent_state: DataAgentState) -> DataAgentState:
    """Run one LLM turn: retrieve memos, build the system prompt and toolset,
    call the model on the patched history, and record the reply.

    Side effects: appends the assistant message to the agent history and a
    summary entry to `agent_state.intermediate_state`.
    """
    logger.debug("llm_chat_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("llm_chat")

    # Minimal state snapshot surfaced to the model via the system prompt.
    selected_state = {
        "workspace": agent_state.workspace.working_dir,
        "current_activated_toolsets": agent_state.toolsets,
    }

    # retrieve memos (best-effort: retrieval failure degrades to no memory
    # rather than aborting the chat turn)
    if constant.REASONING_BANK_ENABLED:
        try:
            # Session-local memory always included; long-term / project
            # memory only when their directories are configured.
            mem_dirs = [agent_state.sess_dir / "short_term"]
            if d := agent_state.long_term_mem_dir:
                mem_dirs.append(d)
            if d := agent_state.project_mem_dir:
                mem_dirs.append(d)
            res = mem_retrieval_subgraph_compiled.invoke(
                mem_retrieval.MemRetrievalState(
                    input_msgs=agent_state.patched_history,
                    mem_dirs=mem_dirs,
                    max_num_memos=constant.MEM_RETRIEVAL_MAX_NUM_MEMOS,
                )
            )
            memos: list[Memo] = res.get("output_memos", [])
            memory_text = _memos_to_markdown(memos)
        except Exception:
            logger.exception("mem_retrieval_error")
            memory_text = None
    else:
        memory_text = None

    # update system prompt with state, toolset descriptions, retrieved
    # memory and the currently-active plan (if any)
    system_prompt = PROMPTS.data.system_prompt.render(
        state_text=wrap_dict_to_toon(selected_state),
        toolsets_desc=ToolRegistry.get_toolsets_desc(BUILTIN_TOOLSETS + ALLOWED_TOOLSETS),
        memory_text=wrap_text_with_block(memory_text, "markdown"),
        current_plan=(
            agent_state.remaining_plans[0] if len(agent_state.remaining_plans) > 0 else None
        ),
    )

    # construct tools: dynamically-activated toolsets plus the builtins
    # (builtins applied last, so they win on name collisions)
    tools: dict[str, Tool] = {}
    for toolset in agent_state.toolsets:
        tools.update(ToolRegistry.get_toolset(toolset))
    for toolset in BUILTIN_TOOLSETS:
        tools.update(ToolRegistry.get_toolset(toolset))

    # Ensure there's at least one non-system message for Anthropic API
    history = agent_state.patched_history
    if len(history) == 0 or all(msg.role == "system" for msg in history):
        # Add a dummy user message if history is empty or only contains system messages
        logger.warning(
            "patched_history is empty or only contains system messages, adding dummy user message"
        )
        history = [
            Message(
                role="user",
                content="Please continue with the task.",
                agent_sender=AGENT_NAME,
            )
        ]

    msg = ModelRegistry.completion(
        LLM_NAME,
        history,
        system_prompt=(
            # Wrapped in a Message only so .with_log can conditionally log it.
            Message(role="system", content=system_prompt)
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
        tools=[tool.name for tool in tools.values()],
    ).with_log()
    agent_state.add_message(msg)

    # Human-readable summary of the turn for intermediate_state.
    llm_output = (
        msg.content
        if msg.content
        else ("[Tool calls: " + str(len(msg.tool_calls)) + "]" if msg.tool_calls else "[No output]")
    )

    agent_state.intermediate_state.append(
        {
            "node_name": "llm_chat",
            "output": llm_output,
        }
    )

    return agent_state
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def tool_calling_node(agent_state: DataAgentState) -> DataAgentState:
    """Execute tool calls from the last message and update the graph state.

    For each tool call: resolve the tool, parse its JSON arguments, inject
    framework parameters when the tool's signature asks for them, run it
    inside the agent's workspace, and append a tool-role message with the
    result (or error text). Failures of one call never abort the others.
    """
    logger.debug("tool_calling_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("tool_calling")
    # Get the last message which contains tool calls
    last_msg = agent_state.patched_history[-1]

    if not last_msg.tool_calls:
        raise ValueError("No tool calls found in the last message")

    # construct tools (builtins applied last, winning on name collisions)
    tools: dict[str, Tool] = {}
    for toolset in agent_state.toolsets:
        tools.update(ToolRegistry.get_toolset(toolset))
    for toolset in BUILTIN_TOOLSETS:
        tools.update(ToolRegistry.get_toolset(toolset))

    function_map = {tool.name: tool.func for tool in tools.values()}

    # Collected per-call outcomes for the intermediate_state summary below.
    tool_results = []

    # Execute each tool call
    for tool_call in last_msg.tool_calls:
        tool_name = tool_call.function.name

        # Check if tool exists in function map
        if tool_name not in function_map:
            error_msg = f"Tool {tool_name} not found"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            tool_results.append({"tool": tool_name, "result": error_msg})
            continue

        # Parse tool arguments (must be a JSON object, not a scalar/array)
        try:
            args = json.loads(tool_call.function.arguments)
            assert isinstance(args, dict)
        except json.JSONDecodeError as e:
            error_msg = f"Invalid JSON in tool arguments: {e}"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            tool_results.append({"tool": tool_name, "result": error_msg})
            continue
        except AssertionError as e:
            error_msg = f"Invalid tool arguments: {e}"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            tool_results.append({"tool": tool_name, "result": error_msg})
            continue

        # Execute the tool
        result = None
        try:
            # Pass the graph state to the tool function
            func = function_map[tool_name]

            # Check if function expects agent_state parameter
            import inspect

            sig = inspect.signature(func)
            # Inject framework-managed arguments only when the tool's
            # signature declares them by the conventional names.
            if constant.__AGENT_STATE_NAME__ in sig.parameters:
                args.update({constant.__AGENT_STATE_NAME__: agent_state})
            if constant.__CTX_NAME__ in sig.parameters:
                args.update({constant.__CTX_NAME__: {"current_agent": AGENT_NAME}})

            # Execute the tool in the agent's local environment
            with agent_state.workspace:
                result = func(**args)

            # Create tool response message
            tool_response = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "tool_name": tool_name,
                "content": str(result),  # Ensure result is string
            }
            # NOTE(review): a falsy-but-successful result (0, "", False) is
            # summarized as "No result" here — confirm that is intended.
            tool_results.append(
                {"tool": tool_name, "result": str(result)[:1000] if result else "No result"}
            )

        except Exception as e:
            error_msg = f"Tool {tool_name} execution failed: {e}"
            tool_response = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "tool_name": tool_name,
                "content": error_msg,
            }
            tool_results.append({"tool": tool_name, "result": error_msg})

        # Record the response (success or failure) built above.
        tool_response_msg = Message(**tool_response).with_log()
        agent_state.add_message(tool_response_msg)

    tool_output_parts = []
    for tr in tool_results:
        tool_output_parts.append(f"Tool: {tr['tool']}\nResult: {tr['result']}")

    tool_output = "\n\n".join(tool_output_parts) if tool_output_parts else "No tool calls executed"

    agent_state.intermediate_state.append(
        {
            "node_name": "tool_calling",
            "output": tool_output,
        }
    )

    return agent_state
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# Build and compile the memory-extraction subgraph once at import time so that
# `mem_extraction_node` reuses the same compiled graph on every invocation.
mem_extraction_subgraph = mem_extraction.build()
mem_extraction_subgraph_compiled = mem_extraction_subgraph.compile()
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def mem_extraction_node(agent_state: MemHistoryMixin) -> MemHistoryMixin:
    """Extract memories from the recent conversation window into short-term
    session storage.

    Best-effort: extraction errors are recorded as an assistant message
    instead of propagating. When the state is a DataAgentState, an
    intermediate_state entry summarizing the outcome is also appended.
    """
    logger.debug("mem_extraction_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("mem_extraction")
    # Only the most recent messages are fed to extraction.
    context_window = agent_state.patched_history[-constant.MEM_EXTRACTION_CONTEXT_WINDOW :]
    logger.info("Agent {} begins to Memory Extraction", AGENT_NAME)
    mem_output = "Memory extraction completed"
    try:
        result = mem_extraction_subgraph_compiled.invoke(
            mem_extraction.MemExtractionState(
                # Fixed f-string with no placeholders (plain literal suffices).
                mem_dir=Path(agent_state.sess_dir) / "short_term",
                input_msgs=context_window,
                input_agent_name=AGENT_NAME,
            )
        )
        if isinstance(result, dict) and "output_memos" in result:
            mem_output = f"Extracted {len(result.get('output_memos', []))} memory entries"
    except Exception as e:
        error_msg = f"mem_extraction_error: {sprint_chained_exception(e)}"
        agent_state.add_message(
            Message(
                role="assistant",
                content=error_msg,
                agent_sender=AGENT_NAME,
            ).with_log()
        )
        mem_output = error_msg

    # intermediate_state only exists on the full data-agent state.
    if isinstance(agent_state, DataAgentState):
        agent_state.intermediate_state.append(
            {
                "node_name": "mem_extraction",
                "output": mem_output,
            }
        )

    return agent_state
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def history_compression_node(agent_state: DataAgentState) -> DataAgentState:
    """Compress the message history and record a summary of the compression.

    Delegates the actual compression to
    `history_compression.invoke_history_compression` and appends a
    human-readable before/after summary to `intermediate_state`.
    """
    logger.debug("history_compression_node of Agent {}", AGENT_NAME)
    # NOTE(review): unlike the other nodes, this one does not call
    # agent_state.add_node_history("history_compression"), yet
    # gateway_conditional checks node_history[-2:] for that name to avoid
    # re-running compression — confirm the compression helper records it.

    history_before = len(agent_state.history)
    agent_state = history_compression.invoke_history_compression(agent_state)
    history_after = len(agent_state.history)

    # Default summary is just the message-count delta; if the last patch
    # carries a synthesized message, prefer a preview of its content.
    compression_output = f"Compressed history: {history_before} -> {history_after} messages"
    if agent_state.history_patches:
        last_patch = agent_state.history_patches[-1]
        if last_patch.patched_message and last_patch.patched_message.content:
            compression_output = f"Compressed {last_patch.n_messages} messages into:\n{last_patch.patched_message.content[:500]}"

    agent_state.intermediate_state.append(
        {
            "node_name": "history_compression",
            "output": compression_output,
        }
    )

    return agent_state
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def generate_summary_node(agent_state: DataAgentState) -> DataAgentState:
    """Generate the analysis summary via the LLM and store it in agent state.

    On success, sets `agent_state.output_summary`. On failure, records an
    assistant error message instead. An intermediate_state entry is appended
    either way.
    """
    logger.debug("generate_summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("generate_summary")

    try:
        # Construct a summary request message
        summary_system_prompt = PROMPTS.data.summary_system_prompt
        summary_user_prompt = PROMPTS.data.summary_user_prompt

        agent_state.add_message(
            Message(
                role="user",
                content=summary_user_prompt.render(),
            ).with_log(cond=constant.LOG_SYSTEM_PROMPT)
        )

        # Call LLM to generate summary
        summary_msg = ModelRegistry.completion(
            LLM_NAME,
            agent_state.patched_history,
            system_prompt=summary_system_prompt.render(),
            agent_sender=AGENT_NAME,
        ).with_log()

        agent_state.add_message(summary_msg)

        # Extract summary content
        if summary_msg.role != "assistant" or not summary_msg.content:
            raise ValueError("Failed to get summary from LLM")

        # Store summary in state
        agent_state.output_summary = summary_msg.content
        logger.info("Analysis summary generated successfully")

    except Exception as e:
        error_msg = f"Failed to generate analysis summary: {sprint_chained_exception(e)}"
        agent_state.add_message(
            Message(
                role="assistant",
                content=error_msg,
                agent_sender=AGENT_NAME,
            ).with_log()
        )
        logger.error("generate_summary_node failed: {}", error_msg)

    # Pick the best available text for the intermediate summary, depending on
    # which names the try/except above actually bound.
    # NOTE(review): if completion succeeded but the role check raised, both
    # summary_msg (with content) and error_msg exist — this expression then
    # reports the content even though the node treated it as a failure.
    summary_output = (
        summary_msg.content
        if "summary_msg" in locals() and summary_msg.content
        else (error_msg if "error_msg" in locals() else "No summary generated")
    )

    agent_state.intermediate_state.append(
        {
            "node_name": "generate_summary",
            "output": summary_output,
        }
    )

    return agent_state
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# Build and compile the critic subgraph once at import time so `critic_node`
# reuses the same compiled graph on every invocation.
critic_subgraph = critic_agent.build()
critic_subgraph_compiled = critic_subgraph.compile()
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def critic_node(agent_state: DataAgentState) -> DataAgentState:
    """Run the critic subgraph over the recent history and record its feedback.

    No-op when critics are disabled. Errors are captured as an assistant
    message rather than propagated. An intermediate_state entry is appended
    in both cases.
    """
    logger.trace("critic_node of Agent {}", AGENT_NAME)

    if not constant.CRITIC_ENABLED:
        return agent_state

    try:
        # Fall back to "N/A" when there is no active plan. Previously this
        # fallback was computed but agent_state.remaining_plans[0] was passed
        # directly, raising IndexError on an empty plan list.
        current_plan = (
            agent_state.remaining_plans[0] if len(agent_state.remaining_plans) > 0 else "N/A"
        )
        res = critic_subgraph_compiled.invoke(
            critic_agent.CriticAgentState(
                input_msgs=agent_state.patched_history[-constant.CRITIC_CONTEXT_WINDOW :],
                plan=current_plan,
                is_data_agent=True,
                # RBankState
                sess_dir=agent_state.sess_dir,
                long_term_mem_dir=agent_state.long_term_mem_dir,
                project_mem_dir=agent_state.project_mem_dir,
            )
        )
        assert res.get("critic_msg", None) is not None, "critic_msg is None"
        critic_msg: Message = res.get("critic_msg")
        agent_state.add_message(critic_msg.with_log())
        critic_output = critic_msg.content if critic_msg.content else "No critic feedback"
    except Exception as e:
        error_msg = f"critic_error: {sprint_chained_exception(e)}"
        agent_state.add_message(
            Message(
                role="assistant",
                content=error_msg,
                agent_sender=AGENT_NAME,
            ).with_log()
        )
        critic_output = error_msg

    # critic_output is assigned on both the success and failure paths above,
    # so the previous `"critic_output" in locals()` guard was unnecessary.
    agent_state.intermediate_state.append(
        {
            "node_name": "critic",
            "output": critic_output,
        }
    )

    return agent_state
|
scievo/agents/data_agent/paper_subagent/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Paper Search Subagent
|
| 3 |
+
|
| 4 |
+
A minimal agent for searching academic papers using arxiv_tool.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .build import build
|
| 8 |
+
from .state import PaperSearchAgentState
|
| 9 |
+
|
| 10 |
+
__all__ = ["build", "PaperSearchAgentState"]
|
scievo/agents/data_agent/paper_subagent/build.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from . import execute
|
| 5 |
+
from .state import PaperSearchAgentState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@logger.catch
def build():
    """Build paper search agent graph with iterative query refinement.

    Flow:
        START -> optimize_query -> search -> check_results ->
            (insufficient results) -> optimize_query -> ... (loop)
            (sufficient results)   -> dataset -> metric -> summary -> END
    """
    graph = StateGraph(PaperSearchAgentState)

    # Register all nodes up front.
    node_table = (
        ("optimize_query", execute.optimize_query_node),
        ("search", execute.search_node),
        ("check_results", execute.check_results_node),
        ("dataset", execute.dataset_node),
        ("metric", execute.metric_node),
        ("summary", execute.summary_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Linear prefix: refine the query, search, then evaluate the results.
    graph.add_edge(START, "optimize_query")
    graph.add_edge("optimize_query", "search")
    graph.add_edge("search", "check_results")

    # Either loop back for another refinement pass, or proceed downstream.
    graph.add_conditional_edges(
        "check_results",
        execute.should_continue_search,
        {
            "continue_search": "optimize_query",  # iterate: re-optimize and search again
            "proceed": "dataset",  # accept current results
        },
    )

    # Tail: dataset lookup, metric extraction, and the final summary.
    graph.add_edge("dataset", "metric")
    graph.add_edge("metric", "summary")
    graph.add_edge("summary", END)

    return graph
|
scievo/agents/data_agent/paper_subagent/execute.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Execution nodes for the Paper Search Agent
|
| 3 |
+
|
| 4 |
+
This module provides a minimal execution flow that searches for papers, datasets,
|
| 5 |
+
extracts metrics, and generates a summary.
|
| 6 |
+
Flow: START -> search_node -> dataset_node -> metric_node -> summary_node -> END
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
from scievo.core.llms import ModelRegistry
|
| 12 |
+
from scievo.core.types import Message
|
| 13 |
+
from scievo.core.utils import unwrap_dict_from_toon
|
| 14 |
+
from scievo.prompts.prompt_data import PROMPTS
|
| 15 |
+
from scievo.tools.arxiv_tool import search_papers
|
| 16 |
+
from scievo.tools.dataset_search_tool import search_datasets
|
| 17 |
+
from scievo.tools.metric_search_tool import extract_metrics_from_papers
|
| 18 |
+
|
| 19 |
+
from .state import PaperSearchAgentState
|
| 20 |
+
|
| 21 |
+
LLM_NAME = "paper_search"
|
| 22 |
+
AGENT_NAME = "paper_search"
|
| 23 |
+
|
| 24 |
+
# Minimum thresholds for considering search successful
|
| 25 |
+
MIN_PAPERS_THRESHOLD = 3
|
| 26 |
+
MIN_DATASETS_THRESHOLD = 2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def optimize_query_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Optimize the search query using LLM to improve search results.

    Seeds `current_query`/`query_history` from the user query on first entry,
    then (while under the iteration cap) asks the LLM for a refined query and
    records it. All failures degrade to keeping the current query.
    """
    logger.debug("optimize_query_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("optimize_query")

    # Initialize current_query if not set
    if agent_state.current_query is None:
        agent_state.current_query = agent_state.user_query
        agent_state.query_history = [agent_state.user_query]

    # If we've already tried multiple queries, use the best one or stop
    if agent_state.search_iteration >= agent_state.max_search_iterations:
        logger.info("Reached max iterations, using current query")
        return agent_state

    # Build optimization prompt; include feedback on prior attempts from the
    # second iteration onward so the LLM can widen or narrow the query.
    previous_results = ""
    if agent_state.search_iteration > 0:
        previous_results = f"""
Previous search results:
- Papers found: {len(agent_state.papers)}
- Datasets found: {len(agent_state.datasets)}
- Previous queries tried: {', '.join(agent_state.query_history[-3:])}
"""

    optimization_prompt = f"""You are a research assistant helping to optimize academic paper search queries.

Original user query: "{agent_state.user_query}"
{previous_results}

Your task is to generate an improved search query that is more likely to find relevant academic papers on arXiv.

Guidelines:
1. If previous search found few/no results, make the query MORE GENERAL (remove specific details, use broader terms)
2. If previous search found too many irrelevant results, make the query MORE SPECIFIC (add key terms, use domain-specific vocabulary)
3. Use standard academic terminology and keywords
4. Keep the query concise (2-5 key terms)
5. Consider synonyms and related terms
6. Focus on the core research topic, not implementation details

Generate ONLY the optimized search query (no explanation, just the query text):"""

    try:
        msg = ModelRegistry.completion(
            LLM_NAME,
            [Message(role="user", content=optimization_prompt)],
            system_prompt="You are an expert at crafting effective academic search queries. Return only the optimized query text.",
            agent_sender=AGENT_NAME,
            tools=None,
        )

        # NOTE(review): msg.content is assumed non-None here; a tool-only or
        # empty completion would raise AttributeError and fall into the
        # except branch below — confirm that is acceptable.
        optimized_query = msg.content.strip()
        # Remove quotes if present
        optimized_query = optimized_query.strip('"').strip("'").strip()

        # Only adopt the new query if it is non-empty and actually different.
        if optimized_query and optimized_query != agent_state.current_query:
            agent_state.current_query = optimized_query
            agent_state.query_history.append(optimized_query)
            logger.info(
                f"Optimized query (iteration {agent_state.search_iteration + 1}): {optimized_query}"
            )

            agent_state.add_message(
                Message(
                    role="assistant",
                    content=f"[Query Optimization] Optimized search query: '{optimized_query}'",
                    agent_sender=AGENT_NAME,
                ).with_log()
            )
        else:
            logger.info("Query optimization did not produce a new query, using current query")

    except Exception as e:
        logger.exception("Query optimization error")
        # Continue with current query if optimization fails
        if not agent_state.current_query:
            agent_state.current_query = agent_state.user_query

    return agent_state
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def check_results_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Check if paper search results are sufficient, decide whether to iterate.

    Logs the decision and appends a status message to the history. The actual
    routing decision is recomputed by `should_continue_search` on the
    conditional edge — nothing is persisted in state here.
    """
    logger.debug("check_results_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("check_results")

    papers_count = len(agent_state.papers)

    # Check if we have sufficient papers
    has_sufficient_papers = papers_count >= MIN_PAPERS_THRESHOLD

    # Decision: continue if we don't have enough papers and haven't exceeded max iterations
    should_continue = (
        not has_sufficient_papers
        and agent_state.search_iteration < agent_state.max_search_iterations
    )

    logger.info(
        f"Results check: {papers_count} papers found. "
        f"Sufficient: {has_sufficient_papers} (threshold: {MIN_PAPERS_THRESHOLD}). "
        f"Should continue: {should_continue} (iteration {agent_state.search_iteration}/{agent_state.max_search_iterations})"
    )

    # Record the decision in the message history for traceability; the same
    # predicate is re-evaluated by should_continue_search for routing.
    agent_state.add_message(
        Message(
            role="assistant",
            content=f"[Results Check] Found {papers_count} papers. "
            f"{'Continuing search iteration' if should_continue else 'Proceeding with current results'}.",
            agent_sender=AGENT_NAME,
        ).with_log()
    )

    return agent_state
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def should_continue_search(agent_state: PaperSearchAgentState) -> str:
    """Routing predicate: iterate the paper search or proceed to datasets.

    Only the paper search loops; dataset search runs once after it is done.
    Returns "continue_search" while papers are below the threshold and
    iterations remain, otherwise "proceed".
    """
    enough_papers = len(agent_state.papers) >= MIN_PAPERS_THRESHOLD
    iterations_remaining = agent_state.search_iteration < agent_state.max_search_iterations

    if not enough_papers and iterations_remaining:
        return "continue_search"
    return "proceed"
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def search_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Execute paper search using the search_papers tool.

    Increments ``search_iteration``, runs the search with the refined
    ``current_query`` (falling back to the original ``user_query``), stores
    the parsed results in ``agent_state.papers`` (empty list on any parse
    failure), and appends a status or error message to the history.
    Never raises: all errors are caught and recorded on the state.
    """
    logger.debug("search_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("search")

    # Increment iteration count (counted even if the search fails below,
    # so the iteration budget in should_continue_search is still consumed)
    agent_state.search_iteration += 1

    # Use current_query if available, otherwise use user_query
    query_to_use = agent_state.current_query or agent_state.user_query

    try:
        # Call the search_papers tool directly.
        # Use only arxiv by default to avoid rate limiting issues with
        # Semantic Scholar, which has strict rate limits (429 errors).
        result = search_papers(
            query=query_to_use,
            sources=["arxiv"],  # Use arxiv only to avoid rate limiting
            max_results=10,
        )

        # Parse the result (tool returns TOON format)
        try:
            papers = unwrap_dict_from_toon(result)
            if isinstance(papers, list):
                agent_state.papers = papers
            else:
                # Anything other than a list is treated as "no results"
                logger.warning("Unexpected result format from search_papers")
                agent_state.papers = []
        except Exception as parse_error:
            # Parse failures degrade to an empty result set rather than abort
            logger.warning("Failed to parse search results: {}", parse_error)
            agent_state.papers = []

        logger.info("Found {} papers", len(agent_state.papers))

        # Add search results to history
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Search Results] Found {len(agent_state.papers)} papers for query: '{query_to_use}' (iteration {agent_state.search_iteration})",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    except Exception as e:
        # Tool-level failure: record the error in the history and continue
        # the graph with whatever papers were previously on the state.
        logger.exception("Paper search error")
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Search Error] {str(e)}",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    return agent_state
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def dataset_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Execute dataset search using the search_datasets tool.

    Runs a single dataset search (not iterated, unlike paper search) using
    the refined ``current_query`` when available, stores parsed results in
    ``agent_state.datasets`` (empty list on any failure), and appends a
    status or error message to the history. Never raises.
    """
    logger.debug("dataset_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("dataset")

    try:
        # Use current_query if available, otherwise use user_query
        query_to_use = agent_state.current_query or agent_state.user_query

        # Call the search_datasets tool directly.
        # Pass data_summary if available to search for similar datasets.
        result = search_datasets(
            query=query_to_use,
            sources=["paperswithcode", "huggingface"],  # Default sources
            max_results=10,
            data_summary=agent_state.data_summary,  # Pass data analysis summary
        )

        # Parse the result (tool returns TOON format)
        try:
            datasets = unwrap_dict_from_toon(result)
            if isinstance(datasets, list):
                agent_state.datasets = datasets
            else:
                logger.warning("Unexpected result format from search_datasets")
                agent_state.datasets = []
        except Exception as parse_error:
            logger.warning("Failed to parse dataset search results: {}", parse_error)
            agent_state.datasets = []

        logger.info("Found {} datasets", len(agent_state.datasets))

        # Add search results to history.
        # BUGFIX: report the query actually used for the search
        # (query_to_use) instead of the original user_query, consistent
        # with the message emitted by search_node.
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Dataset Search Results] Found {len(agent_state.datasets)} datasets for query: '{query_to_use}'",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    except Exception as e:
        logger.exception("Dataset search error")
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Dataset Search Error] {str(e)}",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    return agent_state
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def metric_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Extract evaluation metrics from the searched papers.

    Calls extract_metrics_from_papers with the current papers; the tool has
    its own fallback logic for an empty paper list (it suggests common
    metrics for the task). Stores the parsed list in ``agent_state.metrics``
    (empty on any failure) and records a status/error message. Never raises.
    """
    logger.debug("metric_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("metric")

    try:
        if not agent_state.papers:
            logger.info("No papers available for metric extraction, using fallback")

        # A single call covers both cases: when agent_state.papers is empty
        # the tool's fallback is triggered exactly as if [] were passed
        # explicitly (the original duplicated this call in an if/else).
        result = extract_metrics_from_papers(
            papers=agent_state.papers,
            task_query=agent_state.user_query,
            max_results=20,
        )

        # Parse the result (tool returns TOON format)
        try:
            metrics = unwrap_dict_from_toon(result)
            if isinstance(metrics, list):
                agent_state.metrics = metrics
            else:
                logger.warning("Unexpected result format from extract_metrics_from_papers")
                agent_state.metrics = []
        except Exception as parse_error:
            logger.warning("Failed to parse metric extraction results: {}", parse_error)
            agent_state.metrics = []

        logger.info("Extracted {} metrics", len(agent_state.metrics))

        # Add extraction results to history
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Metric Extraction Results] Extracted {len(agent_state.metrics)} evaluation metrics from {len(agent_state.papers)} papers.",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    except Exception as e:
        logger.exception("Metric extraction error")
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Metric Extraction Error] {str(e)}",
                agent_sender=AGENT_NAME,
            ).with_log()
        )

    return agent_state
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def summary_node(agent_state: PaperSearchAgentState) -> PaperSearchAgentState:
    """Generate summary of search results.

    Formats the collected papers, datasets, and metrics into markdown,
    renders the summary prompt template, asks the LLM for a narrative
    summary, and stores it in ``agent_state.output_summary``. When nothing
    was found at all, a short fixed message is stored instead and the LLM
    is not called.
    """
    logger.debug("summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("summary")

    # Short-circuit: nothing to summarize, skip the LLM call entirely
    if not agent_state.papers and not agent_state.datasets and not agent_state.metrics:
        agent_state.output_summary = (
            f"No papers, datasets, or metrics found for query: '{agent_state.user_query}'"
        )
        agent_state.add_message(
            Message(
                role="assistant",
                content=agent_state.output_summary,
                agent_sender=AGENT_NAME,
            ).with_log()
        )
        return agent_state

    # Format papers for summary (top 10, truncated abstracts)
    papers_text = ""
    if agent_state.papers:
        papers_text = "\n\n".join(
            [
                f"**{i+1}. {p.get('title', 'N/A')}**\n"
                f"- Authors: {', '.join(p.get('authors', [])[:5])}{'...' if len(p.get('authors', [])) > 5 else ''}\n"
                f"- Published: {p.get('published', 'N/A')}\n"
                f"- Source: {p.get('source', 'N/A')}\n"
                f"- Summary: {p.get('summary', 'N/A')[:300]}...\n"
                f"- URL: {p.get('url', 'N/A')}"
                for i, p in enumerate(agent_state.papers[:10])
            ]
        )
    else:
        papers_text = "No papers found."

    # Format datasets for summary (more detailed; top 15)
    datasets_text = ""
    if agent_state.datasets:
        datasets_text = "\n\n".join(
            [
                f"**Dataset {i+1}: {d.get('name', 'N/A')}**\n"
                f"- **Source**: {d.get('source', 'N/A')}\n"
                f"- **Description**: {d.get('description', 'N/A')[:500]}{'...' if len(d.get('description', '')) > 500 else ''}\n"
                f"- **Domain**: {d.get('domain', 'N/A')}\n"
                f"- **Size**: {d.get('size', 'N/A')}\n"
                f"- **URL**: {d.get('url', 'N/A')}\n"
                f"- **Download URL**: {d.get('download_url', 'N/A') if d.get('download_url') else 'N/A'}\n"
                f"- **License**: {d.get('license', 'N/A') if d.get('license') else 'Not specified'}\n"
                f"- **Paper URL**: {d.get('paper_url', 'N/A') if d.get('paper_url') else 'N/A'}"
                for i, d in enumerate(agent_state.datasets[:15])  # Show more datasets
            ]
        )
    else:
        datasets_text = "No datasets found."

    # Format metrics for summary (more detailed with formulas; top 20)
    metrics_text = ""
    if agent_state.metrics:
        metrics_text = "\n\n".join(
            [
                f"**Metric {i+1}: {m.get('name', 'N/A')}**\n"
                f"- **Description**: {m.get('description', 'N/A')}\n"
                f"- **Domain**: {m.get('domain', 'N/A')}\n"
                f"- **Source Paper**: {m.get('paper_title', 'N/A')}\n"
                f"- **Paper URL**: {m.get('paper_url', 'N/A') if m.get('paper_url') else 'N/A'}\n"
                f"- **Reported Value**: {m.get('value', 'N/A') if m.get('value') else 'Not specified'}\n"
                f"- **Formula**: {m.get('formula', 'N/A') if m.get('formula') else 'Not provided'}"
                for i, m in enumerate(agent_state.metrics[:20])  # Show more metrics
            ]
        )
    else:
        metrics_text = "No metrics extracted."

    # Render summary prompt from template
    summary_prompt_content = PROMPTS.paper_subagent.summary_prompt.render(
        user_query=agent_state.user_query,
        papers_text=papers_text,
        datasets_text=datasets_text,
        metrics_text=metrics_text,
    )
    summary_prompt = Message(
        role="user",
        content=summary_prompt_content,
        agent_sender=AGENT_NAME,
    )
    agent_state.add_message(summary_prompt)

    # Get summary from LLM (the rendered prompt is already in history)
    system_prompt = PROMPTS.paper_subagent.summary_system_prompt.render()
    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=system_prompt,
        agent_sender=AGENT_NAME,
        tools=None,  # No tools needed for summary
    ).with_log()

    # Store the summary text (empty string if the LLM returned no content)
    agent_state.output_summary = msg.content or ""
    agent_state.add_message(msg)

    logger.info(f"Summary generated: {len(agent_state.output_summary)} characters")

    return agent_state
|
scievo/agents/data_agent/paper_subagent/state.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scievo.core.types import HistoryState, ToolsetState
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class PaperSearchAgentState(ToolsetState, HistoryState):
    """Minimal state for Paper Search Agent.

    This agent searches for academic papers and datasets using the paper_search and dataset_search toolsets.
    Supports iterative query refinement to improve search results.

    NOTE(review): several fields use mutable list defaults; this is safe
    only if the base state classes are pydantic models (which copy field
    defaults per instance) - confirm against scievo.core.types.
    """

    # Input
    user_query: str  # User's original search query
    data_summary: str | None = (
        None  # Data analysis summary from data agent (for dataset similarity search)
    )
    current_query: str | None = None  # Current optimized query (for iteration)
    max_search_iterations: int = 3  # Maximum number of search iterations

    # Iteration tracking
    search_iteration: int = 0  # Current search iteration count
    query_history: list[str] = []  # History of queries tried

    # Output
    papers: list[dict] = []  # Paper search results
    datasets: list[dict] = []  # Dataset search results
    metrics: list[dict] = []  # Extracted metrics from papers
    output_summary: str | None = None  # Final summary
|
scievo/agents/data_agent/plan.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from loguru import logger
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
from scievo.core import constant
|
| 5 |
+
from scievo.core.llms import ModelRegistry
|
| 6 |
+
from scievo.core.plan import Plan
|
| 7 |
+
from scievo.core.types import Message
|
| 8 |
+
from scievo.core.utils import parse_json_from_llm_response
|
| 9 |
+
from scievo.prompts import PROMPTS
|
| 10 |
+
|
| 11 |
+
from .state import DataAgentState
|
| 12 |
+
|
| 13 |
+
# Key of the LLM configuration used for planning in the ModelRegistry.
LLM_NAME = "plan"
# Identifier stamped on every message emitted by the planner nodes.
AGENT_NAME = "data_planner"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@logger.catch
def planner_node(agent_state: DataAgentState) -> DataAgentState:
    """Create the initial plan for the data agent.

    Sends the user query to the planning LLM, parses the response into a
    Plan, and initializes ``plans`` / ``remaining_plans`` / ``past_plans``
    on the state. The raw planner output is also appended to
    ``intermediate_state`` for later inspection.

    NOTE(review): ``@logger.catch`` swallows exceptions and makes this node
    return None on failure - confirm downstream nodes tolerate that.
    """
    logger.trace("planner_node of Agent {}", AGENT_NAME)

    user_query_msg = Message(
        role="user",
        content=agent_state.user_query,
        agent_sender=AGENT_NAME,
    )

    agent_state.add_message(user_query_msg)

    # Ask the planning LLM; the system prompt Message is created only to
    # optionally log it (via with_log) before its content is passed along.
    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.data.planner_system_prompt.render(is_replanner=False),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
    ).with_log()

    agent_state.add_message(msg)

    # NOTE: we don't add the parsed plan message to the history
    plans = parse_json_from_llm_response(msg, Plan)

    # NOTE: nudge the executor to start working on the freshly created plan
    agent_state.add_message(
        Message(
            role="user",
            content="Follow the current plan.",
            agent_sender=AGENT_NAME,
        )
    )

    agent_state.plans = plans
    agent_state.remaining_plans = plans.steps
    agent_state.past_plans = []

    # dummy user response, just for logging - deliberately NOT added to the
    # history; with_log() is the only reason this Message is constructed
    if len(agent_state.remaining_plans) > 0:
        Message(
            role="user",
            content=PROMPTS.data.replanner_user_response.render(
                next_step=agent_state.remaining_plans[0],
            ),
            agent_sender=AGENT_NAME,
        ).with_log()
    else:
        logger.warning("No plans generated by planner - remaining_plans is empty")

    # Defensive check: msg is always bound here on the normal path; the
    # locals() guard only matters if control flow ever changes above.
    planner_output = msg.content if "msg" in locals() and msg.content else "No planner output"

    agent_state.intermediate_state.append(
        {
            "node_name": "planner",
            "output": planner_output,
        }
    )

    return agent_state
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def replanner_node(agent_state: DataAgentState) -> DataAgentState:
    """Advance the plan by one step and ask the LLM whether to replan.

    Pops the current step into ``past_plans``, asks the replanner LLM, and
    then either keeps the existing plan, installs a modified plan, or
    switches the agent into talk mode when all work is done. The raw
    replanner output is recorded in ``intermediate_state``.
    """
    logger.trace("replanner_node of Agent {}", AGENT_NAME)

    # NOTE: when all the plans are done, go into the talk mode
    if len(agent_state.remaining_plans) == 0:
        logger.debug("All plans are done, going into talk mode")
        agent_state.talk_mode = True
        return agent_state

    # Move current plan to past_plans
    agent_state.past_plans.append(agent_state.remaining_plans.pop(0))

    user_query = agent_state.user_query

    user_msg = Message(
        role="user",
        content=PROMPTS.data.replanner_user_prompt.render(
            user_query=user_query,
            plan=agent_state.plans.steps,
            past_steps=agent_state.past_plans,
        ),
        agent_sender=AGENT_NAME,
    ).with_log()

    agent_state.add_message(user_msg)

    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.data.planner_system_prompt.render(is_replanner=True),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
    ).with_log()

    agent_state.add_message(msg)

    class Replan(BaseModel):
        # continued=True: keep executing the current plan unchanged.
        # continued=False + non-empty `modified`: replace the plan.
        # continued=False + empty `modified`: all work is done.
        continued: bool = False
        modified: list[str] = []

    # NOTE: we don't add the parsed message to the history
    plans = parse_json_from_llm_response(msg, Replan)

    if plans.continued:
        pass  # Keep the current plan as-is.
    elif plans.modified:
        # BUGFIX: `continued` is a bool, so the original
        # `if ... is True / elif ... is False / else` chain could never
        # reach the branch that installed the modified plan - a revised
        # plan from the replanner was silently discarded. Apply it here.
        agent_state.plans = Plan(steps=plans.modified)
        agent_state.remaining_plans = plans.modified
    else:
        # Not continuing and no modification proposed: work is finished.
        logger.debug("Replanner indicates all plans are done, going into talk mode")
        agent_state.talk_mode = True
        return agent_state

    if len(agent_state.remaining_plans) > 0:
        agent_state.add_message(
            Message(
                role="user",
                content=PROMPTS.data.replanner_user_response.render(
                    next_step=agent_state.remaining_plans[0],
                ),
                agent_sender=AGENT_NAME,
            )
        )
    else:
        logger.warning("No remaining plans after replan - going to talk mode")
        agent_state.talk_mode = True

    # msg is always bound at this point (any exception above would have
    # propagated out), so a plain truthiness check on its content suffices.
    replanner_output = msg.content if msg.content else "No replanner output"

    agent_state.intermediate_state.append(
        {
            "node_name": "replanner",
            "output": replanner_output,
        }
    )

    return agent_state
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def should_replan(agent_state: DataAgentState) -> str:
    """Route after replanning: finalize when in talk mode, otherwise loop
    back through the gateway for the next plan step."""
    return "finalize" if agent_state.talk_mode else "gateway"
|
scievo/agents/data_agent/state.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scievo.core.code_env import LocalEnv
|
| 2 |
+
from scievo.core.plan import PlanState
|
| 3 |
+
from scievo.core.types import HistoryState, RBankState, ToolsetState
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DataAgentState(ToolsetState, PlanState, HistoryState, RBankState):
    """State of the data agent.

    Combines toolset access, planning, conversation history, and reasoning
    bank state. Instances always carry the "fs" toolset (appended in
    __init__ on top of any caller-supplied toolsets).
    """

    # The user's task/query driving this agent run (input)
    user_query: str
    # Local environment for the agent
    workspace: LocalEnv

    # Optional additional description of the data (input)
    data_desc: str | None = None

    # talking mode - set when all plans are done and the agent should respond
    talk_mode: bool = False

    # output summary generated by the agent (output)
    output_summary: str | None = None

    # Paper subagent results (populated by the paper search subagent)
    papers: list[dict] = []
    datasets: list[dict] = []
    metrics: list[dict] = []
    paper_search_summary: str | None = None

    # Intermediate states - per-node outputs recorded for inspection
    intermediate_state: list[dict] = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Always enable filesystem tools for the data agent.
        # NOTE(review): this appends unconditionally, so "fs" may be
        # duplicated if a caller already included it in `toolsets`.
        self.toolsets.append("fs")
|
scievo/agents/dummy_agent.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from scievo.core.types import GraphState, Message
|
| 5 |
+
from scievo.prompts import PROMPTS
|
| 6 |
+
|
| 7 |
+
# Registry key of the (unused) LLM configuration for the demo agent.
LLM_NAME = "dummy"
# Identifier stamped on messages emitted by the dummy agent's nodes.
AGENT_NAME = "dummy"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def say_hello(graph_state: GraphState) -> GraphState:
    """Append a logged "Hello" assistant message to this agent's data messages.

    Demo node: it only demonstrates message creation and state mutation.
    """
    logger.debug("say_hello of Agent {}", AGENT_NAME)
    greeting = Message(
        role="assistant",
        content="Hello",
        llm_sender=None,
        agent_sender=AGENT_NAME,
    ).with_log()
    graph_state.agents[AGENT_NAME].data_msgs.append(greeting)
    return graph_state
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@logger.catch
def build():
    """Assemble the two-node demo graph: START -> dummy1 -> dummy2 -> END."""
    graph = StateGraph(GraphState)

    # Both nodes run the same demo handler.
    for node_name in ("dummy1", "dummy2"):
        graph.add_node(node_name, say_hello)

    graph.add_edge(START, "dummy1")
    graph.add_edge("dummy1", "dummy2")
    graph.add_edge("dummy2", END)

    return graph
|
scievo/agents/experiment_agent/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment Agent - High-level orchestrator for code modification experiments.
|
| 3 |
+
|
| 4 |
+
This agent coordinates three sub-agents:
|
| 5 |
+
1. Coding Subagent V2 - Plans and executes code modifications
|
| 6 |
+
2. Exec Subagent - Runs experiments/commands in a local shell
|
| 7 |
+
3. Summary Subagent - Generates comprehensive experiment summaries
|
| 8 |
+
|
| 9 |
+
The agent runs in a revision loop until the experiment succeeds or max revisions is reached.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from .build import build
|
| 13 |
+
from .state import ExperimentAgentState
|
| 14 |
+
|
| 15 |
+
__all__ = ["build", "ExperimentAgentState"]
|
scievo/agents/experiment_agent/build.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build the Experiment Agent graph.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from langgraph.graph import END, START, StateGraph
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
from . import execute
|
| 9 |
+
from .state import ExperimentAgentState
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@logger.catch
def build():
    """Build the Experiment Agent graph with sub-agent composition.

    Flow: init -> [coding -> exec -> summary -> analysis -> revision_judge]
    where the judge loops back to coding for another revision or exits via
    finalize.
    """
    graph = StateGraph(ExperimentAgentState)

    # ==================== NODES ====================
    graph.add_node("init", execute.init_node)  # prepares initial context
    # Sub-agent nodes - each invokes a compiled sub-graph
    graph.add_node("coding", execute.run_coding_subagent)
    graph.add_node("exec", execute.run_exec_subagent)
    graph.add_node("summary", execute.run_summary_subagent)
    # Analyzes loop results and generates insights
    graph.add_node("analysis", execute.analysis_node)
    # Decides whether to continue revising or complete
    graph.add_node("revision_judge", execute.revision_judge_node)
    # Prepares final output
    graph.add_node("finalize", execute.finalize_node)

    # ==================== EDGES ====================
    # Linear pipeline through the revision-loop body.
    for src, dst in (
        (START, "init"),
        ("init", "coding"),
        ("coding", "exec"),
        ("exec", "summary"),
        ("summary", "analysis"),
        ("analysis", "revision_judge"),
    ):
        graph.add_edge(src, dst)

    # The judge either loops back for the next revision or completes.
    graph.add_conditional_edges(
        "revision_judge",
        execute.should_continue_revision,
        {
            "continue": "coding",  # back to coding for next revision
            "complete": "finalize",  # exit the loop
        },
    )

    graph.add_edge("finalize", END)

    return graph
|
scievo/agents/experiment_agent/coding_subagent_v2/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Coding Subagent V2
|
| 3 |
+
|
| 4 |
+
This agent follows the plan-and-execute paradigm for coding tasks.
|
| 5 |
+
It integrates with OpenHands SDK for external code manipulation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .build import build
|
| 9 |
+
from .state import CodingAgentState
|
| 10 |
+
|
| 11 |
+
__all__ = ["build", "CodingAgentState"]
|
scievo/agents/experiment_agent/coding_subagent_v2/build.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from . import execute
|
| 5 |
+
from .state import CodingAgentState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@logger.catch
def build():
    """Build the coding agent graph.

    Minimal graph that delegates all coding work to the OpenHands SDK:
    START -> openhands -> summary -> END. OpenHands performs its own
    internal planning and tool calling, so no external LLM chat loop is
    required here.
    """
    graph = StateGraph(CodingAgentState)

    # Only two nodes: OpenHands execution, then summarization.
    graph.add_node("openhands", execute.openhands_node)
    graph.add_node("summary", execute.summary_node)

    graph.add_edge(START, "openhands")
    graph.add_edge("openhands", "summary")
    graph.add_edge("summary", END)

    return graph
|
scievo/agents/experiment_agent/coding_subagent_v2/execute.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Execution nodes for the Coding Subagent V2
|
| 3 |
+
|
| 4 |
+
This module provides a minimal execution flow that delegates all coding work
|
| 5 |
+
to OpenHands SDK. The flow is: START -> openhands_node -> summary_node -> END
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from loguru import logger
|
| 11 |
+
from openhands.sdk.event import ActionEvent
|
| 12 |
+
|
| 13 |
+
from scievo.core import constant
|
| 14 |
+
from scievo.core.llms import ModelRegistry
|
| 15 |
+
from scievo.core.types import Message
|
| 16 |
+
from scievo.prompts import PROMPTS
|
| 17 |
+
|
| 18 |
+
from .state import CodingAgentState
|
| 19 |
+
|
| 20 |
+
# Registry key of the LLM configuration used by this coding subagent.
LLM_NAME = "experiment_coding"
# Identifier stamped on messages/logs emitted by this subagent's nodes.
AGENT_NAME = "experiment_coding"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def openhands_node(agent_state: CodingAgentState) -> CodingAgentState:
    """
    Execute the coding task using OpenHands sub-agent.

    This node directly invokes the OpenHands conversation to handle
    the entire coding workflow. OpenHands has its own internal planning,
    tool calling, and execution mechanisms.

    Args:
        agent_state: Coding-agent state; must carry an initialized
            ``openhands_conversation`` and a ``workspace``; ``user_query``
            and ``data_summary`` are optional inputs for the prompt.

    Returns:
        The same ``agent_state``, mutated in place: the sub-agent's last
        response (or an error notice) is appended to the message history.
        On a missing conversation, ``output_summary`` is set to an error
        string and the node returns early.
    """
    logger.debug("openhands_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("openhands")

    conversation = agent_state.openhands_conversation
    if conversation is None:
        # Without a conversation there is nothing to delegate to; record the
        # failure in output_summary and bail out early.
        logger.error("OpenHands conversation not initialized")
        agent_state.output_summary = "Error: OpenHands conversation not initialized."
        return agent_state

    try:
        # Construct the message for OpenHands
        instruction = agent_state.user_query or "No specific coding task provided."
        bg_info = agent_state.data_summary or "No background information available."
        # prefix with `> ` for markdown blockquote
        instruction = "\n".join([f"> {line}" for line in instruction.splitlines()])
        bg_info = "\n".join([f"> {line}" for line in bg_info.splitlines()])
        workspace_dir = os.path.abspath(agent_state.workspace.working_dir)

        message = f"""\
# Requirements:
- At the end of your response, provide a detailed explanation of what you did and why.
- Ensure that all changes are made in a way that maintains the integrity of the codebase.
- Avoid long-running executions of training or data processing; focus on code changes. If needed for code testing, design some simple test code instead.

# Important Notes:
- DO NOT train the full model. Just train a demo if needed for testing code changes.
- DO NOT run large data processing tasks. Just simulate with small data if needed for testing code
- Always ensure that the code runs without errors after your changes.
- I would run the full experiments later after getting your code changes.

# Workspace
{workspace_dir}

# Task:
{instruction}

# Background information:
```
{bg_info}
```
"""

        logger.info("Sending task to OpenHands sub-agent: {}", instruction[:100])

        # Send message to the OpenHands agent
        conversation.send_message(message)

        # Run the agent until completion
        # NOTE(review): the workspace is entered as a context manager here —
        # presumably it activates the working directory/environment for the
        # duration of the run; confirm against LocalEnv.__enter__.
        with agent_state.workspace:
            conversation.run()

        # Extract the last response from OpenHands: scan the events newest-first
        # for an agent-authored ActionEvent whose LLM message can be read.
        if conversation.state.events:
            for e in reversed(conversation.state.events):
                if isinstance(e, ActionEvent) and e.source == "agent":
                    if hasattr(e, "llm_message") and e.llm_message:
                        content = e.llm_message.content
                    elif (m := getattr(e, "to_llm_message", None)) is not None and callable(m):
                        content = m().content
                    else:
                        # Unable to extract content from this event
                        continue
                    # assumes each content item exposes `.text` — TODO confirm SDK type
                    last_response = "\n".join([c.text for c in content])
                    break
            else:
                # for-else: loop exhausted without break, i.e. no extractable event.
                last_response = "Coding task completed (no detailed response available)."
        else:
            last_response = "Coding task completed (no detailed response available)."

        # Log the result
        logger.info("OpenHands sub-agent completed task")

        # Store the response in history for summary generation
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[OpenHands Sub-Agent Result]\n{last_response}",
                agent_sender="openhands",
            ).with_log()
        )

    except Exception as e:
        # Best-effort: record the failure in the history so the graph can
        # continue to the summary node instead of crashing.
        logger.exception("OpenHands agent error")
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[OpenHands Error] {str(e)}",
                agent_sender="openhands",
            ).with_log()
        )

    return agent_state
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def summary_node(agent_state: CodingAgentState) -> CodingAgentState:
    """Produce the final textual summary of the coding run.

    Appends the summary-request prompt to the history, asks the configured
    LLM for a summary of the conversation, and stores the result in
    ``agent_state.output_summary`` as well as the message history.
    """
    logger.debug("summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("summary")

    # Append the summary-request prompt to the conversation history.
    agent_state.add_message(
        Message(
            role="user",
            content=PROMPTS.experiment_coding_v2.summary_prompt.render(),
            agent_sender=AGENT_NAME,
        )
    )

    # Render (and conditionally log) the system prompt before the completion
    # call so the side-effect ordering matches the rest of the agent.
    system_msg = Message(
        role="system",
        content=PROMPTS.experiment_coding_v2.summary_system_prompt.render(),
    ).with_log(cond=constant.LOG_SYSTEM_PROMPT)

    reply = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=system_msg.content,
        agent_sender=AGENT_NAME,
        tools=None,  # No tools needed for final summary
    ).with_log()

    # Persist the summary text on the state and in the history.
    agent_state.output_summary = reply.content or ""
    agent_state.add_message(reply)

    logger.info(f"Coding task summary generated: {len(agent_state.output_summary)} characters")

    return agent_state
|
scievo/agents/experiment_agent/coding_subagent_v2/state.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
from typing import TYPE_CHECKING, Optional
|
| 4 |
+
|
| 5 |
+
if TYPE_CHECKING:
|
| 6 |
+
from openhands.sdk import Conversation
|
| 7 |
+
|
| 8 |
+
from pydantic import PrivateAttr
|
| 9 |
+
|
| 10 |
+
from scievo.core.code_env import LocalEnv
|
| 11 |
+
from scievo.core.types import HistoryState, ToolsetState
|
| 12 |
+
from scievo.prompts import SKILLS
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CodingAgentState(ToolsetState, HistoryState):
    """State of the Coding Subagent V2.

    This agent delegates coding tasks to OpenHands SDK which has its own
    internal planning mechanism. No external planning is needed.

    Note: No RBankState - memory extraction is not used in this agent.

    Construction side effects: if no OpenHands conversation is supplied,
    ``__init__`` builds one from environment variables (and raises
    ``RuntimeError`` when ``SCIEVO_ENABLE_OPENHANDS`` is not set to a
    truthy value), then registers the "openhands" toolset.
    """

    # Summary of the data from data agent, providing background info for the coding task (input)
    data_summary: str

    # User's coding task description (input, optional)
    user_query: str | None = None

    # Local environment for the agent (input)
    workspace: LocalEnv

    # OpenHands Conversation object - persists throughout the execution (private)
    # This maintains the conversation history with the external coding agent
    _openhands_conversation: Optional["Conversation"] = PrivateAttr(default=None)

    # Output summary (output)
    output_summary: str | None = None

    def __init__(self, _openhands_conversation: Optional["Conversation"] = None, *args, **kwargs):
        """Initialize the state, lazily constructing an OpenHands conversation.

        Args:
            _openhands_conversation: Pre-built conversation to reuse; when
                None, a new one is assembled from env configuration.

        Raises:
            RuntimeError: if a conversation must be built but the
                ``SCIEVO_ENABLE_OPENHANDS`` env flag is not truthy.
        """
        super().__init__(*args, **kwargs)
        # Create a default empty conversation if not provided
        if _openhands_conversation is None:
            # Truthy values accepted for the enable flag.
            enable_openhands = os.getenv("SCIEVO_ENABLE_OPENHANDS", "").strip().lower() in {
                "1",
                "true",
                "yes",
                "y",
            }
            if not enable_openhands:
                raise RuntimeError(
                    "OpenHands coding subagent (v2) is disabled. "
                    "Set env `SCIEVO_ENABLE_OPENHANDS=1` to enable it, or use the Claude coding subagent "
                    "(`CODING_AGENT_VERSION=v3`)."
                )

            # Setup openhands paths first (must be before any openhands imports)
            # Local imports so importing this module doesn't require OpenHands unless v2 is used.
            from openhands.sdk import LLM, Agent, AgentContext, Conversation, Tool
            from openhands.sdk.context.skills import Skill

            from scievo.core import openhands_import  # noqa: F401

            # Try to import LLMSummarizingCondenser if available
            try:
                from openhands.sdk.context.condenser import LLMSummarizingCondenser
            except ImportError:
                # Fallback: LLMSummarizingCondenser is not available in this version
                LLMSummarizingCondenser = None

            # Credentials/model come from env; OPENHANDS_* takes precedence.
            api_key = os.getenv("OPENHANDS_API_KEY") or os.getenv("LLM_API_KEY")
            model = os.getenv("OPENHANDS_MODEL", "anthropic/claude-sonnet-4-5-20250929")

            # Unique usage_id per instance so usage tracking is per-agent.
            llm = LLM(
                model=model,
                api_key=api_key,
                usage_id=f"openhands-coding-agent-{uuid.uuid4().hex[:8]}",
            )

            from openhands.tools.file_editor import FileEditorTool
            from openhands.tools.glob import GlobTool
            from openhands.tools.grep import GrepTool
            from openhands.tools.task_tracker import TaskTrackerTool
            from openhands.tools.terminal import TerminalTool

            # Tool surface exposed to the sub-agent.
            tools = [
                Tool(name=FileEditorTool.name),
                Tool(name=TaskTrackerTool.name),
                Tool(name=TerminalTool.name),
                Tool(name=GlobTool.name),
                Tool(name=GrepTool.name),
            ]
            # Skills and a system-message suffix steer the sub-agent toward
            # short-running, absolute-path, uv-based workflows.
            agent_context = AgentContext(
                skills=[
                    Skill(
                        name="Python Dependency Management by `uv` instead of `pip`",
                        content="For Python projects: Always prioritize using 'uv' for managing dependencies and virtual environments. "
                        "Avoid using 'pip' or other package managers that directly affect the native system environment. "
                        "Use 'uv sync' to install dependencies from lock files, 'uv venv' to create isolated environments, "
                        "and 'uv add' to add new packages. This approach ensures project isolation and reproducibility. "
                        "This skill applies only to Python projects.",
                    ),
                    Skill(
                        name="Avoid Long Time Operations",
                        content="Avoid using tools or commands that may lead to long wait times or blocking operations, "
                        "such as training the model directly within this environment. ",
                    ),
                    Skill(
                        name="File Operations Should Use Absolute Paths as Much as Possible",
                        content="When using the File Editor tool and other file-related tools, always refer to files using their absolute paths. "
                        "This ensures that file operations are unambiguous and correctly targeted within the workspace. ",
                    ),
                    Skill(
                        name="UV - Python Package Manager Skill",
                        content=SKILLS.uv_skill,
                    ),
                ],
                system_message_suffix="""\
<CLI_MODE>
You are operating in CLI mode, so all file paths should be absolute paths as much as possible.
Besides, try to avoid long time operations that may block the process, e.g., training the deep learning model directly.
</CLI_MODE>

<SHORT_RUNNING>
- DO NOT train the full model. Just train a demo if needed for testing code changes.
- DO NOT run large data processing tasks. Just simulate with small data if needed for testing code
- The full experiments will be run later by the user after getting the code changes.
- IMPORTANT: If a command takes longer than 10 minutes (a.k.a. 600 seconds), you should leave it to the user to run later.
</SHORT_RUNNING>
""",
            )
            # Build agent kwargs - only include condenser if available
            agent_kwargs = {
                "llm": llm,
                "tools": tools,
                "system_prompt_kwargs": {"cli_mode": True},
                "agent_context": agent_context,
            }
            # Add condenser only if LLMSummarizingCondenser is available
            if LLMSummarizingCondenser is not None:
                agent_kwargs["condenser"] = LLMSummarizingCondenser(
                    llm=llm.model_copy(update={"usage_id": "condenser"}),
                    max_size=48,
                    keep_first=4,
                )

            agent = Agent(**agent_kwargs)
            _openhands_conversation = Conversation(
                agent=agent, workspace=self.workspace.working_dir
            )

        self._openhands_conversation = _openhands_conversation

        # Ensure the openhands toolset is included initially
        self.toolsets.append("openhands")

    @property
    def openhands_conversation(self) -> "Conversation":
        """Get the OpenHands Conversation object."""
        return self._openhands_conversation
|
scievo/agents/experiment_agent/coding_subagent_v3_claude/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Coding Subagent V3 Claude
|
| 3 |
+
|
| 4 |
+
This agent delegates coding tasks to Claude Agent SDK for external code manipulation.
|
| 5 |
+
Claude Agent SDK has its own internal planning and execution mechanisms.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .build import build
|
| 9 |
+
from .state import ClaudeCodingAgentState, CodingAgentState
|
| 10 |
+
|
| 11 |
+
__all__ = ["build", "ClaudeCodingAgentState", "CodingAgentState"]
|
scievo/agents/experiment_agent/coding_subagent_v3_claude/build.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import END, START, StateGraph
|
| 2 |
+
from loguru import logger
|
| 3 |
+
|
| 4 |
+
from . import execute
|
| 5 |
+
from .state import ClaudeCodingAgentState
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@logger.catch
def build():
    """Assemble the minimal LangGraph for the Claude coding subagent.

    The topology is a straight pipeline: START -> claude -> summary -> END.
    All planning, tool use, and execution happen inside the Claude Agent SDK,
    so no external LLM chat loop or tool-calling nodes are needed here.
    """
    graph = StateGraph(ClaudeCodingAgentState)

    # Register the two execution nodes.
    graph.add_node("claude", execute.claude_node)
    graph.add_node("summary", execute.summary_node)

    # Wire them up as a linear chain.
    for src, dst in ((START, "claude"), ("claude", "summary"), ("summary", END)):
        graph.add_edge(src, dst)

    return graph
|
scievo/agents/experiment_agent/coding_subagent_v3_claude/execute.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Execution nodes for the Coding Subagent V3 Claude
|
| 3 |
+
|
| 4 |
+
This module provides a minimal execution flow that delegates all coding work
|
| 5 |
+
to Claude Agent SDK. The flow is: START -> claude_node -> summary_node -> END
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from loguru import logger
|
| 11 |
+
|
| 12 |
+
from scievo.core import constant
|
| 13 |
+
from scievo.core.llms import ModelRegistry
|
| 14 |
+
from scievo.core.types import Message
|
| 15 |
+
from scievo.prompts import PROMPTS
|
| 16 |
+
from scievo.tools.claude_agent_sdk_tool import run_claude_agent_sdk
|
| 17 |
+
from scievo.tools.claude_code_tool import run_claude_code
|
| 18 |
+
|
| 19 |
+
from .state import ClaudeCodingAgentState
|
| 20 |
+
|
| 21 |
+
LLM_NAME = "experiment_coding"
|
| 22 |
+
AGENT_NAME = "experiment_coding"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def claude_node(agent_state: ClaudeCodingAgentState) -> ClaudeCodingAgentState:
    """
    Execute the coding task using Claude Agent SDK.

    This node directly invokes the Claude Agent SDK to handle
    the entire coding workflow. Claude Agent SDK has its own internal planning,
    tool calling, and execution mechanisms.

    Falls back to the Claude Code CLI when the SDK result looks like an error.
    Always appends a result/error message to the history and records a
    ``{"node_name": "claude", "output": ...}`` entry in ``intermediate_state``.
    """
    logger.debug("claude_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("claude")

    try:
        # Construct the message for Claude Agent SDK
        instruction = agent_state.user_query or "No specific coding task provided."
        bg_info = agent_state.data_summary or "No background information available."
        # prefix with `> ` for markdown blockquote
        instruction = "\n".join([f"> {line}" for line in instruction.splitlines()])
        bg_info = "\n".join([f"> {line}" for line in bg_info.splitlines()])
        workspace_dir = os.path.abspath(agent_state.workspace.working_dir)

        prompt = f"""\
# Requirements:
- At the end of your response, provide a detailed explanation of what you did and why.
- Ensure that all changes are made in a way that maintains the integrity of the codebase.
- Avoid long-running executions of training or data processing; focus on code changes. If needed for code testing, design some simple test code instead.

# Important Notes:
- DO NOT train the full model. Just train a demo if needed for testing code changes.
- DO NOT run large data processing tasks. Just simulate with small data if needed for testing code
- Always ensure that the code runs without errors after your changes.
- I would run the full experiments later after getting your code changes.

# Workspace
{workspace_dir}

# Task:
{instruction}

# Background information:
```
{bg_info}
```
"""

        logger.info("Sending task to Claude Agent SDK: {}", instruction[:100])

        # Call Claude Agent SDK tool (preferred)
        sdk_result = run_claude_agent_sdk(
            prompt=prompt,
            cwd=workspace_dir,
            allowed_tools=["Read", "Write", "Edit", "Bash", "Glob", "Grep"],
            permission_mode="acceptEdits",
            **{constant.__AGENT_STATE_NAME__: agent_state},
        )

        sdk_text = str(sdk_result)
        # Heuristic error detection: look for an "error:" marker (that is not
        # "error=None") in the first 20 lines of the stringified result.
        # NOTE(review): this is string-matching on the result's repr — fragile;
        # confirm run_claude_agent_sdk has no structured error signal to use.
        has_error = any(
            (line.strip().startswith("error:") and "error=None" not in line)
            for line in sdk_text.splitlines()[:20]
        )

        if not has_error:
            logger.info("Claude Agent SDK completed task")
            agent_state.add_message(
                Message(
                    role="assistant",
                    content=(
                        "[Claude Agent SDK Result]\n"
                        "Claude Agent SDK has completed the coding task. The changes have been applied to the workspace.\n\n"
                        f"{sdk_text}"
                    ),
                    agent_sender="claude_agent_sdk",
                ).with_log()
            )
        else:
            # Fallback to Claude Code CLI (still Claude-based, but doesn't require SDK install)
            logger.warning("Claude Agent SDK returned an error; falling back to Claude Code CLI")
            cli_result = run_claude_code(
                instruction=prompt,
                cwd=workspace_dir,
                timeout=1800,  # 30-minute hard cap for the CLI fallback run
                **{constant.__AGENT_STATE_NAME__: agent_state},
            )
            agent_state.add_message(
                Message(
                    role="assistant",
                    content=(
                        "[Claude Agent SDK Error]\n"
                        f"{sdk_text}\n\n"
                        "[Claude Code CLI Fallback Result]\n"
                        f"{str(cli_result)}"
                    ),
                    agent_sender="claude_code",
                ).with_log()
            )

    except Exception as e:
        # Record the failure in the history so the graph can still proceed
        # to the summary node.
        logger.exception("Claude Agent SDK error")
        agent_state.add_message(
            Message(
                role="assistant",
                content=f"[Claude Agent SDK Error] {str(e)}",
                agent_sender="claude_agent_sdk",
            ).with_log()
        )

    # Derive the node output from the last assistant message, truncated to
    # 2000 characters; falls back to a generic completion notice.
    claude_output = "Claude Agent SDK execution completed"
    if agent_state.history:
        last_msg = agent_state.history[-1]
        if last_msg.role == "assistant" and last_msg.content:
            claude_output = last_msg.content[:2000]

    agent_state.intermediate_state.append(
        {
            "node_name": "claude",
            "output": claude_output,
        }
    )

    return agent_state
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def summary_node(agent_state: ClaudeCodingAgentState) -> ClaudeCodingAgentState:
    """Produce the final summary of the Claude coding run.

    Appends the summary-request prompt, queries the configured LLM over the
    patched history, stores the summary on the state and in the history, and
    records an ``intermediate_state`` entry for the "summary" node.
    """
    logger.debug("summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("summary")

    # Append the summary-request prompt to the conversation history.
    agent_state.add_message(
        Message(
            role="user",
            content=PROMPTS.experiment_coding_v2.summary_prompt.render(),
            agent_sender=AGENT_NAME,
        )
    )

    # Render (and conditionally log) the system prompt up front so the
    # side-effect ordering matches the original flow.
    system_msg = Message(
        role="system",
        content=PROMPTS.experiment_coding_v2.summary_system_prompt.render(),
    ).with_log(cond=constant.LOG_SYSTEM_PROMPT)

    reply = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=system_msg.content,
        agent_sender=AGENT_NAME,
        tools=None,  # No tools needed for final summary
    ).with_log()

    # Persist the summary text on the state and in the history.
    agent_state.output_summary = reply.content or ""
    agent_state.add_message(reply)

    logger.info(f"Coding task summary generated: {len(agent_state.output_summary)} characters")

    # Record this node's output for downstream inspection.
    agent_state.intermediate_state.append(
        {
            "node_name": "summary",
            "output": agent_state.output_summary or "No summary generated",
        }
    )

    return agent_state
|
scievo/agents/experiment_agent/coding_subagent_v3_claude/state.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scievo.core.code_env import LocalEnv
|
| 2 |
+
from scievo.core.types import HistoryState, ToolsetState
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ClaudeCodingAgentState(ToolsetState, HistoryState):
    """State of the Coding Subagent V3 Claude.

    This agent delegates coding tasks to Claude Agent SDK which has its own
    internal planning mechanism. No external planning is needed.

    Note: No RBankState - memory extraction is not used in this agent.
    """

    # Summary of the data from data agent, providing background info for the coding task (input)
    data_summary: str

    # User's coding task description (input, optional)
    user_query: str | None = None

    # Local environment for the agent (input)
    workspace: LocalEnv

    # Output summary (output)
    output_summary: str | None = None

    # Intermediate states
    # Per-node records of the form {"node_name": ..., "output": ...},
    # appended by the execution nodes (claude_node, summary_node).
    # NOTE(review): the mutable `= []` default is safe only if the base state
    # is a pydantic model (defaults are copied per instance) — confirm in
    # scievo.core.types; otherwise use a default factory.
    intermediate_state: list[dict] = []


# Alias for consistency with v2 (CodingAgentState)
CodingAgentState = ClaudeCodingAgentState
|
scievo/agents/experiment_agent/exec_subagent/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment Execution Agent
|
| 3 |
+
|
| 4 |
+
This agent is responsible for executing experiments in local shell sessions.
|
| 5 |
+
It parses natural language queries to determine commands to execute and manages
|
| 6 |
+
the execution using LocalShellSession.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .build import build
|
| 10 |
+
from .state import ExecAgentState
|
| 11 |
+
|
| 12 |
+
__all__ = ["build", "ExecAgentState"]
|
scievo/agents/experiment_agent/exec_subagent/build.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Build the Experiment Execution Agent graph
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from langgraph.graph import END, START, StateGraph
|
| 6 |
+
from loguru import logger
|
| 7 |
+
|
| 8 |
+
from . import execute
|
| 9 |
+
from .state import ExecAgentState
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def init_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Seed the conversation with the rendered user query.

    On the first visit (empty history) this renders the execution user prompt
    with the user query, workspace, and coding summaries, and appends it as
    the opening user message. On re-entry with a non-empty history the message
    is skipped so the query is never duplicated.

    Fixes over the previous version: drops the redundant ``len(...) == 0``
    check (truthiness already covers it), and replaces the fragile
    ``"user_msg" in locals()`` introspection with an explicitly initialized
    output variable.

    Args:
        agent_state: The execution-agent state, mutated in place.

    Returns:
        The same ``agent_state`` with the initial message added (when the
        history was empty) and an ``intermediate_state`` record for "init".
    """
    logger.trace("init_node of ExecAgent")

    # Recorded in intermediate_state; replaced by the rendered prompt content
    # when a message is actually added below.
    init_output = "Initialization complete"

    # Add the initial user query message if history is empty
    if not agent_state.history:
        # Imported lazily to keep module import light (matches original style).
        from scievo.core.types import Message
        from scievo.prompts import PROMPTS

        summaries = agent_state.coding_summaries
        user_msg = Message(
            role="user",
            content=PROMPTS.experiment_exec.exec_user_prompt.render(
                user_query=agent_state.user_query,
                working_dir=agent_state.workspace,
                # Most recent coding summary, if any exist (None/empty -> None).
                current_coding_summary=summaries[-1] if summaries else None,
                coding_summaries=summaries,
            ),
        )
        agent_state.add_message(user_msg)
        init_output = user_msg.content
    else:
        logger.warning("Agent history is not empty during init_node; skipping adding user query.")

    agent_state.intermediate_state.append(
        {
            "node_name": "init",
            "output": init_output,
        }
    )

    return agent_state
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@logger.catch
def build():
    """Construct the LangGraph for the Experiment Execution Agent.

    Topology: START -> init -> gateway. The gateway routes conditionally to
    llm_chat / tool_calling / monitoring / history_compression (each of which
    loops back to the gateway) or to summary, which terminates the run.
    """
    graph = StateGraph(ExecAgentState)

    # Node registry: name -> node callable.
    node_table = {
        "init": init_node,
        "gateway": execute.gateway_node,
        "llm_chat": execute.llm_chat_node,
        "tool_calling": execute.tool_calling_node,
        "monitoring": execute.monitoring_node,
        "summary": execute.summary_node,
        "history_compression": execute.history_compression_node,
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    # Linear prefix: START -> init -> gateway.
    graph.add_edge(START, "init")
    graph.add_edge("init", "gateway")

    # Gateway decides which worker node runs next.
    graph.add_conditional_edges(
        "gateway",
        execute.gateway_conditional,
        [
            "llm_chat",
            "tool_calling",
            "monitoring",
            "summary",
            "history_compression",
        ],
    )

    # Worker nodes return control to the gateway.
    for worker in ("llm_chat", "tool_calling", "monitoring", "history_compression"):
        graph.add_edge(worker, "gateway")

    # Summary terminates the run.
    graph.add_edge("summary", END)

    return graph
|
scievo/agents/experiment_agent/exec_subagent/execute.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment Execution Agent - handles running experiments in local shell sessions
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import inspect
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from loguru import logger
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from scievo import history_compression
|
| 14 |
+
from scievo.core import constant
|
| 15 |
+
from scievo.core.llms import ModelRegistry
|
| 16 |
+
from scievo.core.types import Message
|
| 17 |
+
from scievo.core.utils import parse_json_from_llm_response, wrap_dict_to_toon
|
| 18 |
+
from scievo.prompts import PROMPTS, SKILLS
|
| 19 |
+
from scievo.tools import Tool, ToolRegistry
|
| 20 |
+
|
| 21 |
+
from .state import ExecAgentState
|
| 22 |
+
|
| 23 |
+
# Model registry keys for the two LLM roles used by this agent.
LLM_NAME = "experiment_execute"
LLM_MONITOR_NAME = "experiment_monitor"
AGENT_NAME = "experiment_exec"

# Toolsets that are always available to this agent.
BUILTIN_TOOLSETS = [
    "state",
    "exec",  # The exec toolset is built-in for this agent
    "fs",
]
# Toolsets the agent may additionally activate at runtime.
ALLOWED_TOOLSETS = [
    "history",
]  # Can be extended if needed

# Backoff schedule between polls of a long-running command; after the list is
# exhausted the last interval (180s) is reused.
MONITORING_INTERVALS = [5, 10, 10, 20, 20, 30, 45, 60, 60, 120, 120, 180]  # in seconds

# load uv skill md
# NOTE(review): UV_SKILL appears unused in this module — the system prompt uses
# SKILLS.uv_skill instead. Confirm whether this path constant is still needed.
UV_SKILL = Path(__file__).parent.parent.parent.parent / "tools" / "skills" / "uv_venv_management.md"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def gateway_node(agent_state: ExecAgentState) -> ExecAgentState:
    """No-op routing hub; the actual branching lives in ``gateway_conditional``."""
    # Nothing to compute here — the conditional edges attached to this node
    # inspect the state and pick the next node to run.
    logger.trace("gateway_node of Agent {}", AGENT_NAME)
    return agent_state
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def gateway_conditional(agent_state: ExecAgentState) -> str:
    """Return the name of the next node to run based on the current state."""
    # Trigger history compression when the token budget is exceeded, unless a
    # compression just happened (avoid back-to-back compression rounds).
    if (
        constant.HISTORY_AUTO_COMPRESSION
        and "history_compression" not in agent_state.node_history[-2:]
        and agent_state.total_patched_tokens > constant.HISTORY_AUTO_COMPRESSION_TOKEN_THRESHOLD
    ):
        return "history_compression"

    if agent_state.is_monitor_mode:
        # A command is still running: back off (capped at the last interval of
        # the schedule) before polling it again in the monitoring node.
        interval_idx = min(agent_state.monitoring_attempts, len(MONITORING_INTERVALS) - 1)
        time2sleep = MONITORING_INTERVALS[interval_idx]
        logger.debug(
            f"A command is currently running. Waiting for {time2sleep} seconds before monitoring again."
        )
        time.sleep(time2sleep)
        return "monitoring"

    last_msg = agent_state.patched_history[-1]

    # Pending tool calls take precedence over role-based routing.
    tool_calls = last_msg.tool_calls
    if tool_calls and len(tool_calls) > 0:
        return "tool_calling"

    # Route on who produced the last message.
    role = last_msg.role
    if role in ("user", "tool"):
        # User or tool message -> call LLM
        return "llm_chat"
    if role == "assistant":
        # Assistant responded without tool calls -> execution is complete
        return "summary"
    raise ValueError(f"Unknown message role: {last_msg.role}")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def llm_chat_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Ask the model for the next action given the current history and tools."""
    logger.debug("llm_chat_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("llm_chat")

    # Snapshot of the state fields the system prompt exposes to the model.
    selected_state = {
        "workspace": agent_state.workspace.working_dir,
        "current_activated_toolsets": agent_state.toolsets,
    }

    # Render the system prompt with current state, tool descriptions and skills.
    system_prompt = PROMPTS.experiment_exec.exec_system_prompt.render(
        state_text=wrap_dict_to_toon(selected_state),
        toolsets_desc=ToolRegistry.get_toolsets_desc(BUILTIN_TOOLSETS + ALLOWED_TOOLSETS),
        uv_skill=SKILLS.uv_skill,
    )

    # Gather tools from both the activated and the built-in toolsets.
    tools: dict[str, Tool] = {}
    for toolset in [*agent_state.toolsets, *BUILTIN_TOOLSETS]:
        tools.update(ToolRegistry.get_toolset(toolset))

    # Get completion from LLM
    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=(
            Message(role="system", content=system_prompt)
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
        tools=[tool.name for tool in tools.values()],
    ).with_log()
    agent_state.add_message(msg)

    # Human-readable trace of what the model produced this turn.
    if msg.content:
        llm_output = msg.content
    elif msg.tool_calls:
        llm_output = "[Tool calls: " + str(len(msg.tool_calls)) + "]"
    else:
        llm_output = "[No output]"

    agent_state.intermediate_state.append(
        {
            "node_name": "llm_chat",
            "output": llm_output,
        }
    )

    return agent_state
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def monitoring_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Monitor a running command and decide whether to continue waiting or interrupt it.

    One invocation = one poll of the running command. Outcomes:
    - no context / command finished: exit monitor mode (and, when finished,
      append a "monitoring ended" user message with the final output);
    - otherwise ask the monitor LLM to choose between waiting and Ctrl-C.
    """
    logger.debug("monitoring_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("monitoring")
    agent_state.monitoring_attempts += 1

    # Total wall-clock seconds spent monitoring so far; once the schedule is
    # exhausted, every extra attempt adds the last (largest) interval.
    if agent_state.monitoring_attempts <= len(MONITORING_INTERVALS):
        total_monitoring_seconds = sum(MONITORING_INTERVALS[: agent_state.monitoring_attempts])
    else:
        total_monitoring_seconds = (
            sum(MONITORING_INTERVALS)
            + (agent_state.monitoring_attempts - len(MONITORING_INTERVALS))
            * MONITORING_INTERVALS[-1]
        )

    # Get the current running command context
    ctx = agent_state.session.get_current_context()
    if ctx is None:
        # No command running, this shouldn't happen but handle it gracefully
        logger.warning("monitoring_node called but no command is running")
        agent_state.monitoring_attempts = 0
        agent_state.is_monitor_mode = False

        agent_state.intermediate_state.append(
            {
                "node_name": "monitoring",
                "output": "No command running - monitoring stopped",
            }
        )
        return agent_state

    # Get current output from the running command (truncated to 32k chars).
    current_output = ctx.get_input_output(max_length=32000)

    if not agent_state.session.is_running_command():
        # Command has completed while we were waiting
        logger.debug("The monitored command has completed.")
        agent_state.monitoring_attempts = 0
        agent_state.is_monitor_mode = False

        # Add monitoring end user prompt message so the main LLM sees the
        # command's final output and error text on its next turn.
        monitoring_end_user_msg = Message(
            role="user",
            content=PROMPTS.experiment_exec.monitoring_end_user_prompt.render(
                command=ctx.command,
                final_output=current_output,
                error_text=ctx.get_error(),
                total_monitoring_seconds=total_monitoring_seconds,
            ),
            agent_sender=AGENT_NAME,
        ).with_log()
        agent_state.add_message(monitoring_end_user_msg)

        return agent_state

    # Work on a copy: the monitoring prompt is only shown to the monitor LLM
    # and must not be persisted into the agent's real history.
    history = agent_state.patched_history.copy()
    # Prepare monitoring prompt
    monitoring_user_msg = Message(
        role="user",
        content=PROMPTS.experiment_exec.monitoring_user_prompt.render(
            command=ctx.command,
            monitoring_attempts=agent_state.monitoring_attempts,
            current_output=current_output,
            total_monitoring_seconds=total_monitoring_seconds,
        ),
        agent_sender=AGENT_NAME,
    )
    history.append(monitoring_user_msg)

    # Ask monitoring LLM to decide (no tools; plain JSON answer expected).
    msg = ModelRegistry.completion(
        LLM_MONITOR_NAME,
        history,
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.experiment_exec.monitoring_system_prompt.render(),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
        tools=None,
    ).with_log()

    # Minimal schema for the monitor's decision payload.
    class MonitorDecisionModel(BaseModel):
        action: str

    r = parse_json_from_llm_response(msg, MonitorDecisionModel)  # just to validate JSON format

    # Substring matching tolerates variants like "WAIT" / "wait_more".
    if "wait" in r.action.lower():
        logger.debug("Monitoring decision: continue waiting for the command to complete.")
        agent_state.is_monitor_mode = True
    elif "ctrlc" in r.action.lower():
        logger.debug("Monitoring decision: interrupting the running command.")
        ctx.cancel()
        logger.debug("Monitoring is interrupted. Command is cancelled.")
        # Tell the main LLM the command was interrupted and what it printed.
        monitoring_ctrlc_user_msg = Message(
            role="user",
            content=PROMPTS.experiment_exec.monitoring_ctrlc_user_prompt.render(
                command=ctx.command,
                output_before_interrupt=current_output,
                total_monitoring_seconds=total_monitoring_seconds,
            ),
            agent_sender=AGENT_NAME,
        )
        agent_state.add_message(monitoring_ctrlc_user_msg)
        agent_state.is_monitor_mode = False
    else:
        # Unknown action: waiting is the safe default.
        logger.warning(
            f"Unknown monitoring action '{r.action}' received. Continuing to wait by default."
        )
        agent_state.is_monitor_mode = True

    # Record a short trace entry for this monitoring round.
    monitoring_output = f"Monitoring attempt {agent_state.monitoring_attempts}, total time: {total_monitoring_seconds}s"
    if ctx:
        monitoring_output += f"\nCommand: {ctx.command if hasattr(ctx, 'command') else 'Unknown'}\nAction: {r.action}"

    agent_state.intermediate_state.append(
        {
            "node_name": "monitoring",
            "output": monitoring_output,
        }
    )

    return agent_state
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def summary_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Generate a summary of the experiment execution.

    Asks the execution LLM to summarize the run, stores the raw text in
    ``execution_summary`` and a parsed/structured version in
    ``execution_summary_dict`` (with a fallback dict if JSON parsing fails).
    """
    logger.debug("summary_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("summary")

    # Construct a prompt to generate the summary
    summary_prompt = Message(
        role="user",
        content=PROMPTS.experiment_exec.summary_user_prompt.render(),
        agent_sender=AGENT_NAME,
    )
    agent_state.add_message(summary_prompt)

    # Get summary from LLM
    msg = ModelRegistry.completion(
        LLM_NAME,
        agent_state.patched_history,
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.experiment_exec.summary_system_prompt.render(),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
        tools=None,  # No tools needed for summary
    ).with_log()

    # Store the raw summary text and keep the reply in the history.
    agent_state.execution_summary = msg.content or ""
    agent_state.add_message(msg)

    # Parse JSON summary from the response
    try:
        # Local schema for the structured summary the prompt asks the LLM for.
        class ExecutionSummary(BaseModel):
            status: str
            commands_executed: list[str]
            key_outputs: str
            errors_issues: str

        summary_dict = parse_json_from_llm_response(msg, ExecutionSummary)
        agent_state.execution_summary_dict = summary_dict.model_dump()
    except Exception as e:
        logger.warning(f"Failed to parse execution summary as JSON: {e}")
        # If JSON parsing fails, store the text response in a basic dict structure
        agent_state.execution_summary_dict = {
            "status": "Unknown",
            "commands_executed": [],
            "key_outputs": agent_state.execution_summary,
            "errors_issues": str(e),
        }

    # Prefer the structured form for the intermediate-state trace.
    summary_output = (
        json.dumps(agent_state.execution_summary_dict, indent=2)
        if agent_state.execution_summary_dict
        else agent_state.execution_summary
    )

    agent_state.intermediate_state.append(
        {
            "node_name": "summary",
            "output": summary_output,
        }
    )

    return agent_state
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def tool_calling_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Execute every tool call attached to the last assistant message.

    For each tool call this appends a ``role="tool"`` response message to the
    history (either the tool result or an error description), may switch the
    agent into monitor mode for long-running ``exec_command`` calls, and
    finally records a truncated trace of all results in ``intermediate_state``.

    Raises:
        ValueError: if the last message carries no tool calls.
    """
    logger.debug("tool_calling_node of Agent {}", AGENT_NAME)
    agent_state.add_node_history("tool_calling")

    # Get the last message which contains tool calls
    last_msg = agent_state.patched_history[-1]

    if not last_msg.tool_calls:
        raise ValueError("No tool calls found in the last message")

    # Construct tools from both the activated and the built-in toolsets
    tools: dict[str, Tool] = {}
    for toolset in agent_state.toolsets:
        tools.update(ToolRegistry.get_toolset(toolset))
    for toolset in BUILTIN_TOOLSETS:
        tools.update(ToolRegistry.get_toolset(toolset))

    function_map = {tool.name: tool.func for tool in tools.values()}

    # Execute each tool call
    for tool_call in last_msg.tool_calls:
        tool_name = tool_call.function.name

        # Check if tool exists in function map
        if tool_name not in function_map:
            error_msg = f"Tool {tool_name} not found"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            continue

        # Parse tool arguments
        try:
            args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError as e:
            error_msg = f"Invalid JSON in tool arguments: {e}"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            continue
        # BUGFIX: this validation used `assert isinstance(args, dict)` caught
        # via `except AssertionError`, which is silently disabled when Python
        # runs with `-O`. An explicit check keeps the guard unconditionally.
        if not isinstance(args, dict):
            error_msg = f"Invalid tool arguments: expected a JSON object, got {type(args).__name__}"
            tool_response = {
                "role": "tool",
                "tool_name": tool_name,
                "tool_call_id": tool_call.id,
                "content": error_msg,
            }
            agent_state.add_message(Message(**tool_response).with_log())
            continue

        # Execute the tool
        try:
            func = function_map[tool_name]

            # Inject framework-managed parameters when the tool accepts them.
            sig = inspect.signature(func)
            if constant.__AGENT_STATE_NAME__ in sig.parameters:
                args.update({constant.__AGENT_STATE_NAME__: agent_state})
            if constant.__CTX_NAME__ in sig.parameters:
                args.update({constant.__CTX_NAME__: {"current_agent": AGENT_NAME}})

            # Execute the tool
            result = func(**args)

            # Create tool response message
            tool_response = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "tool_name": tool_name,
                "content": str(result),  # Ensure result is string
            }

            # If this is a long-running exec_command, check for the marker the
            # exec toolset emits when a command outlives its synchronous window.
            flag_text = "Try to check the execution status later."
            if tool_name == "exec_command" and flag_text in tool_response["content"]:
                logger.debug("The executed command is still running, entering monitor mode.")
                # BUGFIX: was a bare `assert` (stripped under `python -O`); an
                # explicit raise keeps the invariant check and, like the
                # original assert, is handled by the surrounding `except`.
                if agent_state.session.get_current_context() is None:
                    raise RuntimeError("Expected a current context when entering monitor mode")
                # The command is still running, go into monitor mode in the next step
                agent_state.is_monitor_mode = True

        except Exception as e:
            logger.exception(f"Tool {tool_name} execution failed")
            error_msg = f"Tool {tool_name} execution failed: {e}"
            tool_response = {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "tool_name": tool_name,
                "content": error_msg,
            }

        agent_state.add_message(Message(**tool_response).with_log())

    # Reset monitoring attempts after tool execution
    agent_state.monitoring_attempts = 0

    # Collect the (truncated) result of each tool call for the trace, by
    # finding the matching tool-response message that was just appended.
    tool_results = []
    for tool_call in last_msg.tool_calls:
        tool_name = tool_call.function.name
        for msg in reversed(agent_state.history):
            if (
                msg.role == "tool"
                and hasattr(msg, "tool_call_id")
                and msg.tool_call_id == tool_call.id
            ):
                tool_results.append(
                    {
                        "tool": tool_name,
                        "result": msg.content[:1000] if msg.content else "No result",
                    }
                )
                break

    tool_output_parts = []
    for tr in tool_results:
        tool_output_parts.append(f"Tool: {tr['tool']}\nResult: {tr['result']}")

    tool_output = "\n\n".join(tool_output_parts) if tool_output_parts else "No tool calls executed"

    agent_state.intermediate_state.append(
        {
            "node_name": "tool_calling",
            "output": tool_output,
        }
    )

    return agent_state
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def history_compression_node(agent_state: ExecAgentState) -> ExecAgentState:
    """Compress the conversation history and record a short trace of the result."""
    logger.debug("history_compression_node of Agent {}", AGENT_NAME)

    history_before = len(agent_state.history)
    agent_state = history_compression.invoke_history_compression(agent_state)
    history_after = len(agent_state.history)

    # Default report: message counts before and after compression.
    compression_output = f"Compressed history: {history_before} -> {history_after} messages"

    # Prefer a preview of the most recent patch when one is available.
    last_patch = agent_state.history_patches[-1] if agent_state.history_patches else None
    if last_patch is not None and last_patch.patched_message and last_patch.patched_message.content:
        compression_output = f"Compressed {last_patch.n_messages} messages into:\n{last_patch.patched_message.content[:500]}"

    agent_state.intermediate_state.append(
        {"node_name": "history_compression", "output": compression_output}
    )
    return agent_state
|
scievo/agents/experiment_agent/exec_subagent/state.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scievo.core.code_env import LocalEnv
|
| 2 |
+
from scievo.core.exec.manager import SessionManager
|
| 3 |
+
from scievo.core.exec.pty_session import LocalShellSession
|
| 4 |
+
from scievo.core.types import ExecState, HistoryState, ToolsetState
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ExecAgentState(ExecState, ToolsetState, HistoryState):
    """State of the Experiment Execution Agent.

    This agent is responsible for executing experiments in local shell sessions.
    It combines:
    - ToolsetState: for managing available toolsets
    - HistoryState: for managing conversation history
    (ExecState supplies the shell-session fields such as ``session_id``.)
    """

    # The natural language query describing what experiment to run (input)
    user_query: str

    # Current working directory where experiments are executed (input)
    workspace: LocalEnv

    # Coding summaries from previous revisions (input, optional)
    # Used to provide context about code changes made in each revision
    coding_summaries: list[str] | None = None

    # Raw summary of the experiment execution, try to use `execution_summary_dict` instead (output)
    execution_summary: str = ""

    # Structured summary of the experiment execution (output)
    # Should be:
    # ```json
    # {
    #     "status": "Success" or "Failed",
    #     "commands_executed": ["command 1", "command 2", ...],
    #     "key_outputs": "Highlight any important output or results",
    #     "errors_issues": "Note any errors or issues encountered, or 'None' if successful"
    # }
    # ```
    # NOTE(review): mutable class-level defaults here and below assume a
    # pydantic-style base that copies defaults per instance — confirm.
    execution_summary_dict: dict = {}

    # Number of monitoring attempts for the current running command (internal use)
    monitoring_attempts: int = 0

    # Whether to force monitoring in the next step (internal use)
    is_monitor_mode: bool = False

    # Intermediate states (one dict per node execution, for UI/tracing)
    intermediate_state: list[dict] = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.session_id is None:
            # Lazily create a local shell session rooted at the workspace when
            # the state is constructed without an existing session.
            s = LocalShellSession(cwd=self.workspace.working_dir)
            # Store session ID instead of the session instance
            self.session_id = s.session_id
        # add initial toolset
        # NOTE(review): "exec" is also listed in BUILTIN_TOOLSETS of execute.py,
        # so this append may produce a duplicate entry — confirm intended.
        self.toolsets.append("exec")
|
scievo/agents/experiment_agent/execute.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Execution nodes for the Experiment Agent.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from typing import Literal
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from loguru import logger
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
|
| 13 |
+
from scievo.core import constant
|
| 14 |
+
from scievo.core.llms import ModelRegistry
|
| 15 |
+
from scievo.core.types import Message
|
| 16 |
+
from scievo.core.utils import parse_json_from_llm_response
|
| 17 |
+
from scievo.prompts import PROMPTS
|
| 18 |
+
|
| 19 |
+
from .exec_subagent import build as exec_build
|
| 20 |
+
from .exec_subagent.state import ExecAgentState
|
| 21 |
+
from .state import ExperimentAgentState
|
| 22 |
+
from .summary_subagent import build as summary_build
|
| 23 |
+
from .summary_subagent.state import SummaryAgentState
|
| 24 |
+
|
| 25 |
+
AGENT_NAME = "experiment_agent"
LLM_NAME = "experiment_agent"

# Select the coding subagent implementation from the environment at import time.
load_dotenv()
CODING_AGENT_VERSION = os.getenv("CODING_AGENT_VERSION", "v3")  # default to Claude (v3)
# Accepted truthy spellings for SCIEVO_ENABLE_OPENHANDS.
_OPENHANDS_ENABLED = os.getenv("SCIEVO_ENABLE_OPENHANDS", "").strip().lower() in {
    "1",
    "true",
    "yes",
    "y",
}
match CODING_AGENT_VERSION:
    case "v2":
        # The OpenHands-based coding agent must be explicitly opted into.
        if not _OPENHANDS_ENABLED:
            raise RuntimeError(
                "CODING_AGENT_VERSION=v2 requires OpenHands, but OpenHands is disabled.\n"
                "Hint: set `CODING_AGENT_VERSION=v3` to use the Claude coding agent, or enable OpenHands with "
                "`SCIEVO_ENABLE_OPENHANDS=1`."
            )
        from .coding_subagent_v2 import build as coding_build
        from .coding_subagent_v2.state import CodingAgentState
    case "v3":
        from .coding_subagent_v3_claude import build as coding_build
        from .coding_subagent_v3_claude.state import CodingAgentState
    case _:
        raise ValueError(f"Unsupported CODING_AGENT_VERSION: {CODING_AGENT_VERSION}")

# Compile sub-agent graphs as global variables
coding_graph = coding_build().compile()
exec_graph = exec_build().compile()
summary_graph = summary_build().compile()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def init_node(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Initialize the experiment agent.

    Seeds the conversation history with a context message built from the
    data summary, the user query, and the repo source. Actual git cloning
    and workspace setup happen later inside the coding subagent.
    """
    logger.info("Initializing Experiment Agent")
    agent_state.current_phase = "init"

    # Render the opening context message and record it in history.
    bootstrap_message = Message(
        role="user",
        content=PROMPTS.experiment_agent.init_prompt.render(
            data_summary=agent_state.data_summary,
            user_query=agent_state.user_query,
            repo_source=agent_state.repo_source or "Not specified",
        ),
        agent_sender=AGENT_NAME,
    ).with_log()
    agent_state.add_message(bootstrap_message)

    # Mirror the node output into the intermediate-state trace.
    agent_state.intermediate_state.append(
        {"node_name": "init", "output": bootstrap_message.content}
    )

    return agent_state
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def run_coding_subagent(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Run the Coding Subagent (stateless invocation).

    The coding subagent receives repo_source and handles git cloning
    and workspace setup internally. By default this uses the Claude
    Agent SDK/Claude Code path (v3).

    Side effects on ``agent_state``: replaces ``history`` with the
    subagent's history, appends/updates the current revision's entry in
    ``loop_results`` with the coding summary, and appends a trace entry
    to ``intermediate_state``.
    """
    logger.info(f"Running Coding Subagent (revision {agent_state.current_revision})")
    agent_state.current_phase = "coding"

    # Build revision feedback context (1-based revision numbers; an empty
    # summaries list simply yields an empty feedback list).
    revision_feedback_list = [
        {"revision_number": number, "summary": summary}
        for number, summary in enumerate(agent_state.revision_summaries, start=1)
    ]

    # Collect all previous non-empty coding summaries (0-based revision index).
    previous_coding_summaries = [
        {"revision": i, "summary": loop.get("coding_summary", "")}
        for i, loop in enumerate(agent_state.loop_results)
        if loop.get("coding_summary")
    ]

    # Also include accumulated analysis (falsy -> placeholder text).
    revision_analysis_text = agent_state.revision_analysis or "No previous analysis yet."

    # Build user query using prompt template.
    coding_query = PROMPTS.experiment_agent.coding_subagent_query_prompt.render(
        user_query=agent_state.user_query,
        repo_source=agent_state.repo_source or "Not specified",
        # TODO: limit to last revision and coding summary for now
        revision_feedback_list=revision_feedback_list[-1:],
        previous_coding_summaries=previous_coding_summaries[-1:],
        revision_analysis=revision_analysis_text,
        current_revision=agent_state.current_revision,
    )

    coding_state = CodingAgentState(
        data_summary=agent_state.data_summary,  # Keep data_summary separate
        user_query=coding_query,
        workspace=agent_state.workspace,
    )

    # Invoke coding subagent (stateless call).
    result_state = coding_graph.invoke(coding_state)

    # graph.invoke returns a dict; extract only the fields we need rather
    # than storing the full subagent state.
    agent_state.history = result_state["history"]  # Merge back history

    # Use .get() for safe access in case output_summary is not set.
    coding_summary = result_state.get("output_summary") or "No summary available"

    # Start a new loop_results entry for this revision, or update the
    # existing one if this revision already has an entry.
    if (
        not agent_state.loop_results
        or agent_state.loop_results[-1].get("revision") != agent_state.current_revision
    ):
        agent_state.loop_results.append(
            {
                "revision": agent_state.current_revision,
                "coding_summary": coding_summary,
            }
        )
    else:
        agent_state.loop_results[-1]["coding_summary"] = coding_summary

    # Trace output: summary plus a truncated dump of the subagent's
    # intermediate states when available.
    coding_output = coding_summary
    if isinstance(result_state, dict) and "intermediate_state" in result_state:
        coding_intermediate = result_state.get("intermediate_state", [])
        if coding_intermediate:
            coding_output = (
                f"{coding_summary}\n\n[Coding Subagent Intermediate States]\n"
                + "\n".join(
                    f"{item.get('node_name', 'unknown')}: {item.get('output', '')[:200]}"
                    for item in coding_intermediate
                )
            )

    agent_state.intermediate_state.append(
        {
            "node_name": "run_coding_subagent",
            "output": coding_output,
        }
    )

    return agent_state
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def run_exec_subagent(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Run the Exec Subagent (stateless invocation).

    The workspace path should be extracted from the conversation history
    left by the coding subagent.
    """
    logger.info(f"Running Exec Subagent (revision {agent_state.current_revision})")
    agent_state.current_phase = "exec"

    # Gather every non-empty coding summary recorded so far.
    gathered_summaries = [
        loop.get("coding_summary", "")
        for loop in agent_state.loop_results
        if loop.get("coding_summary")
    ]

    exec_state = ExecAgentState(
        user_query="Run the modified code/experiments and verify the output.",
        workspace=agent_state.workspace,
        coding_summaries=gathered_summaries or None,
        toolsets=["exec"],
    )

    # Stateless invocation; graph.invoke hands back a plain dict.
    subagent_result = exec_graph.invoke(exec_state)

    # Pull out only what we need instead of keeping the whole state.
    agent_state.history = subagent_result["history"]
    agent_state.all_execution_results.append(subagent_result["execution_summary_dict"])

    # Attach the execution result to this revision's loop entry, if present.
    last_loop = agent_state.loop_results[-1] if agent_state.loop_results else None
    if last_loop is not None and last_loop.get("revision") == agent_state.current_revision:
        last_loop["exec_result"] = subagent_result["execution_summary_dict"]

    # Trace output: execution summary plus truncated subagent intermediate states.
    node_output = json.dumps(subagent_result.get("execution_summary_dict", {}), indent=2)
    if isinstance(subagent_result, dict) and "intermediate_state" in subagent_result:
        trace_items = subagent_result.get("intermediate_state", [])
        if trace_items:
            rendered = [
                f"{item.get('node_name', 'unknown')}: {item.get('output', '')[:200]}"
                for item in trace_items
            ]
            node_output = (
                f"{node_output}\n\n[Exec Subagent Intermediate States]\n"
                + "\n".join(rendered)
            )

    agent_state.intermediate_state.append(
        {"node_name": "run_exec_subagent", "output": node_output}
    )

    return agent_state
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def run_summary_subagent(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Run the Summary Subagent (stateless invocation).

    The workspace path should be extracted from the conversation history.
    """
    logger.info(f"Running Summary Subagent (revision {agent_state.current_revision})")
    agent_state.current_phase = "summary"

    summarizer_state = SummaryAgentState(
        workspace=agent_state.workspace,
        history=agent_state.history.copy(),
        output_path=None,  # Or specify a path for saving
        toolsets=["fs"],
    )

    # Stateless invocation; graph.invoke hands back a plain dict.
    subagent_result = summary_graph.invoke(summarizer_state)

    # Keep only the fields we need, not the full subagent state.
    agent_state.history = subagent_result["history"]
    agent_state.revision_summaries.append(subagent_result["summary_text"])

    # Attach the summary to this revision's loop entry, if present.
    last_loop = agent_state.loop_results[-1] if agent_state.loop_results else None
    if last_loop is not None and last_loop.get("revision") == agent_state.current_revision:
        last_loop["summary"] = subagent_result["summary_text"]

    # Trace output: summary text plus truncated subagent intermediate states.
    node_output = subagent_result.get("summary_text", "No summary generated")
    if isinstance(subagent_result, dict) and "intermediate_state" in subagent_result:
        trace_items = subagent_result.get("intermediate_state", [])
        if trace_items:
            rendered = [
                f"{item.get('node_name', 'unknown')}: {item.get('output', '')[:200]}"
                for item in trace_items
            ]
            node_output = (
                f"{node_output}\n\n[Summary Subagent Intermediate States]\n"
                + "\n".join(rendered)
            )

    agent_state.intermediate_state.append(
        {"node_name": "run_summary_subagent", "output": node_output}
    )

    return agent_state
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def analysis_node(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Analyze the current loop results and generate insights.

    This node uses an LLM to analyze what went wrong, what succeeded,
    and what needs improvement. The analysis is accumulated across
    revisions in ``agent_state.revision_analysis`` and, best-effort,
    written to ``<workspace>/experiment_analyses/revision_<N>_analysis.md``.
    """
    logger.info(f"Analyzing loop results for revision {agent_state.current_revision}")
    agent_state.current_phase = "analysis"

    # Get current loop results (empty dict if no loop has run yet).
    current_loop = agent_state.loop_results[-1] if agent_state.loop_results else {}

    # Use LLM to analyze the loop.
    analysis_prompt = PROMPTS.experiment_agent.analysis_prompt.render(
        revision_number=agent_state.current_revision,
        coding_summary=current_loop.get("coding_summary", "No coding summary available"),
        exec_result=json.dumps(current_loop.get("exec_result", {}), indent=2),
        summary=current_loop.get("summary", "No summary available"),
        previous_analysis=agent_state.revision_analysis or "No previous analysis.",
        user_query=agent_state.user_query,
    )

    response = ModelRegistry.completion(
        LLM_NAME,
        [Message(role="user", content=analysis_prompt)],
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.experiment_agent.analysis_system_prompt.render(),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
    )

    # Accumulate analysis across revisions (separated by a horizontal rule).
    analysis_text = response.content
    if agent_state.revision_analysis:
        agent_state.revision_analysis += (
            f"\n\n---\n\n## Revision {agent_state.current_revision} Analysis\n{analysis_text}"
        )
    else:
        agent_state.revision_analysis = (
            f"## Revision {agent_state.current_revision} Analysis\n{analysis_text}"
        )

    # Save analysis result to file (best-effort: a write failure must not
    # abort the agent loop, so it is only logged as a warning).
    # NOTE: the redundant function-local `import os` was removed; `os` is
    # already imported at module level (see os.getenv at the top of the file).
    try:
        analysis_dir = os.path.join(agent_state.workspace.working_dir, "experiment_analyses")
        os.makedirs(analysis_dir, exist_ok=True)

        analysis_file = os.path.join(
            analysis_dir, f"revision_{agent_state.current_revision}_analysis.md"
        )

        with open(analysis_file, "w", encoding="utf-8") as f:
            f.write(f"# Revision {agent_state.current_revision} Analysis\n\n")
            f.write(analysis_text)

        logger.info(f"Analysis saved to {analysis_file}")
    except Exception as e:
        logger.warning(f"Failed to save analysis to file: {e}")

    logger.debug(f"Analysis for revision {agent_state.current_revision} completed")

    agent_state.intermediate_state.append(
        {
            "node_name": "analysis",
            "output": analysis_text,
        }
    )

    return agent_state
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def revision_judge_node(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Judge whether a revision is needed based on the summary.

    This node analyzes the experiment summary and decides:
    1. COMPLETE - Experiment succeeded, no more revisions needed
    2. CONTINUE - Issues found, need another revision loop
    3. COMPLETE (max_revisions) - Hit max revisions limit

    On CONTINUE it increments ``current_revision`` and appends a feedback
    message to history for the next coding iteration. ``final_status`` is
    left as None on CONTINUE (that is what the conditional edge checks).
    """
    logger.info("Revision Judge evaluating results")
    agent_state.current_phase = "judge"

    # Check max revisions (current_revision is 0-based, hence the `- 1`):
    # short-circuit before spending an LLM call on the judge prompt.
    if agent_state.current_revision >= agent_state.max_revisions - 1:
        logger.warning("Max revisions reached")
        agent_state.final_status = "max_revisions_reached"
        judge_output = "Max revisions reached - stopping"
        agent_state.intermediate_state.append(
            {
                "node_name": "revision_judge",
                "output": judge_output,
            }
        )
        return agent_state

    # Get the latest summary (placeholder text when none has been produced).
    latest_summary = (
        agent_state.revision_summaries[-1]
        if agent_state.revision_summaries
        else "No summary available"
    )
    exec_result = agent_state.all_execution_results[-1] if agent_state.all_execution_results else {}

    # Use LLM to judge whether revision is needed (with accumulated analysis)
    judge_prompt = PROMPTS.experiment_agent.judge_prompt.render(
        latest_summary=latest_summary,
        exec_result=json.dumps(exec_result, indent=2),
        user_query=agent_state.user_query,
        revision_analysis=agent_state.revision_analysis or "No analysis available.",
    )

    response = ModelRegistry.completion(
        LLM_NAME,
        [Message(role="user", content=judge_prompt)],
        system_prompt=(
            Message(
                role="system",
                content=PROMPTS.experiment_agent.judge_system_prompt.render(),
            )
            .with_log(cond=constant.LOG_SYSTEM_PROMPT)
            .content
        ),
        agent_sender=AGENT_NAME,
    )

    # Local schema for parsing the judge's JSON reply.
    # NOTE(review): the `[]` default is safe here only because pydantic
    # copies field defaults per instance — confirm BaseModel is pydantic's.
    class JudgeDecisionModel(BaseModel):
        """Model for revision judge decision"""

        decision: str  # "COMPLETE" or "CONTINUE"
        reason: str
        issues_to_fix: list[str] = []

    # Parse the response using utility function.
    # The raw response content is the initial fallback value for the trace
    # output; every path below overwrites it before it is appended.
    judge_output = response.content
    try:
        result = parse_json_from_llm_response(response, JudgeDecisionModel)

        if result.decision == "COMPLETE":
            logger.info("Revision judge decided: COMPLETE")
            agent_state.final_status = "success"
            judge_output = f"Decision: COMPLETE\nReason: {result.reason}"
        else:
            # Any decision other than "COMPLETE" is treated as CONTINUE.
            logger.info(f"Revision judge decided: CONTINUE - {result.reason}")
            # Prepare for next revision
            agent_state.current_revision += 1
            # Add feedback to history for next coding iteration
            # (attempt_number is 1-based for the human-readable prompt).
            feedback_msg = Message(
                role="user",
                content=PROMPTS.experiment_agent.revision_feedback_prompt.render(
                    attempt_number=agent_state.current_revision + 1,
                    reason=result.reason,
                    issues_to_fix=result.issues_to_fix,
                ),
                agent_sender=AGENT_NAME,
            )
            agent_state.add_message(feedback_msg)
            judge_output = f"Decision: CONTINUE\nReason: {result.reason}\nIssues to fix: {result.issues_to_fix}"
    except Exception as e:
        # Deliberate fail-open: an unparseable judge reply ends the loop as
        # "success" rather than looping forever on a malformed response.
        logger.error(f"Error parsing judge response: {e}")
        agent_state.final_status = "success"
        judge_output = f"Error parsing judge response: {e}"

    agent_state.intermediate_state.append(
        {
            "node_name": "revision_judge",
            "output": judge_output,
        }
    )

    return agent_state
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def should_continue_revision(
    agent_state: ExperimentAgentState,
) -> Literal["continue", "complete"]:
    """Conditional edge function to determine next step after revision judge.

    A None ``final_status`` means the judge requested another loop.
    """
    return "continue" if agent_state.final_status is None else "complete"
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def finalize_node(agent_state: ExperimentAgentState) -> ExperimentAgentState:
    """Finalize the experiment and prepare output.

    Builds a markdown report covering status, revision count, the last
    summary, the accumulated analysis, and all execution results.
    """
    logger.info("Finalizing Experiment Agent")
    agent_state.current_phase = "complete"

    # Assemble the pieces of the final report.
    results_dump = json.dumps(agent_state.all_execution_results, indent=2)
    last_summary = (
        agent_state.revision_summaries[-1]
        if agent_state.revision_summaries
        else 'No summary generated'
    )
    analysis_block = agent_state.revision_analysis or 'No analysis available'

    agent_state.final_summary = f"""# Experiment Complete

## Status: {agent_state.final_status}

## Total Revisions: {agent_state.current_revision + 1}

## Final Summary
{last_summary}

## Accumulated Analysis
{analysis_block}

## All Execution Results
{results_dump}
"""

    agent_state.intermediate_state.append(
        {"node_name": "finalize", "output": agent_state.final_summary}
    )

    return agent_state
|