Spaces:
Running
Running
Commit ·
f4dca43
1
Parent(s): c02193c
Deploy 2026-02-09 15:08:25
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- pyproject.toml +31 -4
- src/flow/__init__.py +1 -1
- src/flow/cli/app.py +22 -12
- src/flow/cli/hf_import.py +159 -0
- src/flow/cli/optimize.py +193 -112
- src/flow/cli/repl.py +3 -1
- src/flow/experiments/__init__.py +79 -15
- src/flow/experiments/ablation.py +6 -2
- src/flow/experiments/agent_api.py +305 -0
- src/flow/experiments/evaluators/heuristic.py +1 -1
- src/flow/experiments/evaluators/llm.py +138 -37
- src/flow/experiments/expansion.py +250 -0
- src/flow/experiments/gaia_converter.py +216 -0
- src/flow/experiments/hf_datasets.py +354 -0
- src/flow/experiments/models.py +635 -83
- src/flow/experiments/optimizer.py +270 -12
- src/flow/experiments/presets.py +123 -0
- src/flow/experiments/results.py +118 -0
- src/flow/experiments/runner.py +63 -50
- src/flow/experiments/strategies/__init__.py +103 -0
- src/flow/experiments/strategies/llm_rewriter.py +357 -0
- src/flow/experiments/strategies/tool_selector.py +426 -0
- src/flow/experiments/trace_collector.py +92 -49
- src/flow/experiments/types.py +7 -2
- src/flow/harness/__init__.py +5 -4
- src/flow/harness/base.py +19 -7
- src/flow/harness/compaction/__init__.py +38 -0
- src/flow/harness/compaction/strategies.py +502 -0
- src/flow/harness/compaction/tokenizer.py +131 -0
- src/flow/harness/langgraph/__init__.py +7 -1
- src/flow/harness/langgraph/compaction.py +187 -19
- src/flow/harness/langgraph/harness.py +11 -4
- src/flow/harness/maf/__init__.py +1 -1
- src/flow/harness/maf/agent.py +7 -3
- src/flow/harness/maf/harness.py +11 -5
- src/flow/harness/maf/message_store.py +247 -69
- src/flow/harness/maf/tools/__init__.py +55 -20
- src/flow/harness/maf/wrappers.py +1 -1
- src/flow/harness/miniagent/__init__.py +19 -19
- src/flow/harness/miniagent/agent.py +13 -12
- src/flow/harness/miniagent/client.py +2 -2
- src/flow/harness/miniagent/context.py +5 -4
- src/flow/harness/miniagent/harness.py +54 -30
- src/flow/harness/miniagent/hooks.py +2 -1
- src/flow/harness/miniagent/instructions.py +23 -64
- src/flow/harness/miniagent/otel.py +8 -8
- src/flow/harness/miniagent/tool.py +8 -7
- src/flow/harness/miniagent/tools/__init__.py +27 -28
- src/flow/harness/miniagent/workspace.py +23 -20
- src/flow/harness/registry.py +43 -6
pyproject.toml
CHANGED
|
@@ -38,21 +38,46 @@ dependencies = [
|
|
| 38 |
"uvicorn>=0.27.0",
|
| 39 |
"sqlmodel>=0.0.14",
|
| 40 |
"aiosqlite>=0.19.0",
|
|
|
|
|
|
|
|
|
|
| 41 |
"tiktoken>=0.12.0",
|
| 42 |
]
|
| 43 |
|
| 44 |
[project.optional-dependencies]
|
| 45 |
# Optional features
|
| 46 |
research = ["beautifulsoup4>=4.12.0", "html2text>=2024.2.26"]
|
|
|
|
| 47 |
langgraph = [
|
| 48 |
"langgraph>=0.2.0",
|
| 49 |
"langchain-core>=0.3.0",
|
| 50 |
"langchain-openai>=0.2.0",
|
| 51 |
]
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# Bundles
|
| 55 |
-
all = ["flow-agent[research,langgraph,optimizer]"]
|
| 56 |
dev = [
|
| 57 |
"pytest>=8.0.0",
|
| 58 |
"pytest-asyncio>=0.23.0",
|
|
@@ -85,7 +110,7 @@ packages = ["src/flow"]
|
|
| 85 |
|
| 86 |
[tool.pyright]
|
| 87 |
include = ["src"]
|
| 88 |
-
exclude = ["**/tests/**", "**/.venv/**"]
|
| 89 |
typeCheckingMode = "strict"
|
| 90 |
pythonVersion = "3.10"
|
| 91 |
reportMissingTypeStubs = false
|
|
@@ -108,6 +133,7 @@ show_error_codes = true
|
|
| 108 |
warn_unused_ignores = false
|
| 109 |
disallow_incomplete_defs = true
|
| 110 |
disallow_untyped_decorators = true
|
|
|
|
| 111 |
|
| 112 |
# ============================================================================
|
| 113 |
# Linting - Ruff
|
|
@@ -119,7 +145,7 @@ target-version = "py310"
|
|
| 119 |
src = ["src"]
|
| 120 |
fix = true
|
| 121 |
include = ["*.py", "*.pyi", "**/pyproject.toml"]
|
| 122 |
-
exclude = ["docs/*"]
|
| 123 |
|
| 124 |
[tool.ruff.lint]
|
| 125 |
select = [
|
|
@@ -140,6 +166,7 @@ ignore = [
|
|
| 140 |
"D107", # allow missing docstring in __init__
|
| 141 |
"ANN401", # allow Any type (needed for generic tool/event handling)
|
| 142 |
"S101", # allow assert statements (used in tests)
|
|
|
|
| 143 |
]
|
| 144 |
|
| 145 |
[tool.ruff.lint.per-file-ignores]
|
|
|
|
| 38 |
"uvicorn>=0.27.0",
|
| 39 |
"sqlmodel>=0.0.14",
|
| 40 |
"aiosqlite>=0.19.0",
|
| 41 |
+
"greenlet>=3.0.0", # Required for SQLAlchemy async support
|
| 42 |
+
# Logging dependencies
|
| 43 |
+
"loguru>=0.7.3",
|
| 44 |
"tiktoken>=0.12.0",
|
| 45 |
]
|
| 46 |
|
| 47 |
[project.optional-dependencies]
|
| 48 |
# Optional features
|
| 49 |
research = ["beautifulsoup4>=4.12.0", "html2text>=2024.2.26"]
|
| 50 |
+
|
| 51 |
langgraph = [
|
| 52 |
"langgraph>=0.2.0",
|
| 53 |
"langchain-core>=0.3.0",
|
| 54 |
"langchain-openai>=0.2.0",
|
| 55 |
]
|
| 56 |
+
|
| 57 |
+
smolagents = [
|
| 58 |
+
"smolagents[toolkit]>=1.24.0",
|
| 59 |
+
"pdfminer.six>=20240706",
|
| 60 |
+
"cffi>=1.16.0",
|
| 61 |
+
"cryptography>=42.0.0",
|
| 62 |
+
"Pillow>=11.0.0",
|
| 63 |
+
"puremagic>=1.28",
|
| 64 |
+
"pypdf>=5.1.0",
|
| 65 |
+
"youtube_transcript_api>=0.6.2",
|
| 66 |
+
"python_pptx>=1.0.2",
|
| 67 |
+
"serpapi>=0.1.5",
|
| 68 |
+
"mammoth>=1.8.0",
|
| 69 |
+
"markdownify>=0.13.1",
|
| 70 |
+
"pandas>=2.2.3",
|
| 71 |
+
"openpyxl>=3.1.0",
|
| 72 |
+
"wikipedia-api>=0.9.0",
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
# Bundles
|
| 76 |
+
optimizer = ["gepa>=0.0.27", "litellm>=1.0.0"]
|
| 77 |
+
hf-datasets = ["datasets>=2.0.0"]
|
| 78 |
|
| 79 |
# Bundles
|
| 80 |
+
all = ["flow-agent[research,langgraph,optimizer,smolagents,hf-datasets]"]
|
| 81 |
dev = [
|
| 82 |
"pytest>=8.0.0",
|
| 83 |
"pytest-asyncio>=0.23.0",
|
|
|
|
| 110 |
|
| 111 |
[tool.pyright]
|
| 112 |
include = ["src"]
|
| 113 |
+
exclude = ["**/tests/**", "**/.venv/**", "**/skills/**"]
|
| 114 |
typeCheckingMode = "strict"
|
| 115 |
pythonVersion = "3.10"
|
| 116 |
reportMissingTypeStubs = false
|
|
|
|
| 133 |
warn_unused_ignores = false
|
| 134 |
disallow_incomplete_defs = true
|
| 135 |
disallow_untyped_decorators = true
|
| 136 |
+
exclude = ["src/flow/skills/"]
|
| 137 |
|
| 138 |
# ============================================================================
|
| 139 |
# Linting - Ruff
|
|
|
|
| 145 |
src = ["src"]
|
| 146 |
fix = true
|
| 147 |
include = ["*.py", "*.pyi", "**/pyproject.toml"]
|
| 148 |
+
exclude = ["docs/*", "src/flow/skills/*"]
|
| 149 |
|
| 150 |
[tool.ruff.lint]
|
| 151 |
select = [
|
|
|
|
| 166 |
"D107", # allow missing docstring in __init__
|
| 167 |
"ANN401", # allow Any type (needed for generic tool/event handling)
|
| 168 |
"S101", # allow assert statements (used in tests)
|
| 169 |
+
"B008", # allow Depends() in function defaults (FastAPI pattern)
|
| 170 |
]
|
| 171 |
|
| 172 |
[tool.ruff.lint.per-file-ignores]
|
src/flow/__init__.py
CHANGED
|
@@ -21,6 +21,6 @@ __version__ = "0.1.0"
|
|
| 21 |
|
| 22 |
__all__ = [
|
| 23 |
"MAFHarness",
|
| 24 |
-
"create_agent",
|
| 25 |
"__version__",
|
|
|
|
| 26 |
]
|
|
|
|
| 21 |
|
| 22 |
__all__ = [
|
| 23 |
"MAFHarness",
|
|
|
|
| 24 |
"__version__",
|
| 25 |
+
"create_agent",
|
| 26 |
]
|
src/flow/cli/app.py
CHANGED
|
@@ -65,7 +65,7 @@ def run(
|
|
| 65 |
framework: Annotated[
|
| 66 |
str,
|
| 67 |
typer.Option("--framework", "-f", help="Agent framework: 'maf', 'miniagent', or 'langgraph'"),
|
| 68 |
-
] = "
|
| 69 |
interactive: Annotated[
|
| 70 |
bool,
|
| 71 |
typer.Option("--interactive/--no-interactive", "-i", help="Interactive mode"),
|
|
@@ -110,26 +110,34 @@ async def _run_single_task(
|
|
| 110 |
memory_path: Path,
|
| 111 |
task: str,
|
| 112 |
config_path: Path | None = None,
|
| 113 |
-
framework: str = "
|
| 114 |
) -> None:
|
| 115 |
"""Run a single task and print the result."""
|
|
|
|
|
|
|
|
|
|
| 116 |
from flow.cli.output import print_event
|
| 117 |
from flow.harness.base import EventType
|
| 118 |
|
| 119 |
-
|
| 120 |
-
import flow.harness.maf # noqa: F401
|
| 121 |
-
import flow.harness.miniagent # noqa: F401 # pyright: ignore[reportUnusedImport]
|
| 122 |
|
| 123 |
if framework == "langgraph":
|
| 124 |
try:
|
| 125 |
-
import flow.harness.langgraph
|
|
|
|
|
|
|
| 126 |
except ImportError:
|
| 127 |
console.print("[red]Error:[/] LangGraph dependencies not installed.")
|
| 128 |
console.print("[dim]Install with: pip install flow-agent[langgraph][/]")
|
| 129 |
raise typer.Exit(1)
|
| 130 |
|
|
|
|
|
|
|
|
|
|
| 131 |
from flow.harness import create_harness
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
| 134 |
if config_path:
|
| 135 |
# Load agent config from optimization result
|
|
@@ -137,19 +145,19 @@ async def _run_single_task(
|
|
| 137 |
|
| 138 |
agent_config = load_agent(config_path)
|
| 139 |
# Override framework if specified
|
| 140 |
-
if framework != "
|
| 141 |
agent_config = Agent(
|
| 142 |
name=agent_config.name,
|
| 143 |
-
framework=
|
| 144 |
tools=agent_config.tools,
|
| 145 |
-
|
| 146 |
instructions=agent_config.instructions,
|
| 147 |
compaction=agent_config.compaction,
|
| 148 |
)
|
| 149 |
console.print(f"[dim]Using agent config: {agent_config.name} ({framework})[/]")
|
| 150 |
harness = create_harness(agent_config, workspace)
|
| 151 |
else:
|
| 152 |
-
agent = Agent(name="flow-cli", framework=
|
| 153 |
harness = create_harness(agent, workspace)
|
| 154 |
|
| 155 |
try:
|
|
@@ -167,10 +175,12 @@ async def _run_single_task(
|
|
| 167 |
await harness.close()
|
| 168 |
|
| 169 |
|
| 170 |
-
# Import and register
|
|
|
|
| 171 |
from flow.cli.optimize import optimize as optimize_cmd
|
| 172 |
|
| 173 |
app.command()(optimize_cmd)
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
@app.command()
|
|
|
|
| 65 |
framework: Annotated[
|
| 66 |
str,
|
| 67 |
typer.Option("--framework", "-f", help="Agent framework: 'maf', 'miniagent', or 'langgraph'"),
|
| 68 |
+
] = "miniagent",
|
| 69 |
interactive: Annotated[
|
| 70 |
bool,
|
| 71 |
typer.Option("--interactive/--no-interactive", "-i", help="Interactive mode"),
|
|
|
|
| 110 |
memory_path: Path,
|
| 111 |
task: str,
|
| 112 |
config_path: Path | None = None,
|
| 113 |
+
framework: str = "miniagent",
|
| 114 |
) -> None:
|
| 115 |
"""Run a single task and print the result."""
|
| 116 |
+
# Import harness modules to register them (side-effect imports)
|
| 117 |
+
import flow.harness.maf as _maf
|
| 118 |
+
import flow.harness.miniagent as _miniagent
|
| 119 |
from flow.cli.output import print_event
|
| 120 |
from flow.harness.base import EventType
|
| 121 |
|
| 122 |
+
_ = (_maf, _miniagent)
|
|
|
|
|
|
|
| 123 |
|
| 124 |
if framework == "langgraph":
|
| 125 |
try:
|
| 126 |
+
import flow.harness.langgraph as _lg
|
| 127 |
+
|
| 128 |
+
_ = _lg
|
| 129 |
except ImportError:
|
| 130 |
console.print("[red]Error:[/] LangGraph dependencies not installed.")
|
| 131 |
console.print("[dim]Install with: pip install flow-agent[langgraph][/]")
|
| 132 |
raise typer.Exit(1)
|
| 133 |
|
| 134 |
+
from typing import cast
|
| 135 |
+
|
| 136 |
+
from flow.experiments.models import Agent, Framework
|
| 137 |
from flow.harness import create_harness
|
| 138 |
+
|
| 139 |
+
# Cast the validated framework string to Framework literal type
|
| 140 |
+
framework_typed = cast(Framework, framework)
|
| 141 |
|
| 142 |
if config_path:
|
| 143 |
# Load agent config from optimization result
|
|
|
|
| 145 |
|
| 146 |
agent_config = load_agent(config_path)
|
| 147 |
# Override framework if specified
|
| 148 |
+
if framework != "miniagent":
|
| 149 |
agent_config = Agent(
|
| 150 |
name=agent_config.name,
|
| 151 |
+
framework=framework_typed,
|
| 152 |
tools=agent_config.tools,
|
| 153 |
+
llm_config=agent_config.llm_config,
|
| 154 |
instructions=agent_config.instructions,
|
| 155 |
compaction=agent_config.compaction,
|
| 156 |
)
|
| 157 |
console.print(f"[dim]Using agent config: {agent_config.name} ({framework})[/]")
|
| 158 |
harness = create_harness(agent_config, workspace)
|
| 159 |
else:
|
| 160 |
+
agent = Agent(name="flow-cli", framework=framework_typed)
|
| 161 |
harness = create_harness(agent, workspace)
|
| 162 |
|
| 163 |
try:
|
|
|
|
| 175 |
await harness.close()
|
| 176 |
|
| 177 |
|
| 178 |
+
# Import and register commands
|
| 179 |
+
from flow.cli.hf_import import hf_import as hf_import_cmd
|
| 180 |
from flow.cli.optimize import optimize as optimize_cmd
|
| 181 |
|
| 182 |
app.command()(optimize_cmd)
|
| 183 |
+
app.command(name="hf-import")(hf_import_cmd)
|
| 184 |
|
| 185 |
|
| 186 |
@app.command()
|
src/flow/cli/hf_import.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI command to import Hugging Face datasets."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Annotated
|
| 7 |
+
|
| 8 |
+
import typer
|
| 9 |
+
from rich.console import Console
|
| 10 |
+
|
| 11 |
+
from flow.experiments.hf_datasets import (
|
| 12 |
+
DATASET_CONVERTERS,
|
| 13 |
+
import_hf_dataset,
|
| 14 |
+
save_tasks_to_jsonl,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
console = Console()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def hf_import(
|
| 21 |
+
dataset: Annotated[
|
| 22 |
+
str,
|
| 23 |
+
typer.Argument(help="Hugging Face dataset name (e.g., 'openai/gsm8k')"),
|
| 24 |
+
],
|
| 25 |
+
output: Annotated[
|
| 26 |
+
Path,
|
| 27 |
+
typer.Option(
|
| 28 |
+
"--output",
|
| 29 |
+
"-o",
|
| 30 |
+
help="Output JSONL file path",
|
| 31 |
+
),
|
| 32 |
+
] = Path("tasks/imported.jsonl"),
|
| 33 |
+
config: Annotated[
|
| 34 |
+
str | None,
|
| 35 |
+
typer.Option(
|
| 36 |
+
"--config",
|
| 37 |
+
"-c",
|
| 38 |
+
help="Dataset configuration/subset (e.g., 'main' for gsm8k)",
|
| 39 |
+
),
|
| 40 |
+
] = None,
|
| 41 |
+
split: Annotated[
|
| 42 |
+
str,
|
| 43 |
+
typer.Option(
|
| 44 |
+
"--split",
|
| 45 |
+
"-s",
|
| 46 |
+
help="Dataset split to use",
|
| 47 |
+
),
|
| 48 |
+
] = "train",
|
| 49 |
+
limit: Annotated[
|
| 50 |
+
int | None,
|
| 51 |
+
typer.Option(
|
| 52 |
+
"--limit",
|
| 53 |
+
"-n",
|
| 54 |
+
help="Maximum number of examples to import",
|
| 55 |
+
),
|
| 56 |
+
] = None,
|
| 57 |
+
local_path: Annotated[
|
| 58 |
+
Path | None,
|
| 59 |
+
typer.Option(
|
| 60 |
+
"--local-path",
|
| 61 |
+
"-l",
|
| 62 |
+
help="Path to download dataset snapshot to. Uses huggingface_hub.snapshot_download(). "
|
| 63 |
+
"For private datasets, set HF_TOKEN env variable.",
|
| 64 |
+
),
|
| 65 |
+
] = None,
|
| 66 |
+
list_supported: Annotated[
|
| 67 |
+
bool,
|
| 68 |
+
typer.Option(
|
| 69 |
+
"--list",
|
| 70 |
+
help="List supported datasets and exit",
|
| 71 |
+
),
|
| 72 |
+
] = False,
|
| 73 |
+
) -> None:
|
| 74 |
+
"""Import a Hugging Face dataset for use with GEPA optimization.
|
| 75 |
+
|
| 76 |
+
Converts HF datasets into Flow's task format with evaluation criteria.
|
| 77 |
+
|
| 78 |
+
Examples:
|
| 79 |
+
# Import 100 GSM8K math problems
|
| 80 |
+
flow hf-import openai/gsm8k --config main --output tasks/gsm8k.jsonl --limit 100
|
| 81 |
+
|
| 82 |
+
# Import HumanEval coding problems
|
| 83 |
+
flow hf-import openai_humaneval --output tasks/humaneval.jsonl
|
| 84 |
+
|
| 85 |
+
# Download to local path first (useful for caching or private datasets)
|
| 86 |
+
flow hf-import openai/gsm8k --local-path /data/gsm8k --output tasks/gsm8k.jsonl
|
| 87 |
+
|
| 88 |
+
# For private datasets, set HF_TOKEN env variable
|
| 89 |
+
HF_TOKEN=hf_... flow hf-import org/private-dataset --local-path /data/private
|
| 90 |
+
|
| 91 |
+
# Use with GEPA
|
| 92 |
+
flow optimize \\
|
| 93 |
+
--config examples/gepa_strategy.yaml \\
|
| 94 |
+
--agent examples/base_agent.yaml \\
|
| 95 |
+
--tasks tasks/gsm8k.jsonl \\
|
| 96 |
+
--budget 10
|
| 97 |
+
"""
|
| 98 |
+
if list_supported:
|
| 99 |
+
console.print("\n[bold]Supported Datasets:[/]")
|
| 100 |
+
console.print("\n[dim]You can add custom converters via register_converter()[/]\n")
|
| 101 |
+
for name in sorted(DATASET_CONVERTERS.keys()):
|
| 102 |
+
console.print(f" • {name}")
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
console.print(f"\n[bold]Importing dataset:[/] {dataset}")
|
| 106 |
+
if config:
|
| 107 |
+
console.print(f"[dim]Config:[/] {config}")
|
| 108 |
+
console.print(f"[dim]Split:[/] {split}")
|
| 109 |
+
if limit:
|
| 110 |
+
console.print(f"[dim]Limit:[/] {limit} examples")
|
| 111 |
+
if local_path:
|
| 112 |
+
console.print(f"[dim]Local path:[/] {local_path}")
|
| 113 |
+
console.print()
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
# Import dataset
|
| 117 |
+
tasks = import_hf_dataset(
|
| 118 |
+
dataset_name=dataset,
|
| 119 |
+
config=config,
|
| 120 |
+
split=split,
|
| 121 |
+
limit=limit,
|
| 122 |
+
local_path=local_path,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if not tasks:
|
| 126 |
+
console.print("[red]Error:[/] No tasks were converted")
|
| 127 |
+
raise typer.Exit(1)
|
| 128 |
+
|
| 129 |
+
# Save to file
|
| 130 |
+
save_tasks_to_jsonl(tasks, output)
|
| 131 |
+
|
| 132 |
+
console.print(f"\n[green]Success![/] Imported {len(tasks)} tasks")
|
| 133 |
+
console.print(f"[dim]Output:[/] {output}")
|
| 134 |
+
console.print("\n[bold]Sample task:[/]")
|
| 135 |
+
console.print(f" Name: {tasks[0].name}")
|
| 136 |
+
console.print(f" Prompt: {tasks[0].prompt[:100]}...")
|
| 137 |
+
console.print(f" Criteria: {len(tasks[0].criteria)} evaluation criteria")
|
| 138 |
+
|
| 139 |
+
console.print("\n[bold]Next steps:[/]")
|
| 140 |
+
console.print(" [dim]# Run GEPA optimization[/]")
|
| 141 |
+
console.print(" flow optimize \\")
|
| 142 |
+
console.print(" --config examples/gepa_strategy.yaml \\")
|
| 143 |
+
console.print(" --agent examples/base_agent.yaml \\")
|
| 144 |
+
console.print(f" --tasks {output} \\")
|
| 145 |
+
console.print(" --budget 10")
|
| 146 |
+
|
| 147 |
+
except ImportError:
|
| 148 |
+
console.print("[red]Error:[/] Hugging Face datasets library not installed")
|
| 149 |
+
console.print("[dim]Install with:[/] pip install datasets")
|
| 150 |
+
raise typer.Exit(1)
|
| 151 |
+
except ValueError as e:
|
| 152 |
+
console.print(f"[red]Error:[/] {e}")
|
| 153 |
+
raise typer.Exit(1)
|
| 154 |
+
except Exception as e:
|
| 155 |
+
console.print(f"[red]Error:[/] {e}")
|
| 156 |
+
import traceback
|
| 157 |
+
|
| 158 |
+
traceback.print_exc()
|
| 159 |
+
raise typer.Exit(1)
|
src/flow/cli/optimize.py
CHANGED
|
@@ -18,7 +18,6 @@ from flow.experiments.models import (
|
|
| 18 |
Agent,
|
| 19 |
Candidate,
|
| 20 |
CompactionConfig,
|
| 21 |
-
Experiment,
|
| 22 |
ExperimentResult,
|
| 23 |
GridSearchStrategy,
|
| 24 |
load_experiment,
|
|
@@ -86,6 +85,13 @@ def optimize(
|
|
| 86 |
help="Output directory for results",
|
| 87 |
),
|
| 88 |
] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
no_llm_eval: Annotated[
|
| 90 |
bool,
|
| 91 |
typer.Option(
|
|
@@ -107,7 +113,6 @@ def optimize(
|
|
| 107 |
ranks via Pareto analysis, and exports winning agent configs.
|
| 108 |
|
| 109 |
Examples:
|
| 110 |
-
|
| 111 |
# Use experiment YAML (recommended - defines agent, tasks, and variations)
|
| 112 |
flow optimize --experiment experiment.yaml
|
| 113 |
|
|
@@ -140,6 +145,7 @@ def optimize(
|
|
| 140 |
output_dir=output,
|
| 141 |
use_llm_eval=not no_llm_eval,
|
| 142 |
budget=budget,
|
|
|
|
| 143 |
))
|
| 144 |
|
| 145 |
|
|
@@ -154,15 +160,16 @@ async def _run_optimize(
|
|
| 154 |
output_dir: Path | None,
|
| 155 |
use_llm_eval: bool,
|
| 156 |
budget: int,
|
|
|
|
| 157 |
) -> None:
|
| 158 |
"""Run the optimization."""
|
| 159 |
# If experiment YAML provided, use it as the source of truth
|
| 160 |
if experiment_path:
|
| 161 |
-
await _run_from_experiment(experiment_path, output_dir)
|
| 162 |
return
|
| 163 |
|
| 164 |
# Load tasks
|
| 165 |
-
tasks = _load_tasks(tasks_path, suite)
|
| 166 |
if not tasks:
|
| 167 |
console.print("[red]Error:[/] No tasks specified. Use --tasks or --suite")
|
| 168 |
raise typer.Exit(1)
|
|
@@ -171,7 +178,7 @@ async def _run_optimize(
|
|
| 171 |
base = _load_base_agent(agent_path)
|
| 172 |
|
| 173 |
# Load candidates and check if a strategy is defined in config
|
| 174 |
-
candidates, strategy_instance = _load_candidates_and_strategy(config_path, vary, base, budget)
|
| 175 |
|
| 176 |
# If a strategy was provided (like GepaStrategy), run it directly
|
| 177 |
if strategy_instance is not None:
|
|
@@ -221,7 +228,7 @@ async def _run_optimize(
|
|
| 221 |
raise typer.Exit(1)
|
| 222 |
|
| 223 |
|
| 224 |
-
async def _run_from_experiment(experiment_path: Path, output_dir: Path | None) -> None:
|
| 225 |
"""Run optimization from an experiment YAML file.
|
| 226 |
|
| 227 |
The experiment YAML defines:
|
|
@@ -270,10 +277,13 @@ async def _run_from_experiment(experiment_path: Path, output_dir: Path | None) -
|
|
| 270 |
console.print("[red]Error:[/] Experiment must specify 'suite' or 'tasks'")
|
| 271 |
raise typer.Exit(1)
|
| 272 |
|
|
|
|
|
|
|
|
|
|
| 273 |
# Generate candidates from variations
|
| 274 |
if exp.variations:
|
| 275 |
strategy = GridSearchStrategy(exp.variations)
|
| 276 |
-
candidates = strategy.generate(base, exp.budget)
|
| 277 |
else:
|
| 278 |
candidates = [Candidate(agent=base, mutations={}, rationale="baseline")]
|
| 279 |
|
|
@@ -283,7 +293,7 @@ async def _run_from_experiment(experiment_path: Path, output_dir: Path | None) -
|
|
| 283 |
for t in tasks:
|
| 284 |
console.print(f" - {t.name}")
|
| 285 |
|
| 286 |
-
console.print(
|
| 287 |
for key, values in exp.variations.items():
|
| 288 |
console.print(f" - {key}: {len(values)} variants")
|
| 289 |
|
|
@@ -309,27 +319,31 @@ async def _run_from_experiment(experiment_path: Path, output_dir: Path | None) -
|
|
| 309 |
raise typer.Exit(1)
|
| 310 |
|
| 311 |
|
| 312 |
-
def _load_tasks(tasks_path: Path | None, suite: str | None) -> list[Task]:
|
| 313 |
"""Load tasks from file or built-in suite."""
|
|
|
|
| 314 |
if tasks_path:
|
| 315 |
if not tasks_path.exists():
|
| 316 |
console.print(f"[red]Error:[/] Tasks file not found: {tasks_path}")
|
| 317 |
raise typer.Exit(1)
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
if suite:
|
| 321 |
try:
|
| 322 |
-
|
| 323 |
except ValueError as e:
|
| 324 |
console.print(f"[red]Error:[/] {e}")
|
| 325 |
raise typer.Exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
except ValueError:
|
| 331 |
-
console.print("[red]Error:[/] No built-in suites available. Use --tasks to specify a JSONL file.")
|
| 332 |
-
raise typer.Exit(1)
|
| 333 |
|
| 334 |
|
| 335 |
def _load_base_agent(agent_path: Path | None) -> Agent:
|
|
@@ -344,18 +358,18 @@ def _load_base_agent(agent_path: Path | None) -> Agent:
|
|
| 344 |
return Agent(name="flow_agent")
|
| 345 |
|
| 346 |
|
| 347 |
-
def _load_candidates_and_strategy(
|
| 348 |
config_path: Path | None,
|
| 349 |
vary: str | None,
|
| 350 |
base: Agent,
|
| 351 |
budget: int,
|
| 352 |
) -> tuple[list[Candidate], Any | None]:
|
| 353 |
"""Load candidates from file or generate from variations.
|
| 354 |
-
|
| 355 |
Supports both YAML and Python config files:
|
| 356 |
- YAML: strategy configuration (strategy_type, config)
|
| 357 |
- Python: STRATEGY object, CANDIDATES list, or VARIATIONS dict
|
| 358 |
-
|
| 359 |
Returns:
|
| 360 |
Tuple of (candidates, strategy_instance)
|
| 361 |
- If a STRATEGY is defined in config, returns ([], strategy_instance)
|
|
@@ -374,17 +388,17 @@ def _load_candidates_and_strategy(
|
|
| 374 |
# YAML files currently only support strategy definitions
|
| 375 |
console.print("[red]Error:[/] YAML config must define a strategy")
|
| 376 |
raise typer.Exit(1)
|
| 377 |
-
|
| 378 |
# Python config file
|
| 379 |
candidates, variations, strategy_obj = _load_python_config(config_path)
|
| 380 |
|
| 381 |
# If a strategy object was provided (e.g., GepaStrategy), return it
|
| 382 |
if strategy_obj is not None:
|
| 383 |
return [], strategy_obj
|
| 384 |
-
|
| 385 |
if variations:
|
| 386 |
strategy = GridSearchStrategy(variations)
|
| 387 |
-
return strategy.generate(base, budget), None
|
| 388 |
elif candidates:
|
| 389 |
return candidates, None
|
| 390 |
else:
|
|
@@ -394,7 +408,7 @@ def _load_candidates_and_strategy(
|
|
| 394 |
if vary:
|
| 395 |
variations = _parse_vary_flag(vary)
|
| 396 |
strategy = GridSearchStrategy(variations)
|
| 397 |
-
return strategy.generate(base, budget), None
|
| 398 |
|
| 399 |
# Default: explore context engineering dimensions
|
| 400 |
strategy = GridSearchStrategy(variations={
|
|
@@ -402,9 +416,8 @@ def _load_candidates_and_strategy(
|
|
| 402 |
CompactionConfig.head_tail(10, 40),
|
| 403 |
CompactionConfig.none(),
|
| 404 |
],
|
| 405 |
-
"tools": ["minimal", "standard"],
|
| 406 |
})
|
| 407 |
-
return strategy.generate(base, budget), None
|
| 408 |
|
| 409 |
|
| 410 |
def _load_yaml_strategy(path: Path) -> Any | None:
|
|
@@ -442,9 +455,12 @@ def _load_yaml_strategy(path: Path) -> Any | None:
|
|
| 442 |
console.print("[red]Error:[/] GEPA optimizer not available.")
|
| 443 |
console.print("[dim]Install with: pip install flow-agent[optimizer][/]")
|
| 444 |
raise typer.Exit(1)
|
|
|
|
|
|
|
|
|
|
| 445 |
else:
|
| 446 |
console.print(f"[red]Error:[/] Unknown strategy type: {strategy_type}")
|
| 447 |
-
console.print("[dim]Supported: gepa[/]")
|
| 448 |
raise typer.Exit(1)
|
| 449 |
|
| 450 |
|
|
@@ -526,128 +542,193 @@ async def _run_active_strategy(
|
|
| 526 |
use_llm_eval: bool,
|
| 527 |
budget: int,
|
| 528 |
) -> None:
|
| 529 |
-
"""Run an active optimization strategy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
logger = logging.getLogger(__name__)
|
| 531 |
-
|
| 532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
optimizer_runner = FlowOptimizer(
|
| 534 |
parallel=parallel,
|
| 535 |
use_llm_evaluator=use_llm_eval,
|
| 536 |
-
output_dir=
|
| 537 |
)
|
| 538 |
|
| 539 |
-
|
| 540 |
main_loop = asyncio.get_running_loop()
|
| 541 |
|
| 542 |
-
# Define evaluator function to inject into strategy
|
| 543 |
def evaluator(candidate: Candidate, minibatch: list[Task] | None = None) -> ExperimentResult:
|
| 544 |
-
"""Evaluate a candidate on a minibatch of tasks."""
|
|
|
|
|
|
|
| 545 |
eval_tasks = minibatch if minibatch else tasks
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
try:
|
| 552 |
-
# Run async evaluation on the main loop and wait for result
|
| 553 |
-
# This is safe because strategy.generate (which calls this)
|
| 554 |
-
# is running in an executor thread.
|
| 555 |
future = asyncio.run_coroutine_threadsafe(
|
| 556 |
-
optimizer_runner.optimize([candidate], eval_tasks),
|
| 557 |
main_loop
|
| 558 |
)
|
| 559 |
optimization_result = future.result()
|
| 560 |
-
|
| 561 |
-
# Check if we got any results
|
| 562 |
if not optimization_result.summaries:
|
| 563 |
-
logger.warning(f"[EVALUATOR] Optimization produced no summaries for candidate '{candidate.agent.name}'")
|
| 564 |
-
# Return a fallback result with zero score instead of raising
|
| 565 |
return ExperimentResult(
|
| 566 |
-
candidate=candidate,
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
eval_score=0.0,
|
| 570 |
-
eval_passed=False,
|
| 571 |
-
eval_reasoning="Evaluation failed to produce results",
|
| 572 |
-
traces={}
|
| 573 |
)
|
| 574 |
-
|
| 575 |
summary = optimization_result.summaries[0]
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
# Log individual task results for debugging
|
| 579 |
-
if summary.task_results:
|
| 580 |
-
for tr in summary.task_results:
|
| 581 |
-
logger.info(f"[EVALUATOR] Task '{tr.task_name}': score={tr.eval_score:.3f}, passed={tr.eval_passed}")
|
| 582 |
-
logger.debug(f"[EVALUATOR] Reasoning: '{tr.eval_reasoning[:150]}'")
|
| 583 |
-
logger.debug(f"[EVALUATOR] Metrics: tokens={tr.metrics.total_tokens}, duration={tr.run_result.duration_seconds if tr.run_result else 0:.2f}s")
|
| 584 |
-
|
| 585 |
-
# Convert CandidateSummary to ExperimentResult for GEPA
|
| 586 |
-
|
| 587 |
if summary.task_results:
|
| 588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
return ExperimentResult(
|
| 590 |
candidate=candidate,
|
| 591 |
-
run_result=
|
| 592 |
-
metrics=
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
|
|
|
| 597 |
)
|
| 598 |
-
|
| 599 |
-
# Fallback to aggregate metrics if no individual task results
|
| 600 |
return ExperimentResult(
|
| 601 |
-
candidate=candidate,
|
| 602 |
-
|
| 603 |
-
metrics={"score": summary.avg_score},
|
| 604 |
-
eval_score=summary.avg_score,
|
| 605 |
eval_passed=summary.pass_rate > 0.5,
|
| 606 |
-
eval_reasoning=f"Aggregate pass rate: {summary.pass_rate}",
|
| 607 |
-
traces={}
|
| 608 |
)
|
| 609 |
-
|
| 610 |
except Exception as e:
|
| 611 |
logger.error(f"Error evaluating candidate '{candidate.agent.name}': {e}", exc_info=True)
|
| 612 |
-
# Return a fallback result instead of propagating the exception
|
| 613 |
return ExperimentResult(
|
| 614 |
-
candidate=candidate,
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
eval_score=0.0,
|
| 618 |
-
eval_passed=False,
|
| 619 |
-
eval_reasoning=f"Evaluation error: {str(e)}",
|
| 620 |
-
traces={}
|
| 621 |
)
|
| 622 |
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
# GepaStrategy accepts them in __init__, but we might have loaded it from config
|
| 626 |
-
# without them.
|
| 627 |
-
if hasattr(strategy, "evaluator") and strategy.evaluator is None:
|
| 628 |
strategy.evaluator = evaluator
|
| 629 |
if hasattr(strategy, "dataset") and strategy.dataset is None:
|
| 630 |
strategy.dataset = tasks
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
candidates = await loop.run_in_executor(None, strategy.generate, base_agent, budget)
|
| 637 |
|
| 638 |
console.print("\n[bold green]Optimization complete![/]")
|
| 639 |
console.print(f"Generated {len(candidates)} candidates.")
|
| 640 |
|
| 641 |
-
|
| 642 |
-
if
|
|
|
|
|
|
|
| 643 |
from flow.experiments.models import export_agent
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
|
|
|
| 647 |
for i, cand in enumerate(candidates):
|
| 648 |
-
# Basic export
|
| 649 |
name = cand.agent.name or f"candidate_{i}"
|
| 650 |
-
export_agent(cand.agent,
|
| 651 |
-
|
| 652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
|
|
|
|
| 18 |
Agent,
|
| 19 |
Candidate,
|
| 20 |
CompactionConfig,
|
|
|
|
| 21 |
ExperimentResult,
|
| 22 |
GridSearchStrategy,
|
| 23 |
load_experiment,
|
|
|
|
| 85 |
help="Output directory for results",
|
| 86 |
),
|
| 87 |
] = None,
|
| 88 |
+
limit: Annotated[
|
| 89 |
+
int | None,
|
| 90 |
+
typer.Option(
|
| 91 |
+
"--limit", "-l",
|
| 92 |
+
help="Max number of tasks to run (use first N tasks from suite/file)",
|
| 93 |
+
),
|
| 94 |
+
] = None,
|
| 95 |
no_llm_eval: Annotated[
|
| 96 |
bool,
|
| 97 |
typer.Option(
|
|
|
|
| 113 |
ranks via Pareto analysis, and exports winning agent configs.
|
| 114 |
|
| 115 |
Examples:
|
|
|
|
| 116 |
# Use experiment YAML (recommended - defines agent, tasks, and variations)
|
| 117 |
flow optimize --experiment experiment.yaml
|
| 118 |
|
|
|
|
| 145 |
output_dir=output,
|
| 146 |
use_llm_eval=not no_llm_eval,
|
| 147 |
budget=budget,
|
| 148 |
+
limit=limit,
|
| 149 |
))
|
| 150 |
|
| 151 |
|
|
|
|
| 160 |
output_dir: Path | None,
|
| 161 |
use_llm_eval: bool,
|
| 162 |
budget: int,
|
| 163 |
+
limit: int | None = None,
|
| 164 |
) -> None:
|
| 165 |
"""Run the optimization."""
|
| 166 |
# If experiment YAML provided, use it as the source of truth
|
| 167 |
if experiment_path:
|
| 168 |
+
await _run_from_experiment(experiment_path, output_dir, limit=limit)
|
| 169 |
return
|
| 170 |
|
| 171 |
# Load tasks
|
| 172 |
+
tasks = _load_tasks(tasks_path, suite, limit=limit)
|
| 173 |
if not tasks:
|
| 174 |
console.print("[red]Error:[/] No tasks specified. Use --tasks or --suite")
|
| 175 |
raise typer.Exit(1)
|
|
|
|
| 178 |
base = _load_base_agent(agent_path)
|
| 179 |
|
| 180 |
# Load candidates and check if a strategy is defined in config
|
| 181 |
+
candidates, strategy_instance = await _load_candidates_and_strategy(config_path, vary, base, budget)
|
| 182 |
|
| 183 |
# If a strategy was provided (like GepaStrategy), run it directly
|
| 184 |
if strategy_instance is not None:
|
|
|
|
| 228 |
raise typer.Exit(1)
|
| 229 |
|
| 230 |
|
| 231 |
+
async def _run_from_experiment(experiment_path: Path, output_dir: Path | None, limit: int | None = None) -> None:
|
| 232 |
"""Run optimization from an experiment YAML file.
|
| 233 |
|
| 234 |
The experiment YAML defines:
|
|
|
|
| 277 |
console.print("[red]Error:[/] Experiment must specify 'suite' or 'tasks'")
|
| 278 |
raise typer.Exit(1)
|
| 279 |
|
| 280 |
+
if limit is not None and limit > 0:
|
| 281 |
+
tasks = tasks[:limit]
|
| 282 |
+
|
| 283 |
# Generate candidates from variations
|
| 284 |
if exp.variations:
|
| 285 |
strategy = GridSearchStrategy(exp.variations)
|
| 286 |
+
candidates = await strategy.generate(base, exp.budget)
|
| 287 |
else:
|
| 288 |
candidates = [Candidate(agent=base, mutations={}, rationale="baseline")]
|
| 289 |
|
|
|
|
| 293 |
for t in tasks:
|
| 294 |
console.print(f" - {t.name}")
|
| 295 |
|
| 296 |
+
console.print("\n[bold]Variations:[/]")
|
| 297 |
for key, values in exp.variations.items():
|
| 298 |
console.print(f" - {key}: {len(values)} variants")
|
| 299 |
|
|
|
|
| 319 |
raise typer.Exit(1)
|
| 320 |
|
| 321 |
|
| 322 |
+
def _load_tasks(tasks_path: Path | None, suite: str | None, limit: int | None = None) -> list[Task]:
|
| 323 |
"""Load tasks from file or built-in suite."""
|
| 324 |
+
tasks: list[Task] = []
|
| 325 |
if tasks_path:
|
| 326 |
if not tasks_path.exists():
|
| 327 |
console.print(f"[red]Error:[/] Tasks file not found: {tasks_path}")
|
| 328 |
raise typer.Exit(1)
|
| 329 |
+
tasks = load_tasks_from_jsonl(tasks_path)
|
| 330 |
+
elif suite:
|
|
|
|
| 331 |
try:
|
| 332 |
+
tasks = get_task_suite(suite)
|
| 333 |
except ValueError as e:
|
| 334 |
console.print(f"[red]Error:[/] {e}")
|
| 335 |
raise typer.Exit(1)
|
| 336 |
+
else:
|
| 337 |
+
# Default: quick suite
|
| 338 |
+
try:
|
| 339 |
+
tasks = get_task_suite("quick")
|
| 340 |
+
except ValueError:
|
| 341 |
+
console.print("[red]Error:[/] No built-in suites available. Use --tasks to specify a JSONL file.")
|
| 342 |
+
raise typer.Exit(1)
|
| 343 |
|
| 344 |
+
if limit is not None and limit > 0:
|
| 345 |
+
tasks = tasks[:limit]
|
| 346 |
+
return tasks
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
def _load_base_agent(agent_path: Path | None) -> Agent:
|
|
|
|
| 358 |
return Agent(name="flow_agent")
|
| 359 |
|
| 360 |
|
| 361 |
+
async def _load_candidates_and_strategy(
|
| 362 |
config_path: Path | None,
|
| 363 |
vary: str | None,
|
| 364 |
base: Agent,
|
| 365 |
budget: int,
|
| 366 |
) -> tuple[list[Candidate], Any | None]:
|
| 367 |
"""Load candidates from file or generate from variations.
|
| 368 |
+
|
| 369 |
Supports both YAML and Python config files:
|
| 370 |
- YAML: strategy configuration (strategy_type, config)
|
| 371 |
- Python: STRATEGY object, CANDIDATES list, or VARIATIONS dict
|
| 372 |
+
|
| 373 |
Returns:
|
| 374 |
Tuple of (candidates, strategy_instance)
|
| 375 |
- If a STRATEGY is defined in config, returns ([], strategy_instance)
|
|
|
|
| 388 |
# YAML files currently only support strategy definitions
|
| 389 |
console.print("[red]Error:[/] YAML config must define a strategy")
|
| 390 |
raise typer.Exit(1)
|
| 391 |
+
|
| 392 |
# Python config file
|
| 393 |
candidates, variations, strategy_obj = _load_python_config(config_path)
|
| 394 |
|
| 395 |
# If a strategy object was provided (e.g., GepaStrategy), return it
|
| 396 |
if strategy_obj is not None:
|
| 397 |
return [], strategy_obj
|
| 398 |
+
|
| 399 |
if variations:
|
| 400 |
strategy = GridSearchStrategy(variations)
|
| 401 |
+
return await strategy.generate(base, budget), None
|
| 402 |
elif candidates:
|
| 403 |
return candidates, None
|
| 404 |
else:
|
|
|
|
| 408 |
if vary:
|
| 409 |
variations = _parse_vary_flag(vary)
|
| 410 |
strategy = GridSearchStrategy(variations)
|
| 411 |
+
return await strategy.generate(base, budget), None
|
| 412 |
|
| 413 |
# Default: explore context engineering dimensions
|
| 414 |
strategy = GridSearchStrategy(variations={
|
|
|
|
| 416 |
CompactionConfig.head_tail(10, 40),
|
| 417 |
CompactionConfig.none(),
|
| 418 |
],
|
|
|
|
| 419 |
})
|
| 420 |
+
return await strategy.generate(base, budget), None
|
| 421 |
|
| 422 |
|
| 423 |
def _load_yaml_strategy(path: Path) -> Any | None:
|
|
|
|
| 455 |
console.print("[red]Error:[/] GEPA optimizer not available.")
|
| 456 |
console.print("[dim]Install with: pip install flow-agent[optimizer][/]")
|
| 457 |
raise typer.Exit(1)
|
| 458 |
+
elif strategy_type == "llm_rewriter":
|
| 459 |
+
from flow.experiments.strategies.llm_rewriter import LLMRewriterStrategy
|
| 460 |
+
return LLMRewriterStrategy(config=strategy_config)
|
| 461 |
else:
|
| 462 |
console.print(f"[red]Error:[/] Unknown strategy type: {strategy_type}")
|
| 463 |
+
console.print("[dim]Supported: gepa, llm_rewriter[/]")
|
| 464 |
raise typer.Exit(1)
|
| 465 |
|
| 466 |
|
|
|
|
| 542 |
use_llm_eval: bool,
|
| 543 |
budget: int,
|
| 544 |
) -> None:
|
| 545 |
+
"""Run an active optimization strategy.
|
| 546 |
+
|
| 547 |
+
For strategies that use the ExperimentRunner protocol (LLMRewriterStrategy),
|
| 548 |
+
delegates to FlowOptimizer.optimize_with_strategy() which handles setup,
|
| 549 |
+
evaluation, Pareto analysis, and export.
|
| 550 |
+
|
| 551 |
+
For GEPA (which uses its own evaluator callback), uses the legacy path
|
| 552 |
+
with a bridging evaluator function.
|
| 553 |
+
"""
|
| 554 |
+
# Check if strategy uses GEPA's evaluator pattern (legacy path)
|
| 555 |
+
is_gepa = hasattr(strategy, "evaluator")
|
| 556 |
+
|
| 557 |
+
if is_gepa:
|
| 558 |
+
await _run_gepa_strategy(strategy, base_agent, tasks, output_dir, parallel, use_llm_eval, budget)
|
| 559 |
+
else:
|
| 560 |
+
# Modern path: use optimize_with_strategy which passes self as runner
|
| 561 |
+
optimizer = FlowOptimizer(
|
| 562 |
+
parallel=parallel,
|
| 563 |
+
use_llm_evaluator=use_llm_eval,
|
| 564 |
+
output_dir=output_dir,
|
| 565 |
+
)
|
| 566 |
+
result = await optimizer.optimize_with_strategy(
|
| 567 |
+
strategy=strategy,
|
| 568 |
+
base=base_agent,
|
| 569 |
+
tasks=tasks,
|
| 570 |
+
budget=budget,
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
console.print("\n[bold green]Optimization complete![/]")
|
| 574 |
+
console.print(f"\nBest agents exported to: [cyan]{result.output_dir / 'agents'}[/]")
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
async def _run_gepa_strategy(
|
| 578 |
+
strategy: Any,
|
| 579 |
+
base_agent: Agent,
|
| 580 |
+
tasks: list[Task],
|
| 581 |
+
output_dir: Path | None,
|
| 582 |
+
parallel: int,
|
| 583 |
+
use_llm_eval: bool,
|
| 584 |
+
budget: int,
|
| 585 |
+
) -> None:
|
| 586 |
+
"""Run GEPA strategy with its custom evaluator callback bridge."""
|
| 587 |
logger = logging.getLogger(__name__)
|
| 588 |
+
|
| 589 |
+
import threading
|
| 590 |
+
from datetime import datetime
|
| 591 |
+
|
| 592 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 593 |
+
if output_dir is None:
|
| 594 |
+
base_output_dir = Path.home() / ".flow" / "optimizations"
|
| 595 |
+
else:
|
| 596 |
+
base_output_dir = output_dir
|
| 597 |
+
run_dir = base_output_dir / f"gepa_{timestamp}"
|
| 598 |
+
run_dir.mkdir(parents=True, exist_ok=True)
|
| 599 |
+
|
| 600 |
+
eval_counter = 0
|
| 601 |
+
counter_lock = threading.Lock()
|
| 602 |
+
|
| 603 |
optimizer_runner = FlowOptimizer(
|
| 604 |
parallel=parallel,
|
| 605 |
use_llm_evaluator=use_llm_eval,
|
| 606 |
+
output_dir=run_dir,
|
| 607 |
)
|
| 608 |
|
|
|
|
| 609 |
main_loop = asyncio.get_running_loop()
|
| 610 |
|
|
|
|
| 611 |
def evaluator(candidate: Candidate, minibatch: list[Task] | None = None) -> ExperimentResult:
|
| 612 |
+
"""Evaluate a candidate on a minibatch of tasks (GEPA bridge)."""
|
| 613 |
+
nonlocal eval_counter
|
| 614 |
+
|
| 615 |
eval_tasks = minibatch if minibatch else tasks
|
| 616 |
+
candidate_id = candidate.mutations.get("_candidate_id", "unknown")
|
| 617 |
+
|
| 618 |
+
with counter_lock:
|
| 619 |
+
rollout_num = eval_counter
|
| 620 |
+
eval_counter += 1
|
| 621 |
+
|
| 622 |
+
rollout_dir = run_dir / f"rollout_{rollout_num}_{candidate_id}"
|
| 623 |
+
logger.debug(f"[EVALUATOR] Evaluating {candidate_id} on {len(eval_tasks)} tasks (rollout {rollout_num})")
|
| 624 |
+
|
| 625 |
try:
|
|
|
|
|
|
|
|
|
|
| 626 |
future = asyncio.run_coroutine_threadsafe(
|
| 627 |
+
optimizer_runner.optimize([candidate], eval_tasks, run_dir=rollout_dir),
|
| 628 |
main_loop
|
| 629 |
)
|
| 630 |
optimization_result = future.result()
|
| 631 |
+
|
|
|
|
| 632 |
if not optimization_result.summaries:
|
|
|
|
|
|
|
| 633 |
return ExperimentResult(
|
| 634 |
+
candidate=candidate, run_result=None,
|
| 635 |
+
metrics={"score": 0.0}, eval_score=0.0,
|
| 636 |
+
eval_passed=False, eval_reasoning="No results", traces={}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
)
|
| 638 |
+
|
| 639 |
summary = optimization_result.summaries[0]
|
| 640 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
if summary.task_results:
|
| 642 |
+
total_tokens = sum(tr.metrics.total_tokens for tr in summary.task_results)
|
| 643 |
+
avg_duration = sum(
|
| 644 |
+
tr.run_result.duration_seconds for tr in summary.task_results if tr.run_result
|
| 645 |
+
) / max(len(summary.task_results), 1)
|
| 646 |
+
combined_reasoning = "\n".join(
|
| 647 |
+
f"Task {tr.task_name}: {tr.eval_reasoning}" for tr in summary.task_results
|
| 648 |
+
)
|
| 649 |
return ExperimentResult(
|
| 650 |
candidate=candidate,
|
| 651 |
+
run_result=summary.task_results[0].run_result,
|
| 652 |
+
metrics={"total_tokens": total_tokens, "avg_duration": avg_duration,
|
| 653 |
+
"pass_rate": summary.pass_rate, "num_tasks": len(summary.task_results)},
|
| 654 |
+
eval_score=summary.avg_score,
|
| 655 |
+
eval_passed=summary.pass_rate > 0.5,
|
| 656 |
+
eval_reasoning=combined_reasoning,
|
| 657 |
+
traces=summary.task_results[0].run_result.trace if summary.task_results[0].run_result else {},
|
| 658 |
)
|
| 659 |
+
|
|
|
|
| 660 |
return ExperimentResult(
|
| 661 |
+
candidate=candidate, run_result=None,
|
| 662 |
+
metrics={"score": summary.avg_score}, eval_score=summary.avg_score,
|
|
|
|
|
|
|
| 663 |
eval_passed=summary.pass_rate > 0.5,
|
| 664 |
+
eval_reasoning=f"Aggregate pass rate: {summary.pass_rate}", traces={}
|
|
|
|
| 665 |
)
|
| 666 |
+
|
| 667 |
except Exception as e:
|
| 668 |
logger.error(f"Error evaluating candidate '{candidate.agent.name}': {e}", exc_info=True)
|
|
|
|
| 669 |
return ExperimentResult(
|
| 670 |
+
candidate=candidate, run_result=None,
|
| 671 |
+
metrics={"score": 0.0, "error": str(e)}, eval_score=0.0,
|
| 672 |
+
eval_passed=False, eval_reasoning=f"Evaluation error: {e!s}", traces={}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
)
|
| 674 |
|
| 675 |
+
# Inject GEPA-specific dependencies
|
| 676 |
+
if strategy.evaluator is None:
|
|
|
|
|
|
|
|
|
|
| 677 |
strategy.evaluator = evaluator
|
| 678 |
if hasattr(strategy, "dataset") and strategy.dataset is None:
|
| 679 |
strategy.dataset = tasks
|
| 680 |
+
|
| 681 |
+
candidates = await strategy.generate(base_agent, budget, tasks=tasks, runner=None)
|
| 682 |
+
|
| 683 |
+
if hasattr(strategy, "print_report"):
|
| 684 |
+
strategy.print_report()
|
|
|
|
| 685 |
|
| 686 |
console.print("\n[bold green]Optimization complete![/]")
|
| 687 |
console.print(f"Generated {len(candidates)} candidates.")
|
| 688 |
|
| 689 |
+
output_path = output_dir if output_dir else run_dir
|
| 690 |
+
if output_path:
|
| 691 |
+
import json
|
| 692 |
+
|
| 693 |
from flow.experiments.models import export_agent
|
| 694 |
+
|
| 695 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 696 |
+
(output_path / "agents").mkdir(exist_ok=True)
|
| 697 |
+
|
| 698 |
for i, cand in enumerate(candidates):
|
|
|
|
| 699 |
name = cand.agent.name or f"candidate_{i}"
|
| 700 |
+
export_agent(cand.agent, output_path / "agents" / f"{name}.yaml", metrics={"rationale": cand.rationale})
|
| 701 |
+
|
| 702 |
+
if hasattr(strategy, "get_report") and strategy.get_report():
|
| 703 |
+
report = strategy.get_report()
|
| 704 |
+
report_data = {
|
| 705 |
+
"baseline_prompt": report.baseline_prompt,
|
| 706 |
+
"baseline_score": report.baseline_score,
|
| 707 |
+
"final_prompt": report.final_prompt,
|
| 708 |
+
"final_score": report.final_score,
|
| 709 |
+
"best_candidate_id": report.best_candidate_id,
|
| 710 |
+
"improvement": report.improvement,
|
| 711 |
+
"improvement_percent": (report.improvement / max(report.baseline_score, 0.001)) * 100,
|
| 712 |
+
"total_candidates_evaluated": report.total_candidates_evaluated,
|
| 713 |
+
"total_generations": report.total_generations,
|
| 714 |
+
"candidate_history": [
|
| 715 |
+
{
|
| 716 |
+
"generation": r.generation,
|
| 717 |
+
"candidate_id": r.candidate_id,
|
| 718 |
+
"avg_score": r.avg_score,
|
| 719 |
+
"best_score": r.best_score,
|
| 720 |
+
"best_eval_num": r.best_eval_num,
|
| 721 |
+
"eval_count": r.eval_count,
|
| 722 |
+
"pass_rate": r.pass_rate,
|
| 723 |
+
"is_selected": r.is_selected,
|
| 724 |
+
"instructions_preview": r.instructions_preview,
|
| 725 |
+
}
|
| 726 |
+
for r in report.candidate_history
|
| 727 |
+
]
|
| 728 |
+
}
|
| 729 |
+
with open(output_path / "optimization_report.json", "w") as f:
|
| 730 |
+
json.dump(report_data, f, indent=2)
|
| 731 |
+
console.print(f"Optimization report saved to: [cyan]{output_path / 'optimization_report.json'}[/]")
|
| 732 |
+
|
| 733 |
+
console.print(f"\nAgents exported to: [cyan]{output_path / 'agents'}[/]")
|
| 734 |
|
src/flow/cli/repl.py
CHANGED
|
@@ -47,7 +47,9 @@ class FlowREPL:
|
|
| 47 |
"""Get or create the harness instance."""
|
| 48 |
if self._harness is None:
|
| 49 |
# Import maf module to register the harness, then use registry
|
| 50 |
-
import flow.harness.maf
|
|
|
|
|
|
|
| 51 |
from flow.harness import create_harness
|
| 52 |
|
| 53 |
agent = Agent(name="flow-repl")
|
|
|
|
| 47 |
"""Get or create the harness instance."""
|
| 48 |
if self._harness is None:
|
| 49 |
# Import maf module to register the harness, then use registry
|
| 50 |
+
import flow.harness.maf as _maf
|
| 51 |
+
|
| 52 |
+
_ = _maf
|
| 53 |
from flow.harness import create_harness
|
| 54 |
|
| 55 |
agent = Agent(name="flow-repl")
|
src/flow/experiments/__init__.py
CHANGED
|
@@ -27,7 +27,7 @@ Example usage:
|
|
| 27 |
strategy = GridSearchStrategy(variations={
|
| 28 |
"enable_memory": [True, False],
|
| 29 |
})
|
| 30 |
-
candidates = strategy.generate(base, budget=10)
|
| 31 |
|
| 32 |
# Run optimization
|
| 33 |
optimizer = FlowOptimizer(parallel=4)
|
|
@@ -37,18 +37,6 @@ Example usage:
|
|
| 37 |
"""
|
| 38 |
|
| 39 |
# Core models
|
| 40 |
-
from .models import (
|
| 41 |
-
Agent,
|
| 42 |
-
Candidate,
|
| 43 |
-
CandidateStrategy,
|
| 44 |
-
CompactionConfig,
|
| 45 |
-
ExperimentResult,
|
| 46 |
-
GridSearchStrategy,
|
| 47 |
-
export_agent,
|
| 48 |
-
export_optimization_results,
|
| 49 |
-
load_agent,
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
# Experiment runner + Pareto analysis
|
| 53 |
from .ablation import (
|
| 54 |
compute_pareto_frontier,
|
|
@@ -66,6 +54,16 @@ from .evaluators import (
|
|
| 66 |
TraceEvaluator,
|
| 67 |
)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# Metrics
|
| 70 |
from .metrics import (
|
| 71 |
LLMCallInfo,
|
|
@@ -75,6 +73,26 @@ from .metrics import (
|
|
| 75 |
format_metrics_summary,
|
| 76 |
metrics_to_dict,
|
| 77 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Optimizer
|
| 80 |
from .optimizer import (
|
|
@@ -82,6 +100,7 @@ from .optimizer import (
|
|
| 82 |
FlowOptimizer,
|
| 83 |
OptimizationResult,
|
| 84 |
TaskResult,
|
|
|
|
| 85 |
load_tasks_from_jsonl,
|
| 86 |
)
|
| 87 |
|
|
@@ -96,11 +115,24 @@ from .reporters import (
|
|
| 96 |
)
|
| 97 |
|
| 98 |
# Runner
|
| 99 |
-
from .runner import FlowExperimentRunner, setup_tracing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Trace collection
|
| 102 |
from .trace_collector import FlowTraceCollector
|
| 103 |
-
from .types import CriterionResult, EvalCriterion, EvalResult, RunResult, Task
|
| 104 |
|
| 105 |
__all__ = [ # noqa: RUF022 # Intentionally grouped by category
|
| 106 |
# Core models
|
|
@@ -108,11 +140,27 @@ __all__ = [ # noqa: RUF022 # Intentionally grouped by category
|
|
| 108 |
"Candidate",
|
| 109 |
"CandidateStrategy",
|
| 110 |
"CompactionConfig",
|
|
|
|
| 111 |
"ExperimentResult",
|
|
|
|
|
|
|
| 112 |
"GridSearchStrategy",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
"export_agent",
|
| 114 |
"load_agent",
|
|
|
|
| 115 |
"export_optimization_results",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# Types
|
| 117 |
"Task",
|
| 118 |
"EvalCriterion",
|
|
@@ -130,6 +178,7 @@ __all__ = [ # noqa: RUF022 # Intentionally grouped by category
|
|
| 130 |
"metrics_to_dict",
|
| 131 |
# Runner
|
| 132 |
"FlowExperimentRunner",
|
|
|
|
| 133 |
"setup_tracing",
|
| 134 |
# Evaluators
|
| 135 |
"Evaluator",
|
|
@@ -154,5 +203,20 @@ __all__ = [ # noqa: RUF022 # Intentionally grouped by category
|
|
| 154 |
"OptimizationResult",
|
| 155 |
"CandidateSummary",
|
| 156 |
"TaskResult",
|
|
|
|
| 157 |
"load_tasks_from_jsonl",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
]
|
|
|
|
| 27 |
strategy = GridSearchStrategy(variations={
|
| 28 |
"enable_memory": [True, False],
|
| 29 |
})
|
| 30 |
+
candidates = await strategy.generate(base, budget=10)
|
| 31 |
|
| 32 |
# Run optimization
|
| 33 |
optimizer = FlowOptimizer(parallel=4)
|
|
|
|
| 37 |
"""
|
| 38 |
|
| 39 |
# Core models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Experiment runner + Pareto analysis
|
| 41 |
from .ablation import (
|
| 42 |
compute_pareto_frontier,
|
|
|
|
| 54 |
TraceEvaluator,
|
| 55 |
)
|
| 56 |
|
| 57 |
+
# Expansion pipeline
|
| 58 |
+
from .expansion import expand_variations, generate_candidates
|
| 59 |
+
|
| 60 |
+
# HF Dataset Integration
|
| 61 |
+
from .hf_datasets import (
|
| 62 |
+
import_hf_dataset,
|
| 63 |
+
register_converter,
|
| 64 |
+
save_tasks_to_jsonl,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
# Metrics
|
| 68 |
from .metrics import (
|
| 69 |
LLMCallInfo,
|
|
|
|
| 73 |
format_metrics_summary,
|
| 74 |
metrics_to_dict,
|
| 75 |
)
|
| 76 |
+
from .models import (
|
| 77 |
+
Agent,
|
| 78 |
+
Candidate,
|
| 79 |
+
CandidateStrategy,
|
| 80 |
+
CompactionConfig,
|
| 81 |
+
Experiment,
|
| 82 |
+
ExperimentResult,
|
| 83 |
+
ExperimentRunner,
|
| 84 |
+
Framework,
|
| 85 |
+
GridSearchStrategy,
|
| 86 |
+
LiteralVariation,
|
| 87 |
+
StrategyIteration,
|
| 88 |
+
StrategyVariation,
|
| 89 |
+
VariationItem,
|
| 90 |
+
compute_max_experiments,
|
| 91 |
+
export_agent,
|
| 92 |
+
export_optimization_results,
|
| 93 |
+
load_agent,
|
| 94 |
+
load_experiment,
|
| 95 |
+
)
|
| 96 |
|
| 97 |
# Optimizer
|
| 98 |
from .optimizer import (
|
|
|
|
| 100 |
FlowOptimizer,
|
| 101 |
OptimizationResult,
|
| 102 |
TaskResult,
|
| 103 |
+
evaluate_agent,
|
| 104 |
load_tasks_from_jsonl,
|
| 105 |
)
|
| 106 |
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
# Runner
|
| 118 |
+
from .runner import FlowExperimentRunner, get_shared_collector, setup_tracing
|
| 119 |
+
|
| 120 |
+
# Strategy registry
|
| 121 |
+
from .strategies import get_registered_strategies, get_strategy, register_strategy
|
| 122 |
+
|
| 123 |
+
# Presets
|
| 124 |
+
from .presets import AgentPreset, get_all_presets, get_preset
|
| 125 |
+
|
| 126 |
+
# Results (simple API)
|
| 127 |
+
from .results import (
|
| 128 |
+
AgentOptimizationResult,
|
| 129 |
+
EvaluationResult,
|
| 130 |
+
ImprovementMetrics,
|
| 131 |
+
)
|
| 132 |
|
| 133 |
# Trace collection
|
| 134 |
from .trace_collector import FlowTraceCollector
|
| 135 |
+
from .types import CriterionResult, EvalCriterion, EvalResult, RunResult, Task, get_task_suite
|
| 136 |
|
| 137 |
__all__ = [ # noqa: RUF022 # Intentionally grouped by category
|
| 138 |
# Core models
|
|
|
|
| 140 |
"Candidate",
|
| 141 |
"CandidateStrategy",
|
| 142 |
"CompactionConfig",
|
| 143 |
+
"Experiment",
|
| 144 |
"ExperimentResult",
|
| 145 |
+
"ExperimentRunner",
|
| 146 |
+
"Framework",
|
| 147 |
"GridSearchStrategy",
|
| 148 |
+
"LiteralVariation",
|
| 149 |
+
"StrategyIteration",
|
| 150 |
+
"StrategyVariation",
|
| 151 |
+
"VariationItem",
|
| 152 |
+
"compute_max_experiments",
|
| 153 |
"export_agent",
|
| 154 |
"load_agent",
|
| 155 |
+
"load_experiment",
|
| 156 |
"export_optimization_results",
|
| 157 |
+
# Expansion pipeline
|
| 158 |
+
"expand_variations",
|
| 159 |
+
"generate_candidates",
|
| 160 |
+
# Strategy registry
|
| 161 |
+
"get_strategy",
|
| 162 |
+
"register_strategy",
|
| 163 |
+
"get_registered_strategies",
|
| 164 |
# Types
|
| 165 |
"Task",
|
| 166 |
"EvalCriterion",
|
|
|
|
| 178 |
"metrics_to_dict",
|
| 179 |
# Runner
|
| 180 |
"FlowExperimentRunner",
|
| 181 |
+
"get_shared_collector",
|
| 182 |
"setup_tracing",
|
| 183 |
# Evaluators
|
| 184 |
"Evaluator",
|
|
|
|
| 203 |
"OptimizationResult",
|
| 204 |
"CandidateSummary",
|
| 205 |
"TaskResult",
|
| 206 |
+
"evaluate_agent",
|
| 207 |
"load_tasks_from_jsonl",
|
| 208 |
+
# Presets
|
| 209 |
+
"AgentPreset",
|
| 210 |
+
"get_preset",
|
| 211 |
+
"get_all_presets",
|
| 212 |
+
# Results (simple API)
|
| 213 |
+
"EvaluationResult",
|
| 214 |
+
"AgentOptimizationResult",
|
| 215 |
+
"ImprovementMetrics",
|
| 216 |
+
# Task suites
|
| 217 |
+
"get_task_suite",
|
| 218 |
+
# HF Datasets
|
| 219 |
+
"import_hf_dataset",
|
| 220 |
+
"register_converter",
|
| 221 |
+
"save_tasks_to_jsonl",
|
| 222 |
]
|
src/flow/experiments/ablation.py
CHANGED
|
@@ -46,9 +46,13 @@ async def run_single_experiment(
|
|
| 46 |
ExperimentResult with metrics and evaluation
|
| 47 |
"""
|
| 48 |
# Import harness modules to register them, then use registry
|
| 49 |
-
import flow.harness.maf
|
|
|
|
|
|
|
| 50 |
try:
|
| 51 |
-
import flow.harness.miniagent
|
|
|
|
|
|
|
| 52 |
except ImportError:
|
| 53 |
pass # miniagent harness is optional
|
| 54 |
from flow.harness import create_harness
|
|
|
|
| 46 |
ExperimentResult with metrics and evaluation
|
| 47 |
"""
|
| 48 |
# Import harness modules to register them, then use registry
|
| 49 |
+
import flow.harness.maf as _maf
|
| 50 |
+
|
| 51 |
+
_ = _maf
|
| 52 |
try:
|
| 53 |
+
import flow.harness.miniagent as _miniagent
|
| 54 |
+
|
| 55 |
+
_ = _miniagent
|
| 56 |
except ImportError:
|
| 57 |
pass # miniagent harness is optional
|
| 58 |
from flow.harness import create_harness
|
src/flow/experiments/agent_api.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""Implementation of Agent.evaluate() and Agent.optimize() methods.
|
| 4 |
+
|
| 5 |
+
This module contains the implementation details, keeping the Agent class
|
| 6 |
+
itself clean and focused on configuration.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import contextlib
|
| 12 |
+
import io
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import TYPE_CHECKING, Any
|
| 16 |
+
|
| 17 |
+
from .models import Candidate, CompactionConfig, GridSearchStrategy
|
| 18 |
+
from .optimizer import FlowOptimizer, OptimizationResult, evaluate_agent
|
| 19 |
+
from .results import AgentOptimizationResult, EvaluationResult, ImprovementMetrics
|
| 20 |
+
from .types import Task, get_task_suite, load_tasks_from_jsonl
|
| 21 |
+
|
| 22 |
+
if TYPE_CHECKING:
|
| 23 |
+
from .models import Agent
|
| 24 |
+
from .optimizer import CandidateSummary
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Default variations for optimize() when none provided
|
| 28 |
+
DEFAULT_VARIATIONS: dict[str, list[Any]] = {
|
| 29 |
+
"tools": ["minimal", "standard"],
|
| 30 |
+
"compaction": [
|
| 31 |
+
CompactionConfig.none(),
|
| 32 |
+
CompactionConfig.head_tail(10, 40),
|
| 33 |
+
CompactionConfig.sliding_window(100_000),
|
| 34 |
+
],
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
# Known active strategy names and their classes
|
| 38 |
+
_STRATEGY_MAP: dict[str, str] = {
|
| 39 |
+
"tools": "flow.experiments.strategies.tool_selector.ToolSelectorStrategy",
|
| 40 |
+
"instructions": "flow.experiments.strategies.llm_rewriter.LLMRewriterStrategy",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _resolve_tasks(tasks: str | list[Task] | Path) -> list[Task]:
|
| 45 |
+
"""Resolve tasks specification to list of Task objects.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
tasks: One of:
|
| 49 |
+
- str: Suite name (e.g., "quick", "coding", "gaia_level1")
|
| 50 |
+
- list[Task]: Already resolved tasks
|
| 51 |
+
- Path: Path to JSONL file
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
List of Task objects
|
| 55 |
+
"""
|
| 56 |
+
if isinstance(tasks, str):
|
| 57 |
+
return get_task_suite(tasks)
|
| 58 |
+
elif isinstance(tasks, Path):
|
| 59 |
+
return load_tasks_from_jsonl(tasks)
|
| 60 |
+
else:
|
| 61 |
+
return tasks
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _summary_to_eval_result(summary: CandidateSummary) -> EvaluationResult:
|
| 65 |
+
"""Convert internal CandidateSummary to user-friendly EvaluationResult."""
|
| 66 |
+
return EvaluationResult(
|
| 67 |
+
score=summary.avg_score,
|
| 68 |
+
tokens=summary.total_tokens,
|
| 69 |
+
pass_rate=summary.pass_rate,
|
| 70 |
+
duration=summary.avg_duration * summary.task_count,
|
| 71 |
+
task_count=summary.task_count,
|
| 72 |
+
_details=summary,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@contextlib.contextmanager
|
| 77 |
+
def _suppress_output():
|
| 78 |
+
"""Context manager to suppress stdout/stderr."""
|
| 79 |
+
old_stdout, old_stderr = sys.stdout, sys.stderr
|
| 80 |
+
sys.stdout = io.StringIO()
|
| 81 |
+
sys.stderr = io.StringIO()
|
| 82 |
+
try:
|
| 83 |
+
yield
|
| 84 |
+
finally:
|
| 85 |
+
sys.stdout, sys.stderr = old_stdout, old_stderr
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
async def _evaluate_agent_impl(
|
| 89 |
+
agent: Agent,
|
| 90 |
+
tasks: str | list[Task] | Path,
|
| 91 |
+
parallel: int,
|
| 92 |
+
use_llm_eval: bool,
|
| 93 |
+
quiet: bool,
|
| 94 |
+
agent_id: str | None = None,
|
| 95 |
+
) -> EvaluationResult:
|
| 96 |
+
"""Implementation of Agent.evaluate().
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
agent_id: If set (from deploy()), results are auto-persisted to DB.
|
| 100 |
+
"""
|
| 101 |
+
resolved_tasks = _resolve_tasks(tasks)
|
| 102 |
+
|
| 103 |
+
if quiet:
|
| 104 |
+
with _suppress_output():
|
| 105 |
+
summary = await evaluate_agent(
|
| 106 |
+
agent,
|
| 107 |
+
resolved_tasks,
|
| 108 |
+
parallel=parallel,
|
| 109 |
+
use_llm_evaluator=use_llm_eval,
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
summary = await evaluate_agent(
|
| 113 |
+
agent,
|
| 114 |
+
resolved_tasks,
|
| 115 |
+
parallel=parallel,
|
| 116 |
+
use_llm_evaluator=use_llm_eval,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
result = _summary_to_eval_result(summary)
|
| 120 |
+
|
| 121 |
+
# Auto-persist if agent was deployed
|
| 122 |
+
if agent_id is not None:
|
| 123 |
+
try:
|
| 124 |
+
from flow.ui.services.persistence_adapter import PersistenceAdapter
|
| 125 |
+
|
| 126 |
+
adapter = PersistenceAdapter()
|
| 127 |
+
result.job_id = await adapter.persist_evaluation(summary, agent_id)
|
| 128 |
+
except ImportError:
|
| 129 |
+
pass # DB not available, skip persistence
|
| 130 |
+
|
| 131 |
+
return result
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _resolve_strategy(name: str) -> Any:
|
| 135 |
+
"""Import and instantiate a named strategy.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
name: Strategy name ("tools", "instructions")
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
Strategy instance
|
| 142 |
+
|
| 143 |
+
Raises:
|
| 144 |
+
ValueError: If name is not a known strategy
|
| 145 |
+
"""
|
| 146 |
+
if name not in _STRATEGY_MAP:
|
| 147 |
+
available = ["grid"] + list(_STRATEGY_MAP.keys())
|
| 148 |
+
raise ValueError(f"Unknown strategy: {name!r}. Available: {available}")
|
| 149 |
+
|
| 150 |
+
module_path, class_name = _STRATEGY_MAP[name].rsplit(".", 1)
|
| 151 |
+
import importlib
|
| 152 |
+
mod = importlib.import_module(module_path)
|
| 153 |
+
cls = getattr(mod, class_name)
|
| 154 |
+
return cls(config={
|
| 155 |
+
"max_iterations": 3,
|
| 156 |
+
"min_improvement": 0.01,
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _opt_result_to_agent_result(
|
| 161 |
+
opt_result: OptimizationResult,
|
| 162 |
+
baseline_agent: Agent,
|
| 163 |
+
) -> AgentOptimizationResult:
|
| 164 |
+
"""Convert internal OptimizationResult to user-friendly AgentOptimizationResult."""
|
| 165 |
+
# Find baseline: look for the original agent name with no mutations, else first summary
|
| 166 |
+
baseline_summary = next(
|
| 167 |
+
(s for s in opt_result.summaries if s.name == baseline_agent.name and s.candidate.mutations == {}),
|
| 168 |
+
None,
|
| 169 |
+
)
|
| 170 |
+
best_summary = opt_result.get_best_candidate("score") or opt_result.summaries[0]
|
| 171 |
+
|
| 172 |
+
if baseline_summary is not None:
|
| 173 |
+
baseline_result = _summary_to_eval_result(baseline_summary)
|
| 174 |
+
else:
|
| 175 |
+
# For active strategies, the baseline is the first iteration in optimization_history
|
| 176 |
+
history = best_summary.candidate.optimization_history
|
| 177 |
+
if history:
|
| 178 |
+
baseline_result = EvaluationResult(
|
| 179 |
+
score=history[0].avg_score,
|
| 180 |
+
tokens=0,
|
| 181 |
+
pass_rate=history[0].pass_rate,
|
| 182 |
+
duration=0.0,
|
| 183 |
+
task_count=best_summary.task_count or len(history[0].change_description),
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
baseline_result = _summary_to_eval_result(best_summary)
|
| 187 |
+
|
| 188 |
+
best_result = _summary_to_eval_result(best_summary)
|
| 189 |
+
|
| 190 |
+
score_delta = best_result.score - baseline_result.score
|
| 191 |
+
token_reduction_pct = 0.0
|
| 192 |
+
if baseline_result.tokens > 0:
|
| 193 |
+
token_reduction_pct = (baseline_result.tokens - best_result.tokens) / baseline_result.tokens * 100
|
| 194 |
+
|
| 195 |
+
return AgentOptimizationResult(
|
| 196 |
+
baseline=baseline_result,
|
| 197 |
+
best=best_result,
|
| 198 |
+
improvement=ImprovementMetrics(
|
| 199 |
+
score_delta=score_delta,
|
| 200 |
+
token_reduction_pct=token_reduction_pct,
|
| 201 |
+
),
|
| 202 |
+
best_agent=best_summary.candidate.agent,
|
| 203 |
+
candidates_tested=len(opt_result.summaries),
|
| 204 |
+
pareto_frontier=opt_result.pareto_frontier,
|
| 205 |
+
output_dir=opt_result.output_dir,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
async def _optimize_agent_impl(
|
| 210 |
+
agent: Agent,
|
| 211 |
+
tasks: str | list[Task] | Path,
|
| 212 |
+
variations: dict[str, list[Any]] | None,
|
| 213 |
+
parallel: int,
|
| 214 |
+
budget: int,
|
| 215 |
+
use_llm_eval: bool,
|
| 216 |
+
quiet: bool,
|
| 217 |
+
agent_id: str | None = None,
|
| 218 |
+
strategy: str | list[str] | None = None,
|
| 219 |
+
) -> AgentOptimizationResult:
|
| 220 |
+
"""Implementation of Agent.optimize().
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
agent_id: If set (from deploy()), results are auto-persisted to DB.
|
| 224 |
+
strategy: Active optimization strategy name(s). None or "grid" uses
|
| 225 |
+
grid search. A string like "tools" or "instructions" runs that
|
| 226 |
+
strategy. A list runs them sequentially, each starting from the
|
| 227 |
+
previous best.
|
| 228 |
+
"""
|
| 229 |
+
resolved_tasks = _resolve_tasks(tasks)
|
| 230 |
+
|
| 231 |
+
# Normalize strategy to a list (or None for grid search)
|
| 232 |
+
if strategy is None or strategy == "grid":
|
| 233 |
+
strategy_list = None
|
| 234 |
+
elif isinstance(strategy, str):
|
| 235 |
+
strategy_list = [strategy]
|
| 236 |
+
else:
|
| 237 |
+
strategy_list = list(strategy)
|
| 238 |
+
|
| 239 |
+
# ── Grid search path (original behavior) ──
|
| 240 |
+
if strategy_list is None:
|
| 241 |
+
actual_variations = variations if variations is not None else DEFAULT_VARIATIONS
|
| 242 |
+
|
| 243 |
+
grid_strategy = GridSearchStrategy(variations=actual_variations)
|
| 244 |
+
candidates = await grid_strategy.generate(agent, budget=budget)
|
| 245 |
+
|
| 246 |
+
baseline_candidate = Candidate(agent=agent, mutations={}, rationale="baseline")
|
| 247 |
+
has_baseline = any(c.agent.name == agent.name and c.mutations == {} for c in candidates)
|
| 248 |
+
if not has_baseline:
|
| 249 |
+
candidates.insert(0, baseline_candidate)
|
| 250 |
+
|
| 251 |
+
optimizer = FlowOptimizer(parallel=parallel, use_llm_evaluator=use_llm_eval)
|
| 252 |
+
|
| 253 |
+
if quiet:
|
| 254 |
+
with _suppress_output():
|
| 255 |
+
opt_result = await optimizer.optimize(candidates, resolved_tasks)
|
| 256 |
+
else:
|
| 257 |
+
opt_result = await optimizer.optimize(candidates, resolved_tasks)
|
| 258 |
+
|
| 259 |
+
result = _opt_result_to_agent_result(opt_result, agent)
|
| 260 |
+
|
| 261 |
+
# ── Active strategy path ──
|
| 262 |
+
else:
|
| 263 |
+
current_agent = agent
|
| 264 |
+
last_opt_result: OptimizationResult | None = None
|
| 265 |
+
|
| 266 |
+
for strat_name in strategy_list:
|
| 267 |
+
strat_instance = _resolve_strategy(strat_name)
|
| 268 |
+
optimizer = FlowOptimizer(parallel=parallel, use_llm_evaluator=use_llm_eval)
|
| 269 |
+
|
| 270 |
+
if quiet:
|
| 271 |
+
with _suppress_output():
|
| 272 |
+
last_opt_result = await optimizer.optimize_with_strategy(
|
| 273 |
+
strategy=strat_instance,
|
| 274 |
+
base=current_agent,
|
| 275 |
+
tasks=resolved_tasks,
|
| 276 |
+
budget=budget,
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
last_opt_result = await optimizer.optimize_with_strategy(
|
| 280 |
+
strategy=strat_instance,
|
| 281 |
+
base=current_agent,
|
| 282 |
+
tasks=resolved_tasks,
|
| 283 |
+
budget=budget,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# Next stage starts from the best agent found
|
| 287 |
+
best = last_opt_result.get_best_candidate("score")
|
| 288 |
+
if best:
|
| 289 |
+
current_agent = best.candidate.agent
|
| 290 |
+
|
| 291 |
+
assert last_opt_result is not None
|
| 292 |
+
result = _opt_result_to_agent_result(last_opt_result, agent)
|
| 293 |
+
|
| 294 |
+
# Auto-persist if agent was deployed
|
| 295 |
+
if agent_id is not None:
|
| 296 |
+
try:
|
| 297 |
+
from flow.ui.services.persistence_adapter import PersistenceAdapter
|
| 298 |
+
|
| 299 |
+
adapter = PersistenceAdapter()
|
| 300 |
+
opt_to_persist = opt_result if strategy_list is None else last_opt_result
|
| 301 |
+
result.job_id = await adapter.persist_optimization(opt_to_persist, agent_id)
|
| 302 |
+
except ImportError:
|
| 303 |
+
pass # DB not available, skip persistence
|
| 304 |
+
|
| 305 |
+
return result
|
src/flow/experiments/evaluators/heuristic.py
CHANGED
|
@@ -73,7 +73,7 @@ class HeuristicEvaluator:
|
|
| 73 |
|
| 74 |
# Check if agent reported task complete
|
| 75 |
output_lower = run_result.output.lower()
|
| 76 |
-
if "complete" in output_lower or "
|
| 77 |
criteria_results.append(
|
| 78 |
CriterionResult(
|
| 79 |
name="task_completed",
|
|
|
|
| 73 |
|
| 74 |
# Check if agent reported task complete
|
| 75 |
output_lower = run_result.output.lower()
|
| 76 |
+
if "complete" in output_lower or "finished" in output_lower:
|
| 77 |
criteria_results.append(
|
| 78 |
CriterionResult(
|
| 79 |
name="task_completed",
|
src/flow/experiments/evaluators/llm.py
CHANGED
|
@@ -11,6 +11,21 @@ from ..types import CriterionResult, EvalResult, RunResult
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class LLMEvaluator:
|
| 16 |
"""Evaluator that uses an LLM to assess agent output against criteria.
|
|
@@ -39,6 +54,7 @@ class LLMEvaluator:
|
|
| 39 |
model_name: str = "gpt-4o",
|
| 40 |
passing_threshold: float = 0.7,
|
| 41 |
temperature: float | None = None,
|
|
|
|
| 42 |
) -> None:
|
| 43 |
"""Initialize the LLM evaluator.
|
| 44 |
|
|
@@ -50,13 +66,56 @@ class LLMEvaluator:
|
|
| 50 |
temperature: Temperature for LLM calls. None means don't specify
|
| 51 |
(use model default). Some models like gpt-5.2-chat
|
| 52 |
only support temperature=1.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
"""
|
| 54 |
self.model_client = model_client
|
| 55 |
self.model_name = model_name
|
| 56 |
self.passing_threshold = passing_threshold
|
| 57 |
self.temperature = temperature
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
"""Build the evaluation prompt for the LLM."""
|
| 61 |
criteria_text = "\n".join(
|
| 62 |
f"- **{c.name}** (weight: {c.weight}): {c.instruction}"
|
|
@@ -66,6 +125,8 @@ class LLMEvaluator:
|
|
| 66 |
# Extract execution trace summary for research/multi-step tasks
|
| 67 |
trace_summary = self._get_trace_summary(run_result)
|
| 68 |
|
|
|
|
|
|
|
| 69 |
return f"""You are an expert evaluator assessing an AI agent's output.
|
| 70 |
|
| 71 |
## Task
|
|
@@ -73,15 +134,18 @@ The agent was given this task:
|
|
| 73 |
```
|
| 74 |
{run_result.task.prompt}
|
| 75 |
```
|
| 76 |
-
|
| 77 |
## Agent Output
|
| 78 |
```
|
| 79 |
-
{run_result.output
|
| 80 |
```
|
| 81 |
|
| 82 |
## Files Created
|
| 83 |
{json.dumps(run_result.files_created, indent=2) if run_result.files_created else "None"}
|
| 84 |
|
|
|
|
|
|
|
|
|
|
| 85 |
## Execution Trace
|
| 86 |
{trace_summary}
|
| 87 |
|
|
@@ -95,27 +159,61 @@ The agent was given this task:
|
|
| 95 |
Evaluate the agent's output against each criterion. Consider both the final output AND the execution
|
| 96 |
trace (tools used, steps taken) when assessing correctness.
|
| 97 |
|
| 98 |
-
For each criterion:
|
| 99 |
-
1.
|
| 100 |
-
2.
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
{{
|
| 108 |
-
"name": "criterion_name",
|
| 109 |
-
"score": 0.85,
|
| 110 |
-
"passed": true,
|
| 111 |
-
"reasoning": "Brief explanation"
|
| 112 |
-
}}
|
| 113 |
-
],
|
| 114 |
-
"overall_reasoning": "Summary of the overall evaluation"
|
| 115 |
-
}}
|
| 116 |
-
```
|
| 117 |
"""
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
def _get_trace_summary(self, run_result: RunResult) -> str:
|
| 120 |
"""Extract a summary of the execution trace for evaluation."""
|
| 121 |
if not run_result.trace:
|
|
@@ -137,11 +235,12 @@ Total tool calls: {metrics.tool_call_count}
|
|
| 137 |
{tool_summary}
|
| 138 |
Tokens used: {metrics.total_tokens} (input: {metrics.input_tokens}, output: {metrics.output_tokens})"""
|
| 139 |
|
| 140 |
-
async def evaluate(self, run_result: RunResult) -> EvalResult:
|
| 141 |
"""Evaluate the agent's output using an LLM.
|
| 142 |
|
| 143 |
Args:
|
| 144 |
run_result: The result from running an agent on a task
|
|
|
|
| 145 |
|
| 146 |
Returns:
|
| 147 |
EvalResult with LLM-generated scores and reasoning
|
|
@@ -158,45 +257,44 @@ Tokens used: {metrics.total_tokens} (input: {metrics.input_tokens}, output: {met
|
|
| 158 |
),
|
| 159 |
)
|
| 160 |
|
| 161 |
-
prompt = self._get_evaluation_prompt(run_result)
|
| 162 |
|
| 163 |
try:
|
| 164 |
-
# Build params
|
| 165 |
params: dict[str, Any] = {
|
| 166 |
"model": self.model_name,
|
| 167 |
"messages": [
|
| 168 |
{
|
| 169 |
"role": "system",
|
| 170 |
-
"content": "You are an expert evaluator.
|
| 171 |
},
|
| 172 |
{"role": "user", "content": prompt},
|
| 173 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
}
|
| 175 |
if self.temperature is not None:
|
| 176 |
params["temperature"] = self.temperature
|
| 177 |
|
| 178 |
response = await self.model_client.chat.completions.create(**params)
|
| 179 |
|
| 180 |
-
# Extract
|
| 181 |
-
response_text = response.choices[0].message.content or ""
|
| 182 |
-
|
| 183 |
-
# Parse JSON from response
|
| 184 |
-
json_start = response_text.find("{")
|
| 185 |
-
json_end = response_text.rfind("}") + 1
|
| 186 |
-
if json_start >= 0 and json_end > json_start:
|
| 187 |
-
eval_data = json.loads(response_text[json_start:json_end])
|
| 188 |
-
else:
|
| 189 |
-
raise ValueError("No JSON found in response")
|
| 190 |
|
| 191 |
# Build criterion results
|
| 192 |
criteria_results = []
|
| 193 |
total_weighted_score = 0.0
|
|
|
|
| 194 |
total_weight = 0.0
|
| 195 |
|
| 196 |
for cr_data in eval_data.get("criteria_results", []):
|
| 197 |
cr = CriterionResult(
|
| 198 |
name=cr_data.get("name", "unknown"),
|
| 199 |
score=float(cr_data.get("score", 0.0)),
|
|
|
|
| 200 |
passed=bool(cr_data.get("passed", False)),
|
| 201 |
reasoning=cr_data.get("reasoning", ""),
|
| 202 |
)
|
|
@@ -210,13 +308,16 @@ Tokens used: {metrics.total_tokens} (input: {metrics.input_tokens}, output: {met
|
|
| 210 |
break
|
| 211 |
|
| 212 |
total_weighted_score += cr.score * weight
|
|
|
|
| 213 |
total_weight += weight
|
| 214 |
|
| 215 |
-
# Calculate overall
|
| 216 |
overall_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
|
|
|
|
| 217 |
|
| 218 |
return EvalResult(
|
| 219 |
score=overall_score,
|
|
|
|
| 220 |
passed=overall_score >= self.passing_threshold,
|
| 221 |
criteria_results=criteria_results,
|
| 222 |
reasoning=eval_data.get("overall_reasoning", ""),
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
+
# Presets for how agent output is formatted before sending to the LLM judge.
|
| 15 |
+
# Agent outputs can be very large (100K-600K+ chars for multi-step tasks).
|
| 16 |
+
# The final answer is almost always at the end, so "head_tail" (default)
|
| 17 |
+
# keeps both the initial approach and the final answer visible to the judge.
|
| 18 |
+
#
|
| 19 |
+
# Each preset specifies {"head": N, "tail": M} where N chars from the start
|
| 20 |
+
# and M chars from the end are kept. When truncation occurs, a marker like
|
| 21 |
+
# "... [150,000 chars truncated] ..." is inserted.
|
| 22 |
+
OUTPUT_FORMAT_PRESETS: dict[str, dict[str, int]] = {
|
| 23 |
+
"head_tail": {"head": 2000, "tail": 10000}, # Default: sees start + final answer
|
| 24 |
+
"head_only": {"head": 8000, "tail": 0}, # Legacy: first 8K only
|
| 25 |
+
"tail_only": {"head": 0, "tail": 12000}, # Only the final output
|
| 26 |
+
"full": {"head": 0, "tail": 0}, # No truncation (watch context limits)
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
|
| 30 |
class LLMEvaluator:
|
| 31 |
"""Evaluator that uses an LLM to assess agent output against criteria.
|
|
|
|
| 54 |
model_name: str = "gpt-4o",
|
| 55 |
passing_threshold: float = 0.7,
|
| 56 |
temperature: float | None = None,
|
| 57 |
+
output_format: str | dict[str, int] = "head_tail",
|
| 58 |
) -> None:
|
| 59 |
"""Initialize the LLM evaluator.
|
| 60 |
|
|
|
|
| 66 |
temperature: Temperature for LLM calls. None means don't specify
|
| 67 |
(use model default). Some models like gpt-5.2-chat
|
| 68 |
only support temperature=1.0.
|
| 69 |
+
output_format: How to format agent output for the judge. Either a
|
| 70 |
+
preset name ("head_tail", "head_only", "tail_only", "full")
|
| 71 |
+
or a dict with "head" and "tail" char counts.
|
| 72 |
+
See OUTPUT_FORMAT_PRESETS for details.
|
| 73 |
"""
|
| 74 |
self.model_client = model_client
|
| 75 |
self.model_name = model_name
|
| 76 |
self.passing_threshold = passing_threshold
|
| 77 |
self.temperature = temperature
|
| 78 |
|
| 79 |
+
# Resolve output format
|
| 80 |
+
if isinstance(output_format, str):
|
| 81 |
+
if output_format not in OUTPUT_FORMAT_PRESETS:
|
| 82 |
+
raise ValueError(
|
| 83 |
+
f"Unknown output_format '{output_format}'. "
|
| 84 |
+
f"Available: {list(OUTPUT_FORMAT_PRESETS.keys())}"
|
| 85 |
+
)
|
| 86 |
+
fmt = OUTPUT_FORMAT_PRESETS[output_format]
|
| 87 |
+
else:
|
| 88 |
+
fmt = output_format
|
| 89 |
+
self._output_head = fmt["head"]
|
| 90 |
+
self._output_tail = fmt["tail"]
|
| 91 |
+
|
| 92 |
+
def _format_output(self, output: str) -> str:
|
| 93 |
+
"""Format agent output for the evaluation prompt.
|
| 94 |
+
|
| 95 |
+
Uses head+tail truncation to ensure the judge sees both the initial
|
| 96 |
+
approach and the final answer. When output is truncated, a marker is
|
| 97 |
+
inserted showing how many characters were removed.
|
| 98 |
+
|
| 99 |
+
The strategy is configured via the output_format parameter on __init__.
|
| 100 |
+
"""
|
| 101 |
+
head = self._output_head
|
| 102 |
+
tail = self._output_tail
|
| 103 |
+
budget = head + tail
|
| 104 |
+
|
| 105 |
+
# No truncation if budget is 0 (full mode) or output fits
|
| 106 |
+
if budget == 0 or len(output) <= budget:
|
| 107 |
+
return output
|
| 108 |
+
|
| 109 |
+
parts: list[str] = []
|
| 110 |
+
if head > 0:
|
| 111 |
+
parts.append(output[:head])
|
| 112 |
+
truncated = len(output) - budget
|
| 113 |
+
parts.append(f"\n\n... [{truncated:,} chars truncated] ...\n\n")
|
| 114 |
+
if tail > 0:
|
| 115 |
+
parts.append(output[-tail:])
|
| 116 |
+
return "".join(parts)
|
| 117 |
+
|
| 118 |
+
def _get_evaluation_prompt(self, run_result: RunResult, instructions: str | None = None) -> str:
|
| 119 |
"""Build the evaluation prompt for the LLM."""
|
| 120 |
criteria_text = "\n".join(
|
| 121 |
f"- **{c.name}** (weight: {c.weight}): {c.instruction}"
|
|
|
|
| 125 |
# Extract execution trace summary for research/multi-step tasks
|
| 126 |
trace_summary = self._get_trace_summary(run_result)
|
| 127 |
|
| 128 |
+
instructions_section = f"\n## Agent Instructions\n```\n{instructions}\n```\n" if instructions else ""
|
| 129 |
+
|
| 130 |
return f"""You are an expert evaluator assessing an AI agent's output.
|
| 131 |
|
| 132 |
## Task
|
|
|
|
| 134 |
```
|
| 135 |
{run_result.task.prompt}
|
| 136 |
```
|
| 137 |
+
{instructions_section}
|
| 138 |
## Agent Output
|
| 139 |
```
|
| 140 |
+
{self._format_output(run_result.output)}
|
| 141 |
```
|
| 142 |
|
| 143 |
## Files Created
|
| 144 |
{json.dumps(run_result.files_created, indent=2) if run_result.files_created else "None"}
|
| 145 |
|
| 146 |
+
## Tool Results
|
| 147 |
+
{self._format_tool_results(run_result.tool_results)}
|
| 148 |
+
|
| 149 |
## Execution Trace
|
| 150 |
{trace_summary}
|
| 151 |
|
|
|
|
| 159 |
Evaluate the agent's output against each criterion. Consider both the final output AND the execution
|
| 160 |
trace (tools used, steps taken) when assessing correctness.
|
| 161 |
|
| 162 |
+
For each criterion, provide TWO scores:
|
| 163 |
+
1. **score** (0.0 or 1.0): Does the agent's final answer exactly match what's required? This is strict exact-match.
|
| 164 |
+
2. **reasoning_score** (0.0 to 1.0): Did the agent demonstrate correct reasoning/methodology? Give partial credit for:
|
| 165 |
+
- Correct approach but wrong format (e.g., "17000" when "17" was expected)
|
| 166 |
+
- Correct methodology but wrong final number
|
| 167 |
+
- Identifying the right sources/data but making a calculation error
|
| 168 |
+
- Partial completion of a multi-part task
|
| 169 |
+
3. **passed**: true if score >= 1.0 (exact match)
|
| 170 |
+
4. Provide brief reasoning explaining both scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
"""
|
| 172 |
|
| 173 |
+
def _get_eval_schema(self) -> dict[str, Any]:
|
| 174 |
+
"""Get JSON schema for structured evaluation output."""
|
| 175 |
+
return {
|
| 176 |
+
"name": "evaluation_result",
|
| 177 |
+
"strict": True,
|
| 178 |
+
"schema": {
|
| 179 |
+
"type": "object",
|
| 180 |
+
"properties": {
|
| 181 |
+
"criteria_results": {
|
| 182 |
+
"type": "array",
|
| 183 |
+
"items": {
|
| 184 |
+
"type": "object",
|
| 185 |
+
"properties": {
|
| 186 |
+
"name": {"type": "string"},
|
| 187 |
+
"score": {"type": "number"},
|
| 188 |
+
"reasoning_score": {"type": "number"},
|
| 189 |
+
"passed": {"type": "boolean"},
|
| 190 |
+
"reasoning": {"type": "string"},
|
| 191 |
+
},
|
| 192 |
+
"required": ["name", "score", "reasoning_score", "passed", "reasoning"],
|
| 193 |
+
"additionalProperties": False,
|
| 194 |
+
},
|
| 195 |
+
},
|
| 196 |
+
"overall_reasoning": {"type": "string"},
|
| 197 |
+
},
|
| 198 |
+
"required": ["criteria_results", "overall_reasoning"],
|
| 199 |
+
"additionalProperties": False,
|
| 200 |
+
},
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
def _format_tool_results(self, tool_results: list[dict[str, str]]) -> str:
|
| 204 |
+
"""Format tool results for the evaluation prompt."""
|
| 205 |
+
if not tool_results:
|
| 206 |
+
return "None"
|
| 207 |
+
lines = []
|
| 208 |
+
for tr in tool_results:
|
| 209 |
+
tool = tr.get("tool", "unknown")
|
| 210 |
+
output = tr.get("output", "")
|
| 211 |
+
# Truncate long tool outputs
|
| 212 |
+
if len(output) > 500:
|
| 213 |
+
output = output[:500] + f"... [{len(output) - 500} chars truncated]"
|
| 214 |
+
lines.append(f"**{tool}**:\n```\n{output}\n```")
|
| 215 |
+
return "\n".join(lines)
|
| 216 |
+
|
| 217 |
def _get_trace_summary(self, run_result: RunResult) -> str:
|
| 218 |
"""Extract a summary of the execution trace for evaluation."""
|
| 219 |
if not run_result.trace:
|
|
|
|
| 235 |
{tool_summary}
|
| 236 |
Tokens used: {metrics.total_tokens} (input: {metrics.input_tokens}, output: {metrics.output_tokens})"""
|
| 237 |
|
| 238 |
+
async def evaluate(self, run_result: RunResult, instructions: str | None = None) -> EvalResult:
|
| 239 |
"""Evaluate the agent's output using an LLM.
|
| 240 |
|
| 241 |
Args:
|
| 242 |
run_result: The result from running an agent on a task
|
| 243 |
+
instructions: Optional instructions used by the agent during the run
|
| 244 |
|
| 245 |
Returns:
|
| 246 |
EvalResult with LLM-generated scores and reasoning
|
|
|
|
| 257 |
),
|
| 258 |
)
|
| 259 |
|
| 260 |
+
prompt = self._get_evaluation_prompt(run_result, instructions=instructions)
|
| 261 |
|
| 262 |
try:
|
| 263 |
+
# Build params with structured output
|
| 264 |
params: dict[str, Any] = {
|
| 265 |
"model": self.model_name,
|
| 266 |
"messages": [
|
| 267 |
{
|
| 268 |
"role": "system",
|
| 269 |
+
"content": "You are an expert evaluator.",
|
| 270 |
},
|
| 271 |
{"role": "user", "content": prompt},
|
| 272 |
],
|
| 273 |
+
"response_format": {
|
| 274 |
+
"type": "json_schema",
|
| 275 |
+
"json_schema": self._get_eval_schema(),
|
| 276 |
+
},
|
| 277 |
}
|
| 278 |
if self.temperature is not None:
|
| 279 |
params["temperature"] = self.temperature
|
| 280 |
|
| 281 |
response = await self.model_client.chat.completions.create(**params)
|
| 282 |
|
| 283 |
+
# Extract and parse response - structured output guarantees valid JSON
|
| 284 |
+
response_text = response.choices[0].message.content or "{}"
|
| 285 |
+
eval_data = json.loads(response_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
# Build criterion results
|
| 288 |
criteria_results = []
|
| 289 |
total_weighted_score = 0.0
|
| 290 |
+
total_weighted_reasoning = 0.0
|
| 291 |
total_weight = 0.0
|
| 292 |
|
| 293 |
for cr_data in eval_data.get("criteria_results", []):
|
| 294 |
cr = CriterionResult(
|
| 295 |
name=cr_data.get("name", "unknown"),
|
| 296 |
score=float(cr_data.get("score", 0.0)),
|
| 297 |
+
reasoning_score=float(cr_data.get("reasoning_score", 0.0)),
|
| 298 |
passed=bool(cr_data.get("passed", False)),
|
| 299 |
reasoning=cr_data.get("reasoning", ""),
|
| 300 |
)
|
|
|
|
| 308 |
break
|
| 309 |
|
| 310 |
total_weighted_score += cr.score * weight
|
| 311 |
+
total_weighted_reasoning += cr.reasoning_score * weight
|
| 312 |
total_weight += weight
|
| 313 |
|
| 314 |
+
# Calculate overall scores
|
| 315 |
overall_score = total_weighted_score / total_weight if total_weight > 0 else 0.0
|
| 316 |
+
overall_reasoning_score = total_weighted_reasoning / total_weight if total_weight > 0 else 0.0
|
| 317 |
|
| 318 |
return EvalResult(
|
| 319 |
score=overall_score,
|
| 320 |
+
reasoning_score=overall_reasoning_score,
|
| 321 |
passed=overall_score >= self.passing_threshold,
|
| 322 |
criteria_results=criteria_results,
|
| 323 |
reasoning=eval_data.get("overall_reasoning", ""),
|
src/flow/experiments/expansion.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""Variation expansion pipeline.
|
| 4 |
+
|
| 5 |
+
Expands experiment variations (literals + strategies) into concrete values,
|
| 6 |
+
then generates candidates via Cartesian product.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from dataclasses import asdict
|
| 13 |
+
from itertools import product as itertools_product
|
| 14 |
+
from typing import TYPE_CHECKING, Any
|
| 15 |
+
|
| 16 |
+
from .models import (
|
| 17 |
+
Agent,
|
| 18 |
+
Candidate,
|
| 19 |
+
CompactionConfig,
|
| 20 |
+
ExperimentRunner,
|
| 21 |
+
LiteralVariation,
|
| 22 |
+
StrategyVariation,
|
| 23 |
+
VariationItem,
|
| 24 |
+
)
|
| 25 |
+
from .strategies import get_strategy
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from .types import Task
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
async def expand_variations(
|
| 34 |
+
variations: dict[str, list[VariationItem]],
|
| 35 |
+
base: Agent,
|
| 36 |
+
tasks: list[Task],
|
| 37 |
+
runner: ExperimentRunner | None = None,
|
| 38 |
+
) -> dict[str, list[Any]]:
|
| 39 |
+
"""Expand all variations to concrete values.
|
| 40 |
+
|
| 41 |
+
- LiteralVariation: value passes through directly
|
| 42 |
+
- StrategyVariation: strategy.generate() is called to produce values
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
variations: Parsed variations from Experiment
|
| 46 |
+
base: Base agent for strategies
|
| 47 |
+
tasks: Tasks for active strategies
|
| 48 |
+
runner: Optional ExperimentRunner for active strategies
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Dict mapping dimension names to lists of concrete values
|
| 52 |
+
"""
|
| 53 |
+
expanded: dict[str, list[Any]] = {}
|
| 54 |
+
|
| 55 |
+
for dimension, items in variations.items():
|
| 56 |
+
expanded[dimension] = []
|
| 57 |
+
logger.info(f"Expanding dimension '{dimension}' ({len(items)} items)")
|
| 58 |
+
|
| 59 |
+
for item in items:
|
| 60 |
+
if isinstance(item, LiteralVariation):
|
| 61 |
+
# Literal: add directly
|
| 62 |
+
expanded[dimension].append(item.value)
|
| 63 |
+
logger.debug(f" Literal: {_format_value(item.value)}")
|
| 64 |
+
|
| 65 |
+
elif isinstance(item, StrategyVariation):
|
| 66 |
+
# Strategy: invoke and collect results
|
| 67 |
+
logger.info(f" Running strategy '{item.strategy}' (max_candidates={item.max_candidates})")
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
strategy = get_strategy(item.strategy, item.config)
|
| 71 |
+
|
| 72 |
+
candidates = await strategy.generate(
|
| 73 |
+
base=base,
|
| 74 |
+
budget=item.max_candidates,
|
| 75 |
+
tasks=tasks,
|
| 76 |
+
runner=runner,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Extract mutated values for this dimension
|
| 80 |
+
for cand in candidates:
|
| 81 |
+
if dimension in cand.mutations:
|
| 82 |
+
value = cand.mutations[dimension]
|
| 83 |
+
expanded[dimension].append(value)
|
| 84 |
+
logger.debug(f" Strategy produced: {_format_value(value)}")
|
| 85 |
+
else:
|
| 86 |
+
# Strategy didn't mutate this dimension, use base value
|
| 87 |
+
base_value = getattr(base, dimension, None)
|
| 88 |
+
if base_value is not None:
|
| 89 |
+
expanded[dimension].append(base_value)
|
| 90 |
+
logger.debug(f" Strategy kept base: {_format_value(base_value)}")
|
| 91 |
+
|
| 92 |
+
logger.info(f" Strategy '{item.strategy}' produced {len(candidates)} candidates")
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(f" Strategy '{item.strategy}' failed: {e}")
|
| 96 |
+
raise
|
| 97 |
+
|
| 98 |
+
logger.info(f"Dimension '{dimension}': {len(expanded[dimension])} total values")
|
| 99 |
+
|
| 100 |
+
return expanded
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def generate_candidates(
|
| 104 |
+
base: Agent,
|
| 105 |
+
expanded: dict[str, list[Any]],
|
| 106 |
+
budget: int = 1000,
|
| 107 |
+
) -> list[Candidate]:
|
| 108 |
+
"""Generate candidates via Cartesian product of expanded variations.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
base: Base agent to mutate
|
| 112 |
+
expanded: Dict mapping dimension names to lists of concrete values
|
| 113 |
+
budget: Maximum candidates to generate (safety limit)
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
List of Candidate objects
|
| 117 |
+
"""
|
| 118 |
+
if not expanded:
|
| 119 |
+
return [Candidate(agent=base, mutations={}, rationale="baseline")]
|
| 120 |
+
|
| 121 |
+
dimensions = list(expanded.keys())
|
| 122 |
+
value_lists = [expanded[d] for d in dimensions]
|
| 123 |
+
|
| 124 |
+
# Check if any dimension is empty
|
| 125 |
+
for dim, values in zip(dimensions, value_lists, strict=True):
|
| 126 |
+
if not values:
|
| 127 |
+
logger.warning(f"Dimension '{dim}' has no values, using baseline")
|
| 128 |
+
return [Candidate(agent=base, mutations={}, rationale="baseline (empty variations)")]
|
| 129 |
+
|
| 130 |
+
candidates: list[Candidate] = []
|
| 131 |
+
|
| 132 |
+
for values in itertools_product(*value_lists):
|
| 133 |
+
if len(candidates) >= budget:
|
| 134 |
+
logger.warning(f"Reached budget limit ({budget}), stopping candidate generation")
|
| 135 |
+
break
|
| 136 |
+
|
| 137 |
+
mutations = dict(zip(dimensions, values, strict=True))
|
| 138 |
+
candidate = _create_candidate(base, mutations)
|
| 139 |
+
candidates.append(candidate)
|
| 140 |
+
|
| 141 |
+
logger.info(f"Generated {len(candidates)} candidates from {len(dimensions)} dimensions")
|
| 142 |
+
return candidates
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _create_candidate(base: Agent, mutations: dict[str, Any]) -> Candidate:
|
| 146 |
+
"""Create a candidate by applying mutations to base agent.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
base: Base agent
|
| 150 |
+
mutations: Dict of field name -> value
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
New Candidate with mutated agent
|
| 154 |
+
"""
|
| 155 |
+
# Build mutated agent dict
|
| 156 |
+
agent_dict = asdict(base)
|
| 157 |
+
|
| 158 |
+
for key, value in mutations.items():
|
| 159 |
+
if key == "compaction" and isinstance(value, CompactionConfig):
|
| 160 |
+
agent_dict["compaction"] = asdict(value)
|
| 161 |
+
elif key in agent_dict:
|
| 162 |
+
agent_dict[key] = value
|
| 163 |
+
|
| 164 |
+
# Reconstruct CompactionConfig from dict
|
| 165 |
+
comp_data = agent_dict.pop("compaction")
|
| 166 |
+
if isinstance(comp_data, dict):
|
| 167 |
+
compaction = CompactionConfig(**comp_data)
|
| 168 |
+
else:
|
| 169 |
+
compaction = comp_data
|
| 170 |
+
|
| 171 |
+
# Handle tools field - keep as-is (str, list, or dict)
|
| 172 |
+
tools = agent_dict.pop("tools", "standard")
|
| 173 |
+
|
| 174 |
+
mutated = Agent(
|
| 175 |
+
**{k: v for k, v in agent_dict.items() if k not in ("compaction", "tools")},
|
| 176 |
+
compaction=compaction,
|
| 177 |
+
tools=tools,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Build name from mutations
|
| 181 |
+
name_parts = _build_name_parts(mutations)
|
| 182 |
+
mutated.name = f"{base.name}_{'_'.join(name_parts)}" if name_parts else base.name
|
| 183 |
+
|
| 184 |
+
# Serialize mutations for storage
|
| 185 |
+
serializable_mutations = _serialize_mutations(mutations)
|
| 186 |
+
|
| 187 |
+
return Candidate(
|
| 188 |
+
agent=mutated,
|
| 189 |
+
mutations=serializable_mutations,
|
| 190 |
+
rationale=f"Variations: {', '.join(name_parts)}" if name_parts else "baseline",
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _build_name_parts(mutations: dict[str, Any]) -> list[str]:
    """Build name parts from mutations for candidate naming.

    Each mutation is rendered as a short human-readable fragment; the caller
    joins the fragments with underscores to form the candidate's name.

    Args:
        mutations: Dict of field name -> mutated value

    Returns:
        List of name fragments (one or two per mutation)
    """
    name_parts: list[str] = []

    for k, v in mutations.items():
        if isinstance(v, CompactionConfig):
            name_parts.append(f"{v.strategy}")
            if v.strategy == "head_tail":
                name_parts.append(f"h{v.head_size}_t{v.tail_size}")
        elif k == "tools":
            # Presets are named directly; lists/dicts are summarized by size.
            # (The original list and fallback branches were identical; merged.)
            if isinstance(v, str):
                name_parts.append(f"tools={v}")
            else:
                name_parts.append(f"tools=[{len(v)}]")
        elif k == "llm_config" and isinstance(v, dict):
            provider = v.get("provider", "unknown")
            model = v.get("model", "")
            name_parts.append(f"{provider}/{model}" if model else provider)
        elif k == "instructions":
            # Truncate instructions so the candidate name stays manageable.
            preview = str(v)[:30].replace(" ", "_").replace("\n", "_")
            name_parts.append(f"instr={preview}")
        elif isinstance(v, bool):
            name_parts.append(f"{k}={'on' if v else 'off'}")
        else:
            name_parts.append(f"{k}={v}")

    return name_parts
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _serialize_mutations(mutations: dict[str, Any]) -> dict[str, Any]:
    """Serialize mutations for storage (convert non-serializable types)."""
    # CompactionConfig instances become plain dicts; everything else passes through.
    return {key: (asdict(value) if isinstance(value, CompactionConfig) else value) for key, value in mutations.items()}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _format_value(value: Any) -> str:
    """Format a value for logging."""
    # Guard-style dispatch: first matching shape wins.
    if isinstance(value, CompactionConfig):
        return f"CompactionConfig({value.strategy})"
    if isinstance(value, str) and len(value) > 50:
        return f'"{value[:50]}..."'
    if isinstance(value, dict):
        return f"dict({len(value)} keys)"
    if isinstance(value, list):
        return f"list({len(value)} items)"
    return repr(value)
|
src/flow/experiments/gaia_converter.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from loguru import logger
|
| 6 |
+
|
| 7 |
+
from flow.experiments.types import EvalCriterion, Task
|
| 8 |
+
from flow.tools.text_inspector_qa import TextInspectorTool
|
| 9 |
+
from flow.tools.visual_inspector_qa import VisualInspectorTool
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _get_augmented_prompt_for_files(
    local_path: str,
    task: Task,
    visual_inspector_tool: VisualInspectorTool | None,
    text_inspector_tool: TextInspectorTool,
) -> str:
    """Build a prompt suffix describing the file(s) attached to a GAIA task.

    Reads ``task.metadata["gaia_file"]`` and, depending on the extension, either
    references the file path directly, or (for zip archives) extracts the archive
    and inlines an inspector-generated caption for each contained file.

    Args:
        local_path: Directory that holds the task's attached files.
        task: Task whose metadata may name an attached file.
        visual_inspector_tool: Optional image captioner for extracted images.
        text_inspector_tool: Captioner used for all other extracted files.

    Returns:
        Prompt suffix to append to the task prompt, or "" when no file is attached.
    """
    gaia_file = task.metadata.get("gaia_file")
    if not gaia_file:
        return ""

    file_name = str(gaia_file)
    full_file_path = str(Path(local_path) / file_name)
    ext = Path(file_name).suffix.lower()

    prompt_use_files = "\n\nTo answer the question above, you will have to use these attached files:"

    if ext in [".pdf", ".xlsx"]:
        # Prefer a pre-rendered PNG of the document when one exists alongside it.
        # NOTE(review): split(".")[0] truncates at the FIRST dot, so a name like
        # "v1.2.report.pdf" maps to "v1.png" — confirm attachment naming convention.
        image_path = file_name.split(".")[0] + ".png"
        full_image_path = Path(local_path) / image_path
        if full_image_path.exists():
            prompt_use_files += f"\nAttached image: {full_image_path}"
        else:
            prompt_use_files += f"\nAttached file: {full_file_path}"

    elif ext == ".zip":
        import shutil

        # Extract next to the archive (same path minus the ".zip" suffix).
        # NOTE(review): unpack_archive does not guard against path traversal
        # ("zip slip") in crafted archives — acceptable for trusted benchmark
        # data; verify before accepting untrusted archives.
        folder_name = full_file_path.replace(".zip", "")
        os.makedirs(folder_name, exist_ok=True)
        shutil.unpack_archive(full_file_path, folder_name)

        # For archives, replace the generic header entirely: list every
        # extracted file and attach a caption for each.
        prompt_use_files = (
            "\n\nYou have been given a zip archive of supporting files. "
            "We extracted it into a directory: find the extracted files at the following paths:\n"
        )

        for root, _, files in os.walk(folder_name):
            for file in files:
                file_path = os.path.join(root, file)
                prompt_use_files += f"- {file_path}\n"
                # Images go to the visual inspector (when available); everything
                # else — including images when no visual tool — to the text inspector.
                if Path(file).suffix.lower() in [".png", ".jpg", ".jpeg"] and visual_inspector_tool is not None:
                    prompt = f"""Write a caption of 5 sentences maximum for this image. Pay special attention to any details that might be useful for someone answering the following question:
{task.prompt}. But do not try to answer the question directly!
Do not add any information that is not present in the image.
""".strip()
                    prompt_use_files += (
                        "> Description of this image: "
                        + visual_inspector_tool(image_path=file_path, question=prompt)
                        + "\n\n"
                    )
                else:
                    prompt = f"""Write a short caption (5 sentences maximum) for this file. Pay special attention to any details that might be useful for someone answering the following question:
{task.prompt}. But do not try to answer the question directly!
Do not add any information that is not present in the file.
""".strip()
                    prompt_use_files += (
                        "> Description of this file: "
                        + text_inspector_tool.forward_initial_exam_mode(file_path=file_path, question=prompt)
                        + "\n\n"
                    )
    elif ext in [".png", ".jpg", ".jpeg"]:
        prompt_use_files += f"\nAttached image: {full_file_path}"
    elif ext in [".mp3", ".m4a", ".wav"]:
        prompt_use_files += f"\nAttached audio: {full_file_path}"
    else:
        prompt_use_files += f"\nAttached file: {full_file_path}"

    return prompt_use_files
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def extract_task(row: dict[str, Any]) -> dict[str, Any] | None:
|
| 84 |
+
"""Extract task fields from a row with flexible field names.
|
| 85 |
+
|
| 86 |
+
GAIA dataset has inconsistent field names across versions, so we try
|
| 87 |
+
multiple variants for each field.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
row: Raw row from parquet/jsonl
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Normalized task dict, or None if task should be skipped
|
| 94 |
+
"""
|
| 95 |
+
# Question field variants
|
| 96 |
+
question = row.get("Question") or row.get("question") or row.get("query") or row.get("prompt")
|
| 97 |
+
|
| 98 |
+
# Answer field variants
|
| 99 |
+
answer = row.get("Final answer") or row.get("answer") or row.get("final_answer")
|
| 100 |
+
|
| 101 |
+
# Task ID field variants
|
| 102 |
+
task_id = str(row.get("task_id") or row.get("question_id") or row.get("id") or row.get("uuid"))
|
| 103 |
+
|
| 104 |
+
# Level field
|
| 105 |
+
level = row.get("Level") or row.get("level")
|
| 106 |
+
if isinstance(level, str) and level.isdigit():
|
| 107 |
+
level = int(level)
|
| 108 |
+
|
| 109 |
+
# File attachment
|
| 110 |
+
file_name = row.get("file_name") or row.get("filename")
|
| 111 |
+
|
| 112 |
+
# Skip tasks without question or valid answer (test set has "?" placeholders)
|
| 113 |
+
if not question or answer is None or str(answer).strip() in ["?", ""]:
|
| 114 |
+
return None
|
| 115 |
+
|
| 116 |
+
return {
|
| 117 |
+
"task_id": task_id,
|
| 118 |
+
"question": question,
|
| 119 |
+
"answer": str(answer),
|
| 120 |
+
"level": level,
|
| 121 |
+
"file_name": file_name,
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def convert_to_flow_task(gaia_task: dict[str, Any]) -> Task:
    """Build a Flow Task from a normalized GAIA task dict.

    The expected answer is stored twice on purpose:
    1. In the single "correct_answer" criterion, so LLM-as-judge can evaluate it.
    2. In ``metadata["gaia_answer"]``, so custom evaluators can do exact-match scoring.

    Args:
        gaia_task: Normalized GAIA task dict (as produced by :func:`extract_task`)

    Returns:
        Flow-compatible Task
    """
    expected_answer = gaia_task["answer"]

    correctness = EvalCriterion(
        name="correct_answer",
        instruction=f"The agent's final answer must match: {expected_answer}",
        weight=1.0,
    )

    return Task(
        name=gaia_task["task_id"],
        prompt=gaia_task["question"],
        criteria=[correctness],
        metadata={
            "gaia_answer": expected_answer,
            "gaia_level": gaia_task.get("level"),
            "gaia_file": gaia_task.get("file_name"),
            "source": "gaia-benchmark",
        },
    )
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def convert_gaia(example: dict[str, Any], index: int, dataset_metadata: dict[str, Any] | None = None) -> Task:
    """Convert a raw GAIA example into a Flow Task, augmenting the prompt with file info.

    Args:
        example: Raw GAIA row (parquet/jsonl dict)
        index: Example index (used for logging)
        dataset_metadata: Must provide 'config', 'split', and 'local_path'

    Returns:
        Flow Task, with the prompt augmented by attached-file descriptions when
        inspector tools are available.

    Raises:
        ValueError: If dataset_metadata (or a required key) is missing, or the
            example has no usable question/answer.
    """
    logger.debug(f"Processing task at index: {index}")

    if dataset_metadata is None:
        raise ValueError("dataset_metadata is required and cannot be None.")

    # Validate required fields in dataset_metadata
    config = dataset_metadata.get("config")
    split = dataset_metadata.get("split")
    local_path = dataset_metadata.get("local_path")

    if config is None:
        raise ValueError("dataset_metadata 'config' is required and cannot be None.")

    if split is None:
        raise ValueError("dataset_metadata 'split' is required and cannot be None.")

    if local_path is None:
        raise ValueError("dataset_metadata 'local_path' is required and cannot be None.")

    # Derive GAIA year from the config when possible (e.g., "2023_level2" -> "2023"),
    # falling back to "2023" to preserve existing behavior if parsing fails.
    gaia_year = "2023"
    if isinstance(config, str):
        year_candidate = config.split("_", 1)[0]
        if year_candidate.isdigit() and len(year_candidate) == 4:
            gaia_year = year_candidate

    resolved_local_path = str(Path(local_path) / gaia_year / split)

    extracted_task = extract_task(example)
    # extract_task returns None for unusable rows (e.g. test-set "?" answer
    # placeholders). Fail loudly with a clear message instead of crashing with
    # a TypeError inside convert_to_flow_task; the caller logs per-example errors.
    if extracted_task is None:
        raise ValueError(f"Example at index {index} has no usable question/answer.")

    converted_task = convert_to_flow_task(extracted_task)

    try:
        visual_inspector_tool = VisualInspectorTool()
        text_inspector_tool = TextInspectorTool()
    except RuntimeError as exc:
        logger.warning(
            "Inspector tools could not be initialized (likely missing environment "
            "variables). Skipping file-based prompt augmentation. Error: {}",
            exc,
        )
        prompt_for_files = ""
    else:
        prompt_for_files = _get_augmented_prompt_for_files(
            local_path=resolved_local_path,
            task=converted_task,
            visual_inspector_tool=visual_inspector_tool,
            text_inspector_tool=text_inspector_tool,
        )

    if len(prompt_for_files) > 0:
        # Keep the unaugmented question available to evaluators.
        converted_task.metadata["original_prompt"] = converted_task.prompt
        converted_task.prompt = converted_task.prompt + prompt_for_files

    return converted_task
|
src/flow/experiments/hf_datasets.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert Hugging Face datasets to Flow task format.
|
| 2 |
+
|
| 3 |
+
This module provides utilities to convert HF datasets (like GSM8K, MATH, HumanEval)
|
| 4 |
+
into Flow's task format for use with GEPA optimization.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# From CLI
|
| 8 |
+
python -m flow.cli.hf_import gsm8k --output tasks/gsm8k.jsonl --limit 100
|
| 9 |
+
|
| 10 |
+
# Programmatically
|
| 11 |
+
from flow.experiments.hf_datasets import import_hf_dataset
|
| 12 |
+
tasks = import_hf_dataset("openai/gsm8k", split="train", limit=50)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
from flow.experiments.types import EvalCriterion, Task
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Dataset-specific converters
|
| 29 |
+
# Each converter knows how to extract question/answer from a specific dataset
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def convert_gsm8k(example: dict[str, Any], index: int, dataset_metadata: dict[str, Any] | None = None) -> Task:
    """Convert a GSM8K math problem into a Flow task.

    GSM8K format:
        {
            "question": "Natalia sold clips to 48 of her friends...",
            "answer": "Natalia sold 48/2 = 24 clips in May. ... #### 72"
        }
    """
    problem_text = example["question"]
    worked_solution = example["answer"]

    # The gold answer follows the "####" marker in the worked solution, if present.
    final_answer = worked_solution.split("####")[-1].strip() if "####" in worked_solution else None

    criteria = [
        EvalCriterion(
            name="correctness",
            instruction=f"The solution correctly answers: {problem_text}. The correct answer is {final_answer}",
            weight=1.0,
        ),
        EvalCriterion(
            name="reasoning",
            instruction="The solution shows clear mathematical reasoning and step-by-step work",
            weight=0.7,
        ),
    ]

    task_metadata: dict[str, Any] = {
        "dataset": "gsm8k",
        "index": index,
        "answer": worked_solution,
        "final_answer": final_answer,
    }
    task_metadata.update(dataset_metadata or {})

    return Task(
        name=f"gsm8k_{index}",
        prompt=f"Solve this math problem step by step:\n\n{problem_text}",
        criteria=criteria,
        metadata=task_metadata,
    )
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def convert_math(example: dict[str, Any], index: int, dataset_metadata: dict[str, Any] | None = None) -> Task:
    """Convert a problem from the MATH dataset into a Flow task.

    MATH format:
        {
            "problem": "What is 2+2?",
            "solution": "The answer is 4",
            "level": "Level 1",
            "type": "Algebra"
        }
    """
    problem = example["problem"]
    solution = example.get("solution", "")
    level = example.get("level", "Unknown")
    category = example.get("type", "Unknown")

    criteria = [
        EvalCriterion(name="correctness", instruction=f"The solution correctly solves: {problem}", weight=1.0),
        EvalCriterion(
            name="mathematical_rigor",
            instruction="The solution uses proper mathematical notation and reasoning",
            weight=0.8,
        ),
    ]

    task_metadata: dict[str, Any] = {
        "dataset": "math",
        "index": index,
        "level": level,
        "type": category,
        "solution": solution,
    }
    task_metadata.update(dataset_metadata or {})

    return Task(
        name=f"math_{category.lower()}_{index}",
        prompt=f"Solve this {level} {category} problem:\n\n{problem}",
        criteria=criteria,
        metadata=task_metadata,
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def convert_humaneval(example: dict[str, Any], index: int, dataset_metadata: dict[str, Any] | None = None) -> Task:
    r"""Convert a HumanEval coding problem into a Flow task.

    HumanEval format:
        {
            "task_id": "HumanEval/0",
            "prompt": "def has_close_elements(numbers, threshold):\n    ...",
            "canonical_solution": "    ...",
            "test": "def check(...):\n    ...",
            "entry_point": "has_close_elements"
        }
    """
    task_id = example.get("task_id", f"task_{index}")
    code_stub = example["prompt"]
    entry_point = example.get("entry_point", "")
    test = example.get("test", "")

    criteria = [
        EvalCriterion(
            name="correctness", instruction="The code implementation is correct and passes all test cases", weight=1.0
        ),
        EvalCriterion(
            name="code_quality",
            instruction="The code is clean, well-documented, and follows best practices",
            weight=0.6,
        ),
    ]

    task_metadata: dict[str, Any] = {
        "dataset": "humaneval",
        "task_id": task_id,
        "entry_point": entry_point,
        "test": test,
    }
    task_metadata.update(dataset_metadata or {})

    return Task(
        name=f"humaneval_{task_id.replace('/', '_')}",
        prompt=f"Complete this Python function:\n\n{code_stub}\n\nMake sure it passes the test cases.",
        criteria=criteria,
        metadata=task_metadata,
    )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def convert_mbpp(example: dict[str, Any], index: int, dataset_metadata: dict[str, Any] | None = None) -> Task:
    """Convert an MBPP coding problem into a Flow task.

    MBPP format:
        {
            "task_id": 1,
            "text": "Write a function to find the minimum cost path...",
            "code": "def min_cost(cost, m, n): ...",
            "test_list": ["assert min_cost(...) == ..."]
        }
    """
    task_id = example.get("task_id", index)
    description = example.get("text", "")
    test_cases = example.get("test_list", [])

    criteria = [
        EvalCriterion(name="correctness", instruction=f"The solution correctly implements: {description}", weight=1.0),
        EvalCriterion(name="efficiency", instruction="The solution uses an efficient algorithm", weight=0.7),
    ]

    task_metadata: dict[str, Any] = {"dataset": "mbpp", "task_id": task_id, "test_list": test_cases}
    task_metadata.update(dataset_metadata or {})

    return Task(
        name=f"mbpp_{task_id}",
        prompt=f"{description}\n\nImplement this in Python.",
        criteria=criteria,
        metadata=task_metadata,
    )
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Registry of dataset converters
|
| 185 |
+
def _get_gaia_converter():
    """Import and return the GAIA converter lazily (defers the smolagents dependency)."""
    from flow.experiments import gaia_converter

    return gaia_converter.convert_gaia
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Maps dataset identifiers to converter callables. import_hf_dataset() matches
# entries by substring against the requested dataset name, so both short and
# fully-qualified names are listed. Values are converters taking
# (example, index, dataset_metadata) -> Task, except the GAIA entry, which is a
# zero-arg lazy loader that returns the real converter when called.
DATASET_CONVERTERS = {
    "openai/gsm8k": convert_gsm8k,
    "gsm8k": convert_gsm8k,
    "competition_math": convert_math,
    "hendrycks/math": convert_math,
    "humaneval": convert_humaneval,
    "openai_humaneval": convert_humaneval,
    "mbpp": convert_mbpp,
    "google-research-datasets/mbpp": convert_mbpp,
    "gaia-benchmark/GAIA": _get_gaia_converter,  # Lazy loaded (avoids smolagents import)
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def import_hf_dataset(
    dataset_name: str,
    config: str | None = None,
    split: str = "train",
    limit: int | None = None,
    converter_override: Any = None,
    local_path: str | Path | None = None,
) -> list[Task]:
    """Import a Hugging Face dataset and convert to Flow tasks.

    Args:
        dataset_name: HF dataset name (e.g., "openai/gsm8k")
        config: Dataset configuration/subset (e.g., "main")
        split: Dataset split to use (default: "train")
        limit: Maximum number of examples to convert (default: all; 0 also means all)
        converter_override: Custom converter function (optional)
        local_path: Path to download the dataset snapshot to using huggingface_hub.snapshot_download().
            When provided, downloads the dataset to this path first, then loads from local files.
            If the snapshot already exists at this path, it will be reused.
            For private datasets, set the HF_TOKEN environment variable with your Hugging Face token.

    Returns:
        List of Flow Task objects (examples whose conversion fails are logged and skipped)

    Raises:
        ImportError: If `datasets` (or, with local_path, `huggingface_hub`) is not installed.
        ValueError: If no converter matches dataset_name and no override is given.

    Environment Variables:
        HF_TOKEN: Hugging Face API token for accessing private/gated datasets.
            Required when using local_path with private datasets.

    Example:
        >>> # Load from Hugging Face Hub (default behavior)
        >>> tasks = import_hf_dataset("openai/gsm8k", config="main", split="train", limit=50)
        >>> print(f"Loaded {len(tasks)} tasks")

        >>> # Download to local path first, then load
        >>> tasks = import_hf_dataset("openai/gsm8k", config="main", split="train", local_path="/data/gsm8k")

        >>> # For private datasets, set HF_TOKEN env variable first
        >>> # export HF_TOKEN="hf_..."
        >>> tasks = import_hf_dataset("org/private-dataset", split="train", local_path="/data/private")
    """
    # `datasets` is an optional dependency; fail with an actionable message.
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("Hugging Face datasets library is required. Install with: pip install datasets") from e

    # Download to local path if specified, then load from there
    if local_path is not None:
        try:
            from huggingface_hub import snapshot_download
        except ImportError as e:
            raise ImportError(
                "huggingface_hub library is required for local_path support. Install with: pip install huggingface_hub"
            ) from e

        local_path = Path(local_path)
        hf_token = os.environ.get("HF_TOKEN")
        logger.info(f"Downloading dataset {dataset_name} to local path: {local_path}")
        # snapshot_download reuses an existing snapshot at local_dir if present.
        snapshot_path = snapshot_download(
            repo_id=dataset_name,
            repo_type="dataset",
            local_dir=str(local_path),
            token=hf_token,
        )
        logger.info(f"Loading dataset from local snapshot: {snapshot_path} (split: {split})")
        dataset = load_dataset(snapshot_path, config, split=split)
    else:
        logger.info(f"Loading dataset: {dataset_name} (config: {config}, split: {split})")
        dataset = load_dataset(dataset_name, config, split=split)

    # Apply limit (truthiness check: limit=0 behaves like "no limit")
    if limit:
        dataset = dataset.select(range(min(limit, len(dataset))))

    logger.info(f"Converting {len(dataset)} examples to Flow tasks...")

    # Find converter
    converter = converter_override
    if converter is None:
        # Registry keys are matched by SUBSTRING against dataset_name (so
        # "gsm8k" matches "openai/gsm8k"); first match in dict order wins.
        for key, conv in DATASET_CONVERTERS.items():
            if key in dataset_name:
                # Handle lazy loaders (functions that return the actual converter)
                if conv is _get_gaia_converter:
                    converter = conv()
                else:
                    converter = conv
                break

    if converter is None:
        raise ValueError(
            f"No converter found for dataset '{dataset_name}'. "
            f"Available: {list(DATASET_CONVERTERS.keys())}\n"
            f"Use converter_override parameter to provide a custom converter."
        )

    # Build dataset metadata to pass to converters
    dataset_metadata: dict[str, Any] = {}
    dataset_metadata["local_path"] = str(local_path) if local_path else None
    dataset_metadata["config"] = config
    dataset_metadata["split"] = split

    # Convert examples; a failing example is logged and skipped, never fatal.
    tasks = []
    for i, example in enumerate(dataset):
        try:
            task = converter(example, i, dataset_metadata)
            tasks.append(task)
        except Exception as e:
            logger.warning(f"Failed to convert example {i}: {e}", exc_info=True)

    logger.info(f"Successfully converted {len(tasks)} tasks")
    return tasks
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def save_tasks_to_jsonl(tasks: list[Task], output_path: Path) -> None:
    """Save tasks to JSONL file (one JSON object per line).

    Args:
        tasks: List of Task objects
        output_path: Path to output JSONL file (parent directories are created)
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Write UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) can fail on non-ASCII prompts/answers.
    with open(output_path, "w", encoding="utf-8") as f:
        for task in tasks:
            # Convert to a plain dict so it round-trips through JSON.
            task_dict = {
                "name": task.name,
                "prompt": task.prompt,
                "criteria": [{"name": c.name, "instruction": c.instruction, "weight": c.weight} for c in task.criteria],
                "metadata": task.metadata,
            }
            f.write(json.dumps(task_dict) + "\n")

    logger.info(f"Saved {len(tasks)} tasks to {output_path}")
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def register_converter(dataset_name: str, converter_func: Any) -> None:
    """Register a custom converter for a dataset.

    The converter is added to the global registry and becomes available to
    subsequent :func:`import_hf_dataset` lookups.

    Args:
        dataset_name: Dataset identifier
        converter_func: Function that converts example dict to Task

    Example:
        >>> def my_converter(example, index):
        ...     return Task(name=f"task_{index}", prompt=example["text"], ...)
        >>> register_converter("my/dataset", my_converter)
    """
    DATASET_CONVERTERS[dataset_name] = converter_func
    logger.info(f"Registered converter for '{dataset_name}'")
|
src/flow/experiments/models.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
"""Core data models for the optimization framework.
|
| 4 |
|
| 5 |
Defines:
|
|
|
|
|
|
|
| 6 |
- CompactionConfig: Extensible compaction strategy configuration
|
| 7 |
- Agent: Framework-agnostic agent definition (what the customer brings)
|
| 8 |
- Candidate: A mutated agent variant produced by optimization
|
|
@@ -17,14 +19,17 @@ from __future__ import annotations
|
|
| 17 |
from dataclasses import asdict, dataclass, field
|
| 18 |
from itertools import product as itertools_product
|
| 19 |
from pathlib import Path
|
| 20 |
-
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
| 21 |
|
| 22 |
import yaml
|
| 23 |
|
| 24 |
if TYPE_CHECKING:
|
| 25 |
-
from collections.abc import Awaitable, Callable
|
|
|
|
|
|
|
| 26 |
|
| 27 |
from .evaluators.base import Evaluator
|
|
|
|
| 28 |
from .types import Task
|
| 29 |
|
| 30 |
|
|
@@ -34,6 +39,89 @@ if TYPE_CHECKING:
|
|
| 34 |
|
| 35 |
# Tool presets define common tool configurations.
|
| 36 |
# Each preset maps tool names to their configuration dicts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
TOOL_PRESETS: dict[str, dict[str, dict[str, Any]]] = {
|
| 38 |
"full": {
|
| 39 |
"read_file": {},
|
|
@@ -43,9 +131,8 @@ TOOL_PRESETS: dict[str, dict[str, dict[str, Any]]] = {
|
|
| 43 |
"glob_files": {},
|
| 44 |
"ls": {},
|
| 45 |
"grep": {},
|
| 46 |
-
"bash": {"timeout":
|
| 47 |
"check_processes": {},
|
| 48 |
-
"python_repl": {},
|
| 49 |
"think": {},
|
| 50 |
"todo_write": {},
|
| 51 |
"todo_read": {},
|
|
@@ -65,20 +152,21 @@ TOOL_PRESETS: dict[str, dict[str, dict[str, Any]]] = {
|
|
| 65 |
"glob_files": {},
|
| 66 |
"ls": {},
|
| 67 |
"grep": {},
|
| 68 |
-
"bash": {"timeout":
|
| 69 |
"check_processes": {},
|
| 70 |
-
"python_repl": {},
|
| 71 |
"think": {},
|
| 72 |
"todo_write": {},
|
| 73 |
"todo_read": {},
|
| 74 |
"memory": {},
|
| 75 |
"skills": {},
|
|
|
|
|
|
|
| 76 |
},
|
| 77 |
"minimal": {
|
| 78 |
"read_file": {},
|
| 79 |
"write_file": {},
|
| 80 |
"edit_file": {},
|
| 81 |
-
"bash": {"timeout":
|
| 82 |
"think": {},
|
| 83 |
},
|
| 84 |
"readonly": {
|
|
@@ -91,16 +179,17 @@ TOOL_PRESETS: dict[str, dict[str, dict[str, Any]]] = {
|
|
| 91 |
}
|
| 92 |
|
| 93 |
|
| 94 |
-
def resolve_tools(tools: str | list[str] | dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
| 95 |
"""Normalize tool specification to dict form.
|
| 96 |
|
| 97 |
-
Accepts
|
|
|
|
| 98 |
- str: Preset name (e.g., "standard", "minimal", "full", "readonly")
|
| 99 |
- list[str]: List of tool names with default configs
|
| 100 |
- dict[str, dict]: Full specification with per-tool configs
|
| 101 |
|
| 102 |
Args:
|
| 103 |
-
tools: Tool specification in any supported format
|
| 104 |
|
| 105 |
Returns:
|
| 106 |
Dict mapping tool names to their configuration dicts
|
|
@@ -109,6 +198,9 @@ def resolve_tools(tools: str | list[str] | dict[str, dict[str, Any]]) -> dict[st
|
|
| 109 |
ValueError: If preset name is unknown
|
| 110 |
|
| 111 |
Example:
|
|
|
|
|
|
|
|
|
|
| 112 |
>>> resolve_tools("standard")
|
| 113 |
{"read_file": {}, "write_file": {}, ...}
|
| 114 |
|
|
@@ -118,6 +210,8 @@ def resolve_tools(tools: str | list[str] | dict[str, dict[str, Any]]) -> dict[st
|
|
| 118 |
>>> resolve_tools({"bash": {"timeout": 60}})
|
| 119 |
{"bash": {"timeout": 60}}
|
| 120 |
"""
|
|
|
|
|
|
|
| 121 |
if isinstance(tools, str):
|
| 122 |
if tools not in TOOL_PRESETS:
|
| 123 |
raise ValueError(f"Unknown tool preset: {tools}. Available: {list(TOOL_PRESETS.keys())}")
|
|
@@ -148,8 +242,8 @@ class CompactionConfig:
|
|
| 148 |
token_budget: Maximum tokens for context window (used by token-based strategies)
|
| 149 |
"""
|
| 150 |
|
| 151 |
-
strategy: str = "
|
| 152 |
-
params: dict[str, Any] = field(default_factory=
|
| 153 |
token_budget: int = 100_000
|
| 154 |
|
| 155 |
# =========================================================================
|
|
@@ -278,6 +372,10 @@ class CompactionConfig:
|
|
| 278 |
return self.params.get("head_ratio", 0.2)
|
| 279 |
|
| 280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
@dataclass
|
| 282 |
class Agent:
|
| 283 |
"""Framework-agnostic agent definition.
|
|
@@ -289,10 +387,10 @@ class Agent:
|
|
| 289 |
|
| 290 |
Attributes:
|
| 291 |
name: Unique identifier for this agent
|
| 292 |
-
framework: Which harness to use ("maf", "
|
| 293 |
description: Human-readable description
|
| 294 |
instructions: System prompt / instructions (optional, uses framework default if None)
|
| 295 |
-
instructions_preset: Preset name for instructions (
|
| 296 |
llm_config: LLM configuration with provider and model info:
|
| 297 |
{"provider": "azure|openai|anthropic", "model": "gpt-4o"}
|
| 298 |
If None, auto-detects from environment variables.
|
|
@@ -304,13 +402,270 @@ class Agent:
|
|
| 304 |
"""
|
| 305 |
|
| 306 |
name: str
|
| 307 |
-
framework:
|
| 308 |
description: str = ""
|
| 309 |
instructions: str | None = None
|
| 310 |
-
instructions_preset: str | None = None # e.g., "
|
| 311 |
llm_config: dict[str, Any] | None = None # {"provider": "azure", "model": "gpt-4o"}
|
| 312 |
compaction: CompactionConfig = field(default_factory=CompactionConfig)
|
| 313 |
-
tools: str | list[str] | dict[str, dict[str, Any]] =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
|
| 316 |
@dataclass
|
|
@@ -325,11 +680,15 @@ class Candidate:
|
|
| 325 |
agent: The mutated agent configuration
|
| 326 |
mutations: Dict describing what was changed from the base
|
| 327 |
rationale: Human-readable explanation of why this candidate exists
|
|
|
|
|
|
|
|
|
|
| 328 |
"""
|
| 329 |
|
| 330 |
agent: Agent
|
| 331 |
mutations: dict[str, Any] = field(default_factory=dict)
|
| 332 |
rationale: str = ""
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
@dataclass
|
|
@@ -345,47 +704,93 @@ class ExperimentResult:
|
|
| 345 |
traces: dict[str, Any] = field(default_factory=dict)
|
| 346 |
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
@runtime_checkable
|
| 349 |
class CandidateStrategy(Protocol):
|
| 350 |
"""Protocol for generating candidate variants from a base agent.
|
| 351 |
|
| 352 |
Implementations can be:
|
| 353 |
-
-
|
| 354 |
-
-
|
| 355 |
-
|
| 356 |
|
| 357 |
-
All logic is internal to the strategy
|
| 358 |
and receives the final list of candidates.
|
| 359 |
|
| 360 |
Examples:
|
| 361 |
- GridSearchStrategy: Exhaustive grid over parameter combinations
|
| 362 |
-
- (Future)
|
| 363 |
- (Future) BayesianStrategy: Bayesian optimization over parameters
|
| 364 |
"""
|
| 365 |
|
| 366 |
-
def generate(
|
| 367 |
self,
|
| 368 |
base: Agent,
|
| 369 |
budget: int,
|
| 370 |
*,
|
| 371 |
tasks: list[Task] | None = None,
|
| 372 |
-
|
| 373 |
-
run_experiment: Callable[[Candidate, Task], Awaitable[ExperimentResult]] | None = None,
|
| 374 |
) -> list[Candidate]:
|
| 375 |
"""Generate candidate variants from a base agent.
|
| 376 |
|
| 377 |
Args:
|
| 378 |
base: The base agent to optimize
|
| 379 |
budget: Maximum number of candidates to return
|
| 380 |
-
tasks: Optional tasks for strategies that run internal experiments
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
| 384 |
|
| 385 |
Returns:
|
| 386 |
List of Candidate objects (at most `budget` items).
|
| 387 |
For iterative strategies, returns the final/best candidates after
|
| 388 |
-
internal optimization loops complete.
|
|
|
|
| 389 |
"""
|
| 390 |
...
|
| 391 |
|
|
@@ -419,23 +824,22 @@ class GridSearchStrategy:
|
|
| 419 |
"""
|
| 420 |
self.variations = variations
|
| 421 |
|
| 422 |
-
def generate(
|
| 423 |
self,
|
| 424 |
base: Agent,
|
| 425 |
budget: int,
|
| 426 |
*,
|
| 427 |
tasks: list[Task] | None = None,
|
| 428 |
-
|
| 429 |
-
run_experiment: Callable[[Candidate, Task], Awaitable[ExperimentResult]] | None = None,
|
| 430 |
) -> list[Candidate]:
|
| 431 |
"""Generate all grid combinations up to budget.
|
| 432 |
|
| 433 |
-
Note: tasks
|
| 434 |
-
|
| 435 |
-
|
| 436 |
"""
|
| 437 |
# Delete unused params to satisfy linters
|
| 438 |
-
del tasks,
|
| 439 |
|
| 440 |
if not self.variations:
|
| 441 |
return [Candidate(agent=base, mutations={}, rationale="baseline")]
|
|
@@ -690,6 +1094,38 @@ def _extract_metrics(
|
|
| 690 |
# =============================================================================
|
| 691 |
|
| 692 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
@dataclass
|
| 694 |
class Experiment:
|
| 695 |
"""Experiment configuration for optimization.
|
|
@@ -699,53 +1135,173 @@ class Experiment:
|
|
| 699 |
- Experiment YAML: How to test it (variations, tasks, evaluation settings)
|
| 700 |
|
| 701 |
Attributes:
|
| 702 |
-
base_agent: Path to base agent YAML file
|
| 703 |
suite: Built-in task suite name (e.g., "coding", "quick")
|
| 704 |
tasks: Path to custom tasks JSONL file (alternative to suite)
|
| 705 |
-
variations: Dict
|
| 706 |
parallel: Max concurrent experiments
|
| 707 |
-
budget: Maximum candidates to generate
|
| 708 |
use_llm_eval: Whether to use LLM-as-Judge evaluation
|
| 709 |
|
| 710 |
Example YAML:
|
| 711 |
```yaml
|
| 712 |
-
base_agent:
|
| 713 |
suite: coding
|
| 714 |
|
| 715 |
variations:
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
|
|
|
| 724 |
|
| 725 |
tools:
|
| 726 |
- minimal
|
| 727 |
- standard
|
| 728 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
| 730 |
-
parallel:
|
| 731 |
-
budget:
|
| 732 |
use_llm_eval: true
|
| 733 |
```
|
| 734 |
"""
|
| 735 |
|
| 736 |
-
base_agent: str
|
| 737 |
suite: str | None = None
|
| 738 |
tasks: str | None = None
|
| 739 |
-
variations: dict[str, list[
|
| 740 |
parallel: int = 4
|
| 741 |
budget: int = 100
|
| 742 |
use_llm_eval: bool = True
|
| 743 |
|
| 744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
def load_experiment(path: Path) -> Experiment:
|
| 746 |
"""Load an Experiment from a YAML file.
|
| 747 |
|
| 748 |
-
|
| 749 |
|
| 750 |
Args:
|
| 751 |
path: Path to the experiment YAML file
|
|
@@ -762,38 +1318,34 @@ def load_experiment(path: Path) -> Experiment:
|
|
| 762 |
|
| 763 |
data = yaml.safe_load(path.read_text())
|
| 764 |
|
| 765 |
-
#
|
| 766 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 767 |
raw_variations = data.get("variations", {})
|
| 768 |
|
| 769 |
-
for
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
else:
|
| 787 |
-
raise ValueError(f"Unknown compaction shorthand: {v}")
|
| 788 |
-
else:
|
| 789 |
-
parsed_compactions.append(v)
|
| 790 |
-
variations["compaction"] = parsed_compactions
|
| 791 |
-
else:
|
| 792 |
-
# Other variations pass through as-is
|
| 793 |
-
variations[key] = values
|
| 794 |
|
| 795 |
return Experiment(
|
| 796 |
-
base_agent=data
|
| 797 |
suite=data.get("suite"),
|
| 798 |
tasks=data.get("tasks"),
|
| 799 |
variations=variations,
|
|
|
|
| 3 |
"""Core data models for the optimization framework.
|
| 4 |
|
| 5 |
Defines:
|
| 6 |
+
- COMPACTION_STRATEGIES: Registry of compaction strategies for schema API
|
| 7 |
+
- DEFAULT_TOKEN_BUDGET: Default token budget (200k) for modern models
|
| 8 |
- CompactionConfig: Extensible compaction strategy configuration
|
| 9 |
- Agent: Framework-agnostic agent definition (what the customer brings)
|
| 10 |
- Candidate: A mutated agent variant produced by optimization
|
|
|
|
| 19 |
from dataclasses import asdict, dataclass, field
|
| 20 |
from itertools import product as itertools_product
|
| 21 |
from pathlib import Path
|
| 22 |
+
from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
|
| 23 |
|
| 24 |
import yaml
|
| 25 |
|
| 26 |
if TYPE_CHECKING:
|
| 27 |
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
| 28 |
+
|
| 29 |
+
from flow.harness.base import Event
|
| 30 |
|
| 31 |
from .evaluators.base import Evaluator
|
| 32 |
+
from .results import AgentOptimizationResult, EvaluationResult
|
| 33 |
from .types import Task
|
| 34 |
|
| 35 |
|
|
|
|
| 39 |
|
| 40 |
# Tool presets define common tool configurations.
|
| 41 |
# Each preset maps tool names to their configuration dicts.
|
| 42 |
+
# =============================================================================
|
| 43 |
+
# Compaction Strategy Configuration
|
| 44 |
+
# =============================================================================
|
| 45 |
+
|
| 46 |
+
# Default token budget for modern models (GPT-4o, Claude 3.5, etc.)
|
| 47 |
+
DEFAULT_TOKEN_BUDGET = 200_000
|
| 48 |
+
|
| 49 |
+
# Compaction strategies registry for schema API
|
| 50 |
+
# All strategies use token-based triggers (not message count) for safety
|
| 51 |
+
COMPACTION_STRATEGIES: dict[str, dict[str, Any]] = {
|
| 52 |
+
"head_tail": {
|
| 53 |
+
"label": "Head + Tail",
|
| 54 |
+
"description": "Keep head (system prompt, initial context) + tail (recent messages). Drops middle when over budget.",
|
| 55 |
+
"params": {
|
| 56 |
+
"head_ratio": {
|
| 57 |
+
"type": "number",
|
| 58 |
+
"default": 0.2,
|
| 59 |
+
"min": 0,
|
| 60 |
+
"max": 1,
|
| 61 |
+
"description": "Fraction of budget for head messages (0.2 = 20%)",
|
| 62 |
+
},
|
| 63 |
+
"token_budget": {
|
| 64 |
+
"type": "number",
|
| 65 |
+
"default": DEFAULT_TOKEN_BUDGET,
|
| 66 |
+
"min": 1000,
|
| 67 |
+
"description": "Max tokens before compaction triggers",
|
| 68 |
+
},
|
| 69 |
+
},
|
| 70 |
+
},
|
| 71 |
+
"sliding_window": {
|
| 72 |
+
"label": "Sliding Window",
|
| 73 |
+
"description": "Keep system message + most recent messages that fit within budget. Simple and effective.",
|
| 74 |
+
"params": {
|
| 75 |
+
"token_budget": {
|
| 76 |
+
"type": "number",
|
| 77 |
+
"default": DEFAULT_TOKEN_BUDGET,
|
| 78 |
+
"min": 1000,
|
| 79 |
+
"description": "Max tokens for context window",
|
| 80 |
+
},
|
| 81 |
+
},
|
| 82 |
+
},
|
| 83 |
+
"summarization": {
|
| 84 |
+
"label": "Summarization",
|
| 85 |
+
"description": "Summarize middle messages using LLM instead of dropping them. Preserves context but adds latency.",
|
| 86 |
+
"params": {
|
| 87 |
+
"head_messages": {
|
| 88 |
+
"type": "number",
|
| 89 |
+
"default": 2,
|
| 90 |
+
"min": 1,
|
| 91 |
+
"description": "Messages to preserve at head",
|
| 92 |
+
},
|
| 93 |
+
"tail_messages": {
|
| 94 |
+
"type": "number",
|
| 95 |
+
"default": 4,
|
| 96 |
+
"min": 1,
|
| 97 |
+
"description": "Messages to preserve at tail",
|
| 98 |
+
},
|
| 99 |
+
"summary_max_tokens": {
|
| 100 |
+
"type": "number",
|
| 101 |
+
"default": 1000,
|
| 102 |
+
"min": 100,
|
| 103 |
+
"description": "Max tokens for the summary",
|
| 104 |
+
},
|
| 105 |
+
"token_budget": {
|
| 106 |
+
"type": "number",
|
| 107 |
+
"default": DEFAULT_TOKEN_BUDGET,
|
| 108 |
+
"min": 1000,
|
| 109 |
+
"description": "Max tokens before compaction triggers",
|
| 110 |
+
},
|
| 111 |
+
},
|
| 112 |
+
},
|
| 113 |
+
"none": {
|
| 114 |
+
"label": "No Compaction",
|
| 115 |
+
"description": "Context grows unbounded. Only use for benchmarking or very short tasks.",
|
| 116 |
+
"params": {},
|
| 117 |
+
},
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# =============================================================================
|
| 122 |
+
# Tool Configuration
|
| 123 |
+
# =============================================================================
|
| 124 |
+
|
| 125 |
TOOL_PRESETS: dict[str, dict[str, dict[str, Any]]] = {
|
| 126 |
"full": {
|
| 127 |
"read_file": {},
|
|
|
|
| 131 |
"glob_files": {},
|
| 132 |
"ls": {},
|
| 133 |
"grep": {},
|
| 134 |
+
"bash": {"timeout": 300},
|
| 135 |
"check_processes": {},
|
|
|
|
| 136 |
"think": {},
|
| 137 |
"todo_write": {},
|
| 138 |
"todo_read": {},
|
|
|
|
| 152 |
"glob_files": {},
|
| 153 |
"ls": {},
|
| 154 |
"grep": {},
|
| 155 |
+
"bash": {"timeout": 300},
|
| 156 |
"check_processes": {},
|
|
|
|
| 157 |
"think": {},
|
| 158 |
"todo_write": {},
|
| 159 |
"todo_read": {},
|
| 160 |
"memory": {},
|
| 161 |
"skills": {},
|
| 162 |
+
"web_search": {},
|
| 163 |
+
"web_fetch": {},
|
| 164 |
},
|
| 165 |
"minimal": {
|
| 166 |
"read_file": {},
|
| 167 |
"write_file": {},
|
| 168 |
"edit_file": {},
|
| 169 |
+
"bash": {"timeout": 300},
|
| 170 |
"think": {},
|
| 171 |
},
|
| 172 |
"readonly": {
|
|
|
|
| 179 |
}
|
| 180 |
|
| 181 |
|
| 182 |
+
def resolve_tools(tools: str | list[str] | dict[str, dict[str, Any]] | None) -> dict[str, dict[str, Any]]:
|
| 183 |
"""Normalize tool specification to dict form.
|
| 184 |
|
| 185 |
+
Accepts four input formats:
|
| 186 |
+
- None: No tools (empty set)
|
| 187 |
- str: Preset name (e.g., "standard", "minimal", "full", "readonly")
|
| 188 |
- list[str]: List of tool names with default configs
|
| 189 |
- dict[str, dict]: Full specification with per-tool configs
|
| 190 |
|
| 191 |
Args:
|
| 192 |
+
tools: Tool specification in any supported format, or None for no tools
|
| 193 |
|
| 194 |
Returns:
|
| 195 |
Dict mapping tool names to their configuration dicts
|
|
|
|
| 198 |
ValueError: If preset name is unknown
|
| 199 |
|
| 200 |
Example:
|
| 201 |
+
>>> resolve_tools(None)
|
| 202 |
+
{}
|
| 203 |
+
|
| 204 |
>>> resolve_tools("standard")
|
| 205 |
{"read_file": {}, "write_file": {}, ...}
|
| 206 |
|
|
|
|
| 210 |
>>> resolve_tools({"bash": {"timeout": 60}})
|
| 211 |
{"bash": {"timeout": 60}}
|
| 212 |
"""
|
| 213 |
+
if tools is None:
|
| 214 |
+
return {}
|
| 215 |
if isinstance(tools, str):
|
| 216 |
if tools not in TOOL_PRESETS:
|
| 217 |
raise ValueError(f"Unknown tool preset: {tools}. Available: {list(TOOL_PRESETS.keys())}")
|
|
|
|
| 242 |
token_budget: Maximum tokens for context window (used by token-based strategies)
|
| 243 |
"""
|
| 244 |
|
| 245 |
+
strategy: str = "none"
|
| 246 |
+
params: dict[str, Any] = field(default_factory=dict)
|
| 247 |
token_budget: int = 100_000
|
| 248 |
|
| 249 |
# =========================================================================
|
|
|
|
| 372 |
return self.params.get("head_ratio", 0.2)
|
| 373 |
|
| 374 |
|
| 375 |
+
# Supported agent frameworks (harnesses)
|
| 376 |
+
Framework = Literal["maf", "miniagent", "langgraph"]
|
| 377 |
+
|
| 378 |
+
|
| 379 |
@dataclass
|
| 380 |
class Agent:
|
| 381 |
"""Framework-agnostic agent definition.
|
|
|
|
| 387 |
|
| 388 |
Attributes:
|
| 389 |
name: Unique identifier for this agent
|
| 390 |
+
framework: Which harness to use ("maf", "miniagent", "langgraph")
|
| 391 |
description: Human-readable description
|
| 392 |
instructions: System prompt / instructions (optional, uses framework default if None)
|
| 393 |
+
instructions_preset: Preset name for instructions (default: "general")
|
| 394 |
llm_config: LLM configuration with provider and model info:
|
| 395 |
{"provider": "azure|openai|anthropic", "model": "gpt-4o"}
|
| 396 |
If None, auto-detects from environment variables.
|
|
|
|
| 402 |
"""
|
| 403 |
|
| 404 |
name: str
|
| 405 |
+
framework: Framework = "miniagent"
|
| 406 |
description: str = ""
|
| 407 |
instructions: str | None = None
|
| 408 |
+
instructions_preset: str | None = None # e.g., "general"
|
| 409 |
llm_config: dict[str, Any] | None = None # {"provider": "azure", "model": "gpt-4o"}
|
| 410 |
compaction: CompactionConfig = field(default_factory=CompactionConfig)
|
| 411 |
+
tools: str | list[str] | dict[str, dict[str, Any]] | None = None
|
| 412 |
+
|
| 413 |
+
# Set by deploy() — when set, evaluate/optimize auto-persist to DB
|
| 414 |
+
_id: str | None = field(default=None, repr=False, compare=False)
|
| 415 |
+
|
| 416 |
+
@property
|
| 417 |
+
def id(self) -> str | None:
|
| 418 |
+
"""Agent ID in the database, set after deploy()."""
|
| 419 |
+
return self._id
|
| 420 |
+
|
| 421 |
+
@classmethod
|
| 422 |
+
def from_preset(cls, name: str) -> Agent:
|
| 423 |
+
"""Create an Agent from a named preset.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
name: Preset name (e.g., "coding", "research", "document-analysis")
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
A new Agent instance with the preset's configuration
|
| 430 |
+
|
| 431 |
+
Example:
|
| 432 |
+
agent = Agent.from_preset("coding")
|
| 433 |
+
result = await agent.evaluate(tasks="quick")
|
| 434 |
+
"""
|
| 435 |
+
from .presets import get_preset
|
| 436 |
+
|
| 437 |
+
preset = get_preset(name)
|
| 438 |
+
return cls(
|
| 439 |
+
name=preset.agent.name,
|
| 440 |
+
framework=preset.agent.framework,
|
| 441 |
+
description=preset.agent.description,
|
| 442 |
+
instructions=preset.agent.instructions,
|
| 443 |
+
instructions_preset=preset.agent.instructions_preset,
|
| 444 |
+
llm_config=preset.agent.llm_config,
|
| 445 |
+
compaction=preset.agent.compaction,
|
| 446 |
+
tools=preset.agent.tools,
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
async def run(self, task: str, workspace: Path | None = None) -> str:
|
| 450 |
+
"""Run the agent on a task and return the final output.
|
| 451 |
+
|
| 452 |
+
This is the simplest way to use an agent — give it a task, get a result.
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
task: The task/prompt to execute
|
| 456 |
+
workspace: Optional workspace directory (creates temp dir if None)
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
The agent's final text output
|
| 460 |
+
|
| 461 |
+
Example:
|
| 462 |
+
agent = Agent(name="coding-agent", tools="standard")
|
| 463 |
+
output = await agent.run("Create hello.py that prints Hello World")
|
| 464 |
+
print(output)
|
| 465 |
+
"""
|
| 466 |
+
output_parts: list[str] = []
|
| 467 |
+
async for event in self.run_stream(task, workspace=workspace):
|
| 468 |
+
if event.type.value == "text_delta":
|
| 469 |
+
output_parts.append(event.content)
|
| 470 |
+
return "".join(output_parts)
|
| 471 |
+
|
| 472 |
+
async def run_stream(
|
| 473 |
+
self, task: str, *, workspace: Path | None = None
|
| 474 |
+
) -> AsyncIterator[Event]:
|
| 475 |
+
"""Run the agent on a task with streaming events.
|
| 476 |
+
|
| 477 |
+
Yields real-time events as the agent works — text chunks, tool calls,
|
| 478 |
+
tool results, and completion. Use this for live output in notebooks or CLIs.
|
| 479 |
+
|
| 480 |
+
Args:
|
| 481 |
+
task: The task/prompt to execute
|
| 482 |
+
workspace: Optional workspace directory (creates temp dir if None)
|
| 483 |
+
|
| 484 |
+
Yields:
|
| 485 |
+
Event objects (text_delta, tool_call_start, tool_result, done, etc.)
|
| 486 |
+
|
| 487 |
+
Example:
|
| 488 |
+
agent = Agent(name="coding-agent", tools="standard")
|
| 489 |
+
async for event in agent.run_stream("Create hello.py"):
|
| 490 |
+
if event.type.value == "text_delta":
|
| 491 |
+
print(event.content, end="", flush=True)
|
| 492 |
+
"""
|
| 493 |
+
import tempfile
|
| 494 |
+
|
| 495 |
+
# Lazy imports to avoid circular deps and keep Agent lightweight
|
| 496 |
+
from flow.harness import create_harness
|
| 497 |
+
from flow.harness.registry import ensure_harnesses_registered
|
| 498 |
+
|
| 499 |
+
ensure_harnesses_registered()
|
| 500 |
+
|
| 501 |
+
if workspace is None:
|
| 502 |
+
workspace = Path(tempfile.mkdtemp(prefix="flow_run_"))
|
| 503 |
+
|
| 504 |
+
harness = create_harness(self, workspace)
|
| 505 |
+
try:
|
| 506 |
+
async for event in harness.run_stream(task):
|
| 507 |
+
yield event
|
| 508 |
+
finally:
|
| 509 |
+
await harness.close()
|
| 510 |
+
|
| 511 |
+
async def deploy(self) -> str:
|
| 512 |
+
"""Register this agent in the Flow database.
|
| 513 |
+
|
| 514 |
+
Creates an AgentConfig row in the local SQLite DB (~/.flow/flow_ui.db).
|
| 515 |
+
No running server required — this is a pure DB write. After deploying,
|
| 516 |
+
all evaluate() and optimize() calls auto-persist results to the DB.
|
| 517 |
+
|
| 518 |
+
Run ``flow serve`` separately to browse results in the UI.
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
The agent ID (UUID string)
|
| 522 |
+
|
| 523 |
+
Example:
|
| 524 |
+
agent = Agent(name="coding-agent", tools="standard")
|
| 525 |
+
agent_id = await agent.deploy()
|
| 526 |
+
# Results now auto-persist
|
| 527 |
+
result = await agent.evaluate(tasks="quick")
|
| 528 |
+
# Run `flow serve` to view at http://localhost:7860/agents/{agent_id}
|
| 529 |
+
"""
|
| 530 |
+
try:
|
| 531 |
+
from flow.ui.services.persistence_adapter import PersistenceAdapter
|
| 532 |
+
except ImportError as e:
|
| 533 |
+
raise ImportError(
|
| 534 |
+
"DB dependencies not available. Install flow with UI support "
|
| 535 |
+
"to use deploy(): pip install flow[ui] or uv sync"
|
| 536 |
+
) from e
|
| 537 |
+
|
| 538 |
+
adapter = PersistenceAdapter()
|
| 539 |
+
self._id = await adapter.deploy_agent(self)
|
| 540 |
+
return self._id
|
| 541 |
+
|
| 542 |
+
async def evaluate(
|
| 543 |
+
self,
|
| 544 |
+
tasks: str | list[Task] | Path = "quick",
|
| 545 |
+
*,
|
| 546 |
+
parallel: int = 4,
|
| 547 |
+
use_llm_eval: bool = True,
|
| 548 |
+
quiet: bool = False,
|
| 549 |
+
) -> EvaluationResult:
|
| 550 |
+
"""Evaluate this agent on a set of tasks.
|
| 551 |
+
|
| 552 |
+
If the agent has been deployed (via deploy()), results are
|
| 553 |
+
automatically persisted to the database.
|
| 554 |
+
|
| 555 |
+
Args:
|
| 556 |
+
tasks: Task specification - suite name (str like "quick", "coding"),
|
| 557 |
+
list of Task objects, or Path to JSONL file
|
| 558 |
+
parallel: Number of concurrent task executions
|
| 559 |
+
use_llm_eval: Whether to use LLM-as-Judge for scoring
|
| 560 |
+
quiet: Suppress verbose output
|
| 561 |
+
|
| 562 |
+
Returns:
|
| 563 |
+
EvaluationResult with score, tokens, pass_rate, etc.
|
| 564 |
+
|
| 565 |
+
Example:
|
| 566 |
+
agent = Agent(name="my-agent", tools="standard")
|
| 567 |
+
result = await agent.evaluate(tasks="quick")
|
| 568 |
+
print(f"Score: {result.score:.2f}, Pass rate: {result.pass_rate:.0%}")
|
| 569 |
+
"""
|
| 570 |
+
from .agent_api import _evaluate_agent_impl
|
| 571 |
+
|
| 572 |
+
return await _evaluate_agent_impl(
|
| 573 |
+
self, tasks, parallel, use_llm_eval, quiet, agent_id=self._id
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
async def optimize(
|
| 577 |
+
self,
|
| 578 |
+
tasks: str | list[Task] | Path = "quick",
|
| 579 |
+
*,
|
| 580 |
+
strategy: str | list[str] | None = None,
|
| 581 |
+
variations: dict[str, list[Any]] | None = None,
|
| 582 |
+
parallel: int = 4,
|
| 583 |
+
budget: int = 50,
|
| 584 |
+
use_llm_eval: bool = True,
|
| 585 |
+
quiet: bool = False,
|
| 586 |
+
) -> AgentOptimizationResult:
|
| 587 |
+
"""Optimize this agent's configuration.
|
| 588 |
+
|
| 589 |
+
Supports two modes:
|
| 590 |
+
- **Grid search** (default): Exhaustive search over parameter combinations
|
| 591 |
+
- **Active strategies**: Iterative evaluate-reflect-adjust optimization
|
| 592 |
+
|
| 593 |
+
If the agent has been deployed (via deploy()), results are
|
| 594 |
+
automatically persisted to the database.
|
| 595 |
+
|
| 596 |
+
Args:
|
| 597 |
+
tasks: Task specification - suite name (str), list of Tasks, or Path
|
| 598 |
+
strategy: Optimization strategy to use:
|
| 599 |
+
- None or "grid": Grid search over variations (default)
|
| 600 |
+
- "tools": Iteratively discover optimal tool configuration
|
| 601 |
+
- "instructions": Iteratively rewrite instructions from failures
|
| 602 |
+
- list: Run multiple strategies sequentially, e.g.
|
| 603 |
+
["instructions", "tools"] optimizes instructions first,
|
| 604 |
+
then tools starting from the improved agent
|
| 605 |
+
variations: Custom grid search variations (only used with grid strategy)
|
| 606 |
+
parallel: Number of concurrent experiments
|
| 607 |
+
budget: Maximum number of candidates to test
|
| 608 |
+
use_llm_eval: Whether to use LLM-as-Judge for scoring
|
| 609 |
+
quiet: Suppress verbose output
|
| 610 |
+
|
| 611 |
+
Returns:
|
| 612 |
+
AgentOptimizationResult with baseline, best, and improvement metrics
|
| 613 |
+
|
| 614 |
+
Example:
|
| 615 |
+
agent = Agent(name="my-agent", tools="standard")
|
| 616 |
+
|
| 617 |
+
# Grid search (default)
|
| 618 |
+
result = await agent.optimize(tasks="quick")
|
| 619 |
+
|
| 620 |
+
# Active: discover optimal tools
|
| 621 |
+
result = await agent.optimize(tasks="quick", strategy="tools")
|
| 622 |
+
|
| 623 |
+
# Active: improve instructions
|
| 624 |
+
result = await agent.optimize(tasks="quick", strategy="instructions")
|
| 625 |
+
|
| 626 |
+
# Pipeline: instructions first, then tools
|
| 627 |
+
result = await agent.optimize(
|
| 628 |
+
tasks="quick", strategy=["instructions", "tools"]
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
print(f"Best score: {result.best.score:.2f}")
|
| 632 |
+
optimized = result.best_agent
|
| 633 |
+
"""
|
| 634 |
+
from .agent_api import _optimize_agent_impl
|
| 635 |
+
|
| 636 |
+
return await _optimize_agent_impl(
|
| 637 |
+
self, tasks, variations, parallel, budget, use_llm_eval, quiet,
|
| 638 |
+
agent_id=self._id,
|
| 639 |
+
strategy=strategy,
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
@dataclass
|
| 644 |
+
class StrategyIteration:
|
| 645 |
+
"""One iteration of an active strategy's optimization loop.
|
| 646 |
+
|
| 647 |
+
Tracks what was tried, how it scored, and why the change was made.
|
| 648 |
+
Active strategies accumulate these to provide a full audit trail.
|
| 649 |
+
|
| 650 |
+
Attributes:
|
| 651 |
+
iteration: Iteration number (0 = baseline)
|
| 652 |
+
instructions_preview: First 200 chars of instructions used
|
| 653 |
+
full_instructions: Complete instructions text for this iteration
|
| 654 |
+
avg_score: Average score across tasks for this iteration
|
| 655 |
+
pass_rate: Fraction of tasks that passed
|
| 656 |
+
failures_count: Number of tasks that failed
|
| 657 |
+
change_description: What was changed (e.g., "Added bash timeout instructions")
|
| 658 |
+
change_rationale: Why the change was made (e.g., "3/5 tasks failed due to hanging bash commands")
|
| 659 |
+
"""
|
| 660 |
+
|
| 661 |
+
iteration: int
|
| 662 |
+
instructions_preview: str
|
| 663 |
+
avg_score: float
|
| 664 |
+
pass_rate: float
|
| 665 |
+
failures_count: int
|
| 666 |
+
change_description: str = ""
|
| 667 |
+
change_rationale: str = ""
|
| 668 |
+
full_instructions: str = ""
|
| 669 |
|
| 670 |
|
| 671 |
@dataclass
|
|
|
|
| 680 |
agent: The mutated agent configuration
|
| 681 |
mutations: Dict describing what was changed from the base
|
| 682 |
rationale: Human-readable explanation of why this candidate exists
|
| 683 |
+
optimization_history: Audit trail from active optimization strategies.
|
| 684 |
+
Each entry records one iteration of the optimization loop with
|
| 685 |
+
scores, failure counts, and descriptions of what changed and why.
|
| 686 |
"""
|
| 687 |
|
| 688 |
agent: Agent
|
| 689 |
mutations: dict[str, Any] = field(default_factory=dict)
|
| 690 |
rationale: str = ""
|
| 691 |
+
optimization_history: list[StrategyIteration] = field(default_factory=list)
|
| 692 |
|
| 693 |
|
| 694 |
@dataclass
|
|
|
|
| 704 |
traces: dict[str, Any] = field(default_factory=dict)
|
| 705 |
|
| 706 |
|
| 707 |
+
@runtime_checkable
|
| 708 |
+
class ExperimentRunner(Protocol):
|
| 709 |
+
"""Protocol for evaluating candidates against tasks.
|
| 710 |
+
|
| 711 |
+
This is the interface that active strategies use to test candidate
|
| 712 |
+
configurations. The FlowOptimizer implements this protocol, providing
|
| 713 |
+
strategies with access to the full execution pipeline (harness creation,
|
| 714 |
+
agent execution, trace collection, LLM evaluation, metrics extraction)
|
| 715 |
+
without exposing internal details.
|
| 716 |
+
|
| 717 |
+
Passive strategies (GridSearchStrategy, etc.) ignore this entirely.
|
| 718 |
+
Active strategies call evaluate() in a loop to iteratively refine
|
| 719 |
+
candidates based on real execution results.
|
| 720 |
+
|
| 721 |
+
The evaluate() method returns a CandidateSummary (from optimizer.py)
|
| 722 |
+
which contains:
|
| 723 |
+
- avg_score, pass_rate: Aggregate performance metrics
|
| 724 |
+
- task_results: list[TaskResult] — per-task details including:
|
| 725 |
+
- eval_reasoning: Why the evaluator scored it this way
|
| 726 |
+
- eval_score, eval_passed: Score and pass/fail status
|
| 727 |
+
- criteria_results: Per-criterion breakdown
|
| 728 |
+
- run_result.output: What the agent produced
|
| 729 |
+
- run_result.trace: Full OTel execution trace
|
| 730 |
+
- metrics: Token counts, tool usage, duration
|
| 731 |
+
"""
|
| 732 |
+
|
| 733 |
+
async def evaluate(
|
| 734 |
+
self,
|
| 735 |
+
candidate: Candidate,
|
| 736 |
+
tasks: list[Task],
|
| 737 |
+
) -> Any:
|
| 738 |
+
"""Evaluate a candidate on a set of tasks.
|
| 739 |
+
|
| 740 |
+
Args:
|
| 741 |
+
candidate: The candidate to evaluate
|
| 742 |
+
tasks: Tasks to run the candidate on
|
| 743 |
+
|
| 744 |
+
Returns:
|
| 745 |
+
CandidateSummary with aggregated scores and per-task details.
|
| 746 |
+
Typed as Any to avoid circular imports — the actual return type
|
| 747 |
+
is flow.experiments.optimizer.CandidateSummary.
|
| 748 |
+
"""
|
| 749 |
+
...
|
| 750 |
+
|
| 751 |
+
|
| 752 |
@runtime_checkable
|
| 753 |
class CandidateStrategy(Protocol):
|
| 754 |
"""Protocol for generating candidate variants from a base agent.
|
| 755 |
|
| 756 |
Implementations can be:
|
| 757 |
+
- Passive (single-shot): GridSearchStrategy ignores optional params
|
| 758 |
+
- Active (iterative): Uses runner to evaluate candidates, inspect failures,
|
| 759 |
+
and iteratively refine configurations based on real execution results
|
| 760 |
|
| 761 |
+
All logic is internal to the strategy — the caller just calls generate()
|
| 762 |
and receives the final list of candidates.
|
| 763 |
|
| 764 |
Examples:
|
| 765 |
- GridSearchStrategy: Exhaustive grid over parameter combinations
|
| 766 |
+
- (Future) InstructionOptimizer: Iteratively improves instructions from failures
|
| 767 |
- (Future) BayesianStrategy: Bayesian optimization over parameters
|
| 768 |
"""
|
| 769 |
|
| 770 |
+
async def generate(
|
| 771 |
self,
|
| 772 |
base: Agent,
|
| 773 |
budget: int,
|
| 774 |
*,
|
| 775 |
tasks: list[Task] | None = None,
|
| 776 |
+
runner: ExperimentRunner | None = None,
|
|
|
|
| 777 |
) -> list[Candidate]:
|
| 778 |
"""Generate candidate variants from a base agent.
|
| 779 |
|
| 780 |
Args:
|
| 781 |
base: The base agent to optimize
|
| 782 |
budget: Maximum number of candidates to return
|
| 783 |
+
tasks: Optional tasks for active strategies that run internal experiments
|
| 784 |
+
runner: Optional experiment runner for active strategies.
|
| 785 |
+
Active strategies call runner.evaluate(candidate, tasks)
|
| 786 |
+
to test candidates and use results to guide optimization.
|
| 787 |
+
Passive strategies ignore this parameter.
|
| 788 |
|
| 789 |
Returns:
|
| 790 |
List of Candidate objects (at most `budget` items).
|
| 791 |
For iterative strategies, returns the final/best candidates after
|
| 792 |
+
internal optimization loops complete. Candidates may include
|
| 793 |
+
optimization_history with per-iteration audit trail.
|
| 794 |
"""
|
| 795 |
...
|
| 796 |
|
|
|
|
| 824 |
"""
|
| 825 |
self.variations = variations
|
| 826 |
|
| 827 |
+
async def generate(
|
| 828 |
self,
|
| 829 |
base: Agent,
|
| 830 |
budget: int,
|
| 831 |
*,
|
| 832 |
tasks: list[Task] | None = None,
|
| 833 |
+
runner: ExperimentRunner | None = None,
|
|
|
|
| 834 |
) -> list[Candidate]:
|
| 835 |
"""Generate all grid combinations up to budget.
|
| 836 |
|
| 837 |
+
Note: tasks and runner are accepted for protocol compatibility but
|
| 838 |
+
ignored — GridSearchStrategy is a passive strategy that doesn't
|
| 839 |
+
run experiments internally.
|
| 840 |
"""
|
| 841 |
# Delete unused params to satisfy linters
|
| 842 |
+
del tasks, runner
|
| 843 |
|
| 844 |
if not self.variations:
|
| 845 |
return [Candidate(agent=base, mutations={}, rationale="baseline")]
|
|
|
|
| 1094 |
# =============================================================================
|
| 1095 |
|
| 1096 |
|
| 1097 |
+
@dataclass
|
| 1098 |
+
class LiteralVariation:
|
| 1099 |
+
"""A literal/static variation value.
|
| 1100 |
+
|
| 1101 |
+
Used for predefined values like tool presets, compaction configs, etc.
|
| 1102 |
+
"""
|
| 1103 |
+
|
| 1104 |
+
value: Any
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
@dataclass
|
| 1108 |
+
class StrategyVariation:
|
| 1109 |
+
"""A strategy that generates variation values dynamically.
|
| 1110 |
+
|
| 1111 |
+
Used for active optimization strategies like GEPA (instructions)
|
| 1112 |
+
or agentic tool selection.
|
| 1113 |
+
|
| 1114 |
+
Attributes:
|
| 1115 |
+
strategy: Strategy name (e.g., "gepa", "agentic")
|
| 1116 |
+
max_candidates: Number of candidates this strategy will produce
|
| 1117 |
+
config: Strategy-specific configuration
|
| 1118 |
+
"""
|
| 1119 |
+
|
| 1120 |
+
strategy: str
|
| 1121 |
+
max_candidates: int = 1
|
| 1122 |
+
config: dict[str, Any] = field(default_factory=dict)
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
# Union type for variation items
|
| 1126 |
+
VariationItem = LiteralVariation | StrategyVariation
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
@dataclass
|
| 1130 |
class Experiment:
|
| 1131 |
"""Experiment configuration for optimization.
|
|
|
|
| 1135 |
- Experiment YAML: How to test it (variations, tasks, evaluation settings)
|
| 1136 |
|
| 1137 |
Attributes:
|
| 1138 |
+
base_agent: Path to base agent YAML file (required)
|
| 1139 |
suite: Built-in task suite name (e.g., "coding", "quick")
|
| 1140 |
tasks: Path to custom tasks JSONL file (alternative to suite)
|
| 1141 |
+
variations: Dict mapping dimension names to lists of VariationItems
|
| 1142 |
parallel: Max concurrent experiments
|
| 1143 |
+
budget: Maximum candidates to generate (safety limit)
|
| 1144 |
use_llm_eval: Whether to use LLM-as-Judge evaluation
|
| 1145 |
|
| 1146 |
Example YAML:
|
| 1147 |
```yaml
|
| 1148 |
+
base_agent: agents/coder.yaml
|
| 1149 |
suite: coding
|
| 1150 |
|
| 1151 |
variations:
|
| 1152 |
+
instructions:
|
| 1153 |
+
# Literal values
|
| 1154 |
+
- "You are a helpful coding assistant"
|
| 1155 |
+
- file: prompts/expert.md
|
| 1156 |
+
# Strategy (active optimization)
|
| 1157 |
+
- strategy: gepa
|
| 1158 |
+
max_candidates: 3
|
| 1159 |
+
config:
|
| 1160 |
+
reflection_lm: gpt-4o
|
| 1161 |
|
| 1162 |
tools:
|
| 1163 |
- minimal
|
| 1164 |
- standard
|
| 1165 |
+
- strategy: agentic
|
| 1166 |
+
max_candidates: 2
|
| 1167 |
+
|
| 1168 |
+
compaction:
|
| 1169 |
+
- strategy: none
|
| 1170 |
+
- strategy: sliding_window
|
| 1171 |
+
token_budget: 50000
|
| 1172 |
|
| 1173 |
+
parallel: 8
|
| 1174 |
+
budget: 100
|
| 1175 |
use_llm_eval: true
|
| 1176 |
```
|
| 1177 |
"""
|
| 1178 |
|
| 1179 |
+
base_agent: str
|
| 1180 |
suite: str | None = None
|
| 1181 |
tasks: str | None = None
|
| 1182 |
+
variations: dict[str, list[VariationItem]] = field(default_factory=dict)
|
| 1183 |
parallel: int = 4
|
| 1184 |
budget: int = 100
|
| 1185 |
use_llm_eval: bool = True
|
| 1186 |
|
| 1187 |
|
| 1188 |
+
def compute_max_experiments(variations: dict[str, list[VariationItem]]) -> int:
|
| 1189 |
+
"""Compute maximum number of experiments from variations.
|
| 1190 |
+
|
| 1191 |
+
Each dimension contributes its count (literals + strategy max_candidates),
|
| 1192 |
+
and total is the Cartesian product.
|
| 1193 |
+
|
| 1194 |
+
Args:
|
| 1195 |
+
variations: Parsed variations dict
|
| 1196 |
+
|
| 1197 |
+
Returns:
|
| 1198 |
+
Maximum number of experiments
|
| 1199 |
+
"""
|
| 1200 |
+
if not variations:
|
| 1201 |
+
return 1
|
| 1202 |
+
|
| 1203 |
+
import math
|
| 1204 |
+
|
| 1205 |
+
counts = []
|
| 1206 |
+
for items in variations.values():
|
| 1207 |
+
dim_count = 0
|
| 1208 |
+
for item in items:
|
| 1209 |
+
if isinstance(item, StrategyVariation):
|
| 1210 |
+
dim_count += item.max_candidates
|
| 1211 |
+
else:
|
| 1212 |
+
dim_count += 1
|
| 1213 |
+
counts.append(max(dim_count, 1))
|
| 1214 |
+
|
| 1215 |
+
return math.prod(counts)
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
def _parse_literal_value(dimension: str, value: Any) -> Any:
|
| 1219 |
+
"""Parse a literal value for a specific dimension.
|
| 1220 |
+
|
| 1221 |
+
Handles special cases like compaction configs and file references.
|
| 1222 |
+
|
| 1223 |
+
Args:
|
| 1224 |
+
dimension: The dimension name (e.g., "compaction", "tools")
|
| 1225 |
+
value: The raw value from YAML
|
| 1226 |
+
|
| 1227 |
+
Returns:
|
| 1228 |
+
Parsed value appropriate for the dimension
|
| 1229 |
+
"""
|
| 1230 |
+
# Handle file references
|
| 1231 |
+
if isinstance(value, dict) and "file" in value:
|
| 1232 |
+
file_path = Path(value["file"])
|
| 1233 |
+
if file_path.exists():
|
| 1234 |
+
return file_path.read_text()
|
| 1235 |
+
# If relative path, caller should resolve it
|
| 1236 |
+
return value
|
| 1237 |
+
|
| 1238 |
+
# Handle compaction dimension
|
| 1239 |
+
if dimension == "compaction":
|
| 1240 |
+
if isinstance(value, dict):
|
| 1241 |
+
# Dict with strategy key is a CompactionConfig
|
| 1242 |
+
return CompactionConfig(**value)
|
| 1243 |
+
elif isinstance(value, str):
|
| 1244 |
+
# Shorthand: "none", "head_tail", etc.
|
| 1245 |
+
if value == "none":
|
| 1246 |
+
return CompactionConfig.none()
|
| 1247 |
+
elif value == "head_tail":
|
| 1248 |
+
return CompactionConfig.head_tail()
|
| 1249 |
+
elif value == "sliding_window":
|
| 1250 |
+
return CompactionConfig.sliding_window()
|
| 1251 |
+
elif value == "summarization":
|
| 1252 |
+
return CompactionConfig.summarization()
|
| 1253 |
+
else:
|
| 1254 |
+
raise ValueError(f"Unknown compaction shorthand: {value}")
|
| 1255 |
+
|
| 1256 |
+
# All other values pass through as-is
|
| 1257 |
+
return value
|
| 1258 |
+
|
| 1259 |
+
|
| 1260 |
+
# Known compaction strategy names (these are NOT optimization strategies)
|
| 1261 |
+
_COMPACTION_STRATEGY_NAMES = {"none", "head_tail", "sliding_window", "summarization", "last_n", "head_tail_tokens"}
|
| 1262 |
+
|
| 1263 |
+
|
| 1264 |
+
def _is_strategy_variation(item: Any) -> bool:
|
| 1265 |
+
"""Check if an item is a StrategyVariation (optimization strategy).
|
| 1266 |
+
|
| 1267 |
+
Distinguishes between:
|
| 1268 |
+
- StrategyVariation: {"strategy": "gepa", "max_candidates": 3, "config": {...}}
|
| 1269 |
+
- Compaction literal: {"strategy": "sliding_window", "token_budget": 50000}
|
| 1270 |
+
|
| 1271 |
+
The key difference:
|
| 1272 |
+
- Optimization strategies have max_candidates or config keys
|
| 1273 |
+
- Compaction configs have strategy names like "none", "sliding_window", etc.
|
| 1274 |
+
|
| 1275 |
+
Args:
|
| 1276 |
+
item: The raw item from YAML
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
True if this is a StrategyVariation, False if literal
|
| 1280 |
+
"""
|
| 1281 |
+
if not isinstance(item, dict):
|
| 1282 |
+
return False
|
| 1283 |
+
|
| 1284 |
+
if "strategy" not in item:
|
| 1285 |
+
return False
|
| 1286 |
+
|
| 1287 |
+
strategy_name = item["strategy"]
|
| 1288 |
+
|
| 1289 |
+
# If it has max_candidates or config, it's definitely an optimization strategy
|
| 1290 |
+
if "max_candidates" in item or "config" in item:
|
| 1291 |
+
return True
|
| 1292 |
+
|
| 1293 |
+
# If the strategy name is a known compaction strategy, it's a literal
|
| 1294 |
+
if strategy_name in _COMPACTION_STRATEGY_NAMES:
|
| 1295 |
+
return False
|
| 1296 |
+
|
| 1297 |
+
# Otherwise assume it's an optimization strategy (will fail at runtime if invalid)
|
| 1298 |
+
return True
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
def load_experiment(path: Path) -> Experiment:
|
| 1302 |
"""Load an Experiment from a YAML file.
|
| 1303 |
|
| 1304 |
+
Parses variations into VariationItem objects (LiteralVariation or StrategyVariation).
|
| 1305 |
|
| 1306 |
Args:
|
| 1307 |
path: Path to the experiment YAML file
|
|
|
|
| 1318 |
|
| 1319 |
data = yaml.safe_load(path.read_text())
|
| 1320 |
|
| 1321 |
+
# Validate required fields
|
| 1322 |
+
if "base_agent" not in data:
|
| 1323 |
+
raise ValueError("Experiment YAML must specify 'base_agent'")
|
| 1324 |
+
|
| 1325 |
+
# Parse variations into VariationItem objects
|
| 1326 |
+
variations: dict[str, list[VariationItem]] = {}
|
| 1327 |
raw_variations = data.get("variations", {})
|
| 1328 |
|
| 1329 |
+
for dimension, items in raw_variations.items():
|
| 1330 |
+
parsed_items: list[VariationItem] = []
|
| 1331 |
+
|
| 1332 |
+
for item in items:
|
| 1333 |
+
if _is_strategy_variation(item):
|
| 1334 |
+
# This is a StrategyVariation (optimization strategy like "gepa")
|
| 1335 |
+
parsed_items.append(StrategyVariation(
|
| 1336 |
+
strategy=item["strategy"],
|
| 1337 |
+
max_candidates=item.get("max_candidates", 1),
|
| 1338 |
+
config=item.get("config", {}),
|
| 1339 |
+
))
|
| 1340 |
+
else:
|
| 1341 |
+
# This is a LiteralVariation
|
| 1342 |
+
parsed_value = _parse_literal_value(dimension, item)
|
| 1343 |
+
parsed_items.append(LiteralVariation(value=parsed_value))
|
| 1344 |
+
|
| 1345 |
+
variations[dimension] = parsed_items
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1346 |
|
| 1347 |
return Experiment(
|
| 1348 |
+
base_agent=data["base_agent"],
|
| 1349 |
suite=data.get("suite"),
|
| 1350 |
tasks=data.get("tasks"),
|
| 1351 |
variations=variations,
|
src/flow/experiments/optimizer.py
CHANGED
|
@@ -24,11 +24,13 @@ from .ablation import compute_pareto_frontier
|
|
| 24 |
from .evaluators import LLMEvaluator
|
| 25 |
from .metrics import TraceMetrics, extract_metrics
|
| 26 |
from .models import (
|
|
|
|
| 27 |
Candidate,
|
| 28 |
export_optimization_results,
|
| 29 |
)
|
| 30 |
from .runner import FlowExperimentRunner, setup_tracing
|
| 31 |
-
from .types import RunResult, Task
|
|
|
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
|
|
@@ -45,6 +47,7 @@ class TaskResult:
|
|
| 45 |
eval_passed: bool
|
| 46 |
eval_reasoning: str
|
| 47 |
criteria_results: list[dict[str, Any]] = field(default_factory=list) # Per-criterion scores
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
@dataclass
|
|
@@ -57,6 +60,7 @@ class CandidateSummary:
|
|
| 57 |
|
| 58 |
# Aggregated metrics
|
| 59 |
avg_score: float = 0.0
|
|
|
|
| 60 |
avg_tokens: float = 0.0
|
| 61 |
avg_duration: float = 0.0
|
| 62 |
pass_rate: float = 0.0
|
|
@@ -69,12 +73,17 @@ class CandidateSummary:
|
|
| 69 |
|
| 70 |
def to_dict(self) -> dict[str, Any]:
|
| 71 |
"""Convert to dictionary for serialization."""
|
|
|
|
|
|
|
|
|
|
| 72 |
return {
|
| 73 |
"name": self.name,
|
|
|
|
| 74 |
"agent": asdict(self.candidate.agent),
|
| 75 |
"mutations": self.candidate.mutations,
|
| 76 |
"rationale": self.candidate.rationale,
|
| 77 |
"avg_score": self.avg_score,
|
|
|
|
| 78 |
"avg_tokens": self.avg_tokens,
|
| 79 |
"avg_duration": self.avg_duration,
|
| 80 |
"pass_rate": self.pass_rate,
|
|
@@ -82,15 +91,22 @@ class CandidateSummary:
|
|
| 82 |
"task_count": self.task_count,
|
| 83 |
"pareto_rank": self.pareto_rank,
|
| 84 |
"is_pareto_optimal": self.is_pareto_optimal,
|
| 85 |
-
# Include per-task results with
|
| 86 |
"task_results": [
|
| 87 |
{
|
| 88 |
"task_name": tr.task_name,
|
|
|
|
|
|
|
| 89 |
"eval_score": tr.eval_score,
|
|
|
|
| 90 |
"eval_passed": tr.eval_passed,
|
| 91 |
"eval_reasoning": tr.eval_reasoning,
|
|
|
|
| 92 |
"tokens": tr.metrics.total_tokens,
|
| 93 |
"duration": tr.run_result.duration_seconds,
|
|
|
|
|
|
|
|
|
|
| 94 |
}
|
| 95 |
for tr in self.task_results
|
| 96 |
],
|
|
@@ -149,7 +165,7 @@ class FlowOptimizer:
|
|
| 149 |
})
|
| 150 |
optimizer = FlowOptimizer(parallel=4)
|
| 151 |
base = Agent(name="my_agent")
|
| 152 |
-
candidates = strategy.generate(base, budget=10)
|
| 153 |
result = await optimizer.optimize(candidates, tasks)
|
| 154 |
print(f"Best: {result.rank_by_score[0]}")
|
| 155 |
"""
|
|
@@ -164,11 +180,16 @@ class FlowOptimizer:
|
|
| 164 |
self.use_llm_evaluator = use_llm_evaluator
|
| 165 |
self.output_dir = output_dir or Path.home() / ".flow" / "optimizations"
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
async def optimize(
|
| 168 |
self,
|
| 169 |
candidates: list[Candidate],
|
| 170 |
tasks: list[Task],
|
| 171 |
progress_callback: Callable[[int, int, str, str], None] | None = None,
|
|
|
|
| 172 |
) -> OptimizationResult:
|
| 173 |
"""Run optimization across all candidates and tasks.
|
| 174 |
|
|
@@ -176,13 +197,15 @@ class FlowOptimizer:
|
|
| 176 |
candidates: Candidates to test
|
| 177 |
tasks: Tasks to run each candidate on
|
| 178 |
progress_callback: Optional callback(completed, total, candidate_name, task_name)
|
|
|
|
| 179 |
|
| 180 |
Returns:
|
| 181 |
OptimizationResult with rankings and exported agents
|
| 182 |
"""
|
| 183 |
start_time = datetime.now()
|
| 184 |
timestamp = start_time.strftime("%Y%m%d_%H%M%S")
|
| 185 |
-
run_dir
|
|
|
|
| 186 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 187 |
|
| 188 |
setup_tracing("flow-optimizer")
|
|
@@ -202,6 +225,127 @@ class FlowOptimizer:
|
|
| 202 |
if self.use_llm_evaluator:
|
| 203 |
evaluator = self._create_evaluator()
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
task_results = await self._run_parallel(
|
| 206 |
candidates, tasks, run_dir, evaluator, progress_callback
|
| 207 |
)
|
|
@@ -266,10 +410,11 @@ class FlowOptimizer:
|
|
| 266 |
|
| 267 |
async with lock:
|
| 268 |
completed += 1
|
| 269 |
-
status = "
|
| 270 |
print(
|
| 271 |
f" [{completed}/{total}] {candidate.agent.name}/{task.name}: "
|
| 272 |
f"{status} score={result.eval_score:.2f} "
|
|
|
|
| 273 |
f"tokens={result.metrics.total_tokens:,}"
|
| 274 |
)
|
| 275 |
if progress_callback:
|
|
@@ -289,6 +434,43 @@ class FlowOptimizer:
|
|
| 289 |
|
| 290 |
return valid_results
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
async def _run_single(
|
| 293 |
self,
|
| 294 |
candidate: Candidate,
|
|
@@ -298,9 +480,13 @@ class FlowOptimizer:
|
|
| 298 |
) -> TaskResult:
|
| 299 |
"""Run a single candidate-task experiment."""
|
| 300 |
# Import harness modules to register them, then use registry
|
| 301 |
-
import flow.harness.maf
|
|
|
|
|
|
|
| 302 |
try:
|
| 303 |
-
import flow.harness.miniagent
|
|
|
|
|
|
|
| 304 |
except ImportError:
|
| 305 |
pass # miniagent harness is optional
|
| 306 |
from flow.harness import create_harness
|
|
@@ -313,16 +499,24 @@ class FlowOptimizer:
|
|
| 313 |
metrics = extract_metrics(run_result.trace)
|
| 314 |
|
| 315 |
criteria_results: list[dict[str, Any]] = []
|
|
|
|
| 316 |
if evaluator:
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
eval_score = eval_result.score
|
| 319 |
eval_passed = eval_result.passed
|
| 320 |
eval_reasoning = eval_result.reasoning
|
|
|
|
| 321 |
# Convert criteria results to dicts for serialization
|
| 322 |
criteria_results = [
|
| 323 |
{
|
| 324 |
"name": cr.name,
|
| 325 |
"score": cr.score,
|
|
|
|
| 326 |
"passed": cr.passed,
|
| 327 |
"reasoning": cr.reasoning,
|
| 328 |
}
|
|
@@ -342,6 +536,7 @@ class FlowOptimizer:
|
|
| 342 |
eval_passed=eval_passed,
|
| 343 |
eval_reasoning=eval_reasoning,
|
| 344 |
criteria_results=criteria_results,
|
|
|
|
| 345 |
)
|
| 346 |
finally:
|
| 347 |
await harness.close()
|
|
@@ -370,6 +565,7 @@ class FlowOptimizer:
|
|
| 370 |
candidate=candidate,
|
| 371 |
task_results=results,
|
| 372 |
avg_score=sum(r.eval_score for r in results) / len(results),
|
|
|
|
| 373 |
avg_tokens=sum(r.metrics.total_tokens for r in results) / len(results),
|
| 374 |
avg_duration=sum(r.run_result.duration_seconds for r in results) / len(results),
|
| 375 |
pass_rate=sum(1 for r in results if r.eval_passed) / len(results),
|
|
@@ -425,7 +621,7 @@ class FlowOptimizer:
|
|
| 425 |
logger.info("Creating AsyncAzureOpenAI client for evaluator")
|
| 426 |
client = AsyncAzureOpenAI(
|
| 427 |
api_key=api_key,
|
| 428 |
-
api_version="2024-
|
| 429 |
azure_endpoint=endpoint,
|
| 430 |
)
|
| 431 |
|
|
@@ -480,13 +676,14 @@ class FlowOptimizer:
|
|
| 480 |
print(" OPTIMIZATION RESULTS")
|
| 481 |
print("=" * 70)
|
| 482 |
|
| 483 |
-
print(f"\n{'Candidate':<30} | {'Score':>8} | {'Tokens':>10} | {'Pareto':>8}")
|
| 484 |
-
print("-" *
|
| 485 |
|
| 486 |
for summary in sorted(result.summaries, key=lambda s: s.avg_score, reverse=True):
|
| 487 |
-
pareto = "
|
| 488 |
print(
|
| 489 |
f"{summary.name:<30} | {summary.avg_score:>8.2f} | "
|
|
|
|
| 490 |
f"{summary.avg_tokens:>10,.0f} | {pareto:>8}"
|
| 491 |
)
|
| 492 |
|
|
@@ -510,3 +707,64 @@ def load_tasks_from_jsonl(path: Path) -> list[Task]:
|
|
| 510 |
List of Task objects
|
| 511 |
"""
|
| 512 |
return _load_tasks_impl(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
from .evaluators import LLMEvaluator
|
| 25 |
from .metrics import TraceMetrics, extract_metrics
|
| 26 |
from .models import (
|
| 27 |
+
Agent,
|
| 28 |
Candidate,
|
| 29 |
export_optimization_results,
|
| 30 |
)
|
| 31 |
from .runner import FlowExperimentRunner, setup_tracing
|
| 32 |
+
from .types import RunResult, Task
|
| 33 |
+
from .types import load_tasks_from_jsonl as _load_tasks_impl
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
|
|
|
| 47 |
eval_passed: bool
|
| 48 |
eval_reasoning: str
|
| 49 |
criteria_results: list[dict[str, Any]] = field(default_factory=list) # Per-criterion scores
|
| 50 |
+
eval_reasoning_score: float = 0.0 # Partial credit for correct methodology
|
| 51 |
|
| 52 |
|
| 53 |
@dataclass
|
|
|
|
| 60 |
|
| 61 |
# Aggregated metrics
|
| 62 |
avg_score: float = 0.0
|
| 63 |
+
avg_reasoning_score: float = 0.0
|
| 64 |
avg_tokens: float = 0.0
|
| 65 |
avg_duration: float = 0.0
|
| 66 |
pass_rate: float = 0.0
|
|
|
|
| 73 |
|
| 74 |
def to_dict(self) -> dict[str, Any]:
|
| 75 |
"""Convert to dictionary for serialization."""
|
| 76 |
+
# Extract candidate_id from mutations if available (set by GEPA adapter)
|
| 77 |
+
candidate_id = self.candidate.mutations.get("_candidate_id", None)
|
| 78 |
+
|
| 79 |
return {
|
| 80 |
"name": self.name,
|
| 81 |
+
"candidate_id": candidate_id,
|
| 82 |
"agent": asdict(self.candidate.agent),
|
| 83 |
"mutations": self.candidate.mutations,
|
| 84 |
"rationale": self.candidate.rationale,
|
| 85 |
"avg_score": self.avg_score,
|
| 86 |
+
"avg_reasoning_score": self.avg_reasoning_score,
|
| 87 |
"avg_tokens": self.avg_tokens,
|
| 88 |
"avg_duration": self.avg_duration,
|
| 89 |
"pass_rate": self.pass_rate,
|
|
|
|
| 91 |
"task_count": self.task_count,
|
| 92 |
"pareto_rank": self.pareto_rank,
|
| 93 |
"is_pareto_optimal": self.is_pareto_optimal,
|
| 94 |
+
# Include per-task results with full agent output and trace
|
| 95 |
"task_results": [
|
| 96 |
{
|
| 97 |
"task_name": tr.task_name,
|
| 98 |
+
"task_prompt": tr.run_result.task.prompt,
|
| 99 |
+
"agent_output": tr.run_result.output,
|
| 100 |
"eval_score": tr.eval_score,
|
| 101 |
+
"eval_reasoning_score": tr.eval_reasoning_score,
|
| 102 |
"eval_passed": tr.eval_passed,
|
| 103 |
"eval_reasoning": tr.eval_reasoning,
|
| 104 |
+
"criteria_results": tr.criteria_results,
|
| 105 |
"tokens": tr.metrics.total_tokens,
|
| 106 |
"duration": tr.run_result.duration_seconds,
|
| 107 |
+
"files_created": tr.run_result.files_created,
|
| 108 |
+
"tool_results": tr.run_result.tool_results,
|
| 109 |
+
"trace": tr.run_result.trace,
|
| 110 |
}
|
| 111 |
for tr in self.task_results
|
| 112 |
],
|
|
|
|
| 165 |
})
|
| 166 |
optimizer = FlowOptimizer(parallel=4)
|
| 167 |
base = Agent(name="my_agent")
|
| 168 |
+
candidates = await strategy.generate(base, budget=10)
|
| 169 |
result = await optimizer.optimize(candidates, tasks)
|
| 170 |
print(f"Best: {result.rank_by_score[0]}")
|
| 171 |
"""
|
|
|
|
| 180 |
self.use_llm_evaluator = use_llm_evaluator
|
| 181 |
self.output_dir = output_dir or Path.home() / ".flow" / "optimizations"
|
| 182 |
|
| 183 |
+
# Internal state set during optimize() for use by evaluate()
|
| 184 |
+
self._evaluator: LLMEvaluator | None = None
|
| 185 |
+
self._run_dir: Path | None = None
|
| 186 |
+
|
| 187 |
async def optimize(
|
| 188 |
self,
|
| 189 |
candidates: list[Candidate],
|
| 190 |
tasks: list[Task],
|
| 191 |
progress_callback: Callable[[int, int, str, str], None] | None = None,
|
| 192 |
+
run_dir: Path | None = None,
|
| 193 |
) -> OptimizationResult:
|
| 194 |
"""Run optimization across all candidates and tasks.
|
| 195 |
|
|
|
|
| 197 |
candidates: Candidates to test
|
| 198 |
tasks: Tasks to run each candidate on
|
| 199 |
progress_callback: Optional callback(completed, total, candidate_name, task_name)
|
| 200 |
+
run_dir: Optional fixed directory for this run. If None, creates timestamped subdir.
|
| 201 |
|
| 202 |
Returns:
|
| 203 |
OptimizationResult with rankings and exported agents
|
| 204 |
"""
|
| 205 |
start_time = datetime.now()
|
| 206 |
timestamp = start_time.strftime("%Y%m%d_%H%M%S")
|
| 207 |
+
if run_dir is None:
|
| 208 |
+
run_dir = self.output_dir / timestamp
|
| 209 |
run_dir.mkdir(parents=True, exist_ok=True)
|
| 210 |
|
| 211 |
setup_tracing("flow-optimizer")
|
|
|
|
| 225 |
if self.use_llm_evaluator:
|
| 226 |
evaluator = self._create_evaluator()
|
| 227 |
|
| 228 |
+
# Store for use by evaluate() (ExperimentRunner protocol)
|
| 229 |
+
self._evaluator = evaluator
|
| 230 |
+
self._run_dir = run_dir
|
| 231 |
+
|
| 232 |
+
task_results = await self._run_parallel(
|
| 233 |
+
candidates, tasks, run_dir, evaluator, progress_callback
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
summaries = self._aggregate_results(task_results, candidates)
|
| 237 |
+
pareto_names = self._compute_pareto(summaries)
|
| 238 |
+
|
| 239 |
+
rank_by_score = sorted(summaries, key=lambda s: s.avg_score, reverse=True)
|
| 240 |
+
rank_by_tokens = sorted(summaries, key=lambda s: s.avg_tokens)
|
| 241 |
+
rank_by_efficiency = sorted(
|
| 242 |
+
summaries,
|
| 243 |
+
key=lambda s: s.avg_score / max(s.avg_tokens, 1),
|
| 244 |
+
reverse=True,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
summary_dicts = [s.to_dict() for s in summaries]
|
| 248 |
+
exported = export_optimization_results(
|
| 249 |
+
summary_dicts, pareto_names, run_dir, timestamp
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
end_time = datetime.now()
|
| 253 |
+
|
| 254 |
+
result = OptimizationResult(
|
| 255 |
+
timestamp=timestamp,
|
| 256 |
+
output_dir=run_dir,
|
| 257 |
+
summaries=summaries,
|
| 258 |
+
pareto_frontier=pareto_names,
|
| 259 |
+
exported_agents=exported,
|
| 260 |
+
rank_by_score=[s.name for s in rank_by_score],
|
| 261 |
+
rank_by_tokens=[s.name for s in rank_by_tokens],
|
| 262 |
+
rank_by_efficiency=[s.name for s in rank_by_efficiency],
|
| 263 |
+
total_experiments=len(task_results),
|
| 264 |
+
total_duration_seconds=(end_time - start_time).total_seconds(),
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
self._save_results(result, run_dir)
|
| 268 |
+
self._print_summary(result)
|
| 269 |
+
|
| 270 |
+
return result
|
| 271 |
+
|
| 272 |
+
async def optimize_with_strategy(
|
| 273 |
+
self,
|
| 274 |
+
strategy: Any, # CandidateStrategy
|
| 275 |
+
base: Agent,
|
| 276 |
+
tasks: list[Task],
|
| 277 |
+
budget: int = 50,
|
| 278 |
+
progress_callback: Callable[[int, int, str, str], None] | None = None,
|
| 279 |
+
run_dir: Path | None = None,
|
| 280 |
+
) -> OptimizationResult:
|
| 281 |
+
"""Run optimization using a CandidateStrategy.
|
| 282 |
+
|
| 283 |
+
This is the entry point for strategy-driven optimization. It:
|
| 284 |
+
1. Sets up infrastructure (evaluator, tracing, output dir)
|
| 285 |
+
2. Passes self as ExperimentRunner to the strategy
|
| 286 |
+
3. Runs the strategy's generate() to get candidates
|
| 287 |
+
4. Does a final evaluation of returned candidates
|
| 288 |
+
5. Performs Pareto analysis and exports results
|
| 289 |
+
|
| 290 |
+
For active strategies, the strategy will call self.evaluate()
|
| 291 |
+
during generate() to test candidates iteratively.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
strategy: A CandidateStrategy implementation
|
| 295 |
+
base: Base agent to optimize
|
| 296 |
+
tasks: Tasks to evaluate candidates on
|
| 297 |
+
budget: Maximum candidates for the strategy to produce
|
| 298 |
+
progress_callback: Optional callback(completed, total, candidate, task)
|
| 299 |
+
run_dir: Optional fixed output directory
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
OptimizationResult with rankings and exported agents
|
| 303 |
+
"""
|
| 304 |
+
start_time = datetime.now()
|
| 305 |
+
timestamp = start_time.strftime("%Y%m%d_%H%M%S")
|
| 306 |
+
if run_dir is None:
|
| 307 |
+
run_dir = self.output_dir / timestamp
|
| 308 |
+
run_dir.mkdir(parents=True, exist_ok=True)
|
| 309 |
+
|
| 310 |
+
setup_tracing("flow-optimizer")
|
| 311 |
+
|
| 312 |
+
# Set up evaluator and store state for evaluate()
|
| 313 |
+
evaluator = None
|
| 314 |
+
if self.use_llm_evaluator:
|
| 315 |
+
evaluator = self._create_evaluator()
|
| 316 |
+
self._evaluator = evaluator
|
| 317 |
+
self._run_dir = run_dir
|
| 318 |
+
|
| 319 |
+
print("=" * 70)
|
| 320 |
+
print(" FLOW OPTIMIZER (Strategy Mode)")
|
| 321 |
+
print("=" * 70)
|
| 322 |
+
print(f" Strategy: {type(strategy).__name__}")
|
| 323 |
+
print(f" Base Agent: {base.name}")
|
| 324 |
+
print(f" Tasks: {len(tasks)}")
|
| 325 |
+
print(f" Budget: {budget}")
|
| 326 |
+
print(f" Parallel: {self.parallel}")
|
| 327 |
+
print(f" Output: {run_dir}")
|
| 328 |
+
print("=" * 70)
|
| 329 |
+
|
| 330 |
+
# Pass self as runner — FlowOptimizer implements the ExperimentRunner
|
| 331 |
+
# protocol via the evaluate() method above
|
| 332 |
+
candidates = await strategy.generate(
|
| 333 |
+
base=base,
|
| 334 |
+
budget=budget,
|
| 335 |
+
tasks=tasks,
|
| 336 |
+
runner=self,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
if not candidates:
|
| 340 |
+
logger.warning("Strategy produced no candidates")
|
| 341 |
+
candidates = [Candidate(agent=base, mutations={}, rationale="baseline (strategy produced none)")]
|
| 342 |
+
|
| 343 |
+
print(f"\nStrategy produced {len(candidates)} candidates. Running final evaluation...")
|
| 344 |
+
|
| 345 |
+
# Save config
|
| 346 |
+
self._save_config(candidates, tasks, run_dir)
|
| 347 |
+
|
| 348 |
+
# Final evaluation of all candidates across all tasks
|
| 349 |
task_results = await self._run_parallel(
|
| 350 |
candidates, tasks, run_dir, evaluator, progress_callback
|
| 351 |
)
|
|
|
|
| 410 |
|
| 411 |
async with lock:
|
| 412 |
completed += 1
|
| 413 |
+
status = "PASS" if result.eval_passed else "FAIL"
|
| 414 |
print(
|
| 415 |
f" [{completed}/{total}] {candidate.agent.name}/{task.name}: "
|
| 416 |
f"{status} score={result.eval_score:.2f} "
|
| 417 |
+
f"reasoning={result.eval_reasoning_score:.2f} "
|
| 418 |
f"tokens={result.metrics.total_tokens:,}"
|
| 419 |
)
|
| 420 |
if progress_callback:
|
|
|
|
| 434 |
|
| 435 |
return valid_results
|
| 436 |
|
| 437 |
+
async def evaluate(
    self,
    candidate: Candidate,
    tasks: list[Task],
) -> CandidateSummary:
    """Run one candidate through the optimizer's full evaluation pipeline.

    Implements the ExperimentRunner protocol: active strategies invoke this
    during their search loop so every trial candidate goes through the same
    harness execution, tracing, LLM evaluation, and metric extraction as the
    final ranking pass.

    The optimizer must already be initialized (optimize() sets up _evaluator
    and _run_dir), because this method reuses the optimizer's evaluator and
    output directory.

    Args:
        candidate: The candidate to evaluate
        tasks: Tasks to run the candidate on

    Returns:
        CandidateSummary with aggregated scores and per-task details
    """
    if self._run_dir is None:
        raise RuntimeError(
            "evaluate() requires the optimizer to be initialized. "
            "Call optimize() first, or use optimize_with_strategy() which handles setup."
        )

    raw_results = await self._run_parallel(
        [candidate], tasks, self._run_dir, self._evaluator, None
    )
    aggregated = self._aggregate_results(raw_results, [candidate])
    if aggregated:
        return aggregated[0]
    # Every experiment failed — fall back to an empty summary for this candidate.
    return CandidateSummary(name=candidate.agent.name, candidate=candidate)
|
| 473 |
+
|
| 474 |
async def _run_single(
|
| 475 |
self,
|
| 476 |
candidate: Candidate,
|
|
|
|
| 480 |
) -> TaskResult:
|
| 481 |
"""Run a single candidate-task experiment."""
|
| 482 |
# Import harness modules to register them, then use registry
|
| 483 |
+
import flow.harness.maf as _maf
|
| 484 |
+
|
| 485 |
+
_ = _maf
|
| 486 |
try:
|
| 487 |
+
import flow.harness.miniagent as _miniagent
|
| 488 |
+
|
| 489 |
+
_ = _miniagent
|
| 490 |
except ImportError:
|
| 491 |
pass # miniagent harness is optional
|
| 492 |
from flow.harness import create_harness
|
|
|
|
| 499 |
metrics = extract_metrics(run_result.trace)
|
| 500 |
|
| 501 |
criteria_results: list[dict[str, Any]] = []
|
| 502 |
+
eval_reasoning_score = 0.0
|
| 503 |
if evaluator:
|
| 504 |
+
if isinstance(evaluator, LLMEvaluator):
|
| 505 |
+
eval_result = await evaluator.evaluate(
|
| 506 |
+
run_result, instructions=candidate.agent.instructions
|
| 507 |
+
)
|
| 508 |
+
else:
|
| 509 |
+
eval_result = await evaluator.evaluate(run_result)
|
| 510 |
eval_score = eval_result.score
|
| 511 |
eval_passed = eval_result.passed
|
| 512 |
eval_reasoning = eval_result.reasoning
|
| 513 |
+
eval_reasoning_score = eval_result.reasoning_score
|
| 514 |
# Convert criteria results to dicts for serialization
|
| 515 |
criteria_results = [
|
| 516 |
{
|
| 517 |
"name": cr.name,
|
| 518 |
"score": cr.score,
|
| 519 |
+
"reasoning_score": cr.reasoning_score,
|
| 520 |
"passed": cr.passed,
|
| 521 |
"reasoning": cr.reasoning,
|
| 522 |
}
|
|
|
|
| 536 |
eval_passed=eval_passed,
|
| 537 |
eval_reasoning=eval_reasoning,
|
| 538 |
criteria_results=criteria_results,
|
| 539 |
+
eval_reasoning_score=eval_reasoning_score,
|
| 540 |
)
|
| 541 |
finally:
|
| 542 |
await harness.close()
|
|
|
|
| 565 |
candidate=candidate,
|
| 566 |
task_results=results,
|
| 567 |
avg_score=sum(r.eval_score for r in results) / len(results),
|
| 568 |
+
avg_reasoning_score=sum(r.eval_reasoning_score for r in results) / len(results),
|
| 569 |
avg_tokens=sum(r.metrics.total_tokens for r in results) / len(results),
|
| 570 |
avg_duration=sum(r.run_result.duration_seconds for r in results) / len(results),
|
| 571 |
pass_rate=sum(1 for r in results if r.eval_passed) / len(results),
|
|
|
|
| 621 |
logger.info("Creating AsyncAzureOpenAI client for evaluator")
|
| 622 |
client = AsyncAzureOpenAI(
|
| 623 |
api_key=api_key,
|
| 624 |
+
api_version="2024-08-01-preview", # Required for json_schema response_format
|
| 625 |
azure_endpoint=endpoint,
|
| 626 |
)
|
| 627 |
|
|
|
|
| 676 |
print(" OPTIMIZATION RESULTS")
|
| 677 |
print("=" * 70)
|
| 678 |
|
| 679 |
+
print(f"\n{'Candidate':<30} | {'Score':>8} | {'Reason':>8} | {'Tokens':>10} | {'Pareto':>8}")
|
| 680 |
+
print("-" * 75)
|
| 681 |
|
| 682 |
for summary in sorted(result.summaries, key=lambda s: s.avg_score, reverse=True):
|
| 683 |
+
pareto = "*" if summary.is_pareto_optimal else ""
|
| 684 |
print(
|
| 685 |
f"{summary.name:<30} | {summary.avg_score:>8.2f} | "
|
| 686 |
+
f"{summary.avg_reasoning_score:>8.2f} | "
|
| 687 |
f"{summary.avg_tokens:>10,.0f} | {pareto:>8}"
|
| 688 |
)
|
| 689 |
|
|
|
|
| 707 |
List of Task objects
|
| 708 |
"""
|
| 709 |
return _load_tasks_impl(path)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
async def evaluate_agent(
    agent: Agent,
    tasks: list[Task],
    *,
    parallel: int = 4,
    use_llm_evaluator: bool = True,
    output_dir: Path | None = None,
) -> CandidateSummary:
    """Run a single agent over a task suite and report aggregate performance.

    Typical uses:
    - Getting baseline performance before optimization
    - Testing a specific agent configuration
    - Validating an exported/promoted agent

    Example:
        from flow.experiments import Agent, evaluate_agent, get_task_suite

        agent = Agent(name="my-agent", instructions="You are helpful.")
        tasks = get_task_suite("coding")

        result = await evaluate_agent(agent, tasks)
        print(f"Score: {result.avg_score:.2f}")
        print(f"Pass rate: {result.pass_rate:.0%}")
        print(f"Avg tokens: {result.avg_tokens:,.0f}")

    Args:
        agent: The agent to evaluate
        tasks: List of tasks to run the agent on
        parallel: Number of concurrent task executions (default: 4)
        use_llm_evaluator: Whether to use LLM-as-Judge for scoring (default: True)
        output_dir: Optional directory for results (default: ~/.flow/evaluations)

    Returns:
        CandidateSummary with aggregated metrics:
        - avg_score: Mean evaluation score across tasks
        - pass_rate: Fraction of tasks that passed
        - avg_tokens: Mean token usage per task
        - avg_duration: Mean execution time per task
        - task_results: Per-task breakdown with scores and reasoning
    """
    # The optimizer pipeline operates on candidates, so wrap the bare agent.
    wrapped = Candidate(agent=agent, mutations={}, rationale="baseline evaluation")

    # Evaluations get their own default output location, separate from optimizer runs.
    if output_dir is not None:
        results_dir = output_dir
    else:
        results_dir = Path.home() / ".flow" / "evaluations"

    optimizer = FlowOptimizer(
        parallel=parallel,
        use_llm_evaluator=use_llm_evaluator,
        output_dir=results_dir,
    )
    outcome = await optimizer.optimize([wrapped], tasks)

    if not outcome.summaries:
        raise RuntimeError("Evaluation produced no results")
    return outcome.summaries[0]
|
src/flow/experiments/presets.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Agent presets — pre-configured agent bundles for common use cases.
|
| 3 |
+
|
| 4 |
+
Presets are the single source of truth for agent templates. They are:
|
| 5 |
+
- Defined here in Python
|
| 6 |
+
- Served to the UI via the /api/schema/agent endpoint
|
| 7 |
+
- Used in code via Agent.from_preset("coding")
|
| 8 |
+
|
| 9 |
+
Each preset bundles a full Agent configuration with metadata
|
| 10 |
+
(label, description, suggested datasets, tags) so users can
|
| 11 |
+
get started quickly without configuring every field.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import TYPE_CHECKING
|
| 18 |
+
|
| 19 |
+
from .models import Agent, CompactionConfig
|
| 20 |
+
|
| 21 |
+
if TYPE_CHECKING:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class AgentPreset:
    """A pre-configured agent bundle for a specific use case.

    Presets pair a ready-to-run Agent with UI/discovery metadata so users
    can start without configuring every field by hand.

    Attributes:
        name: Machine identifier (e.g., "coding", "research")
        label: Human-readable name (e.g., "Coding Agent")
        description: What this preset is optimized for
        agent: Fully configured Agent instance
        suggested_datasets: Task suite names to evaluate this preset
        tags: Categorization tags for UI display
    """

    # Machine identifier — the key used by get_preset() / AGENT_PRESETS.
    name: str
    # Human-readable display name.
    label: str
    # One-to-two sentence summary of what the preset is optimized for.
    description: str
    # Fully configured Agent instance ready to run or evaluate.
    agent: Agent
    # Task suite names recommended for evaluating this preset.
    suggested_datasets: list[str] = field(default_factory=list)
    # Categorization tags for UI display/filtering.
    tags: list[str] = field(default_factory=list)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# =============================================================================
|
| 47 |
+
# Preset Registry
|
| 48 |
+
# =============================================================================
|
| 49 |
+
|
| 50 |
+
# Registry of built-in presets, keyed by the machine name accepted by
# get_preset() (and, per the module docstring, Agent.from_preset()).
AGENT_PRESETS: dict[str, AgentPreset] = {
    # General-purpose coding agent on the miniagent framework with the
    # "standard" tool set and compaction disabled.
    "coding": AgentPreset(
        name="coding",
        label="Coding Agent",
        description="Writes, debugs, and refactors code. Reads and edits files, "
        "runs shell commands, and tracks progress with todos.",
        agent=Agent(
            name="coding-agent",
            framework="miniagent",
            instructions_preset="general",
            compaction=CompactionConfig.none(),
            tools="standard",
        ),
        suggested_datasets=["quick", "coding"],
        tags=["code", "files", "debugging"],
    ),
    # Fact-finding agent; same "standard" tool set, no compaction.
    "research": AgentPreset(
        name="research",
        label="Research Agent",
        description="Answers factual questions using web search, fetches and reads "
        "web pages, and executes code for calculations. Verifies claims from sources.",
        agent=Agent(
            name="research-agent",
            framework="miniagent",
            instructions_preset="general",
            compaction=CompactionConfig.none(),
            tools="standard",
        ),
        suggested_datasets=["quick"],
        tags=["web", "search", "facts"],
    ),
    # Document-processing agent; the only preset using the "full" tool set.
    "document-analysis": AgentPreset(
        name="document-analysis",
        label="Document Analysis Agent",
        description="Processes and analyzes documents including PDFs, Word docs, "
        "spreadsheets, and presentations. Uses specialized skills for document formats.",
        agent=Agent(
            name="document-analysis-agent",
            framework="miniagent",
            instructions_preset="general",
            compaction=CompactionConfig.none(),
            tools="full",
        ),
        suggested_datasets=["quick"],
        tags=["documents", "analysis", "skills"],
    ),
}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_preset(name: str) -> AgentPreset:
    """Look up an agent preset by its machine name.

    Args:
        name: Preset identifier (e.g., "coding", "research", "document-analysis")

    Returns:
        The AgentPreset

    Raises:
        ValueError: If preset name is not found
    """
    preset = AGENT_PRESETS.get(name)
    if preset is None:
        available = ", ".join(AGENT_PRESETS.keys())
        raise ValueError(f"Unknown preset: {name!r}. Available: {available}")
    return preset
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_all_presets() -> dict[str, AgentPreset]:
    """Return a shallow copy of every available agent preset.

    Returns:
        Dict mapping preset names to AgentPreset instances
    """
    # Copy so callers cannot mutate the registry through the returned dict.
    return {key: preset for key, preset in AGENT_PRESETS.items()}
|
src/flow/experiments/results.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""Simple result types for the Agent API.
|
| 4 |
+
|
| 5 |
+
These types provide a clean, user-friendly interface for
|
| 6 |
+
evaluation and optimization results.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import TYPE_CHECKING, Any
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from .optimizer import CandidateSummary
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class EvaluationResult:
    """Aggregate outcome of running an agent over a set of tasks.

    Attributes:
        score: Average evaluation score (0.0 to 1.0)
        tokens: Total tokens used across all tasks
        pass_rate: Fraction of tasks that passed (0.0 to 1.0)
        duration: Total duration in seconds
        task_count: Number of tasks evaluated

    Example:
        result = await agent.evaluate(tasks="quick")
        print(f"Score: {result.score:.2f}")
        print(f"Pass rate: {result.pass_rate:.0%}")
    """

    score: float
    tokens: int
    pass_rate: float
    duration: float
    task_count: int

    # Set when agent was deployed — links to the DB job
    job_id: str | None = field(default=None, repr=False)

    # Internal reference to full details (for advanced users)
    _details: CandidateSummary | None = field(default=None, repr=False)

    def __str__(self) -> str:
        # Compact one-line summary: score, comma-grouped tokens, pass rate as %.
        parts = (
            f"score={self.score:.2f}",
            f"tokens={self.tokens:,}",
            f"pass_rate={self.pass_rate:.0%}",
        )
        return ", ".join(parts)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
class ImprovementMetrics:
    """Delta between a baseline and the best optimized configuration.

    Attributes:
        score_delta: Improvement in score (best - baseline)
        token_reduction_pct: Token reduction as percentage (positive = fewer tokens)

    Example:
        if result.improvement.token_reduction_pct > 20:
            print("Significant token savings!")
    """

    score_delta: float
    token_reduction_pct: float

    def __str__(self) -> str:
        # Score: signed two-decimal value, or a bare "0" when unchanged.
        if self.score_delta != 0:
            score_str = f"{self.score_delta:+.2f}"
        else:
            score_str = "0"
        # Tokens: a positive reduction means fewer tokens, shown as "-N%";
        # a negative reduction means more tokens, shown as "+N%".
        pct = self.token_reduction_pct
        if pct == 0:
            token_str = "0%"
        else:
            sign = "-" if pct > 0 else "+"
            token_str = f"{sign}{abs(pct):.0f}%"
        return f"score: {score_str}, tokens: {token_str}"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
class AgentOptimizationResult:
    """Full outcome of an agent optimization run.

    Attributes:
        baseline: Performance of the original agent
        best: Performance of the best found configuration
        improvement: Metrics showing improvement over baseline
        best_agent: The optimized agent configuration
        candidates_tested: Number of candidates evaluated
        pareto_frontier: Names of Pareto-optimal candidates
        output_dir: Directory where detailed results are saved

    Example:
        result = await agent.optimize(tasks="quick")
        print(f"Best score: {result.best.score:.2f}")
        print(f"Token reduction: {result.improvement.token_reduction_pct:.0f}%")

        # Use the optimized agent
        optimized_agent = result.best_agent
    """

    baseline: EvaluationResult
    best: EvaluationResult
    improvement: ImprovementMetrics
    best_agent: Any  # Agent type (Any to avoid circular import)
    candidates_tested: int
    pareto_frontier: list[str]
    output_dir: Path

    # Set when agent was deployed — links to the DB job
    job_id: str | None = field(default=None, repr=False)

    def __str__(self) -> str:
        summary_lines = [
            f"Optimization: {self.baseline} → {self.best}",
            f"Improvement: {self.improvement}",
            f"Candidates tested: {self.candidates_tested}",
        ]
        return "\n".join(summary_lines)
|
src/flow/experiments/runner.py
CHANGED
|
@@ -17,6 +17,8 @@ from opentelemetry.sdk.trace import TracerProvider
|
|
| 17 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
| 18 |
from opentelemetry.semconv._incubating.attributes.service_attributes import SERVICE_NAME
|
| 19 |
|
|
|
|
|
|
|
| 20 |
from .trace_collector import FlowTraceCollector
|
| 21 |
from .types import RunResult, Task
|
| 22 |
|
|
@@ -25,23 +27,40 @@ if TYPE_CHECKING:
|
|
| 25 |
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
def setup_tracing(service_name: str = "flow-experiments") ->
|
| 30 |
-
"""Setup OpenTelemetry tracing with
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
Args:
|
| 36 |
service_name: Name for the tracing service
|
| 37 |
|
| 38 |
Returns:
|
| 39 |
-
The
|
| 40 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
resource = Resource.create({SERVICE_NAME: service_name})
|
| 42 |
provider = TracerProvider(resource=resource)
|
| 43 |
trace.set_tracer_provider(provider)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
# Enable agent framework instrumentation if available
|
| 46 |
try:
|
| 47 |
from agent_framework.observability import enable_instrumentation
|
|
@@ -52,7 +71,16 @@ def setup_tracing(service_name: str = "flow-experiments") -> TracerProvider:
|
|
| 52 |
except Exception as e:
|
| 53 |
logger.debug(f"Could not enable Agent Framework instrumentation: {e}")
|
| 54 |
|
| 55 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
class FlowExperimentRunner:
|
|
@@ -60,7 +88,7 @@ class FlowExperimentRunner:
|
|
| 60 |
|
| 61 |
The runner handles:
|
| 62 |
- Setting up temporary workspaces
|
| 63 |
-
- Collecting execution traces via OpenTelemetry
|
| 64 |
- Measuring execution time
|
| 65 |
- Capturing files created
|
| 66 |
- Supporting streaming execution
|
|
@@ -97,18 +125,14 @@ class FlowExperimentRunner:
|
|
| 97 |
|
| 98 |
async def run(
|
| 99 |
self,
|
| 100 |
-
harness:
|
| 101 |
task: Task,
|
| 102 |
workspace: Path | None = None,
|
| 103 |
) -> RunResult:
|
| 104 |
"""Run a harness on a task and collect results.
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
2. Sets up trace collection
|
| 109 |
-
3. Executes the harness with streaming
|
| 110 |
-
4. Collects output and files created
|
| 111 |
-
5. Returns a RunResult with all data
|
| 112 |
|
| 113 |
Args:
|
| 114 |
harness: The harness to run (any BaseHarness implementation)
|
|
@@ -134,30 +158,30 @@ class FlowExperimentRunner:
|
|
| 134 |
# Track files before execution
|
| 135 |
files_before = set(self._list_files(workspace))
|
| 136 |
|
| 137 |
-
#
|
| 138 |
-
collector =
|
| 139 |
-
processor: SimpleSpanProcessor | None = None
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
processor = SimpleSpanProcessor(collector)
|
| 145 |
-
provider.add_span_processor(processor)
|
| 146 |
-
logger.debug("Trace collection enabled")
|
| 147 |
-
except Exception as e:
|
| 148 |
-
logger.debug(f"Could not set up trace collection: {e}")
|
| 149 |
|
| 150 |
# Execute the harness
|
| 151 |
start_time = time.time()
|
| 152 |
output_chunks: list[str] = []
|
|
|
|
| 153 |
error: str | None = None
|
| 154 |
|
| 155 |
try:
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
try:
|
| 161 |
# Use streaming execution to capture all output
|
| 162 |
async for event in harness.run_stream(task.prompt):
|
| 163 |
# Collect text output
|
|
@@ -167,14 +191,13 @@ class FlowExperimentRunner:
|
|
| 167 |
if event.type in (EventType.TEXT_DELTA, EventType.TEXT_DONE):
|
| 168 |
output_chunks.append(event.content)
|
| 169 |
elif event.type == EventType.TOOL_RESULT:
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
elif event.type == EventType.ERROR:
|
| 173 |
-
# Capture error from harness
|
| 174 |
error = event.content
|
| 175 |
logger.error(f"Harness error: {error}")
|
| 176 |
-
finally:
|
| 177 |
-
os.chdir(original_cwd)
|
| 178 |
|
| 179 |
except Exception as e:
|
| 180 |
error = str(e)
|
|
@@ -183,22 +206,11 @@ class FlowExperimentRunner:
|
|
| 183 |
end_time = time.time()
|
| 184 |
duration_seconds = end_time - start_time
|
| 185 |
|
| 186 |
-
#
|
| 187 |
-
if
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
logger.debug(f"Error flushing processor: {e}")
|
| 192 |
-
|
| 193 |
-
# Get collected traces
|
| 194 |
-
trace_data = collector.get_traces()
|
| 195 |
-
|
| 196 |
-
# Clean up trace processor
|
| 197 |
-
if processor:
|
| 198 |
-
try:
|
| 199 |
-
processor.shutdown()
|
| 200 |
-
except Exception as e:
|
| 201 |
-
logger.debug(f"Error shutting down processor: {e}")
|
| 202 |
|
| 203 |
# Find files created
|
| 204 |
files_after = set(self._list_files(workspace))
|
|
@@ -223,6 +235,7 @@ class FlowExperimentRunner:
|
|
| 223 |
duration_seconds=duration_seconds,
|
| 224 |
workspace=workspace,
|
| 225 |
error=error,
|
|
|
|
| 226 |
)
|
| 227 |
|
| 228 |
def _list_files(self, directory: Path) -> list[str]:
|
|
|
|
| 17 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
| 18 |
from opentelemetry.semconv._incubating.attributes.service_attributes import SERVICE_NAME
|
| 19 |
|
| 20 |
+
from flow.tools.workspace import set_workspace
|
| 21 |
+
|
| 22 |
from .trace_collector import FlowTraceCollector
|
| 23 |
from .types import RunResult, Task
|
| 24 |
|
|
|
|
| 27 |
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
+
# Module-level shared collector — set up once via setup_tracing()
|
| 31 |
+
_shared_collector: FlowTraceCollector | None = None
|
| 32 |
+
|
| 33 |
|
| 34 |
+
def setup_tracing(service_name: str = "flow-experiments") -> FlowTraceCollector:
|
| 35 |
+
"""Setup OpenTelemetry tracing with a single shared collector.
|
| 36 |
|
| 37 |
+
Creates one TracerProvider + one SimpleSpanProcessor + one FlowTraceCollector.
|
| 38 |
+
Idempotent: if already set up, returns the existing collector. This avoids
|
| 39 |
+
the issue where ``trace.set_tracer_provider()`` silently ignores subsequent
|
| 40 |
+
calls (OTEL SDK only allows setting the provider once), which would cause
|
| 41 |
+
a new collector to be created but never receive any spans.
|
| 42 |
|
| 43 |
Args:
|
| 44 |
service_name: Name for the tracing service
|
| 45 |
|
| 46 |
Returns:
|
| 47 |
+
The shared FlowTraceCollector (also stored module-level)
|
| 48 |
"""
|
| 49 |
+
global _shared_collector
|
| 50 |
+
|
| 51 |
+
# Already set up — return existing collector
|
| 52 |
+
if _shared_collector is not None:
|
| 53 |
+
return _shared_collector
|
| 54 |
+
|
| 55 |
resource = Resource.create({SERVICE_NAME: service_name})
|
| 56 |
provider = TracerProvider(resource=resource)
|
| 57 |
trace.set_tracer_provider(provider)
|
| 58 |
|
| 59 |
+
# Create ONE shared collector and ONE processor
|
| 60 |
+
_shared_collector = FlowTraceCollector()
|
| 61 |
+
processor = SimpleSpanProcessor(_shared_collector)
|
| 62 |
+
provider.add_span_processor(processor)
|
| 63 |
+
|
| 64 |
# Enable agent framework instrumentation if available
|
| 65 |
try:
|
| 66 |
from agent_framework.observability import enable_instrumentation
|
|
|
|
| 71 |
except Exception as e:
|
| 72 |
logger.debug(f"Could not enable Agent Framework instrumentation: {e}")
|
| 73 |
|
| 74 |
+
return _shared_collector
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_shared_collector() -> FlowTraceCollector | None:
    """Return the process-wide trace collector, if tracing has been set up.

    Returns:
        The shared FlowTraceCollector created by setup_tracing(), or None
        when setup_tracing() has not been called yet.
    """
    return _shared_collector
|
| 84 |
|
| 85 |
|
| 86 |
class FlowExperimentRunner:
|
|
|
|
| 88 |
|
| 89 |
The runner handles:
|
| 90 |
- Setting up temporary workspaces
|
| 91 |
+
- Collecting execution traces via OpenTelemetry (isolated per task)
|
| 92 |
- Measuring execution time
|
| 93 |
- Capturing files created
|
| 94 |
- Supporting streaming execution
|
|
|
|
| 125 |
|
| 126 |
async def run(
|
| 127 |
self,
|
| 128 |
+
harness: BaseHarness,
|
| 129 |
task: Task,
|
| 130 |
workspace: Path | None = None,
|
| 131 |
) -> RunResult:
|
| 132 |
"""Run a harness on a task and collect results.
|
| 133 |
|
| 134 |
+
Uses a root span to obtain a trace_id, then retrieves only this
|
| 135 |
+
task's spans from the shared collector after execution.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
Args:
|
| 138 |
harness: The harness to run (any BaseHarness implementation)
|
|
|
|
| 158 |
# Track files before execution
|
| 159 |
files_before = set(self._list_files(workspace))
|
| 160 |
|
| 161 |
+
# Get the shared collector (set up by setup_tracing)
|
| 162 |
+
collector = _shared_collector
|
|
|
|
| 163 |
|
| 164 |
+
# Create a root span to get a unique trace_id for this task
|
| 165 |
+
tracer = trace.get_tracer("flow.experiments", "0.1.0")
|
| 166 |
+
task_trace_ids: set[str] = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# Execute the harness
|
| 169 |
start_time = time.time()
|
| 170 |
output_chunks: list[str] = []
|
| 171 |
+
tool_results: list[dict[str, str]] = []
|
| 172 |
error: str | None = None
|
| 173 |
|
| 174 |
try:
|
| 175 |
+
# Set workspace via contextvar (safe for concurrent async tasks —
|
| 176 |
+
# each task gets its own contextvar copy, no process-global cwd mutation)
|
| 177 |
+
set_workspace(workspace)
|
| 178 |
+
|
| 179 |
+
# Create root span — all child spans inherit its trace_id
|
| 180 |
+
with tracer.start_as_current_span(f"task_{task.name}") as root_span:
|
| 181 |
+
trace_id = format(root_span.get_span_context().trace_id, "032x")
|
| 182 |
+
task_trace_ids.add(trace_id)
|
| 183 |
+
logger.debug(f"Task '{task.name}' trace_id: {trace_id}")
|
| 184 |
|
|
|
|
| 185 |
# Use streaming execution to capture all output
|
| 186 |
async for event in harness.run_stream(task.prompt):
|
| 187 |
# Collect text output
|
|
|
|
| 191 |
if event.type in (EventType.TEXT_DELTA, EventType.TEXT_DONE):
|
| 192 |
output_chunks.append(event.content)
|
| 193 |
elif event.type == EventType.TOOL_RESULT:
|
| 194 |
+
tool_results.append({
|
| 195 |
+
"tool": event.tool_name or "unknown",
|
| 196 |
+
"output": event.content,
|
| 197 |
+
})
|
| 198 |
elif event.type == EventType.ERROR:
|
|
|
|
| 199 |
error = event.content
|
| 200 |
logger.error(f"Harness error: {error}")
|
|
|
|
|
|
|
| 201 |
|
| 202 |
except Exception as e:
|
| 203 |
error = str(e)
|
|
|
|
| 206 |
end_time = time.time()
|
| 207 |
duration_seconds = end_time - start_time
|
| 208 |
|
| 209 |
+
# Retrieve only this task's traces from the shared collector
|
| 210 |
+
if collector is not None:
|
| 211 |
+
trace_data = collector.get_traces_for_task(task_trace_ids)
|
| 212 |
+
else:
|
| 213 |
+
trace_data = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
# Find files created
|
| 216 |
files_after = set(self._list_files(workspace))
|
|
|
|
| 235 |
duration_seconds=duration_seconds,
|
| 236 |
workspace=workspace,
|
| 237 |
error=error,
|
| 238 |
+
tool_results=tool_results,
|
| 239 |
)
|
| 240 |
|
| 241 |
def _list_files(self, directory: Path) -> list[str]:
|
src/flow/experiments/strategies/__init__.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""Strategy registry for optimization.
|
| 4 |
+
|
| 5 |
+
Provides a registry of available strategies that can be used in experiment YAML
|
| 6 |
+
via the `strategy:` key in variations.
|
| 7 |
+
|
| 8 |
+
Example YAML:
|
| 9 |
+
variations:
|
| 10 |
+
instructions:
|
| 11 |
+
- "You are helpful" # Literal
|
| 12 |
+
- strategy: gepa # Strategy
|
| 13 |
+
max_candidates: 3
|
| 14 |
+
config:
|
| 15 |
+
reflection_lm: gpt-4o
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
from typing import TYPE_CHECKING, Any
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from ..models import CandidateStrategy
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)

# Strategy registry: maps strategy names (as referenced by the `strategy:` key
# in experiment YAML) to strategy classes. get_strategy() instantiates the
# registered class with a config dict, i.e. strategy_class(config=config).
_STRATEGY_REGISTRY: dict[str, type] = {}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def register_strategy(name: str, strategy_class: type) -> None:
    """Register a strategy class under a name usable in experiment YAML.

    Args:
        name: Strategy name used in YAML
        strategy_class: Strategy class to instantiate

    Note:
        Re-registering an existing name silently replaces the previous class.
    """
    _STRATEGY_REGISTRY[name] = strategy_class
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    logger.debug("Registered strategy: %s", name)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_strategy(name: str, config: dict[str, Any]) -> CandidateStrategy:
|
| 45 |
+
"""Get a strategy instance by name.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
name: Strategy name from YAML
|
| 49 |
+
config: Strategy configuration dict
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Instantiated strategy
|
| 53 |
+
|
| 54 |
+
Raises:
|
| 55 |
+
ValueError: If strategy name is unknown
|
| 56 |
+
"""
|
| 57 |
+
if name not in _STRATEGY_REGISTRY:
|
| 58 |
+
available = list(_STRATEGY_REGISTRY.keys())
|
| 59 |
+
raise ValueError(f"Unknown strategy: {name}. Available: {available}")
|
| 60 |
+
|
| 61 |
+
strategy_class = _STRATEGY_REGISTRY[name]
|
| 62 |
+
return strategy_class(config=config)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_registered_strategies() -> dict[str, type]:
|
| 66 |
+
"""Get all registered strategies.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Dict mapping strategy names to their classes
|
| 70 |
+
"""
|
| 71 |
+
return dict(_STRATEGY_REGISTRY)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# =============================================================================
|
| 75 |
+
# Register built-in strategies
|
| 76 |
+
# =============================================================================
|
| 77 |
+
|
| 78 |
+
def _register_builtin_strategies() -> None:
    """Register the strategies that ship with the package.

    Each built-in lives behind an optional import: a missing dependency
    only disables that one strategy (logged at DEBUG) instead of failing
    the whole module import.
    """

    def _attempt(name: str, load, unavailable_msg: str) -> None:
        # The loader performs the import lazily so ImportError stays
        # scoped to this single strategy.
        try:
            strategy_cls = load()
        except ImportError:
            logger.debug(unavailable_msg)
        else:
            register_strategy(name, strategy_cls)

    def _load_gepa():
        # GEPA strategy (optional - requires gepa package)
        from flow.optimizers.gepa_adapter import GepaStrategy
        return GepaStrategy

    def _load_llm_rewriter():
        # LLM rewriter strategy (simple instruction variations)
        from .llm_rewriter import LLMRewriterStrategy
        return LLMRewriterStrategy

    def _load_tool_selector():
        # Tool selector strategy (generates tool configurations)
        from .tool_selector import ToolSelectorStrategy
        return ToolSelectorStrategy

    _attempt("gepa", _load_gepa, "GEPA strategy not available (gepa package not installed)")
    _attempt("llm_rewriter", _load_llm_rewriter, "LLM rewriter strategy not available")
    _attempt("tool_selector", _load_tool_selector, "Tool selector strategy not available")


# Register on module import
_register_builtin_strategies()
|
src/flow/experiments/strategies/llm_rewriter.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""LLM-based instruction rewriter strategy.
|
| 4 |
+
|
| 5 |
+
This strategy always requires a runner and tasks. It:
|
| 6 |
+
1. Evaluates the current instructions on all tasks
|
| 7 |
+
2. Reflects on failures to understand what went wrong
|
| 8 |
+
3. Rewrites instructions to address failures
|
| 9 |
+
4. Re-evaluates and repeats until convergence or budget exhausted
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from ..models import Agent, Candidate, ExperimentRunner, StrategyIteration
|
| 20 |
+
from ..types import Task
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class LLMRewriterStrategy:
    """Strategy that uses an LLM to iteratively improve agent instructions.

    Runs an evaluate-reflect-rewrite loop. Each iteration evaluates
    the current instructions on tasks via the runner, analyzes failures,
    and rewrites the instructions to address them. Stops when:
    - All tasks pass
    - Score improvement drops below min_improvement
    - max_iterations reached

    Requires both a runner (to evaluate candidates) and tasks (to test on).

    Config options:
        model: LLM for rewriting (default: gpt-4o-mini)
        max_iterations: Max optimization iterations (default: 5)
        min_improvement: Min score gain to continue (default: 0.05)

    Example YAML:
        strategy:
          type: llm_rewriter
          config:
            model: gpt-4o-mini
            max_iterations: 5
            min_improvement: 0.05
    """

    # Free-form configuration dict; recognized keys are documented above.
    config: dict[str, Any] = field(default_factory=dict)

    async def generate(
        self,
        base: Agent,
        budget: int,
        *,
        tasks: list[Task] | None = None,
        runner: ExperimentRunner | None = None,
    ) -> list[Candidate]:
        """Generate optimized instruction variants via evaluate-reflect-rewrite loop.

        Args:
            base: Base agent with instructions to rewrite
            budget: Max candidates to generate (advisory; a single
                optimized candidate is returned)
            tasks: Tasks to evaluate on (required)
            runner: ExperimentRunner for evaluation (required)

        Returns:
            List of candidates with optimized instructions

        Raises:
            ValueError: If tasks or runner not provided
        """
        if runner is None:
            raise ValueError(
                "LLMRewriterStrategy requires a runner. "
                "Use FlowOptimizer.optimize_with_strategy() to provide one."
            )
        if not tasks:
            raise ValueError(
                "LLMRewriterStrategy requires tasks to evaluate against."
            )

        # Fall back to a minimal prompt when the base agent has no instructions.
        base_instructions = base.instructions or "You are a helpful assistant."
        return await self._generate_active(base, base_instructions, budget, tasks, runner)

    @staticmethod
    def _agent_with_instructions(base: Agent, name: str, instructions: str) -> Agent:
        """Clone *base* under a new name with different instructions.

        All other settings (framework, LLM config, compaction, tools) are
        carried over unchanged.
        """
        return Agent(
            name=name,
            framework=base.framework,
            instructions=instructions,
            llm_config=base.llm_config,
            compaction=base.compaction,
            tools=base.tools,
        )

    @staticmethod
    def _summarize_task_results(task_results: list[Any]) -> str:
        """Render one "[PASS|FAIL] name: reasoning" line per task result."""
        task_lines: list[str] = []
        for tr in task_results:
            task_name = getattr(tr, "task_name", "unknown")
            passed = getattr(tr, "eval_passed", True)
            reasoning = getattr(tr, "eval_reasoning", "")
            status = "PASS" if passed else "FAIL"
            task_lines.append(f"  [{status}] {task_name}: {reasoning[:150]}")
        return "\n".join(task_lines)

    async def _generate_active(
        self,
        base: Agent,
        instructions: str,
        budget: int,
        tasks: list[Task],
        runner: ExperimentRunner,
    ) -> list[Candidate]:
        """Run the active optimization loop with real evaluation feedback."""
        model = self.config.get("model", "gpt-4o-mini")
        max_iterations = self.config.get("max_iterations", 5)
        min_improvement = self.config.get("min_improvement", 0.05)

        logger.info(
            f"LLMRewriterStrategy: active mode (max_iterations={max_iterations}, "
            f"min_improvement={min_improvement})"
        )

        current_instructions = instructions
        best_instructions = instructions
        best_score = 0.0
        history: list[StrategyIteration] = []

        for iteration in range(max_iterations):
            # 1. Evaluate current instructions on all tasks.
            candidate = Candidate(
                agent=self._agent_with_instructions(
                    base, f"{base.name}_rewrite_iter{iteration}", current_instructions
                ),
                mutations={"instructions": current_instructions},
            )

            summary = await runner.evaluate(candidate, tasks)

            # getattr defaults keep us tolerant of summaries that omit
            # optional fields.
            avg_score = getattr(summary, "avg_score", 0.0)
            pass_rate = getattr(summary, "pass_rate", 0.0)
            task_results = getattr(summary, "task_results", [])
            failures = [tr for tr in task_results if not getattr(tr, "eval_passed", True)]

            logger.info(
                f"  Iteration {iteration}: avg_score={avg_score:.3f}, "
                f"pass_rate={pass_rate:.1%}, failures={len(failures)}"
            )

            # Build per-task summary for rationale
            tasks_summary = self._summarize_task_results(task_results)

            # Record iteration
            change_desc = "Baseline evaluation" if iteration == 0 else f"Rewrite iteration {iteration}"
            change_rationale = f"Per-task results:\n{tasks_summary}"
            if iteration > 0:
                # history[-1] is the previous iteration (current not yet appended).
                score_delta = avg_score - history[-1].avg_score
                change_rationale = (
                    f"Score {'improved' if score_delta > 0 else 'declined'} by {abs(score_delta):.3f}. "
                    f"{len(failures)} failures remaining.\n{tasks_summary}"
                )

            history.append(
                StrategyIteration(
                    iteration=iteration,
                    instructions_preview=current_instructions[:200],
                    full_instructions=current_instructions,
                    avg_score=avg_score,
                    pass_rate=pass_rate,
                    failures_count=len(failures),
                    change_description=change_desc,
                    change_rationale=change_rationale,
                )
            )

            # Track best
            if avg_score > best_score:
                best_score = avg_score
                best_instructions = current_instructions

            # 2. Check stopping conditions
            if iteration > 0:
                improvement = avg_score - history[-2].avg_score
                # Note: a former `and avg_score <= best_score` clause was
                # removed — it was always true here because best_score was
                # just updated above.
                if improvement < min_improvement:
                    logger.info(
                        f"  Stopping: improvement ({improvement:.3f}) < "
                        f"min_improvement ({min_improvement})"
                    )
                    break

            if not failures:
                logger.info("  Stopping: all tasks passed")
                break

            if iteration == max_iterations - 1:
                break  # Don't rewrite on last iteration

            # 3. Reflect on failures and rewrite
            current_instructions = self._reflect_and_rewrite(
                current_instructions, failures, avg_score, model
            )
            logger.info(f"  Rewrote instructions ({len(current_instructions)} chars)")

        # Build final candidate with optimization history
        final_agent = self._agent_with_instructions(
            base, f"{base.name}_llm_rewriter_optimized", best_instructions
        )

        score_progression = f"{history[0].avg_score:.2f} → {best_score:.2f}"
        return [
            Candidate(
                agent=final_agent,
                mutations={"instructions": best_instructions},
                rationale=f"LLM rewriter active optimization: {len(history)} iterations, {score_progression}",
                optimization_history=history,
            )
        ]

    def _reflect_and_rewrite(
        self,
        instructions: str,
        failures: list[Any],
        current_score: float,
        model: str,
    ) -> str:
        """Analyze failures and rewrite instructions to address them.

        NOTE(review): this makes a synchronous LLM call from the async
        loop, so a slow call blocks the event loop — confirm acceptable.
        """
        # Build failure analysis
        failure_descriptions = []
        for tr in failures[:5]:  # Limit to 5 failures for context
            task_name = getattr(tr, "task_name", "unknown")
            reasoning = getattr(tr, "eval_reasoning", "No reasoning")
            score = getattr(tr, "eval_score", 0.0)
            failure_descriptions.append(
                f"- Task '{task_name}' (score={score:.2f}): {reasoning[:200]}"
            )

        failures_text = "\n".join(failure_descriptions)

        prompt = f"""You are a prompt engineer writing guidelines for a coding assistant.

The assistant's current guidelines scored {current_score:.2f} out of 1.0 on a benchmark.

Here are the tasks where performance was low:
{failures_text}

The current guidelines are:
---
{instructions}
---

Write a new, improved version of the guidelines. The new guidelines should:
1. Help the assistant succeed on a wide range of coding tasks — the failures
   above are examples, but the guidelines must generalize beyond them
2. Include concrete strategies (e.g., always verify output, check edge cases,
   create and run files when asked)
3. Be general-purpose: do NOT reference specific task names, specific answers,
   or specific test cases from the failures above
4. Focus on transferable skills and habits (e.g., "verify output matches
   requirements" not "check that fibonacci returns 55")
5. Be concise

Output ONLY the new guidelines text, nothing else."""

        try:
            return self._call_llm(prompt, model) or instructions
        except Exception as e:
            logger.warning(f"LLM rewrite failed: {e}")
            # Primary prompt failed — the original instructions may have
            # triggered a content filter (Azure, OpenAI, etc.) or caused
            # another error. Try a fallback that omits them entirely.
            logger.info("Retrying rewrite with fallback prompt (without original instructions)")
            return self._fallback_rewrite(failures_text, current_score, model)

    def _fallback_rewrite(
        self,
        failures_text: str,
        current_score: float,
        model: str,
    ) -> str:
        """Generate new instructions from scratch when the primary rewrite is blocked.

        This avoids including the original instructions (which may trigger
        content filters) and instead writes fresh guidelines based solely on
        the task failure descriptions. As a last resort, returns a static
        default so the optimization loop never fails outright.
        """
        prompt = f"""You are a prompt engineer. Write guidelines for a coding assistant.

The assistant scored {current_score:.2f} out of 1.0 on these tasks:
{failures_text}

Write concise guidelines that would help a coding assistant succeed on
a wide range of coding tasks. The failures above are examples — the
guidelines must generalize beyond them. The guidelines should:
1. Instruct the assistant to complete coding tasks by creating files and
   running code
2. Include strategies for verifying output and handling edge cases
3. Be general-purpose: do NOT reference specific task names or answers
   from the failures above
4. Focus on transferable habits and skills

Output ONLY the guidelines text, nothing else."""

        try:
            result = self._call_llm(prompt, model)
            if result:
                logger.info("Fallback rewrite succeeded")
                return result
        except Exception as e2:
            logger.warning(f"Fallback rewrite also failed: {e2}")

        # Last resort: return a sensible default
        logger.info("Using default coding assistant guidelines")
        return (
            "You are a helpful coding assistant. When given a task:\n"
            "1. Create the requested files with correct, working code\n"
            "2. Run the code and verify the output is correct\n"
            "3. Handle edge cases and validate results before finishing"
        )

    def _get_client(self, model: str) -> tuple[Any, str]:
        """Get OpenAI client and model name.

        Prefers Azure OpenAI when both AZURE_OPENAI_API_KEY and
        AZURE_OPENAI_ENDPOINT are set (the deployment name from
        AZURE_OPENAI_DEPLOYMENT overrides *model*); otherwise falls back
        to plain OpenAI via OPENAI_API_KEY.

        Raises:
            ImportError: If the openai package is not installed
            ValueError: If no credentials are found in the environment
        """
        try:
            from openai import AzureOpenAI, OpenAI
        except ImportError as e:
            raise ImportError("openai package required for LLMRewriterStrategy") from e

        azure_key = os.environ.get("AZURE_OPENAI_API_KEY")
        azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")

        if azure_key and azure_endpoint:
            client = AzureOpenAI(
                api_key=azure_key,
                api_version="2024-08-01-preview",
                azure_endpoint=azure_endpoint,
            )
            model_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT", model)
        else:
            openai_key = os.environ.get("OPENAI_API_KEY")
            if not openai_key:
                raise ValueError("No OpenAI or Azure OpenAI credentials found")
            client = OpenAI(api_key=openai_key)
            model_name = model

        return client, model_name

    def _call_llm(self, prompt: str, model: str) -> str:
        """Call the LLM with a single user prompt and return its text.

        Returns an empty string when the completion has no content.
        """
        client, model_name = self._get_client(model)

        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content or ""
|
| 357 |
+
|
src/flow/experiments/strategies/tool_selector.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
|
| 3 |
+
"""Active tool selector strategy.
|
| 4 |
+
|
| 5 |
+
Uses the runner to evaluate tool configurations and iteratively adjust
|
| 6 |
+
the tool set based on actual execution failures. The strategy:
|
| 7 |
+
1. Evaluates the current tool set on all tasks
|
| 8 |
+
2. Analyzes failures and trace data to identify missing/unnecessary tools
|
| 9 |
+
3. Uses an LLM to recommend tool changes
|
| 10 |
+
4. Re-evaluates and repeats until convergence or budget exhausted
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
from ..metrics import extract_metrics
|
| 21 |
+
from ..models import Agent, Candidate, ExperimentRunner, StrategyIteration, TOOL_PRESETS
|
| 22 |
+
from ..types import Task
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
# All tools the strategy can choose from: the union of every tool that
# appears in any preset, in stable sorted order.
ALL_AVAILABLE_TOOLS: list[str] = sorted(
    set().union(*TOOL_PRESETS.values())
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class ToolSelectorStrategy:
    """Strategy that iteratively optimizes tool configurations via evaluation.

    Runs an evaluate-analyze-adjust loop. Each iteration evaluates
    the current tool set on tasks via the runner, analyzes which tools
    were used/missing from traces, and uses an LLM to recommend changes.

    Requires both a runner (to evaluate candidates) and tasks (to test on).

    Config options:
        model: LLM for tool recommendations (default: gpt-4o-mini)
        max_iterations: Max optimization iterations (default: 3)
        min_improvement: Min score gain to continue (default: 0.05)
        available_tools: List of tool names to choose from (default: all known tools)

    Example YAML:
        strategy:
          type: tool_selector
          config:
            model: gpt-4o-mini
            max_iterations: 3
    """

    # Free-form configuration dict; recognized keys are the "Config
    # options" listed in the class docstring. Unknown keys are ignored.
    config: dict[str, Any] = field(default_factory=dict)
|
| 57 |
+
|
| 58 |
+
async def generate(
|
| 59 |
+
self,
|
| 60 |
+
base: Agent,
|
| 61 |
+
budget: int,
|
| 62 |
+
*,
|
| 63 |
+
tasks: list[Task] | None = None,
|
| 64 |
+
runner: ExperimentRunner | None = None,
|
| 65 |
+
) -> list[Candidate]:
|
| 66 |
+
"""Generate optimized tool configurations via evaluate-analyze-adjust loop.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
base: Base agent with initial tool configuration
|
| 70 |
+
budget: Max candidates to generate
|
| 71 |
+
tasks: Tasks to evaluate on (required)
|
| 72 |
+
runner: ExperimentRunner for evaluation (required)
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
List of candidates with optimized tool sets
|
| 76 |
+
|
| 77 |
+
Raises:
|
| 78 |
+
ValueError: If tasks or runner not provided
|
| 79 |
+
"""
|
| 80 |
+
if runner is None:
|
| 81 |
+
raise ValueError(
|
| 82 |
+
"ToolSelectorStrategy requires a runner. "
|
| 83 |
+
"Use FlowOptimizer.optimize_with_strategy() to provide one."
|
| 84 |
+
)
|
| 85 |
+
if not tasks:
|
| 86 |
+
raise ValueError(
|
| 87 |
+
"ToolSelectorStrategy requires tasks to evaluate against."
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# Resolve initial tools to a list
|
| 91 |
+
from ..models import resolve_tools
|
| 92 |
+
if base.tools is None or (isinstance(base.tools, list) and len(base.tools) == 0):
|
| 93 |
+
current_tools = []
|
| 94 |
+
else:
|
| 95 |
+
current_tools = sorted(resolve_tools(base.tools).keys())
|
| 96 |
+
|
| 97 |
+
return await self._generate_active(base, current_tools, budget, tasks, runner)
|
| 98 |
+
|
| 99 |
+
async def _generate_active(
|
| 100 |
+
self,
|
| 101 |
+
base: Agent,
|
| 102 |
+
tools: list[str],
|
| 103 |
+
budget: int,
|
| 104 |
+
tasks: list[Task],
|
| 105 |
+
runner: ExperimentRunner,
|
| 106 |
+
) -> list[Candidate]:
|
| 107 |
+
"""Run active optimization loop with real evaluation feedback."""
|
| 108 |
+
model = self.config.get("model", "gpt-4o-mini")
|
| 109 |
+
max_iterations = self.config.get("max_iterations", 3)
|
| 110 |
+
min_improvement = self.config.get("min_improvement", 0.05)
|
| 111 |
+
available_tools = self.config.get("available_tools", ALL_AVAILABLE_TOOLS)
|
| 112 |
+
|
| 113 |
+
logger.info(
|
| 114 |
+
f"ToolSelectorStrategy: active mode (max_iterations={max_iterations}, "
|
| 115 |
+
f"available_tools={len(available_tools)})"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
current_tools = tools
|
| 119 |
+
best_tools = tools
|
| 120 |
+
best_score = 0.0
|
| 121 |
+
history: list[StrategyIteration] = []
|
| 122 |
+
# Track all unique tool configs tried, for returning as candidates
|
| 123 |
+
iteration_candidates: list[tuple[list[str], str]] = [] # (tools, name_suffix)
|
| 124 |
+
|
| 125 |
+
for iteration in range(max_iterations):
|
| 126 |
+
# 1. Evaluate current tool set
|
| 127 |
+
agent = Agent(
|
| 128 |
+
name=f"{base.name}_tools_iter{iteration}",
|
| 129 |
+
framework=base.framework,
|
| 130 |
+
instructions=base.instructions,
|
| 131 |
+
llm_config=base.llm_config,
|
| 132 |
+
compaction=base.compaction,
|
| 133 |
+
tools=current_tools,
|
| 134 |
+
)
|
| 135 |
+
candidate = Candidate(
|
| 136 |
+
agent=agent,
|
| 137 |
+
mutations={"tools": current_tools},
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
summary = await runner.evaluate(candidate, tasks)
|
| 141 |
+
|
| 142 |
+
avg_score = getattr(summary, "avg_score", 0.0)
|
| 143 |
+
pass_rate = getattr(summary, "pass_rate", 0.0)
|
| 144 |
+
task_results = getattr(summary, "task_results", [])
|
| 145 |
+
failures = [tr for tr in task_results if not getattr(tr, "eval_passed", True)]
|
| 146 |
+
|
| 147 |
+
# Collect tool usage from traces
|
| 148 |
+
tools_used: dict[str, int] = {}
|
| 149 |
+
for tr in task_results:
|
| 150 |
+
metrics = getattr(tr, "metrics", None)
|
| 151 |
+
if metrics and hasattr(metrics, "tool_calls_by_name"):
|
| 152 |
+
for name, count in metrics.tool_calls_by_name.items():
|
| 153 |
+
tools_used[name] = tools_used.get(name, 0) + count
|
| 154 |
+
|
| 155 |
+
logger.info(
|
| 156 |
+
f" Iteration {iteration}: avg_score={avg_score:.3f}, "
|
| 157 |
+
f"pass_rate={pass_rate:.1%}, failures={len(failures)}, "
|
| 158 |
+
f"tools={current_tools}, used={tools_used}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Build per-task summary for rationale
|
| 162 |
+
task_lines: list[str] = []
|
| 163 |
+
for tr in task_results:
|
| 164 |
+
task_name = getattr(tr, "task_name", "unknown")
|
| 165 |
+
passed = getattr(tr, "eval_passed", True)
|
| 166 |
+
reasoning = getattr(tr, "eval_reasoning", "")
|
| 167 |
+
task_metrics = getattr(tr, "metrics", None)
|
| 168 |
+
task_tools: dict[str, int] = {}
|
| 169 |
+
if task_metrics and hasattr(task_metrics, "tool_calls_by_name"):
|
| 170 |
+
task_tools = dict(task_metrics.tool_calls_by_name)
|
| 171 |
+
status = "PASS" if passed else "FAIL"
|
| 172 |
+
tools_info = f" (tools used: {task_tools})" if task_tools else ""
|
| 173 |
+
task_lines.append(f" [{status}] {task_name}{tools_info}: {reasoning[:150]}")
|
| 174 |
+
tasks_summary = "\n".join(task_lines)
|
| 175 |
+
|
| 176 |
+
# Record iteration
|
| 177 |
+
tools_desc = ", ".join(current_tools) or "(none)"
|
| 178 |
+
used_desc = ", ".join(f"{k}={v}" for k, v in sorted(tools_used.items())) or "(none)"
|
| 179 |
+
change_desc = "Baseline evaluation" if iteration == 0 else f"Tool adjustment iteration {iteration}"
|
| 180 |
+
change_rationale = f"Tools used: {used_desc}\n{tasks_summary}"
|
| 181 |
+
if iteration > 0:
|
| 182 |
+
score_delta = avg_score - history[-1].avg_score
|
| 183 |
+
added = set(current_tools) - set(best_tools if iteration == 1 else _prev_tools)
|
| 184 |
+
removed = set(_prev_tools) - set(current_tools) if iteration > 0 else set()
|
| 185 |
+
change_rationale = (
|
| 186 |
+
f"Score {'improved' if score_delta > 0 else 'declined'} by {abs(score_delta):.3f}. "
|
| 187 |
+
f"Added: {sorted(added) or 'none'}. Removed: {sorted(removed) or 'none'}. "
|
| 188 |
+
f"{len(failures)} failures remaining.\n"
|
| 189 |
+
f"Tools used: {used_desc}\n{tasks_summary}"
|
| 190 |
+
)
|
| 191 |
+
history.append(
|
| 192 |
+
StrategyIteration(
|
| 193 |
+
iteration=iteration,
|
| 194 |
+
instructions_preview=f"[{tools_desc}]"[:200],
|
| 195 |
+
full_instructions=f"[{tools_desc}]",
|
| 196 |
+
avg_score=avg_score,
|
| 197 |
+
pass_rate=pass_rate,
|
| 198 |
+
failures_count=len(failures),
|
| 199 |
+
change_description=change_desc,
|
| 200 |
+
change_rationale=change_rationale,
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Track this iteration's config
|
| 205 |
+
label = "baseline" if iteration == 0 else f"iter{iteration}"
|
| 206 |
+
iteration_candidates.append((list(current_tools), label))
|
| 207 |
+
|
| 208 |
+
# Track best
|
| 209 |
+
if avg_score > best_score:
|
| 210 |
+
best_score = avg_score
|
| 211 |
+
best_tools = current_tools
|
| 212 |
+
|
| 213 |
+
# 2. Check stopping conditions
|
| 214 |
+
if iteration > 0:
|
| 215 |
+
improvement = avg_score - history[-2].avg_score
|
| 216 |
+
if improvement < min_improvement and avg_score <= best_score:
|
| 217 |
+
logger.info(
|
| 218 |
+
f" Stopping: improvement ({improvement:.3f}) < "
|
| 219 |
+
f"min_improvement ({min_improvement})"
|
| 220 |
+
)
|
| 221 |
+
break
|
| 222 |
+
|
| 223 |
+
if not failures:
|
| 224 |
+
logger.info(" Stopping: all tasks passed")
|
| 225 |
+
break
|
| 226 |
+
|
| 227 |
+
if iteration == max_iterations - 1:
|
| 228 |
+
break # Don't adjust on last iteration
|
| 229 |
+
|
| 230 |
+
# 3. Analyze failures and adjust tools
|
| 231 |
+
_prev_tools = current_tools
|
| 232 |
+
current_tools = self._analyze_and_adjust(
|
| 233 |
+
current_tools, task_results, tools_used, available_tools, model
|
| 234 |
+
)
|
| 235 |
+
logger.info(f" Adjusted tools: {current_tools}")
|
| 236 |
+
|
| 237 |
+
# Build candidates for all unique tool configs tried
|
| 238 |
+
# This gives the Pareto chart multiple data points to compare
|
| 239 |
+
candidates: list[Candidate] = []
|
| 240 |
+
seen_tool_sets: set[tuple[str, ...]] = set()
|
| 241 |
+
|
| 242 |
+
for iter_tools, label in iteration_candidates:
|
| 243 |
+
tool_key = tuple(sorted(iter_tools))
|
| 244 |
+
if tool_key in seen_tool_sets:
|
| 245 |
+
continue
|
| 246 |
+
seen_tool_sets.add(tool_key)
|
| 247 |
+
|
| 248 |
+
is_best = sorted(iter_tools) == sorted(best_tools)
|
| 249 |
+
suffix = "optimized" if is_best else label
|
| 250 |
+
agent = Agent(
|
| 251 |
+
name=f"{base.name}_tools_{suffix}",
|
| 252 |
+
framework=base.framework,
|
| 253 |
+
instructions=base.instructions,
|
| 254 |
+
llm_config=base.llm_config,
|
| 255 |
+
compaction=base.compaction,
|
| 256 |
+
tools=iter_tools,
|
| 257 |
+
)
|
| 258 |
+
tools_desc = ", ".join(iter_tools) or "(none)"
|
| 259 |
+
candidates.append(
|
| 260 |
+
Candidate(
|
| 261 |
+
agent=agent,
|
| 262 |
+
mutations={"tools": iter_tools},
|
| 263 |
+
rationale=f"Tools: [{tools_desc}]",
|
| 264 |
+
optimization_history=history if is_best else [],
|
| 265 |
+
)
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Ensure best is always included (may differ from any iteration if
|
| 269 |
+
# the best score was from an earlier iteration)
|
| 270 |
+
best_key = tuple(sorted(best_tools))
|
| 271 |
+
if best_key not in seen_tool_sets:
|
| 272 |
+
final_agent = Agent(
|
| 273 |
+
name=f"{base.name}_tools_optimized",
|
| 274 |
+
framework=base.framework,
|
| 275 |
+
instructions=base.instructions,
|
| 276 |
+
llm_config=base.llm_config,
|
| 277 |
+
compaction=base.compaction,
|
| 278 |
+
tools=best_tools,
|
| 279 |
+
)
|
| 280 |
+
tools_desc = ", ".join(best_tools)
|
| 281 |
+
candidates.append(
|
| 282 |
+
Candidate(
|
| 283 |
+
agent=final_agent,
|
| 284 |
+
mutations={"tools": best_tools},
|
| 285 |
+
rationale=f"Tools: [{tools_desc}]",
|
| 286 |
+
optimization_history=history,
|
| 287 |
+
)
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
return candidates
|
| 291 |
+
|
| 292 |
+
    def _analyze_and_adjust(
        self,
        current_tools: list[str],
        task_results: list[Any],
        tools_used: dict[str, int],
        available_tools: list[str],
        model: str,
    ) -> list[str]:
        """Analyze failures and traces, then recommend tool changes.

        Builds a per-task summary (pass/fail status, score, per-task tool
        usage, truncated eval reasoning), sends it to the LLM as a prompt,
        and parses the reply as a comma-separated tool list validated
        against ``available_tools``.

        Args:
            current_tools: Tool names enabled in the current iteration.
            task_results: Per-task result objects; attributes are read via
                getattr with defaults, so any object shape is tolerated.
            tools_used: Aggregate tool-call counts across all tasks.
            available_tools: Full set of selectable tool names.
            model: Model name forwarded to ``_call_llm``.

        Returns:
            Sorted list of validated tool names. Falls back to
            ``_heuristic_adjust`` when the LLM call raises; returns
            ``current_tools`` unchanged when the response is empty or
            contains no valid tool names.
        """
        # Build analysis of what happened
        failure_descriptions: list[str] = []
        for tr in task_results:
            # Read with defaults so arbitrary result objects don't crash us
            task_name = getattr(tr, "task_name", "unknown")
            passed = getattr(tr, "eval_passed", True)
            reasoning = getattr(tr, "eval_reasoning", "")
            score = getattr(tr, "eval_score", 0.0)

            # Get per-task tool usage
            metrics = getattr(tr, "metrics", None)
            task_tools: dict[str, int] = {}
            if metrics and hasattr(metrics, "tool_calls_by_name"):
                task_tools = dict(metrics.tool_calls_by_name)

            status = "PASS" if passed else "FAIL"
            # Reasoning is truncated to 200 chars to keep the prompt bounded
            failure_descriptions.append(
                f"- [{status}] Task '{task_name}' (score={score:.2f}): "
                f"tools_used={task_tools}. {reasoning[:200]}"
            )

        results_text = "\n".join(failure_descriptions)
        not_in_current = sorted(set(available_tools) - set(current_tools))

        prompt = f"""You are optimizing the tool configuration for a coding assistant.

Current tools: {current_tools}
Available tools NOT currently enabled: {not_in_current}

Task results with this tool set:
{results_text}

Tool usage across all tasks: {tools_used}

Based on the failures and tool usage patterns, recommend an updated tool list.
Consider:
- Tools that were needed but missing (e.g., agent tried to search but had no grep)
- Tools that were never used (candidates for removal to reduce complexity)
- Tools that could help with the failed tasks

Rules:
- Only select from the full available set: {available_tools}
- Always include at minimum: read_file, write_file, bash
- Do NOT add tools just because they exist — only add tools that would
  address specific failure patterns seen above

Respond with ONLY a comma-separated list of tool names, nothing else.
Example: read_file, write_file, bash, grep, edit_file"""

        try:
            result = self._call_llm(prompt, model)
            if result:
                # Parse comma-separated tool names
                parsed = [t.strip() for t in result.split(",") if t.strip()]
                # Validate against available tools
                # NOTE(review): the prompt's "always include read_file,
                # write_file, bash" rule is not enforced here — the LLM's
                # (validated) answer is trusted as-is.
                valid = [t for t in parsed if t in available_tools]
                if valid:
                    return sorted(valid)
                logger.warning(f"No valid tools in LLM response: {parsed}")
        except Exception as e:
            logger.warning(f"LLM tool adjustment failed: {e}")
            # Fallback: try adding commonly useful tools
            return self._heuristic_adjust(current_tools, tools_used, available_tools)

        return current_tools
|
| 365 |
+
|
| 366 |
+
def _heuristic_adjust(
|
| 367 |
+
self,
|
| 368 |
+
current_tools: list[str],
|
| 369 |
+
tools_used: dict[str, int],
|
| 370 |
+
available_tools: list[str],
|
| 371 |
+
) -> list[str]:
|
| 372 |
+
"""Fallback heuristic when LLM is unavailable."""
|
| 373 |
+
adjusted = set(current_tools)
|
| 374 |
+
|
| 375 |
+
# If bash was used heavily but grep/glob not available, add them
|
| 376 |
+
if "bash" in tools_used and tools_used["bash"] > 2:
|
| 377 |
+
for tool in ["grep", "glob_files", "ls"]:
|
| 378 |
+
if tool in available_tools:
|
| 379 |
+
adjusted.add(tool)
|
| 380 |
+
|
| 381 |
+
# If write_file was used but edit_file not available, add it
|
| 382 |
+
if "write_file" in tools_used and "edit_file" not in adjusted:
|
| 383 |
+
if "edit_file" in available_tools:
|
| 384 |
+
adjusted.add("edit_file")
|
| 385 |
+
|
| 386 |
+
# Add think if not present (helps with reasoning)
|
| 387 |
+
if "think" in available_tools:
|
| 388 |
+
adjusted.add("think")
|
| 389 |
+
|
| 390 |
+
return sorted(adjusted)
|
| 391 |
+
|
| 392 |
+
def _get_client(self, model: str) -> tuple[Any, str]:
|
| 393 |
+
"""Get OpenAI client and model name."""
|
| 394 |
+
try:
|
| 395 |
+
from openai import AzureOpenAI, OpenAI
|
| 396 |
+
except ImportError as e:
|
| 397 |
+
raise ImportError("openai package required for ToolSelectorStrategy") from e
|
| 398 |
+
|
| 399 |
+
azure_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
| 400 |
+
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
| 401 |
+
|
| 402 |
+
if azure_key and azure_endpoint:
|
| 403 |
+
client = AzureOpenAI(
|
| 404 |
+
api_key=azure_key,
|
| 405 |
+
api_version="2024-08-01-preview",
|
| 406 |
+
azure_endpoint=azure_endpoint,
|
| 407 |
+
)
|
| 408 |
+
model_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT", model)
|
| 409 |
+
else:
|
| 410 |
+
openai_key = os.environ.get("OPENAI_API_KEY")
|
| 411 |
+
if not openai_key:
|
| 412 |
+
raise ValueError("No OpenAI or Azure OpenAI credentials found")
|
| 413 |
+
client = OpenAI(api_key=openai_key)
|
| 414 |
+
model_name = model
|
| 415 |
+
|
| 416 |
+
return client, model_name
|
| 417 |
+
|
| 418 |
+
def _call_llm(self, prompt: str, model: str) -> str:
|
| 419 |
+
"""Call LLM with a prompt."""
|
| 420 |
+
client, model_name = self._get_client(model)
|
| 421 |
+
|
| 422 |
+
response = client.chat.completions.create(
|
| 423 |
+
model=model_name,
|
| 424 |
+
messages=[{"role": "user", "content": prompt}],
|
| 425 |
+
)
|
| 426 |
+
return response.choices[0].message.content or ""
|
src/flow/experiments/trace_collector.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
| 1 |
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
|
| 3 |
-
"""OpenTelemetry trace collector for experiment analysis.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import logging
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
from typing import Any
|
| 8 |
|
|
@@ -12,24 +17,26 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
|
| 14 |
class FlowTraceCollector(SpanExporter):
|
| 15 |
-
"""Collects OpenTelemetry spans
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
Example:
|
| 21 |
collector = FlowTraceCollector()
|
| 22 |
-
# Attach to TracerProvider via SimpleSpanProcessor
|
| 23 |
-
# Run
|
| 24 |
-
|
| 25 |
"""
|
| 26 |
|
| 27 |
def __init__(self) -> None:
|
| 28 |
"""Initialize the trace collector."""
|
| 29 |
-
self.
|
|
|
|
| 30 |
|
| 31 |
def export(self, spans: Any) -> SpanExportResult:
|
| 32 |
-
"""Collect spans
|
| 33 |
|
| 34 |
Args:
|
| 35 |
spans: Sequence of OpenTelemetry ReadableSpan objects
|
|
@@ -39,41 +46,46 @@ class FlowTraceCollector(SpanExporter):
|
|
| 39 |
"""
|
| 40 |
for span in spans:
|
| 41 |
try:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
"timestamp": datetime.fromtimestamp(start_time).isoformat(),
|
| 50 |
-
"data": {
|
| 51 |
-
"operation_name": span.name,
|
| 52 |
-
"span_id": format(span.context.span_id, "016x"),
|
| 53 |
-
"trace_id": format(span.context.trace_id, "032x"),
|
| 54 |
-
"parent_span_id": (
|
| 55 |
-
format(span.parent.span_id, "016x") if span.parent else None
|
| 56 |
-
),
|
| 57 |
-
"duration_ms": duration_ms,
|
| 58 |
-
"attributes": dict(span.attributes) if span.attributes else {},
|
| 59 |
-
"status": str(span.status.status_code.name) if hasattr(span, "status") else "OK",
|
| 60 |
-
"events": [
|
| 61 |
-
{
|
| 62 |
-
"name": event.name,
|
| 63 |
-
"timestamp": datetime.fromtimestamp(
|
| 64 |
-
event.timestamp / 1_000_000_000
|
| 65 |
-
).isoformat(),
|
| 66 |
-
"attributes": dict(event.attributes) if event.attributes else {},
|
| 67 |
-
}
|
| 68 |
-
for event in (span.events or [])
|
| 69 |
-
],
|
| 70 |
-
},
|
| 71 |
-
})
|
| 72 |
except Exception as e:
|
| 73 |
logger.debug(f"Failed to collect span: {e}")
|
| 74 |
|
| 75 |
return SpanExportResult.SUCCESS
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
| 78 |
"""Force flush spans (no-op for simple collection).
|
| 79 |
|
|
@@ -89,16 +101,47 @@ class FlowTraceCollector(SpanExporter):
|
|
| 89 |
"""Shutdown the exporter (no-op)."""
|
| 90 |
pass
|
| 91 |
|
| 92 |
-
def
|
| 93 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
Returns:
|
| 96 |
-
|
| 97 |
"""
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
|
| 3 |
+
"""OpenTelemetry trace collector for experiment analysis.
|
| 4 |
+
|
| 5 |
+
Uses trace_id-based bucketing to isolate spans per task, even when
|
| 6 |
+
multiple tasks run concurrently on a shared TracerProvider.
|
| 7 |
+
"""
|
| 8 |
|
| 9 |
import logging
|
| 10 |
+
import threading
|
| 11 |
from datetime import datetime
|
| 12 |
from typing import Any
|
| 13 |
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class FlowTraceCollector(SpanExporter):
|
| 20 |
+
"""Collects OpenTelemetry spans, bucketed by trace_id for isolation.
|
| 21 |
|
| 22 |
+
All spans from the global TracerProvider flow into this single collector.
|
| 23 |
+
Spans are stored in per-trace_id buckets so that each task can retrieve
|
| 24 |
+
only its own spans without cross-contamination.
|
| 25 |
|
| 26 |
Example:
|
| 27 |
collector = FlowTraceCollector()
|
| 28 |
+
# Attach ONCE to the global TracerProvider via SimpleSpanProcessor
|
| 29 |
+
# Run multiple tasks concurrently — each gets a unique trace_id
|
| 30 |
+
task_traces = collector.get_traces_for_task({"abc123"})
|
| 31 |
"""
|
| 32 |
|
| 33 |
def __init__(self) -> None:
|
| 34 |
"""Initialize the trace collector."""
|
| 35 |
+
self._spans_by_trace: dict[str, list[dict[str, Any]]] = {}
|
| 36 |
+
self._lock = threading.Lock()
|
| 37 |
|
| 38 |
def export(self, spans: Any) -> SpanExportResult:
|
| 39 |
+
"""Collect spans, bucketed by trace_id.
|
| 40 |
|
| 41 |
Args:
|
| 42 |
spans: Sequence of OpenTelemetry ReadableSpan objects
|
|
|
|
| 46 |
"""
|
| 47 |
for span in spans:
|
| 48 |
try:
|
| 49 |
+
trace_id = format(span.context.trace_id, "032x")
|
| 50 |
+
span_dict = self._convert_span(span)
|
| 51 |
+
|
| 52 |
+
with self._lock:
|
| 53 |
+
if trace_id not in self._spans_by_trace:
|
| 54 |
+
self._spans_by_trace[trace_id] = []
|
| 55 |
+
self._spans_by_trace[trace_id].append(span_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
logger.debug(f"Failed to collect span: {e}")
|
| 58 |
|
| 59 |
return SpanExportResult.SUCCESS
|
| 60 |
|
| 61 |
+
def get_traces_for_task(self, trace_ids: set[str]) -> list[dict[str, Any]]:
|
| 62 |
+
"""Get spans matching any of the given trace_ids, removing them.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
trace_ids: Set of trace_id hex strings to retrieve
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
List of span dicts belonging to those trace_ids
|
| 69 |
+
"""
|
| 70 |
+
result: list[dict[str, Any]] = []
|
| 71 |
+
with self._lock:
|
| 72 |
+
for tid in trace_ids:
|
| 73 |
+
result.extend(self._spans_by_trace.pop(tid, []))
|
| 74 |
+
return result
|
| 75 |
+
|
| 76 |
+
def get_traces(self) -> list[dict[str, Any]]:
|
| 77 |
+
"""Get and clear ALL collected traces (legacy API).
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
List of all collected trace spans, clearing internal state
|
| 81 |
+
"""
|
| 82 |
+
with self._lock:
|
| 83 |
+
all_spans: list[dict[str, Any]] = []
|
| 84 |
+
for spans in self._spans_by_trace.values():
|
| 85 |
+
all_spans.extend(spans)
|
| 86 |
+
self._spans_by_trace.clear()
|
| 87 |
+
return all_spans
|
| 88 |
+
|
| 89 |
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
| 90 |
"""Force flush spans (no-op for simple collection).
|
| 91 |
|
|
|
|
| 101 |
"""Shutdown the exporter (no-op)."""
|
| 102 |
pass
|
| 103 |
|
| 104 |
+
    def clear(self) -> None:
        """Clear collected traces without returning them.

        Drops every per-trace bucket under the lock; unlike get_traces(),
        nothing is returned to the caller.
        """
        with self._lock:
            self._spans_by_trace.clear()
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def _convert_span(span: Any) -> dict[str, Any]:
|
| 111 |
+
"""Convert an OTEL ReadableSpan to a dict.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
span: OpenTelemetry ReadableSpan
|
| 115 |
|
| 116 |
Returns:
|
| 117 |
+
Dictionary representation of the span
|
| 118 |
"""
|
| 119 |
+
start_time = span.start_time / 1_000_000_000
|
| 120 |
+
end_time = span.end_time / 1_000_000_000 if span.end_time else None
|
| 121 |
+
duration_ms = ((end_time - start_time) * 1000) if end_time else None
|
| 122 |
+
|
| 123 |
+
return {
|
| 124 |
+
"type": "trace_span",
|
| 125 |
+
"timestamp": datetime.fromtimestamp(start_time).isoformat(),
|
| 126 |
+
"data": {
|
| 127 |
+
"operation_name": span.name,
|
| 128 |
+
"span_id": format(span.context.span_id, "016x"),
|
| 129 |
+
"trace_id": format(span.context.trace_id, "032x"),
|
| 130 |
+
"parent_span_id": (
|
| 131 |
+
format(span.parent.span_id, "016x") if span.parent else None
|
| 132 |
+
),
|
| 133 |
+
"duration_ms": duration_ms,
|
| 134 |
+
"attributes": dict(span.attributes) if span.attributes else {},
|
| 135 |
+
"status": str(span.status.status_code.name) if hasattr(span, "status") else "OK",
|
| 136 |
+
"events": [
|
| 137 |
+
{
|
| 138 |
+
"name": event.name,
|
| 139 |
+
"timestamp": datetime.fromtimestamp(
|
| 140 |
+
event.timestamp / 1_000_000_000
|
| 141 |
+
).isoformat(),
|
| 142 |
+
"attributes": dict(event.attributes) if event.attributes else {},
|
| 143 |
+
}
|
| 144 |
+
for event in (span.events or [])
|
| 145 |
+
],
|
| 146 |
+
},
|
| 147 |
+
}
|
src/flow/experiments/types.py
CHANGED
|
@@ -61,6 +61,7 @@ class RunResult:
|
|
| 61 |
duration_seconds: float
|
| 62 |
workspace: Path
|
| 63 |
error: str | None = None
|
|
|
|
| 64 |
|
| 65 |
@property
|
| 66 |
def success(self) -> bool:
|
|
@@ -74,7 +75,8 @@ class CriterionResult:
|
|
| 74 |
|
| 75 |
Attributes:
|
| 76 |
name: Name of the criterion evaluated
|
| 77 |
-
score: Numeric score (0.0 to 1.0)
|
|
|
|
| 78 |
passed: Whether the criterion was met
|
| 79 |
reasoning: Explanation of the evaluation
|
| 80 |
"""
|
|
@@ -83,6 +85,7 @@ class CriterionResult:
|
|
| 83 |
score: float
|
| 84 |
passed: bool
|
| 85 |
reasoning: str
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
@dataclass
|
|
@@ -90,7 +93,8 @@ class EvalResult:
|
|
| 90 |
"""Result of evaluating an agent's output.
|
| 91 |
|
| 92 |
Attributes:
|
| 93 |
-
score: Overall weighted score (0.0 to 1.0)
|
|
|
|
| 94 |
passed: Whether the evaluation passed overall
|
| 95 |
criteria_results: Results for each individual criterion
|
| 96 |
reasoning: Overall evaluation reasoning/summary
|
|
@@ -100,6 +104,7 @@ class EvalResult:
|
|
| 100 |
passed: bool
|
| 101 |
criteria_results: list[CriterionResult]
|
| 102 |
reasoning: str
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
# =============================================================================
|
|
|
|
| 61 |
duration_seconds: float
|
| 62 |
workspace: Path
|
| 63 |
error: str | None = None
|
| 64 |
+
tool_results: list[dict[str, str]] = field(default_factory=list)
|
| 65 |
|
| 66 |
@property
|
| 67 |
def success(self) -> bool:
|
|
|
|
| 75 |
|
| 76 |
Attributes:
|
| 77 |
name: Name of the criterion evaluated
|
| 78 |
+
score: Numeric score (0.0 to 1.0) — exact match score
|
| 79 |
+
reasoning_score: Partial credit for correct reasoning/methodology (0.0 to 1.0)
|
| 80 |
passed: Whether the criterion was met
|
| 81 |
reasoning: Explanation of the evaluation
|
| 82 |
"""
|
|
|
|
| 85 |
score: float
|
| 86 |
passed: bool
|
| 87 |
reasoning: str
|
| 88 |
+
reasoning_score: float = 0.0
|
| 89 |
|
| 90 |
|
| 91 |
@dataclass
|
|
|
|
| 93 |
"""Result of evaluating an agent's output.
|
| 94 |
|
| 95 |
Attributes:
|
| 96 |
+
score: Overall weighted exact-match score (0.0 to 1.0)
|
| 97 |
+
reasoning_score: Overall weighted reasoning/methodology score (0.0 to 1.0)
|
| 98 |
passed: Whether the evaluation passed overall
|
| 99 |
criteria_results: Results for each individual criterion
|
| 100 |
reasoning: Overall evaluation reasoning/summary
|
|
|
|
| 104 |
passed: bool
|
| 105 |
criteria_results: list[CriterionResult]
|
| 106 |
reasoning: str
|
| 107 |
+
reasoning_score: float = 0.0
|
| 108 |
|
| 109 |
|
| 110 |
# =============================================================================
|
src/flow/harness/__init__.py
CHANGED
|
@@ -16,6 +16,10 @@ Usage:
|
|
| 16 |
harness = create_harness(agent, workspace=Path("/tmp"))
|
| 17 |
"""
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 20 |
from flow.harness.registry import (
|
| 21 |
available_frameworks,
|
|
@@ -24,10 +28,7 @@ from flow.harness.registry import (
|
|
| 24 |
register,
|
| 25 |
)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
# Each harness module calls register() on import
|
| 29 |
-
from flow.harness import maf as _maf # noqa: F401
|
| 30 |
-
from flow.harness import miniagent as _miniagent # noqa: F401
|
| 31 |
|
| 32 |
__all__ = [
|
| 33 |
"BaseHarness",
|
|
|
|
| 16 |
harness = create_harness(agent, workspace=Path("/tmp"))
|
| 17 |
"""
|
| 18 |
|
| 19 |
+
# Auto-register harnesses by importing them
|
| 20 |
+
# Each harness module calls register() on import
|
| 21 |
+
from flow.harness import maf as _maf
|
| 22 |
+
from flow.harness import miniagent as _miniagent
|
| 23 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 24 |
from flow.harness.registry import (
|
| 25 |
available_frameworks,
|
|
|
|
| 28 |
register,
|
| 29 |
)
|
| 30 |
|
| 31 |
+
_ = (_maf, _miniagent) # Suppress unused import warnings
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
__all__ = [
|
| 34 |
"BaseHarness",
|
src/flow/harness/base.py
CHANGED
|
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
|
|
| 10 |
from collections.abc import AsyncIterator
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
from enum import Enum
|
| 13 |
-
from typing import TYPE_CHECKING
|
| 14 |
|
| 15 |
if TYPE_CHECKING:
|
| 16 |
from pathlib import Path
|
|
@@ -62,18 +62,30 @@ class BaseHarness(ABC):
|
|
| 62 |
|
| 63 |
Implementations:
|
| 64 |
- MAFHarness (flow.harness.maf): Microsoft Agent Framework
|
| 65 |
-
- (
|
| 66 |
-
- (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
"""
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
@classmethod
|
| 70 |
@abstractmethod
|
| 71 |
def from_agent(
|
| 72 |
cls,
|
| 73 |
-
agent:
|
| 74 |
-
workspace:
|
| 75 |
-
llm_config:
|
| 76 |
-
) ->
|
| 77 |
"""Create a harness from an Agent definition.
|
| 78 |
|
| 79 |
Args:
|
|
|
|
| 10 |
from collections.abc import AsyncIterator
|
| 11 |
from dataclasses import dataclass, field
|
| 12 |
from enum import Enum
|
| 13 |
+
from typing import TYPE_CHECKING, ClassVar
|
| 14 |
|
| 15 |
if TYPE_CHECKING:
|
| 16 |
from pathlib import Path
|
|
|
|
| 62 |
|
| 63 |
Implementations:
|
| 64 |
- MAFHarness (flow.harness.maf): Microsoft Agent Framework
|
| 65 |
+
- LangGraphHarness (flow.harness.langgraph): LangGraph
|
| 66 |
+
- MiniAgentHarness (flow.harness.miniagent): MiniAgent
|
| 67 |
+
|
| 68 |
+
Class Attributes:
|
| 69 |
+
framework_name: Unique identifier for this framework (e.g., "maf", "langgraph")
|
| 70 |
+
framework_label: Human-readable label (e.g., "Microsoft Agent Framework")
|
| 71 |
+
framework_description: Short description of the framework
|
| 72 |
+
supported_compaction_strategies: List of compaction strategy names this framework supports
|
| 73 |
"""
|
| 74 |
|
| 75 |
+
# Framework metadata - subclasses should override these
|
| 76 |
+
framework_name: ClassVar[str] = ""
|
| 77 |
+
framework_label: ClassVar[str] = ""
|
| 78 |
+
framework_description: ClassVar[str] = ""
|
| 79 |
+
supported_compaction_strategies: ClassVar[list[str]] = []
|
| 80 |
+
|
| 81 |
@classmethod
|
| 82 |
@abstractmethod
|
| 83 |
def from_agent(
|
| 84 |
cls,
|
| 85 |
+
agent: Agent,
|
| 86 |
+
workspace: Path,
|
| 87 |
+
llm_config: LLMClientConfig | None = None,
|
| 88 |
+
) -> BaseHarness:
|
| 89 |
"""Create a harness from an Agent definition.
|
| 90 |
|
| 91 |
Args:
|
src/flow/harness/compaction/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Shared token-aware compaction strategies for all frameworks.
|
| 3 |
+
|
| 4 |
+
This module provides unified compaction strategies that work across
|
| 5 |
+
MAF, MiniAgent, and LangGraph frameworks. All strategies are token-based
|
| 6 |
+
to ensure safety against large messages.
|
| 7 |
+
|
| 8 |
+
Strategies:
|
| 9 |
+
- HeadTailStrategy: Keep head (system prompt) + tail (recent), drop middle
|
| 10 |
+
- SlidingWindowStrategy: Keep system + most recent messages within budget
|
| 11 |
+
- SummarizationStrategy: Summarize middle messages using LLM
|
| 12 |
+
- NoCompactionStrategy: Baseline (no management)
|
| 13 |
+
|
| 14 |
+
Usage:
|
| 15 |
+
from flow.harness.compaction import HeadTailStrategy, count_tokens
|
| 16 |
+
|
| 17 |
+
strategy = HeadTailStrategy(head_ratio=0.2, token_budget=200_000)
|
| 18 |
+
compacted = strategy.compact(messages)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from flow.harness.compaction.strategies import (
|
| 22 |
+
CompactionStrategy,
|
| 23 |
+
HeadTailStrategy,
|
| 24 |
+
NoCompactionStrategy,
|
| 25 |
+
SlidingWindowStrategy,
|
| 26 |
+
SummarizationStrategy,
|
| 27 |
+
)
|
| 28 |
+
from flow.harness.compaction.tokenizer import count_tokens, get_encoder
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"CompactionStrategy",
|
| 32 |
+
"HeadTailStrategy",
|
| 33 |
+
"NoCompactionStrategy",
|
| 34 |
+
"SlidingWindowStrategy",
|
| 35 |
+
"SummarizationStrategy",
|
| 36 |
+
"count_tokens",
|
| 37 |
+
"get_encoder",
|
| 38 |
+
]
|
src/flow/harness/compaction/strategies.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Token-aware compaction strategies for context management.
|
| 3 |
+
|
| 4 |
+
All strategies use token counting (not message counting) to ensure
|
| 5 |
+
safety against large messages that could blow past LLM context limits.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any, Protocol
|
| 12 |
+
|
| 13 |
+
from flow.harness.compaction.tokenizer import (
|
| 14 |
+
count_message_tokens,
|
| 15 |
+
count_messages_tokens,
|
| 16 |
+
get_encoder,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Default token budget (safe for GPT-4o, Claude 3.5, etc.)
|
| 20 |
+
DEFAULT_TOKEN_BUDGET = 200_000
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CompactionStrategy(Protocol):
    """Protocol for compaction strategies.

    All strategies must implement compact() which takes messages and
    returns a (possibly compacted) list of messages.

    This is a structural (duck-typed) interface: implementations do not
    need to inherit from it, only to provide a matching compact() method.
    """

    def compact(
        self,
        messages: list[dict[str, Any]],
        token_budget: int | None = None,
    ) -> list[dict[str, Any]]:
        """Compact messages to fit within token budget.

        Args:
            messages: List of chat message dicts
            token_budget: Max tokens (uses strategy default if None)

        Returns:
            Compacted message list
        """
        ...
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class NoCompactionStrategy:
|
| 49 |
+
"""Baseline: no compaction, context grows unbounded.
|
| 50 |
+
|
| 51 |
+
Use this for benchmarking to see how context grows without management.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def compact(
|
| 55 |
+
self,
|
| 56 |
+
messages: list[dict[str, Any]],
|
| 57 |
+
token_budget: int | None = None,
|
| 58 |
+
) -> list[dict[str, Any]]:
|
| 59 |
+
"""Return messages unchanged."""
|
| 60 |
+
return messages
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
class HeadTailStrategy:
    """Token-aware head+tail compaction.

    Preserves:
    - Head: System prompt, initial user message (critical context)
    - Tail: Recent tool calls and results (working memory)

    Drops middle messages when over budget, respecting atomic groups
    (tool calls and their results must stay together).

    This is the recommended strategy for most use cases.
    """

    head_ratio: float = 0.2  # 20% for head by default
    token_budget: int = DEFAULT_TOKEN_BUDGET
    model: str = "gpt-4o"  # used only for encoder selection / token counting here

    # Statistics (mutated in place by compact(); hidden from repr)
    compaction_count: int = field(default=0, repr=False)
    total_tokens_saved: int = field(default=0, repr=False)

    def _find_atomic_groups(
        self, messages: list[dict[str, Any]]
    ) -> list[tuple[int, ...]]:
        """Group tool_call messages with their results.

        OpenAI requires every tool_call to have a corresponding result.
        This ensures we never split a tool call from its results.

        Returns list of tuples, where each tuple contains indices that
        must stay together.
        """
        groups: list[tuple[int, ...]] = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            if msg.get("tool_calls"):
                # This message has tool calls - find all results
                call_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
                group_indices = [i]

                # Look ahead for results; stops as soon as every id is matched
                j = i + 1
                while j < len(messages) and call_ids:
                    if messages[j].get("role") == "tool":
                        tool_call_id = messages[j].get("tool_call_id")
                        if tool_call_id in call_ids:
                            group_indices.append(j)
                            call_ids.remove(tool_call_id)
                    j += 1

                groups.append(tuple(group_indices))
                # Resume scanning past the furthest grouped index
                i = max(group_indices) + 1 if group_indices else i + 1
            else:
                groups.append((i,))
                i += 1

        return groups

    def compact(
        self,
        messages: list[dict[str, Any]],
        token_budget: int | None = None,
    ) -> list[dict[str, Any]]:
        """Compact if over budget.

        Args:
            messages: Chat message dicts to compact
            token_budget: Per-call override for the instance budget.
                NOTE(review): a value of 0 falls back to the instance
                default because of the truthiness check below.

        Returns:
            The original list when within budget, otherwise a new list
            containing only the head and tail groups that fit.
        """
        if not messages:
            return messages

        budget = token_budget or self.token_budget
        encoder = get_encoder(self.model)
        current_tokens = count_messages_tokens(messages, self.model)

        if current_tokens <= budget:
            return messages

        # COMPACTION NEEDED
        self.compaction_count += 1

        groups = self._find_atomic_groups(messages)
        # Budget split: head gets head_ratio, tail gets the remainder
        head_budget = int(budget * self.head_ratio)
        tail_budget = budget - head_budget

        # Fill head from start (greedy, stops at first group that overflows)
        head_groups: list[tuple[int, ...]] = []
        head_tokens = 0

        for group in groups:
            group_msgs = [messages[i] for i in group]
            group_tokens = sum(
                count_message_tokens(m, self.model, encoder) for m in group_msgs
            )

            if head_tokens + group_tokens <= head_budget:
                head_groups.append(group)
                head_tokens += group_tokens
            else:
                break

        # Fill tail from end (skip head groups)
        remaining_groups = groups[len(head_groups) :]
        tail_groups: list[tuple[int, ...]] = []
        tail_tokens = 0

        for group in reversed(remaining_groups):
            group_msgs = [messages[i] for i in group]
            group_tokens = sum(
                count_message_tokens(m, self.model, encoder) for m in group_msgs
            )

            if tail_tokens + group_tokens <= tail_budget:
                tail_groups.insert(0, group)
                tail_tokens += group_tokens
            else:
                break

        # Build compacted list; sorting indices preserves original order
        kept_indices: set[int] = set()
        for group in head_groups + tail_groups:
            kept_indices.update(group)

        compacted = [messages[i] for i in sorted(kept_indices)]

        # Track savings
        compacted_tokens = count_messages_tokens(compacted, self.model)
        self.total_tokens_saved += current_tokens - compacted_tokens

        return compacted
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@dataclass
class SlidingWindowStrategy:
    """Keep only recent messages within budget.

    Always preserves leading system messages plus the most recent messages
    that fit in the budget. Respects atomic groups (an assistant message
    carrying tool_calls and its tool-result messages stay together).

    Simpler than HeadTailStrategy but may lose important early context.
    """

    token_budget: int = DEFAULT_TOKEN_BUDGET  # max tokens before compaction kicks in
    model: str = "gpt-4o"  # model name used to select the tokenizer

    def _find_atomic_groups(
        self, messages: list[dict[str, Any]]
    ) -> list[tuple[int, ...]]:
        """Group tool_call messages with their results.

        Returns a list of index tuples; the indices in a tuple must be
        kept or dropped as a unit so tool calls are never separated from
        their tool-result messages.
        """
        groups: list[tuple[int, ...]] = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            if msg.get("tool_calls"):
                # Collect the call ids so we can pair each one with the
                # tool message that answers it.
                call_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
                group_indices = [i]

                j = i + 1
                while j < len(messages) and call_ids:
                    if messages[j].get("role") == "tool":
                        tool_call_id = messages[j].get("tool_call_id")
                        if tool_call_id in call_ids:
                            group_indices.append(j)
                            call_ids.remove(tool_call_id)
                    j += 1

                groups.append(tuple(group_indices))
                i = max(group_indices) + 1 if group_indices else i + 1
            else:
                groups.append((i,))
                i += 1

        return groups

    def compact(
        self,
        messages: list[dict[str, Any]],
        token_budget: int | None = None,
    ) -> list[dict[str, Any]]:
        """Keep system message + most recent messages within budget.

        Args:
            messages: Chat messages in OpenAI dict format.
            token_budget: Optional per-call override. ``None`` uses
                ``self.token_budget``; an explicit ``0`` is honored.

        Returns:
            The original list when it already fits; otherwise the leading
            system messages plus the newest whole groups that fit.
        """
        if not messages:
            return messages

        # BUG FIX: was `token_budget or self.token_budget`, which treated
        # an explicit budget of 0 as falsy and silently substituted the
        # default. Compare against None so 0 means "no room for history".
        budget = self.token_budget if token_budget is None else token_budget
        encoder = get_encoder(self.model)

        # Always keep system messages at the start (stop at the first
        # non-system message; system messages appearing later are treated
        # like ordinary history).
        system_msgs: list[dict[str, Any]] = []
        non_system_start = 0

        for i, msg in enumerate(messages):
            if msg.get("role") == "system":
                system_msgs.append(msg)
                non_system_start = i + 1
            else:
                break

        other_msgs = messages[non_system_start:]

        system_tokens = sum(
            count_message_tokens(m, self.model, encoder) for m in system_msgs
        )
        remaining_budget = budget - system_tokens

        # System prompt alone exhausts the budget: keep only it.
        if remaining_budget <= 0:
            return system_msgs

        # Check if we need to compact at all.
        other_tokens = sum(
            count_message_tokens(m, self.model, encoder) for m in other_msgs
        )
        if other_tokens <= remaining_budget:
            return messages  # No compaction needed

        # Find atomic groups in the non-system messages.
        groups = self._find_atomic_groups(other_msgs)

        # Fill from the end (newest first), keeping whole groups only.
        kept_groups: list[tuple[int, ...]] = []
        kept_tokens = 0

        for group in reversed(groups):
            group_msgs = [other_msgs[i] for i in group]
            group_tokens = sum(
                count_message_tokens(m, self.model, encoder) for m in group_msgs
            )

            if kept_tokens + group_tokens <= remaining_budget:
                kept_groups.insert(0, group)
                kept_tokens += group_tokens
            else:
                break

        # Build result from kept groups, preserving original order.
        kept_indices: set[int] = set()
        for group in kept_groups:
            kept_indices.update(group)

        result = [other_msgs[i] for i in sorted(kept_indices)]

        return system_msgs + result
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@dataclass
class SummarizationStrategy:
    """Summarize old messages instead of dropping them.

    When over budget, this strategy:
    1. Keeps: System message + initial user message (head)
    2. Keeps: Most recent messages (tail)
    3. Summarizes: Everything in between into a single "context so far" message

    This preserves critical state (files read, findings, progress) that would
    otherwise be lost with simple truncation strategies.

    Note: Requires an async summarization function to be provided; the
    synchronous ``compact`` path falls back to heuristic extraction.
    """

    head_messages: int = 2  # Keep first N messages
    tail_messages: int = 4  # Keep last N messages
    summary_max_tokens: int = 1000  # token cap passed to summarize_fn
    token_budget: int = DEFAULT_TOKEN_BUDGET  # compaction triggers above this
    model: str = "gpt-4o"  # model name used to select the tokenizer

    # Async function to generate summaries (must be set before use).
    # Signature (inferred from the call site): (messages, max_tokens) -> str.
    summarize_fn: Any = field(default=None, repr=False)

    # Statistics (mutated by compact/compact_async)
    compaction_count: int = field(default=0, repr=False)
    total_tokens_saved: int = field(default=0, repr=False)

    def _find_safe_split_points(
        self, messages: list[dict[str, Any]]
    ) -> tuple[int, int]:
        """Find safe points to split messages without breaking tool call/result pairs.

        Messages are first partitioned into atomic groups (an assistant
        message carrying tool_calls plus the tool results answering it).
        Head/tail boundaries are then placed on group boundaries.

        Returns (head_end, tail_start) indices where it's safe to summarize between.
        When head and tail would overlap, returns (len, len) meaning
        "nothing to summarize".
        """
        groups: list[tuple[int, int]] = []  # (start, end) indices

        i = 0

        while i < len(messages):
            msg = messages[i]
            if msg.get("tool_calls"):
                call_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
                end = i
                j = i + 1
                # Extend the group up to the last tool result that answers
                # one of this message's calls.
                while j < len(messages) and call_ids:
                    if messages[j].get("role") == "tool":
                        tool_call_id = messages[j].get("tool_call_id")
                        if tool_call_id in call_ids:
                            call_ids.discard(tool_call_id)
                            end = j
                    j += 1
                groups.append((i, end + 1))
                i = end + 1
            else:
                groups.append((i, i + 1))
                i += 1

        # Find safe head end (after self.head_messages worth of groups)
        head_end = 0
        for idx, (_start, end) in enumerate(groups):
            if idx < self.head_messages:
                head_end = end
            else:
                break

        # Find safe tail start (before last self.tail_messages groups)
        tail_start = len(messages)
        tail_groups = min(self.tail_messages, len(groups))
        if tail_groups > 0 and len(groups) > tail_groups:
            tail_start = groups[-tail_groups][0]

        # Ensure we don't overlap
        if head_end >= tail_start:
            return len(messages), len(messages)

        return head_end, tail_start

    def _extract_key_info(self, messages: list[dict[str, Any]]) -> str:
        """Extract key info without LLM (fallback).

        Heuristic summary: collects short assistant statements and a note
        about file accesses, joined into a plain-text digest.
        """
        files_read: set[str] = set()
        key_findings: list[str] = []

        for msg in messages:
            if msg.get("role") == "tool" and msg.get("name") == "read_file":
                # NOTE(review): this adds the tool name (always "read_file",
                # given the guard above), not the file path — so files_read
                # can only ever contain the single entry "read_file".
                # Presumably the filename was intended; confirm the tool
                # message schema and fix upstream.
                files_read.add(msg.get("name") or "file")
            if msg.get("role") == "assistant" and msg.get("content"):
                content = msg["content"]
                # Only short messages are treated as "findings" worth keeping.
                if isinstance(content, str) and len(content) < 200:
                    key_findings.append(content)

        parts: list[str] = []
        if files_read:
            parts.append(f"Files accessed: {', '.join(files_read)}")
        if key_findings:
            parts.append(f"Key points: {'; '.join(key_findings[:5])}")

        return "\n".join(parts) if parts else "Previous context was processed."

    def compact(
        self,
        messages: list[dict[str, Any]],
        token_budget: int | None = None,
    ) -> list[dict[str, Any]]:
        """Summarize middle messages if over budget.

        Note: This is synchronous and uses simple extraction.
        For LLM-based summarization, use compact_async().

        Args:
            messages: Chat messages in OpenAI dict format.
            token_budget: Optional override; falsy values (including 0)
                fall back to ``self.token_budget``.

        Returns:
            The original list when within budget or nothing can be
            summarized; otherwise head + summary message + tail.
        """
        if not messages:
            return messages

        budget = token_budget or self.token_budget
        current_tokens = count_messages_tokens(messages, self.model)

        if current_tokens <= budget:
            return messages

        self.compaction_count += 1

        head_end, tail_start = self._find_safe_split_points(messages)

        head = messages[:head_end]
        tail = messages[tail_start:]
        middle = messages[head_end:tail_start]

        # Nothing between head and tail: summarizing would change nothing.
        if not middle:
            return messages

        # Extract key info without LLM
        summary_text = self._extract_key_info(middle)

        summary_message = {
            "role": "user",
            "content": f"[CONTEXT SUMMARY - Previous {len(middle)} messages compressed]\n\n{summary_text}\n\n[END SUMMARY - Continue from here]",
        }

        compacted = head + [summary_message] + tail

        compacted_tokens = count_messages_tokens(compacted, self.model)
        self.total_tokens_saved += current_tokens - compacted_tokens

        return compacted

    async def compact_async(
        self,
        messages: list[dict[str, Any]],
        token_budget: int | None = None,
    ) -> list[dict[str, Any]]:
        """Async version that can use LLM for summarization.

        Same flow as ``compact`` but, when ``summarize_fn`` is set, the
        middle section is summarized by the LLM; any failure falls back to
        the heuristic extraction.
        """
        if not messages:
            return messages

        budget = token_budget or self.token_budget
        current_tokens = count_messages_tokens(messages, self.model)

        if current_tokens <= budget:
            return messages

        self.compaction_count += 1

        head_end, tail_start = self._find_safe_split_points(messages)

        head = messages[:head_end]
        tail = messages[tail_start:]
        middle = messages[head_end:tail_start]

        if not middle:
            return messages

        # Generate summary (LLM if available; heuristic fallback on error)
        if self.summarize_fn:
            try:
                summary_text = await self.summarize_fn(middle, self.summary_max_tokens)
            except Exception:
                summary_text = self._extract_key_info(middle)
        else:
            summary_text = self._extract_key_info(middle)

        summary_message = {
            "role": "user",
            "content": f"""[CONTEXT CHECKPOINT - Your previous work has been summarized below]

{summary_text}

---
IMPORTANT: Continue from where you left off. Do not repeat work already done.""",
        }

        compacted = head + [summary_message] + tail

        compacted_tokens = count_messages_tokens(compacted, self.model)
        self.total_tokens_saved += current_tokens - compacted_tokens

        return compacted
|
src/flow/harness/compaction/tokenizer.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Shared tiktoken wrapper for consistent token counting across frameworks."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import tiktoken
|
| 9 |
+
|
| 10 |
+
# Cache encoders to avoid repeated initialization
|
| 11 |
+
_ENCODER_CACHE: dict[str, tiktoken.Encoding] = {}
|
| 12 |
+
|
| 13 |
+
# Default encoding for unknown models
|
| 14 |
+
DEFAULT_ENCODING = "cl100k_base"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_encoder(model: str = "gpt-4o") -> tiktoken.Encoding:
    """Return the (cached) tiktoken encoder for *model*.

    Args:
        model: Model name (e.g., "gpt-4o", "gpt-4", "gpt-3.5-turbo")

    Returns:
        tiktoken Encoding instance; unknown model names (Claude, etc.)
        fall back to the DEFAULT_ENCODING encoding.
    """
    cached = _ENCODER_CACHE.get(model)
    if cached is not None:
        return cached

    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to tiktoken — use the generic encoding instead.
        enc = tiktoken.get_encoding(DEFAULT_ENCODING)

    _ENCODER_CACHE[model] = enc
    return enc
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def count_tokens(
    text: str,
    model: str = "gpt-4o",
    encoder: tiktoken.Encoding | None = None,
) -> int:
    """Count tokens in a text string.

    Args:
        text: The text to count tokens for.
        model: Model name used to select an encoding when *encoder* is None.
        encoder: Pre-fetched encoder, for callers that count in a loop.

    Returns:
        Number of tokens in *text*.
    """
    enc = get_encoder(model) if encoder is None else encoder
    return len(enc.encode(text))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def count_message_tokens(
    message: dict[str, Any],
    model: str = "gpt-4o",
    encoder: tiktoken.Encoding | None = None,
) -> int:
    """Count tokens in a chat message dict.

    Handles:
    - role overhead (~4 tokens per message)
    - content text (plain string or structured text blocks)
    - tool_calls (name + arguments)
    - tool results

    Args:
        message: Chat message dict with role, content, etc.
        model: Model name for encoding selection.
        encoder: Optional pre-fetched encoder.

    Returns:
        Approximate token count for the message.
    """
    if encoder is None:
        encoder = get_encoder(model)

    tokens = 4  # per-message role/format overhead (approximate)

    content = message.get("content")
    if content:
        if isinstance(content, str):
            tokens += len(encoder.encode(content))
        elif isinstance(content, list):
            # Structured content: only text blocks contribute tokens here.
            for block in content:
                if isinstance(block, dict) and "text" in block:
                    tokens += len(encoder.encode(block["text"]))

    # Tool calls may be dicts (OpenAI wire format) or attribute-style objects.
    for tc in message.get("tool_calls") or ():
        tokens += 4  # per tool-call overhead
        if isinstance(tc, dict):
            fn = tc.get("function", {})
            name = tc.get("name") or fn.get("name", "")
            args = tc.get("arguments") or fn.get("arguments", "")
        else:
            name = getattr(tc, "name", "")
            args = getattr(tc, "arguments", "")

        if name:
            tokens += len(encoder.encode(name))
        if args:
            tokens += len(encoder.encode(args))

    return tokens
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def count_messages_tokens(
    messages: list[dict[str, Any]],
    model: str = "gpt-4o",
) -> int:
    """Count total tokens across all messages.

    Args:
        messages: List of chat message dicts.
        model: Model name for encoding selection.

    Returns:
        Total token count.
    """
    # Fetch the encoder once and reuse it for every message.
    enc = get_encoder(model)
    total = 0
    for msg in messages:
        total += count_message_tokens(msg, model, enc)
    return total
|
src/flow/harness/langgraph/__init__.py
CHANGED
|
@@ -19,7 +19,11 @@ Usage:
|
|
| 19 |
print(event.type, event.content)
|
| 20 |
"""
|
| 21 |
|
| 22 |
-
from flow.harness.langgraph.compaction import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from flow.harness.langgraph.harness import LangGraphHarness
|
| 24 |
from flow.harness.langgraph.otel_callback import OTelCallbackHandler
|
| 25 |
from flow.harness.langgraph.wrappers import build_langgraph_tools, wrap_for_langgraph
|
|
@@ -33,5 +37,7 @@ __all__ = [
|
|
| 33 |
"OTelCallbackHandler",
|
| 34 |
"build_langgraph_tools",
|
| 35 |
"create_compaction_hook",
|
|
|
|
|
|
|
| 36 |
"wrap_for_langgraph",
|
| 37 |
]
|
|
|
|
| 19 |
print(event.type, event.content)
|
| 20 |
"""
|
| 21 |
|
| 22 |
+
from flow.harness.langgraph.compaction import (
|
| 23 |
+
create_compaction_hook,
|
| 24 |
+
create_head_tail_hook,
|
| 25 |
+
create_sliding_window_hook,
|
| 26 |
+
)
|
| 27 |
from flow.harness.langgraph.harness import LangGraphHarness
|
| 28 |
from flow.harness.langgraph.otel_callback import OTelCallbackHandler
|
| 29 |
from flow.harness.langgraph.wrappers import build_langgraph_tools, wrap_for_langgraph
|
|
|
|
| 37 |
"OTelCallbackHandler",
|
| 38 |
"build_langgraph_tools",
|
| 39 |
"create_compaction_hook",
|
| 40 |
+
"create_head_tail_hook",
|
| 41 |
+
"create_sliding_window_hook",
|
| 42 |
"wrap_for_langgraph",
|
| 43 |
]
|
src/flow/harness/langgraph/compaction.py
CHANGED
|
@@ -1,51 +1,219 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
-
Provides
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
from typing import Any
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
Args:
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
A function that can be used as a pre_model_hook in create_react_agent
|
| 26 |
|
| 27 |
Example:
|
| 28 |
-
hook = create_compaction_hook(
|
| 29 |
graph = create_react_agent(
|
| 30 |
model=model,
|
| 31 |
tools=tools,
|
| 32 |
pre_model_hook=hook,
|
| 33 |
)
|
| 34 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def compact_messages(state: dict[str, Any]) -> dict[str, Any]:
|
| 37 |
-
"""Compact messages
|
| 38 |
messages = state.get("messages", [])
|
| 39 |
-
total = len(messages)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
if total <= head_size + tail_size:
|
| 43 |
return {"llm_input_messages": messages}
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
return {"llm_input_messages":
|
| 50 |
|
| 51 |
return compact_messages
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Token-aware message compaction for LangGraph.
|
| 3 |
|
| 4 |
+
Provides pre-model hooks that implement token-based message compaction,
|
| 5 |
+
ensuring safety against large messages that could exceed LLM context limits.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
+
from flow.harness.compaction import (
|
| 13 |
+
HeadTailStrategy,
|
| 14 |
+
SlidingWindowStrategy,
|
| 15 |
+
)
|
| 16 |
|
| 17 |
+
__all__ = [
|
| 18 |
+
"create_compaction_hook",
|
| 19 |
+
"create_head_tail_hook",
|
| 20 |
+
"create_sliding_window_hook",
|
| 21 |
+
]
|
| 22 |
|
| 23 |
+
# Default token budget (safe for GPT-4o, Claude 3.5, etc.)
|
| 24 |
+
DEFAULT_TOKEN_BUDGET = 200_000
|
| 25 |
|
| 26 |
+
|
| 27 |
+
def _langchain_msg_to_dict(msg: Any) -> dict[str, Any]:
|
| 28 |
+
"""Convert LangChain message to dict format for compaction strategies."""
|
| 29 |
+
if isinstance(msg, dict):
|
| 30 |
+
return msg
|
| 31 |
+
|
| 32 |
+
# Handle LangChain message types
|
| 33 |
+
result: dict[str, Any] = {}
|
| 34 |
+
|
| 35 |
+
# Get role from type
|
| 36 |
+
msg_type = getattr(msg, "type", None)
|
| 37 |
+
if msg_type == "human":
|
| 38 |
+
result["role"] = "user"
|
| 39 |
+
elif msg_type == "ai":
|
| 40 |
+
result["role"] = "assistant"
|
| 41 |
+
elif msg_type == "system":
|
| 42 |
+
result["role"] = "system"
|
| 43 |
+
elif msg_type == "tool":
|
| 44 |
+
result["role"] = "tool"
|
| 45 |
+
result["tool_call_id"] = getattr(msg, "tool_call_id", None)
|
| 46 |
+
else:
|
| 47 |
+
result["role"] = msg_type or "user"
|
| 48 |
+
|
| 49 |
+
# Get content
|
| 50 |
+
content = getattr(msg, "content", "")
|
| 51 |
+
result["content"] = content
|
| 52 |
+
|
| 53 |
+
# Get tool calls (for AIMessage)
|
| 54 |
+
tool_calls = getattr(msg, "tool_calls", None)
|
| 55 |
+
if tool_calls:
|
| 56 |
+
result["tool_calls"] = [
|
| 57 |
+
{
|
| 58 |
+
"id": tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None),
|
| 59 |
+
"function": {
|
| 60 |
+
"name": tc.get("name") if isinstance(tc, dict) else getattr(tc, "name", ""),
|
| 61 |
+
"arguments": str(tc.get("args", {})) if isinstance(tc, dict) else str(getattr(tc, "args", {})),
|
| 62 |
+
},
|
| 63 |
+
}
|
| 64 |
+
for tc in tool_calls
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
return result
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _dict_to_langchain_msg(msg_dict: dict[str, Any], original_msg: Any) -> Any:
|
| 71 |
+
"""Preserve original LangChain message (we don't convert back)."""
|
| 72 |
+
# For compaction, we return the original message objects
|
| 73 |
+
# The strategy just tells us which indices to keep
|
| 74 |
+
return original_msg
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def create_compaction_hook(
    head_ratio: float = 0.2,
    token_budget: int = DEFAULT_TOKEN_BUDGET,
    model: str = "gpt-4o",
):
    """Create a pre-model hook for token-aware head+tail compaction.

    This hook compacts messages by keeping head messages (system prompt,
    initial context) and tail messages (recent work), dropping the middle
    when token count exceeds the budget.

    Args:
        head_ratio: Fraction of budget for head messages (0.2 = 20%)
        token_budget: Max tokens before compaction triggers
        model: Model name for tokenizer selection

    Returns:
        A function that can be used as a pre_model_hook in create_react_agent

    Example:
        hook = create_compaction_hook(head_ratio=0.2, token_budget=200000)
        graph = create_react_agent(
            model=model,
            tools=tools,
            pre_model_hook=hook,
        )
    """
    strategy = HeadTailStrategy(
        head_ratio=head_ratio,
        token_budget=token_budget,
        model=model,
    )

    def compact_messages(state: dict[str, Any]) -> dict[str, Any]:
        """Compact messages using token-aware head+tail strategy."""
        messages = state.get("messages", [])

        if not messages:
            return {"llm_input_messages": messages}

        # Convert to dict format for the strategy.
        msg_dicts = [_langchain_msg_to_dict(m) for m in messages]

        # Apply compaction.
        compacted_dicts = strategy.compact(msg_dicts)

        # FIX: map kept dicts back to the original LangChain messages by
        # object identity instead of comparing (role, content). The
        # strategy returns a subset of the exact dict objects it was
        # given (it never copies), so identity is exact and O(n),
        # whereas content comparison was O(n*m) and could mis-pair
        # messages that happen to share identical content.
        kept_ids = {id(d) for d in compacted_dicts}
        compacted_messages = [
            orig for orig, d in zip(messages, msg_dicts) if id(d) in kept_ids
        ]

        return {"llm_input_messages": compacted_messages}

    return compact_messages
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def create_head_tail_hook(
    head_ratio: float = 0.2,
    token_budget: int = DEFAULT_TOKEN_BUDGET,
    model: str = "gpt-4o",
):
    """Alias for create_compaction_hook with head+tail strategy."""
    # Thin pass-through: head+tail is already the default compaction hook.
    return create_compaction_hook(head_ratio=head_ratio, token_budget=token_budget, model=model)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def create_sliding_window_hook(
    token_budget: int = DEFAULT_TOKEN_BUDGET,
    model: str = "gpt-4o",
):
    """Create a pre-model hook for token-aware sliding window compaction.

    This hook keeps the system message plus the most recent messages
    that fit within the token budget.

    Args:
        token_budget: Max tokens for context window
        model: Model name for tokenizer selection

    Returns:
        A function that can be used as a pre_model_hook in create_react_agent

    Example:
        hook = create_sliding_window_hook(token_budget=100000)
        graph = create_react_agent(
            model=model,
            tools=tools,
            pre_model_hook=hook,
        )
    """
    strategy = SlidingWindowStrategy(
        token_budget=token_budget,
        model=model,
    )

    def compact_messages(state: dict[str, Any]) -> dict[str, Any]:
        """Compact messages using sliding window strategy."""
        messages = state.get("messages", [])

        if not messages:
            return {"llm_input_messages": messages}

        # Convert to dict format for the strategy.
        msg_dicts = [_langchain_msg_to_dict(m) for m in messages]

        # Apply compaction.
        compacted_dicts = strategy.compact(msg_dicts)

        # FIX: map kept dicts back to original messages by object identity
        # rather than (role, content) comparison. SlidingWindowStrategy
        # returns a subset of the exact dict objects it received, so
        # identity is exact and O(n); the previous content matching was
        # O(n*m) and could mis-pair duplicate-content messages.
        kept_ids = {id(d) for d in compacted_dicts}
        compacted_messages = [
            orig for orig, d in zip(messages, msg_dicts) if id(d) in kept_ids
        ]

        return {"llm_input_messages": compacted_messages}

    return compact_messages
|
src/flow/harness/langgraph/harness.py
CHANGED
|
@@ -10,7 +10,7 @@ import logging
|
|
| 10 |
import uuid
|
| 11 |
from collections.abc import AsyncIterator
|
| 12 |
from pathlib import Path
|
| 13 |
-
from typing import TYPE_CHECKING, Any
|
| 14 |
|
| 15 |
from opentelemetry import trace
|
| 16 |
|
|
@@ -50,6 +50,12 @@ class LangGraphHarness(BaseHarness):
|
|
| 50 |
print(event.type, event.content)
|
| 51 |
"""
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
@classmethod
|
| 54 |
def from_agent(cls, agent: Agent, workspace: Path) -> LangGraphHarness:
|
| 55 |
"""Create a LangGraph harness from an Agent spec.
|
|
@@ -61,11 +67,12 @@ class LangGraphHarness(BaseHarness):
|
|
| 61 |
Returns:
|
| 62 |
Configured LangGraphHarness instance
|
| 63 |
"""
|
|
|
|
|
|
|
|
|
|
| 64 |
from flow.experiments.models import resolve_tools
|
| 65 |
from flow.harness.langgraph.compaction import create_compaction_hook
|
| 66 |
from flow.harness.langgraph.wrappers import build_langgraph_tools
|
| 67 |
-
from langgraph.checkpoint.memory import InMemorySaver
|
| 68 |
-
from langgraph.prebuilt import create_react_agent
|
| 69 |
|
| 70 |
# Build tools (skip sub_agent - MAF-specific)
|
| 71 |
tools_spec = resolve_tools(agent.tools)
|
|
@@ -234,7 +241,7 @@ class LangGraphHarness(BaseHarness):
|
|
| 234 |
mode, data = chunk
|
| 235 |
|
| 236 |
if mode == "messages":
|
| 237 |
-
msg_chunk,
|
| 238 |
|
| 239 |
# Text content
|
| 240 |
if hasattr(msg_chunk, "content") and msg_chunk.content:
|
|
|
|
| 10 |
import uuid
|
| 11 |
from collections.abc import AsyncIterator
|
| 12 |
from pathlib import Path
|
| 13 |
+
from typing import TYPE_CHECKING, Any, ClassVar
|
| 14 |
|
| 15 |
from opentelemetry import trace
|
| 16 |
|
|
|
|
| 50 |
print(event.type, event.content)
|
| 51 |
"""
|
| 52 |
|
| 53 |
+
# Framework metadata
|
| 54 |
+
framework_name: ClassVar[str] = "langgraph"
|
| 55 |
+
framework_label: ClassVar[str] = "LangGraph"
|
| 56 |
+
framework_description: ClassVar[str] = "Graph-based workflows with state management"
|
| 57 |
+
supported_compaction_strategies: ClassVar[list[str]] = ["head_tail", "sliding_window", "none"]
|
| 58 |
+
|
| 59 |
@classmethod
|
| 60 |
def from_agent(cls, agent: Agent, workspace: Path) -> LangGraphHarness:
|
| 61 |
"""Create a LangGraph harness from an Agent spec.
|
|
|
|
| 67 |
Returns:
|
| 68 |
Configured LangGraphHarness instance
|
| 69 |
"""
|
| 70 |
+
from langgraph.checkpoint.memory import InMemorySaver
|
| 71 |
+
from langgraph.prebuilt import create_react_agent
|
| 72 |
+
|
| 73 |
from flow.experiments.models import resolve_tools
|
| 74 |
from flow.harness.langgraph.compaction import create_compaction_hook
|
| 75 |
from flow.harness.langgraph.wrappers import build_langgraph_tools
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Build tools (skip sub_agent - MAF-specific)
|
| 78 |
tools_spec = resolve_tools(agent.tools)
|
|
|
|
| 241 |
mode, data = chunk
|
| 242 |
|
| 243 |
if mode == "messages":
|
| 244 |
+
msg_chunk, _metadata = data
|
| 245 |
|
| 246 |
# Text content
|
| 247 |
if hasattr(msg_chunk, "content") and msg_chunk.content:
|
src/flow/harness/maf/__init__.py
CHANGED
|
@@ -12,7 +12,7 @@ from flow.harness.registry import register
|
|
| 12 |
register("maf", MAFHarness)
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
-
"create_agent",
|
| 16 |
"HeadTailCompactingChatMessageStore",
|
| 17 |
"MAFHarness",
|
|
|
|
| 18 |
]
|
|
|
|
| 12 |
register("maf", MAFHarness)
|
| 13 |
|
| 14 |
__all__ = [
|
|
|
|
| 15 |
"HeadTailCompactingChatMessageStore",
|
| 16 |
"MAFHarness",
|
| 17 |
+
"create_agent",
|
| 18 |
]
|
src/flow/harness/maf/agent.py
CHANGED
|
@@ -148,15 +148,19 @@ def create_agent(
|
|
| 148 |
# Create message store factory if compaction is enabled
|
| 149 |
message_store_factory = None
|
| 150 |
if enable_compaction:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def create_compacting_store() -> HeadTailCompactingChatMessageStore:
|
| 152 |
return HeadTailCompactingChatMessageStore(
|
| 153 |
-
|
| 154 |
-
tail_size=compaction_tail_size,
|
| 155 |
)
|
| 156 |
|
| 157 |
message_store_factory = create_compacting_store
|
| 158 |
logger.debug(
|
| 159 |
-
f"Message compaction enabled: head={compaction_head_size}, tail={compaction_tail_size}"
|
| 160 |
)
|
| 161 |
|
| 162 |
# Determine if memory is enabled for instructions
|
|
|
|
| 148 |
# Create message store factory if compaction is enabled
|
| 149 |
message_store_factory = None
|
| 150 |
if enable_compaction:
|
| 151 |
+
# Convert head/tail message counts to head_ratio for token-based compaction
|
| 152 |
+
# head_ratio = head_size / (head_size + tail_size)
|
| 153 |
+
total_size = compaction_head_size + compaction_tail_size
|
| 154 |
+
head_ratio = compaction_head_size / total_size if total_size > 0 else 0.2
|
| 155 |
+
|
| 156 |
def create_compacting_store() -> HeadTailCompactingChatMessageStore:
|
| 157 |
return HeadTailCompactingChatMessageStore(
|
| 158 |
+
head_ratio=head_ratio,
|
|
|
|
| 159 |
)
|
| 160 |
|
| 161 |
message_store_factory = create_compacting_store
|
| 162 |
logger.debug(
|
| 163 |
+
f"Message compaction enabled: head={compaction_head_size}, tail={compaction_tail_size}, head_ratio={head_ratio:.2f}"
|
| 164 |
)
|
| 165 |
|
| 166 |
# Determine if memory is enabled for instructions
|
src/flow/harness/maf/harness.py
CHANGED
|
@@ -9,7 +9,7 @@ import logging
|
|
| 9 |
import uuid
|
| 10 |
from collections.abc import AsyncIterator
|
| 11 |
from pathlib import Path
|
| 12 |
-
from typing import TYPE_CHECKING, Any
|
| 13 |
|
| 14 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 15 |
|
|
@@ -67,13 +67,19 @@ class MAFHarness(BaseHarness):
|
|
| 67 |
>>> harness = MAFHarness.from_agent(agent, workspace=Path("/tmp"))
|
| 68 |
"""
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
@classmethod
|
| 71 |
def from_agent(
|
| 72 |
cls,
|
| 73 |
-
agent:
|
| 74 |
workspace: Path,
|
| 75 |
-
llm_config:
|
| 76 |
-
) ->
|
| 77 |
"""Create a MAFHarness from an Agent definition.
|
| 78 |
|
| 79 |
Args:
|
|
@@ -126,7 +132,7 @@ class MAFHarness(BaseHarness):
|
|
| 126 |
|
| 127 |
def __init__(
|
| 128 |
self,
|
| 129 |
-
agent:
|
| 130 |
**create_agent_kwargs: Any,
|
| 131 |
) -> None:
|
| 132 |
"""Initialize the harness.
|
|
|
|
| 9 |
import uuid
|
| 10 |
from collections.abc import AsyncIterator
|
| 11 |
from pathlib import Path
|
| 12 |
+
from typing import TYPE_CHECKING, Any, ClassVar
|
| 13 |
|
| 14 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 15 |
|
|
|
|
| 67 |
>>> harness = MAFHarness.from_agent(agent, workspace=Path("/tmp"))
|
| 68 |
"""
|
| 69 |
|
| 70 |
+
# Framework metadata
|
| 71 |
+
framework_name: ClassVar[str] = "maf"
|
| 72 |
+
framework_label: ClassVar[str] = "Microsoft Agent Framework"
|
| 73 |
+
framework_description: ClassVar[str] = "Default agent implementation with ChatAgent"
|
| 74 |
+
supported_compaction_strategies: ClassVar[list[str]] = ["head_tail", "sliding_window", "none"]
|
| 75 |
+
|
| 76 |
@classmethod
|
| 77 |
def from_agent(
|
| 78 |
cls,
|
| 79 |
+
agent: Agent,
|
| 80 |
workspace: Path,
|
| 81 |
+
llm_config: LLMClientConfig | None = None,
|
| 82 |
+
) -> MAFHarness:
|
| 83 |
"""Create a MAFHarness from an Agent definition.
|
| 84 |
|
| 85 |
Args:
|
|
|
|
| 132 |
|
| 133 |
def __init__(
|
| 134 |
self,
|
| 135 |
+
agent: ChatAgent | None = None,
|
| 136 |
**create_agent_kwargs: Any,
|
| 137 |
) -> None:
|
| 138 |
"""Initialize the harness.
|
src/flow/harness/maf/message_store.py
CHANGED
|
@@ -1,21 +1,82 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
-
Provides ChatMessageStoreProtocol implementations
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from collections.abc import MutableMapping, Sequence
|
| 7 |
from typing import TYPE_CHECKING, Any
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
if TYPE_CHECKING:
|
| 10 |
from agent_framework import ChatMessage
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
IMPORTANT: This store preserves full ChatMessage objects including:
|
| 21 |
- FunctionCallContent (tool calls)
|
|
@@ -24,44 +85,44 @@ class HeadTailCompactingChatMessageStore:
|
|
| 24 |
|
| 25 |
This is critical because OpenAI's API requires tool results to immediately
|
| 26 |
follow their corresponding tool calls.
|
| 27 |
-
|
| 28 |
-
The compaction strategy:
|
| 29 |
-
- Keeps the first N messages (task context, initial instructions)
|
| 30 |
-
- Keeps the last M messages (recent work, current state)
|
| 31 |
-
- Drops middle messages to prevent context overflow
|
| 32 |
"""
|
| 33 |
|
| 34 |
def __init__(
|
| 35 |
self,
|
| 36 |
messages: Sequence["ChatMessage"] | None = None,
|
| 37 |
-
|
| 38 |
-
|
|
|
|
| 39 |
) -> None:
|
| 40 |
-
"""Initialize the
|
| 41 |
|
| 42 |
Args:
|
| 43 |
messages: Initial messages to store
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
"""
|
| 47 |
-
if
|
| 48 |
-
raise ValueError("
|
| 49 |
-
if
|
| 50 |
-
raise ValueError("
|
| 51 |
-
|
| 52 |
-
self._messages: list[
|
| 53 |
-
self.
|
| 54 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
@property
|
| 57 |
-
def
|
| 58 |
-
"""
|
| 59 |
-
return self.
|
| 60 |
|
| 61 |
@property
|
| 62 |
-
def
|
| 63 |
-
"""
|
| 64 |
-
return self.
|
| 65 |
|
| 66 |
@property
|
| 67 |
def total_messages(self) -> int:
|
|
@@ -69,16 +130,126 @@ class HeadTailCompactingChatMessageStore:
|
|
| 69 |
return len(self._messages)
|
| 70 |
|
| 71 |
@property
|
| 72 |
-
def
|
| 73 |
-
"""Number of
|
| 74 |
-
|
| 75 |
-
max_kept = self._head_size + self._tail_size
|
| 76 |
-
return min(total, max_kept)
|
| 77 |
|
| 78 |
@property
|
| 79 |
-
def
|
| 80 |
-
"""
|
| 81 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
async def add_messages(self, messages: Sequence["ChatMessage"]) -> None:
|
| 84 |
"""Add messages to the store.
|
|
@@ -91,38 +262,32 @@ class HeadTailCompactingChatMessageStore:
|
|
| 91 |
self._messages.extend(messages)
|
| 92 |
|
| 93 |
async def list_messages(self) -> list["ChatMessage"]:
|
| 94 |
-
"""Get messages with
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
|
| 99 |
Returns:
|
| 100 |
List of ChatMessage objects after compaction
|
| 101 |
"""
|
| 102 |
-
|
| 103 |
-
max_kept = self._head_size + self._tail_size
|
| 104 |
-
|
| 105 |
-
# No compaction needed
|
| 106 |
-
if total <= max_kept:
|
| 107 |
-
return list(self._messages)
|
| 108 |
-
|
| 109 |
-
# Return head + tail
|
| 110 |
-
head = self._messages[: self._head_size]
|
| 111 |
-
tail = self._messages[-self._tail_size :] if self._tail_size > 0 else []
|
| 112 |
-
|
| 113 |
-
return head + tail
|
| 114 |
|
| 115 |
@classmethod
|
| 116 |
async def deserialize(
|
| 117 |
cls,
|
| 118 |
serialized_store_state: MutableMapping[str, Any],
|
| 119 |
**kwargs: Any,
|
| 120 |
-
) -> "
|
| 121 |
"""Create store from serialized state."""
|
| 122 |
from agent_framework import ChatMessage
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
messages_data = serialized_store_state.get("messages", [])
|
| 128 |
messages = [
|
|
@@ -130,7 +295,12 @@ class HeadTailCompactingChatMessageStore:
|
|
| 130 |
for m in messages_data
|
| 131 |
]
|
| 132 |
|
| 133 |
-
return cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
async def update_from_state(
|
| 136 |
self,
|
|
@@ -149,10 +319,12 @@ class HeadTailCompactingChatMessageStore:
|
|
| 149 |
for m in messages_data
|
| 150 |
]
|
| 151 |
|
| 152 |
-
if "
|
| 153 |
-
self.
|
| 154 |
-
if "
|
| 155 |
-
self.
|
|
|
|
|
|
|
| 156 |
|
| 157 |
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
|
| 158 |
"""Serialize the store state.
|
|
@@ -161,17 +333,23 @@ class HeadTailCompactingChatMessageStore:
|
|
| 161 |
"""
|
| 162 |
return {
|
| 163 |
"messages": [m.to_dict() for m in self._messages],
|
| 164 |
-
"
|
| 165 |
-
"
|
|
|
|
| 166 |
}
|
| 167 |
|
| 168 |
@property
|
| 169 |
-
def stats(self) -> dict[str,
|
| 170 |
"""Get compaction statistics."""
|
| 171 |
return {
|
| 172 |
"total_messages": self.total_messages,
|
| 173 |
-
"
|
| 174 |
-
"
|
| 175 |
-
"
|
| 176 |
-
"
|
|
|
|
| 177 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft. All rights reserved.
|
| 2 |
+
"""Token-aware message store for Microsoft Agent Framework.
|
| 3 |
|
| 4 |
+
Provides ChatMessageStoreProtocol implementations with token-based compaction
|
| 5 |
+
to ensure safety against large messages that could exceed LLM context limits.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from collections.abc import MutableMapping, Sequence
|
| 9 |
from typing import TYPE_CHECKING, Any
|
| 10 |
|
| 11 |
+
from flow.harness.compaction import HeadTailStrategy
|
| 12 |
+
from flow.harness.compaction.tokenizer import get_encoder
|
| 13 |
+
|
| 14 |
+
_ = (HeadTailStrategy, get_encoder) # Used for external access via this module
|
| 15 |
+
|
| 16 |
if TYPE_CHECKING:
|
| 17 |
from agent_framework import ChatMessage
|
| 18 |
|
| 19 |
+
# Default token budget (safe for GPT-4o, Claude 3.5, etc.)
|
| 20 |
+
DEFAULT_TOKEN_BUDGET = 200_000
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _chat_message_to_dict(msg: "ChatMessage") -> dict[str, Any]:
|
| 24 |
+
"""Convert ChatMessage to dict format for token counting."""
|
| 25 |
+
return msg.to_dict()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _count_chat_message_tokens(msg: "ChatMessage", model: str = "gpt-4o") -> int:
    """Estimate the token count of a single ChatMessage.

    Counts the role, the text content (plain string or content-part list),
    and any tool-call names/arguments, plus a small fixed per-message
    overhead for chat-format framing.

    Args:
        msg: The message to measure.
        model: Model name used to select the tokenizer encoding.

    Returns:
        Approximate number of tokens the message will consume.
    """
    msg_dict = _chat_message_to_dict(msg)
    encoder = get_encoder(model)

    tokens = 0
    # Count role
    if "role" in msg_dict:
        tokens += len(encoder.encode(str(msg_dict["role"])))

    # Count content: either a plain string or a list of content parts.
    content = msg_dict.get("content")
    if isinstance(content, str):
        tokens += len(encoder.encode(content))
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, dict):
                if "text" in item:
                    tokens += len(encoder.encode(str(item["text"])))
                elif "content" in item:
                    tokens += len(encoder.encode(str(item["content"])))
            elif isinstance(item, str):
                tokens += len(encoder.encode(item))

    # Count tool calls. Guard against an explicit None value and against
    # serializations that store "arguments" as a parsed dict rather than a
    # JSON string — encoder.encode() only accepts str.
    tool_calls = msg_dict.get("tool_calls") or []
    for tc in tool_calls:
        if isinstance(tc, dict) and "function" in tc:
            func = tc["function"]
            if isinstance(func, dict):
                tokens += len(encoder.encode(str(func.get("name", ""))))
                tokens += len(encoder.encode(str(func.get("arguments", ""))))

    # Base overhead per message (role/format framing)
    tokens += 4

    return tokens
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class TokenAwareChatMessageStore:
|
| 69 |
+
"""A token-aware message store for Agent Framework ChatMessage.
|
| 70 |
|
| 71 |
+
This store implements ChatMessageStoreProtocol and uses token counting
|
| 72 |
+
to trigger compaction, ensuring safety against large messages that could
|
| 73 |
+
exceed LLM context limits.
|
| 74 |
|
| 75 |
+
The compaction strategy:
|
| 76 |
+
- Keeps head messages (system prompt, initial context) based on head_ratio
|
| 77 |
+
- Keeps tail messages (recent work, current state)
|
| 78 |
+
- Drops middle messages when token count exceeds budget
|
| 79 |
+
- Respects atomic groups (tool calls + results must stay together)
|
| 80 |
|
| 81 |
IMPORTANT: This store preserves full ChatMessage objects including:
|
| 82 |
- FunctionCallContent (tool calls)
|
|
|
|
| 85 |
|
| 86 |
This is critical because OpenAI's API requires tool results to immediately
|
| 87 |
follow their corresponding tool calls.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
|
| 90 |
def __init__(
    self,
    messages: Sequence["ChatMessage"] | None = None,
    head_ratio: float = 0.2,
    token_budget: int = DEFAULT_TOKEN_BUDGET,
    model: str = "gpt-4o",
) -> None:
    """Initialize the token-aware store.

    Args:
        messages: Initial messages to store
        head_ratio: Fraction of budget for head messages (0.2 = 20%)
        token_budget: Max tokens before compaction triggers
        model: Model name for tokenizer selection

    Raises:
        ValueError: If head_ratio is outside [0, 1] or token_budget < 1000.
    """
    if head_ratio < 0 or head_ratio > 1:
        raise ValueError("head_ratio must be between 0 and 1")
    if token_budget < 1000:
        raise ValueError("token_budget must be at least 1000")

    # BUGFIX: the annotation must be quoted — ChatMessage is only imported
    # under TYPE_CHECKING, and annotations on attribute targets are
    # evaluated at runtime (this module has no
    # `from __future__ import annotations`), so an unquoted name raises
    # NameError the first time __init__ runs.
    self._messages: list["ChatMessage"] = list(messages) if messages else []
    self._head_ratio = head_ratio
    self._token_budget = token_budget
    self._model = model
    # Compaction telemetry, exposed via the stats property.
    self._compaction_count = 0
    self._total_tokens_saved = 0
|
| 116 |
|
| 117 |
@property
|
| 118 |
+
def head_ratio(self) -> float:
|
| 119 |
+
"""Fraction of budget for head messages."""
|
| 120 |
+
return self._head_ratio
|
| 121 |
|
| 122 |
@property
|
| 123 |
+
def token_budget(self) -> int:
|
| 124 |
+
"""Max tokens before compaction triggers."""
|
| 125 |
+
return self._token_budget
|
| 126 |
|
| 127 |
@property
|
| 128 |
def total_messages(self) -> int:
|
|
|
|
| 130 |
return len(self._messages)
|
| 131 |
|
| 132 |
@property
|
| 133 |
+
def compaction_count(self) -> int:
|
| 134 |
+
"""Number of times compaction has been triggered."""
|
| 135 |
+
return self._compaction_count
|
|
|
|
|
|
|
| 136 |
|
| 137 |
@property
|
| 138 |
+
def total_tokens_saved(self) -> int:
|
| 139 |
+
"""Total tokens saved through compaction."""
|
| 140 |
+
return self._total_tokens_saved
|
| 141 |
+
|
| 142 |
+
def _count_tokens(self) -> int:
    """Sum the estimated token counts of every stored message."""
    total = 0
    for message in self._messages:
        total += _count_chat_message_tokens(message, self._model)
    return total
|
| 145 |
+
|
| 146 |
+
def _find_atomic_groups(
    self, messages: list["ChatMessage"]
) -> list[tuple[int, ...]]:
    """Group tool_call messages with their results.

    OpenAI requires every tool_call to have a corresponding result.
    This ensures we never split a tool call from its results.

    Args:
        messages: Messages to partition into atomic groups.

    Returns:
        A list of index tuples; each tuple must be kept or dropped as a
        unit. Every message index appears in exactly one group.
    """
    groups: list[tuple[int, ...]] = []
    i = 0

    while i < len(messages):
        msg = messages[i]
        msg_dict = _chat_message_to_dict(msg)

        if msg_dict.get("tool_calls"):
            # This message has tool calls - find all results
            call_ids = {
                tc.get("id") for tc in msg_dict["tool_calls"] if tc.get("id")
            }
            group_indices = [i]

            # Look ahead for results
            j = i + 1
            while j < len(messages) and call_ids:
                next_dict = _chat_message_to_dict(messages[j])
                if next_dict.get("role") == "tool":
                    tool_call_id = next_dict.get("tool_call_id")
                    if tool_call_id in call_ids:
                        group_indices.append(j)
                        call_ids.remove(tool_call_id)
                j += 1

            groups.append(tuple(group_indices))

            # BUGFIX: messages scanned between the tool-call message and its
            # last matched result that are NOT matching tool results were
            # previously left out of every group, so compaction dropped them
            # unconditionally. Emit them as singleton groups so every index
            # stays covered by exactly one group.
            grouped = set(group_indices)
            last = max(group_indices)
            for k in range(i + 1, last + 1):
                if k not in grouped:
                    groups.append((k,))

            # group_indices always contains i, so advancing past the last
            # grouped index is safe.
            i = last + 1
        else:
            groups.append((i,))
            i += 1

    return groups
|
| 186 |
+
|
| 187 |
+
def _compact_messages(
    self, messages: list["ChatMessage"]
) -> list["ChatMessage"]:
    """Apply head+tail compaction to messages.

    Returns the input unchanged while its total token count fits within
    the budget. Otherwise keeps as many whole atomic groups as fit from
    the head (up to head_ratio of the budget) and from the tail (the
    remainder of the budget), drops the middle, and updates the
    compaction counters.
    """
    if not messages:
        return messages

    # Measure the whole conversation before deciding anything.
    current_tokens = sum(
        _count_chat_message_tokens(m, self._model) for m in messages
    )

    if current_tokens <= self._token_budget:
        return messages

    # COMPACTION NEEDED
    self._compaction_count += 1

    # Atomic groups keep each tool call together with its results so a
    # kept/dropped decision never splits them.
    groups = self._find_atomic_groups(messages)
    head_budget = int(self._token_budget * self._head_ratio)
    tail_budget = self._token_budget - head_budget

    # Fill head from start
    head_groups: list[tuple[int, ...]] = []
    head_tokens = 0

    for group in groups:
        group_tokens = sum(
            _count_chat_message_tokens(messages[i], self._model) for i in group
        )

        # Stop at the first group that would overflow the head budget.
        if head_tokens + group_tokens <= head_budget:
            head_groups.append(group)
            head_tokens += group_tokens
        else:
            break

    # Fill tail from end (skip head groups)
    remaining_groups = groups[len(head_groups):]
    tail_groups: list[tuple[int, ...]] = []
    tail_tokens = 0

    for group in reversed(remaining_groups):
        group_tokens = sum(
            _count_chat_message_tokens(messages[i], self._model) for i in group
        )

        # insert(0, ...) keeps tail_groups in original (ascending) order.
        if tail_tokens + group_tokens <= tail_budget:
            tail_groups.insert(0, group)
            tail_tokens += group_tokens
        else:
            break

    # Build compacted list
    kept_indices: set[int] = set()
    for group in head_groups + tail_groups:
        kept_indices.update(group)

    # Sorting restores the original message order regardless of which
    # groups survived.
    compacted = [messages[i] for i in sorted(kept_indices)]

    # Track savings
    compacted_tokens = sum(
        _count_chat_message_tokens(m, self._model) for m in compacted
    )
    self._total_tokens_saved += current_tokens - compacted_tokens

    return compacted
|
| 253 |
|
| 254 |
async def add_messages(self, messages: Sequence["ChatMessage"]) -> None:
|
| 255 |
"""Add messages to the store.
|
|
|
|
| 262 |
self._messages.extend(messages)
|
| 263 |
|
| 264 |
async def list_messages(self) -> list["ChatMessage"]:
|
| 265 |
+
"""Get messages with token-aware compaction applied.
|
| 266 |
|
| 267 |
+
Applies head+tail compaction if total tokens exceed budget.
|
| 268 |
+
Respects atomic groups (tool calls stay with their results).
|
| 269 |
|
| 270 |
Returns:
|
| 271 |
List of ChatMessage objects after compaction
|
| 272 |
"""
|
| 273 |
+
return self._compact_messages(self._messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
@classmethod
|
| 276 |
async def deserialize(
|
| 277 |
cls,
|
| 278 |
serialized_store_state: MutableMapping[str, Any],
|
| 279 |
**kwargs: Any,
|
| 280 |
+
) -> "TokenAwareChatMessageStore":
|
| 281 |
"""Create store from serialized state."""
|
| 282 |
from agent_framework import ChatMessage
|
| 283 |
|
| 284 |
+
head_ratio = kwargs.get(
|
| 285 |
+
"head_ratio", serialized_store_state.get("head_ratio", 0.2)
|
| 286 |
+
)
|
| 287 |
+
token_budget = kwargs.get(
|
| 288 |
+
"token_budget", serialized_store_state.get("token_budget", DEFAULT_TOKEN_BUDGET)
|
| 289 |
+
)
|
| 290 |
+
model = kwargs.get("model", serialized_store_state.get("model", "gpt-4o"))
|
| 291 |
|
| 292 |
messages_data = serialized_store_state.get("messages", [])
|
| 293 |
messages = [
|
|
|
|
| 295 |
for m in messages_data
|
| 296 |
]
|
| 297 |
|
| 298 |
+
return cls(
|
| 299 |
+
messages=messages,
|
| 300 |
+
head_ratio=head_ratio,
|
| 301 |
+
token_budget=token_budget,
|
| 302 |
+
model=model,
|
| 303 |
+
)
|
| 304 |
|
| 305 |
async def update_from_state(
|
| 306 |
self,
|
|
|
|
| 319 |
for m in messages_data
|
| 320 |
]
|
| 321 |
|
| 322 |
+
if "head_ratio" in serialized_store_state:
|
| 323 |
+
self._head_ratio = serialized_store_state["head_ratio"]
|
| 324 |
+
if "token_budget" in serialized_store_state:
|
| 325 |
+
self._token_budget = serialized_store_state["token_budget"]
|
| 326 |
+
if "model" in serialized_store_state:
|
| 327 |
+
self._model = serialized_store_state["model"]
|
| 328 |
|
| 329 |
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
|
| 330 |
"""Serialize the store state.
|
|
|
|
| 333 |
"""
|
| 334 |
return {
|
| 335 |
"messages": [m.to_dict() for m in self._messages],
|
| 336 |
+
"head_ratio": self._head_ratio,
|
| 337 |
+
"token_budget": self._token_budget,
|
| 338 |
+
"model": self._model,
|
| 339 |
}
|
| 340 |
|
| 341 |
@property
|
| 342 |
+
def stats(self) -> dict[str, Any]:
|
| 343 |
"""Get compaction statistics."""
|
| 344 |
return {
|
| 345 |
"total_messages": self.total_messages,
|
| 346 |
+
"current_tokens": self._count_tokens(),
|
| 347 |
+
"token_budget": self._token_budget,
|
| 348 |
+
"head_ratio": self._head_ratio,
|
| 349 |
+
"compaction_count": self._compaction_count,
|
| 350 |
+
"total_tokens_saved": self._total_tokens_saved,
|
| 351 |
}
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
# Backwards compatibility alias
|
| 355 |
+
HeadTailCompactingChatMessageStore = TokenAwareChatMessageStore
|
src/flow/harness/maf/tools/__init__.py
CHANGED
|
@@ -6,7 +6,7 @@ the to_maf_tool adapter.
|
|
| 6 |
|
| 7 |
Available tools:
|
| 8 |
- read_file, write_file, edit_file, multi_edit, glob_files, grep, ls
|
| 9 |
-
- bash, check_processes
|
| 10 |
- think, todo_write, todo_read
|
| 11 |
- memory, skills, task
|
| 12 |
- web_search, web_fetch
|
|
@@ -19,28 +19,49 @@ from pathlib import Path
|
|
| 19 |
from typing import Any
|
| 20 |
|
| 21 |
from flow.tools import (
|
| 22 |
-
#
|
| 23 |
-
|
|
|
|
| 24 |
# Execution
|
| 25 |
-
bash,
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# Memory
|
| 29 |
-
memory,
|
| 30 |
-
|
| 31 |
-
web_search, web_fetch,
|
| 32 |
# Notebooks
|
| 33 |
-
notebook_edit,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# Skills
|
| 35 |
-
skills,
|
| 36 |
# Sub-agent
|
| 37 |
-
task,
|
| 38 |
-
#
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
# Adapters
|
| 41 |
to_maf_tool,
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
__all__ = [
|
|
@@ -93,7 +114,6 @@ def build_tools(
|
|
| 93 |
# Execution
|
| 94 |
"bash": bash,
|
| 95 |
"check_processes": check_processes,
|
| 96 |
-
"python_repl": python_repl,
|
| 97 |
# Planning
|
| 98 |
"think": think,
|
| 99 |
"todo_write": todo_write,
|
|
@@ -110,6 +130,11 @@ def build_tools(
|
|
| 110 |
"skills": skills,
|
| 111 |
# Task/sub-agent (default instance)
|
| 112 |
"task": task,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
}
|
| 114 |
|
| 115 |
tools: list[Callable[..., Coroutine[Any, Any, str]]] = []
|
|
@@ -128,10 +153,20 @@ def build_tools(
|
|
| 128 |
tools.append(to_maf_tool(custom_task))
|
| 129 |
elif name == "skills" and config.get("additional_paths"):
|
| 130 |
# Skills with custom paths
|
| 131 |
-
custom_skills = create_skills_tool(
|
| 132 |
-
project_path=Path(config["additional_paths"][0])
|
| 133 |
-
)
|
| 134 |
tools.append(to_maf_tool(custom_skills))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
else:
|
| 136 |
logger.warning(f"Unknown tool name: {name}. Skipping.")
|
| 137 |
|
|
|
|
| 6 |
|
| 7 |
Available tools:
|
| 8 |
- read_file, write_file, edit_file, multi_edit, glob_files, grep, ls
|
| 9 |
+
- bash, check_processes
|
| 10 |
- think, todo_write, todo_read
|
| 11 |
- memory, skills, task
|
| 12 |
- web_search, web_fetch
|
|
|
|
| 19 |
from typing import Any
|
| 20 |
|
| 21 |
from flow.tools import (
|
| 22 |
+
# Base
|
| 23 |
+
Tool,
|
| 24 |
+
Workspace,
|
| 25 |
# Execution
|
| 26 |
+
bash,
|
| 27 |
+
check_processes,
|
| 28 |
+
create_skills_tool,
|
| 29 |
+
# Browsing
|
| 30 |
+
create_smol_web_search_tool,
|
| 31 |
+
create_task_tool,
|
| 32 |
+
create_visit_webpage_tool,
|
| 33 |
+
edit_file,
|
| 34 |
+
glob_files,
|
| 35 |
+
grep,
|
| 36 |
+
ls,
|
| 37 |
# Memory
|
| 38 |
+
memory,
|
| 39 |
+
multi_edit,
|
|
|
|
| 40 |
# Notebooks
|
| 41 |
+
notebook_edit,
|
| 42 |
+
notebook_read,
|
| 43 |
+
# Coding
|
| 44 |
+
read_file,
|
| 45 |
+
# Workspace management
|
| 46 |
+
set_workspace,
|
| 47 |
# Skills
|
| 48 |
+
skills,
|
| 49 |
# Sub-agent
|
| 50 |
+
task,
|
| 51 |
+
# File inspection
|
| 52 |
+
text_inspector,
|
| 53 |
+
# Planning
|
| 54 |
+
think,
|
| 55 |
# Adapters
|
| 56 |
to_maf_tool,
|
| 57 |
+
todo_read,
|
| 58 |
+
todo_write,
|
| 59 |
+
visual_inspector,
|
| 60 |
+
web_fetch,
|
| 61 |
+
# Web
|
| 62 |
+
web_search,
|
| 63 |
+
wikipedia_search,
|
| 64 |
+
write_file,
|
| 65 |
)
|
| 66 |
|
| 67 |
__all__ = [
|
|
|
|
| 114 |
# Execution
|
| 115 |
"bash": bash,
|
| 116 |
"check_processes": check_processes,
|
|
|
|
| 117 |
# Planning
|
| 118 |
"think": think,
|
| 119 |
"todo_write": todo_write,
|
|
|
|
| 130 |
"skills": skills,
|
| 131 |
# Task/sub-agent (default instance)
|
| 132 |
"task": task,
|
| 133 |
+
# Wikipedia search
|
| 134 |
+
"wikipedia_search": wikipedia_search,
|
| 135 |
+
# File inspection tools
|
| 136 |
+
"text_inspector": text_inspector,
|
| 137 |
+
"visual_inspector": visual_inspector,
|
| 138 |
}
|
| 139 |
|
| 140 |
tools: list[Callable[..., Coroutine[Any, Any, str]]] = []
|
|
|
|
| 153 |
tools.append(to_maf_tool(custom_task))
|
| 154 |
elif name == "skills" and config.get("additional_paths"):
|
| 155 |
# Skills with custom paths
|
| 156 |
+
custom_skills = create_skills_tool(project_path=Path(config["additional_paths"][0]))
|
|
|
|
|
|
|
| 157 |
tools.append(to_maf_tool(custom_skills))
|
| 158 |
+
# Web search tool
|
| 159 |
+
elif name == "smol_web_search":
|
| 160 |
+
wst_max_results = config.get("wst_max_results", 10)
|
| 161 |
+
wst_engine = config.get("wst_engine", "duckduckgo")
|
| 162 |
+
custom_smol_web_search = create_smol_web_search_tool(max_results=wst_max_results, engine=wst_engine)
|
| 163 |
+
tools.append(to_maf_tool(custom_smol_web_search))
|
| 164 |
+
|
| 165 |
+
elif name == "visit_webpage":
|
| 166 |
+
vwp_max_output_length = config.get("vwp_max_output_length", 40000)
|
| 167 |
+
custom_visit_webpage = create_visit_webpage_tool(max_output_length=vwp_max_output_length)
|
| 168 |
+
tools.append(to_maf_tool(custom_visit_webpage))
|
| 169 |
+
|
| 170 |
else:
|
| 171 |
logger.warning(f"Unknown tool name: {name}. Skipping.")
|
| 172 |
|
src/flow/harness/maf/wrappers.py
CHANGED
|
@@ -14,8 +14,8 @@ from collections.abc import Callable, Coroutine
|
|
| 14 |
from pathlib import Path
|
| 15 |
from typing import Any
|
| 16 |
|
| 17 |
-
from flow.tools import Tool, to_maf_tool
|
| 18 |
from flow.harness.maf.tools import build_tools as build_maf_tools_impl
|
|
|
|
| 19 |
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
from typing import Any
|
| 16 |
|
|
|
|
| 17 |
from flow.harness.maf.tools import build_tools as build_maf_tools_impl
|
| 18 |
+
from flow.tools import Tool, to_maf_tool
|
| 19 |
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
src/flow/harness/miniagent/__init__.py
CHANGED
|
@@ -51,38 +51,38 @@ MiniAgent's tool loop:
|
|
| 51 |
messages.extend(results) # Next iteration will compact again
|
| 52 |
"""
|
| 53 |
|
| 54 |
-
|
| 55 |
-
from .
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
from .context import (
|
| 58 |
ContextStrategy,
|
| 59 |
-
NoCompactionStrategy,
|
| 60 |
HeadTailStrategy,
|
|
|
|
| 61 |
SlidingWindowStrategy,
|
| 62 |
SummarizationStrategy,
|
| 63 |
)
|
| 64 |
-
from .
|
| 65 |
from .hooks import (
|
| 66 |
-
|
|
|
|
| 67 |
HookEvent,
|
| 68 |
-
|
| 69 |
-
|
|
|
|
| 70 |
PostToolUseEvent,
|
| 71 |
PostToolUseResult,
|
| 72 |
-
PreModelCallEvent,
|
| 73 |
-
PostModelCallEvent,
|
| 74 |
PreCompactEvent,
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
)
|
| 79 |
-
from .instructions import
|
|
|
|
|
|
|
| 80 |
from .workspace import Workspace, get_workspace, set_workspace
|
| 81 |
-
from . import tools
|
| 82 |
-
|
| 83 |
-
# Register with Flow's harness system
|
| 84 |
-
from flow.harness.registry import register
|
| 85 |
-
from .harness import MiniAgentHarness
|
| 86 |
|
| 87 |
register("miniagent", MiniAgentHarness)
|
| 88 |
|
|
|
|
| 51 |
messages.extend(results) # Next iteration will compact again
|
| 52 |
"""
|
| 53 |
|
| 54 |
+
# Register with Flow's harness system
|
| 55 |
+
from flow.harness.registry import register
|
| 56 |
+
|
| 57 |
+
from . import tools
|
| 58 |
+
from .agent import AgentResponse, AgentThread, ChatAgent, StreamEvent, StreamEventType, UsageStats
|
| 59 |
+
from .client import ChatClient, ChatCompletionResult, ClientConfig
|
| 60 |
from .context import (
|
| 61 |
ContextStrategy,
|
|
|
|
| 62 |
HeadTailStrategy,
|
| 63 |
+
NoCompactionStrategy,
|
| 64 |
SlidingWindowStrategy,
|
| 65 |
SummarizationStrategy,
|
| 66 |
)
|
| 67 |
+
from .harness import MiniAgentHarness
|
| 68 |
from .hooks import (
|
| 69 |
+
AgentEndEvent,
|
| 70 |
+
AgentStartEvent,
|
| 71 |
HookEvent,
|
| 72 |
+
Hooks,
|
| 73 |
+
PostCompactEvent,
|
| 74 |
+
PostModelCallEvent,
|
| 75 |
PostToolUseEvent,
|
| 76 |
PostToolUseResult,
|
|
|
|
|
|
|
| 77 |
PreCompactEvent,
|
| 78 |
+
PreModelCallEvent,
|
| 79 |
+
PreToolUseEvent,
|
| 80 |
+
PreToolUseResult,
|
| 81 |
)
|
| 82 |
+
from .instructions import INSTRUCTIONS, get_instructions
|
| 83 |
+
from .messages import ChatMessage, ToolCall, ToolResult
|
| 84 |
+
from .tool import Tool, tool
|
| 85 |
from .workspace import Workspace, get_workspace, set_workspace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
register("miniagent", MiniAgentHarness)
|
| 88 |
|
src/flow/harness/miniagent/agent.py
CHANGED
|
@@ -5,28 +5,29 @@ The key difference: context strategy is called BEFORE each LLM call in the
|
|
| 5 |
tool loop, and the compacted list continues to the next iteration.
|
| 6 |
"""
|
| 7 |
|
|
|
|
|
|
|
| 8 |
from dataclasses import dataclass, field
|
| 9 |
-
from typing import Any, AsyncGenerator
|
| 10 |
from enum import Enum
|
| 11 |
-
import
|
| 12 |
|
| 13 |
-
from .messages import ChatMessage, ToolCall
|
| 14 |
-
from .tool import Tool
|
| 15 |
from .client import ChatClient, ChatCompletionResult
|
| 16 |
from .context import ContextStrategy, NoCompactionStrategy
|
| 17 |
from .hooks import (
|
|
|
|
|
|
|
| 18 |
Hooks,
|
| 19 |
-
|
| 20 |
-
|
| 21 |
PostToolUseEvent,
|
| 22 |
PostToolUseResult,
|
| 23 |
-
PreModelCallEvent,
|
| 24 |
-
PostModelCallEvent,
|
| 25 |
PreCompactEvent,
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
)
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
class StreamEventType(str, Enum):
|
|
@@ -452,7 +453,7 @@ class ChatAgent:
|
|
| 452 |
try:
|
| 453 |
return await tool.invoke(**arguments)
|
| 454 |
except Exception as e:
|
| 455 |
-
return f"Error executing {name}: {
|
| 456 |
|
| 457 |
# === Hook emission methods ===
|
| 458 |
|
|
|
|
| 5 |
tool loop, and the compacted list continues to the next iteration.
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
import json
|
| 9 |
+
from collections.abc import AsyncGenerator
|
| 10 |
from dataclasses import dataclass, field
|
|
|
|
| 11 |
from enum import Enum
|
| 12 |
+
from typing import Any
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from .client import ChatClient, ChatCompletionResult
|
| 15 |
from .context import ContextStrategy, NoCompactionStrategy
|
| 16 |
from .hooks import (
|
| 17 |
+
AgentEndEvent,
|
| 18 |
+
AgentStartEvent,
|
| 19 |
Hooks,
|
| 20 |
+
PostCompactEvent,
|
| 21 |
+
PostModelCallEvent,
|
| 22 |
PostToolUseEvent,
|
| 23 |
PostToolUseResult,
|
|
|
|
|
|
|
| 24 |
PreCompactEvent,
|
| 25 |
+
PreModelCallEvent,
|
| 26 |
+
PreToolUseEvent,
|
| 27 |
+
PreToolUseResult,
|
| 28 |
)
|
| 29 |
+
from .messages import ChatMessage, ToolCall
|
| 30 |
+
from .tool import Tool
|
| 31 |
|
| 32 |
|
| 33 |
class StreamEventType(str, Enum):
|
|
|
|
| 453 |
try:
|
| 454 |
return await tool.invoke(**arguments)
|
| 455 |
except Exception as e:
|
| 456 |
+
return f"Error executing {name}: {e!s}"
|
| 457 |
|
| 458 |
# === Hook emission methods ===
|
| 459 |
|
src/flow/harness/miniagent/client.py
CHANGED
|
@@ -4,9 +4,9 @@ Provides a unified interface for both OpenAI and Azure OpenAI APIs.
|
|
| 4 |
Auto-detects configuration from environment variables.
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
from dataclasses import dataclass
|
| 8 |
from typing import Any
|
| 9 |
-
import os
|
| 10 |
|
| 11 |
# Load .env file if present (override=True to prefer .env over shell env)
|
| 12 |
try:
|
|
@@ -88,7 +88,7 @@ class ChatClient:
|
|
| 88 |
def _create_client(self):
|
| 89 |
"""Create the appropriate async client."""
|
| 90 |
try:
|
| 91 |
-
from openai import
|
| 92 |
except ImportError:
|
| 93 |
raise ImportError(
|
| 94 |
"openai package is required. Install with: pip install openai"
|
|
|
|
| 4 |
Auto-detects configuration from environment variables.
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
from dataclasses import dataclass
|
| 9 |
from typing import Any
|
|
|
|
| 10 |
|
| 11 |
# Load .env file if present (override=True to prefer .env over shell env)
|
| 12 |
try:
|
|
|
|
| 88 |
def _create_client(self):
|
| 89 |
"""Create the appropriate async client."""
|
| 90 |
try:
|
| 91 |
+
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
| 92 |
except ImportError:
|
| 93 |
raise ImportError(
|
| 94 |
"openai package is required. Install with: pip install openai"
|
src/flow/harness/miniagent/context.py
CHANGED
|
@@ -5,8 +5,10 @@ Strategies are called BEFORE each LLM call, and the returned (potentially
|
|
| 5 |
compacted) list continues to the next iteration.
|
| 6 |
"""
|
| 7 |
|
|
|
|
| 8 |
from dataclasses import dataclass, field
|
| 9 |
-
from typing import
|
|
|
|
| 10 |
import tiktoken
|
| 11 |
|
| 12 |
from .messages import ChatMessage
|
|
@@ -471,13 +473,12 @@ SUMMARY:"""
|
|
| 471 |
if tc.name in ("read_file", "Read"):
|
| 472 |
# Try to extract path from arguments
|
| 473 |
try:
|
| 474 |
-
import json
|
| 475 |
args = json.loads(tc.arguments)
|
| 476 |
path = args.get("path") or args.get("file_path") or args.get("filename")
|
| 477 |
if path:
|
| 478 |
files.append(path)
|
| 479 |
-
except:
|
| 480 |
-
pass
|
| 481 |
return list(dict.fromkeys(files)) # Remove duplicates, preserve order
|
| 482 |
|
| 483 |
def _extract_key_info(self, messages: list[ChatMessage]) -> str:
|
|
|
|
| 5 |
compacted) list continues to the next iteration.
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
import json
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any, Protocol
|
| 11 |
+
|
| 12 |
import tiktoken
|
| 13 |
|
| 14 |
from .messages import ChatMessage
|
|
|
|
| 473 |
if tc.name in ("read_file", "Read"):
|
| 474 |
# Try to extract path from arguments
|
| 475 |
try:
|
|
|
|
| 476 |
args = json.loads(tc.arguments)
|
| 477 |
path = args.get("path") or args.get("file_path") or args.get("filename")
|
| 478 |
if path:
|
| 479 |
files.append(path)
|
| 480 |
+
except (json.JSONDecodeError, KeyError, TypeError):
|
| 481 |
+
pass # Skip malformed tool calls
|
| 482 |
return list(dict.fromkeys(files)) # Remove duplicates, preserve order
|
| 483 |
|
| 484 |
def _extract_key_info(self, messages: list[ChatMessage]) -> str:
|
src/flow/harness/miniagent/harness.py
CHANGED
|
@@ -10,7 +10,7 @@ import logging
|
|
| 10 |
import uuid
|
| 11 |
from collections.abc import AsyncIterator
|
| 12 |
from pathlib import Path
|
| 13 |
-
from typing import TYPE_CHECKING, Any
|
| 14 |
|
| 15 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 16 |
|
|
@@ -18,19 +18,21 @@ if TYPE_CHECKING:
|
|
| 18 |
from flow.experiments.models import Agent
|
| 19 |
from flow.llm import LLMClientConfig
|
| 20 |
|
| 21 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from .context import (
|
| 23 |
ContextStrategy,
|
| 24 |
-
NoCompactionStrategy,
|
| 25 |
HeadTailStrategy,
|
|
|
|
| 26 |
SlidingWindowStrategy,
|
| 27 |
SummarizationStrategy,
|
| 28 |
)
|
| 29 |
-
from .client import ChatClient
|
| 30 |
-
from .otel import enable_instrumentation
|
| 31 |
from .instructions import get_instructions
|
| 32 |
-
|
| 33 |
-
from flow.tools import Tool
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
| 36 |
|
|
@@ -61,13 +63,21 @@ class MiniAgentHarness(BaseHarness):
|
|
| 61 |
... print(event)
|
| 62 |
"""
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
@classmethod
|
| 65 |
def from_agent(
|
| 66 |
cls,
|
| 67 |
-
agent:
|
| 68 |
workspace: Path,
|
| 69 |
-
llm_config:
|
| 70 |
-
) ->
|
| 71 |
"""Create a MiniAgentHarness from an Agent definition.
|
| 72 |
|
| 73 |
Args:
|
|
@@ -105,13 +115,13 @@ class MiniAgentHarness(BaseHarness):
|
|
| 105 |
from .otel import create_otel_hooks
|
| 106 |
otel_hooks = create_otel_hooks(model=config.model)
|
| 107 |
|
| 108 |
-
# Resolve instructions: explicit > preset > default "
|
| 109 |
if agent.instructions:
|
| 110 |
instructions = agent.instructions
|
| 111 |
elif agent.instructions_preset:
|
| 112 |
instructions = get_instructions(agent.instructions_preset)
|
| 113 |
else:
|
| 114 |
-
instructions = get_instructions("
|
| 115 |
|
| 116 |
chat_agent = ChatAgent(
|
| 117 |
client=chat_client,
|
|
@@ -126,8 +136,8 @@ class MiniAgentHarness(BaseHarness):
|
|
| 126 |
|
| 127 |
@classmethod
|
| 128 |
def _create_client_config_from_llm_config(
|
| 129 |
-
cls, llm_config:
|
| 130 |
-
) ->
|
| 131 |
"""Create MiniAgent ClientConfig from Flow LLMClientConfig.
|
| 132 |
|
| 133 |
Args:
|
|
@@ -137,6 +147,7 @@ class MiniAgentHarness(BaseHarness):
|
|
| 137 |
MiniAgent ClientConfig
|
| 138 |
"""
|
| 139 |
from flow.llm import LLMProvider
|
|
|
|
| 140 |
from .client import ClientConfig
|
| 141 |
|
| 142 |
match llm_config.provider:
|
|
@@ -177,7 +188,7 @@ class MiniAgentHarness(BaseHarness):
|
|
| 177 |
@classmethod
|
| 178 |
def _create_client_config_from_dict(
|
| 179 |
cls, llm_config: dict[str, Any]
|
| 180 |
-
) ->
|
| 181 |
"""Create ClientConfig from agent's llm_config dict.
|
| 182 |
|
| 183 |
Supports a simple format for YAML configuration:
|
|
@@ -197,6 +208,7 @@ class MiniAgentHarness(BaseHarness):
|
|
| 197 |
ValueError: If required fields or env vars are missing
|
| 198 |
"""
|
| 199 |
import os
|
|
|
|
| 200 |
from .client import ClientConfig
|
| 201 |
|
| 202 |
provider = llm_config.get("provider", "").lower()
|
|
@@ -273,7 +285,7 @@ class MiniAgentHarness(BaseHarness):
|
|
| 273 |
)
|
| 274 |
|
| 275 |
@classmethod
|
| 276 |
-
def _create_context_strategy(cls, agent:
|
| 277 |
"""Map Flow's CompactionConfig to MiniAgent's ContextStrategy."""
|
| 278 |
config = agent.compaction
|
| 279 |
|
|
@@ -328,24 +340,37 @@ class MiniAgentHarness(BaseHarness):
|
|
| 328 |
"""
|
| 329 |
# Import shared tools
|
| 330 |
from flow.tools import (
|
| 331 |
-
|
| 332 |
-
read_file, write_file, edit_file, multi_edit, glob_files, grep, ls,
|
| 333 |
# Execution
|
| 334 |
-
bash,
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
# Memory
|
| 338 |
-
memory,
|
| 339 |
-
|
| 340 |
-
web_search, web_fetch,
|
| 341 |
# Notebooks
|
| 342 |
-
notebook_edit,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
# Skills
|
| 344 |
-
skills,
|
| 345 |
# Sub-agent
|
| 346 |
-
task,
|
| 347 |
-
#
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
)
|
| 350 |
|
| 351 |
# Set workspace for tools that need it (memory, todos, etc.)
|
|
@@ -364,7 +389,6 @@ class MiniAgentHarness(BaseHarness):
|
|
| 364 |
# Execution
|
| 365 |
"bash": bash,
|
| 366 |
"check_processes": check_processes,
|
| 367 |
-
"python_repl": python_repl,
|
| 368 |
# Planning
|
| 369 |
"think": think,
|
| 370 |
"todo_write": todo_write,
|
|
|
|
| 10 |
import uuid
|
| 11 |
from collections.abc import AsyncIterator
|
| 12 |
from pathlib import Path
|
| 13 |
+
from typing import TYPE_CHECKING, Any, ClassVar
|
| 14 |
|
| 15 |
from flow.harness.base import BaseHarness, Event, EventType
|
| 16 |
|
|
|
|
| 18 |
from flow.experiments.models import Agent
|
| 19 |
from flow.llm import LLMClientConfig
|
| 20 |
|
| 21 |
+
from .client import ClientConfig
|
| 22 |
+
|
| 23 |
+
from flow.tools import Tool
|
| 24 |
+
|
| 25 |
+
from .agent import AgentThread, ChatAgent, StreamEvent, StreamEventType
|
| 26 |
+
from .client import ChatClient
|
| 27 |
from .context import (
|
| 28 |
ContextStrategy,
|
|
|
|
| 29 |
HeadTailStrategy,
|
| 30 |
+
NoCompactionStrategy,
|
| 31 |
SlidingWindowStrategy,
|
| 32 |
SummarizationStrategy,
|
| 33 |
)
|
|
|
|
|
|
|
| 34 |
from .instructions import get_instructions
|
| 35 |
+
from .otel import enable_instrumentation
|
|
|
|
| 36 |
|
| 37 |
logger = logging.getLogger(__name__)
|
| 38 |
|
|
|
|
| 63 |
... print(event)
|
| 64 |
"""
|
| 65 |
|
| 66 |
+
# Framework metadata
|
| 67 |
+
framework_name: ClassVar[str] = "miniagent"
|
| 68 |
+
framework_label: ClassVar[str] = "MiniAgent"
|
| 69 |
+
framework_description: ClassVar[str] = "Token-aware context management with advanced compaction"
|
| 70 |
+
supported_compaction_strategies: ClassVar[list[str]] = [
|
| 71 |
+
"head_tail", "sliding_window", "summarization", "none"
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
@classmethod
|
| 75 |
def from_agent(
|
| 76 |
cls,
|
| 77 |
+
agent: Agent,
|
| 78 |
workspace: Path,
|
| 79 |
+
llm_config: LLMClientConfig | None = None,
|
| 80 |
+
) -> MiniAgentHarness:
|
| 81 |
"""Create a MiniAgentHarness from an Agent definition.
|
| 82 |
|
| 83 |
Args:
|
|
|
|
| 115 |
from .otel import create_otel_hooks
|
| 116 |
otel_hooks = create_otel_hooks(model=config.model)
|
| 117 |
|
| 118 |
+
# Resolve instructions: explicit > preset > default "general"
|
| 119 |
if agent.instructions:
|
| 120 |
instructions = agent.instructions
|
| 121 |
elif agent.instructions_preset:
|
| 122 |
instructions = get_instructions(agent.instructions_preset)
|
| 123 |
else:
|
| 124 |
+
instructions = get_instructions("general")
|
| 125 |
|
| 126 |
chat_agent = ChatAgent(
|
| 127 |
client=chat_client,
|
|
|
|
| 136 |
|
| 137 |
@classmethod
|
| 138 |
def _create_client_config_from_llm_config(
|
| 139 |
+
cls, llm_config: LLMClientConfig
|
| 140 |
+
) -> ClientConfig:
|
| 141 |
"""Create MiniAgent ClientConfig from Flow LLMClientConfig.
|
| 142 |
|
| 143 |
Args:
|
|
|
|
| 147 |
MiniAgent ClientConfig
|
| 148 |
"""
|
| 149 |
from flow.llm import LLMProvider
|
| 150 |
+
|
| 151 |
from .client import ClientConfig
|
| 152 |
|
| 153 |
match llm_config.provider:
|
|
|
|
| 188 |
@classmethod
|
| 189 |
def _create_client_config_from_dict(
|
| 190 |
cls, llm_config: dict[str, Any]
|
| 191 |
+
) -> ClientConfig:
|
| 192 |
"""Create ClientConfig from agent's llm_config dict.
|
| 193 |
|
| 194 |
Supports a simple format for YAML configuration:
|
|
|
|
| 208 |
ValueError: If required fields or env vars are missing
|
| 209 |
"""
|
| 210 |
import os
|
| 211 |
+
|
| 212 |
from .client import ClientConfig
|
| 213 |
|
| 214 |
provider = llm_config.get("provider", "").lower()
|
|
|
|
| 285 |
)
|
| 286 |
|
| 287 |
@classmethod
|
| 288 |
+
def _create_context_strategy(cls, agent: Agent) -> ContextStrategy:
|
| 289 |
"""Map Flow's CompactionConfig to MiniAgent's ContextStrategy."""
|
| 290 |
config = agent.compaction
|
| 291 |
|
|
|
|
| 340 |
"""
|
| 341 |
# Import shared tools
|
| 342 |
from flow.tools import (
|
| 343 |
+
Workspace,
|
|
|
|
| 344 |
# Execution
|
| 345 |
+
bash,
|
| 346 |
+
check_processes,
|
| 347 |
+
create_task_tool,
|
| 348 |
+
edit_file,
|
| 349 |
+
glob_files,
|
| 350 |
+
grep,
|
| 351 |
+
ls,
|
| 352 |
# Memory
|
| 353 |
+
memory,
|
| 354 |
+
multi_edit,
|
|
|
|
| 355 |
# Notebooks
|
| 356 |
+
notebook_edit,
|
| 357 |
+
notebook_read,
|
| 358 |
+
# Coding
|
| 359 |
+
read_file,
|
| 360 |
+
# Workspace management
|
| 361 |
+
set_workspace,
|
| 362 |
# Skills
|
| 363 |
+
skills,
|
| 364 |
# Sub-agent
|
| 365 |
+
task,
|
| 366 |
+
# Planning
|
| 367 |
+
think,
|
| 368 |
+
todo_read,
|
| 369 |
+
todo_write,
|
| 370 |
+
web_fetch,
|
| 371 |
+
# Web
|
| 372 |
+
web_search,
|
| 373 |
+
write_file,
|
| 374 |
)
|
| 375 |
|
| 376 |
# Set workspace for tools that need it (memory, todos, etc.)
|
|
|
|
| 389 |
# Execution
|
| 390 |
"bash": bash,
|
| 391 |
"check_processes": check_processes,
|
|
|
|
| 392 |
# Planning
|
| 393 |
"think": think,
|
| 394 |
"todo_write": todo_write,
|
src/flow/harness/miniagent/hooks.py
CHANGED
|
@@ -6,9 +6,10 @@ Inspired by Claude Agent SDK's hooks system. Hooks allow applications to:
|
|
| 6 |
- Control: Block tool calls, stop execution
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
-
from typing import Any, Callable, Awaitable, Literal
|
| 11 |
from enum import Enum
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class HookEvent(str, Enum):
|
|
|
|
| 6 |
- Control: Block tool calls, stop execution
|
| 7 |
"""
|
| 8 |
|
| 9 |
+
from collections.abc import Awaitable, Callable
|
| 10 |
from dataclasses import dataclass, field
|
|
|
|
| 11 |
from enum import Enum
|
| 12 |
+
from typing import Any, Literal
|
| 13 |
|
| 14 |
|
| 15 |
class HookEvent(str, Enum):
|
src/flow/harness/miniagent/instructions.py
CHANGED
|
@@ -89,7 +89,7 @@ Never assume libraries exist. Check package.json, requirements.txt, or equivalen
|
|
| 89 |
# Preset-specific instructions
|
| 90 |
# =============================================================================
|
| 91 |
|
| 92 |
-
|
| 93 |
|
| 94 |
## Response Style
|
| 95 |
|
|
@@ -111,7 +111,11 @@ CODING_AGENT_INSTRUCTIONS = f"""You are an expert coding assistant. You help use
|
|
| 111 |
- **ls**: List directory contents.
|
| 112 |
|
| 113 |
### Execution
|
| 114 |
-
- **bash**: Execute shell commands. Use for git, running tests, installing packages.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
### Planning
|
| 117 |
- **think**: Reason through complex problems before acting.
|
|
@@ -120,68 +124,25 @@ CODING_AGENT_INSTRUCTIONS = f"""You are an expert coding assistant. You help use
|
|
| 120 |
|
| 121 |
### Delegation (if available)
|
| 122 |
- **task**: Delegate complex sub-tasks to a specialist agent with isolated context.
|
| 123 |
-
{EFFICIENCY_INSTRUCTIONS}
|
| 124 |
-
{BEST_PRACTICES_INSTRUCTIONS}
|
| 125 |
-
"""
|
| 126 |
-
|
| 127 |
-
RESEARCH_AGENT_INSTRUCTIONS = f"""You are a research assistant. You help users find information, synthesize knowledge, and answer questions.
|
| 128 |
|
| 129 |
-
##
|
| 130 |
-
|
| 131 |
-
- Be thorough in research, concise in presentation.
|
| 132 |
-
- Cite sources with URLs when reporting findings.
|
| 133 |
-
- Synthesize information - don't just list results.
|
| 134 |
-
{TASK_COMPLETION_INSTRUCTIONS}
|
| 135 |
-
## Tools
|
| 136 |
|
| 137 |
-
###
|
| 138 |
-
-
|
| 139 |
-
- **web_fetch**: Fetch and read web page contents.
|
| 140 |
-
|
| 141 |
-
### Planning
|
| 142 |
-
- **think**: Work through complex questions step by step.
|
| 143 |
-
- **todo_write**: Track research progress on multi-part questions.
|
| 144 |
|
| 145 |
-
##
|
|
|
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
3. Synthesize findings into a coherent answer
|
| 150 |
-
4. If initial searches don't answer the question, refine and search again
|
| 151 |
|
| 152 |
-
##
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
4. **Keep going**: If first searches don't work, try different queries.
|
| 158 |
-
5. **Acknowledge uncertainty**: If information is unclear, say so.
|
| 159 |
-
"""
|
| 160 |
-
|
| 161 |
-
EXPLORE_AGENT_INSTRUCTIONS = f"""You are a codebase exploration specialist. Your job is to quickly find and understand code.
|
| 162 |
-
|
| 163 |
-
## Response Style
|
| 164 |
-
|
| 165 |
-
- Be concise. Your response goes to another agent, so be self-contained.
|
| 166 |
-
- Include file paths and line numbers in findings.
|
| 167 |
-
- Summarize what you found, don't dump raw content.
|
| 168 |
-
{TASK_COMPLETION_INSTRUCTIONS}
|
| 169 |
-
## Tools
|
| 170 |
-
|
| 171 |
-
- **read_file**: Read file contents (read fully, don't chunk).
|
| 172 |
-
- **glob_files**: Find files by pattern.
|
| 173 |
-
- **grep**: Search file contents with regex.
|
| 174 |
-
- **ls**: List directory contents.
|
| 175 |
-
- **think**: Reason about what you're finding.
|
| 176 |
-
- **todo_write**: Track exploration progress for complex searches.
|
| 177 |
{EFFICIENCY_INSTRUCTIONS}
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
1. **Start broad, then narrow**: Use glob/grep to find candidates, then batch-read.
|
| 181 |
-
2. **Be efficient**: Don't read files you don't need.
|
| 182 |
-
3. **Report clearly**: Include file paths and line numbers.
|
| 183 |
-
4. **Keep searching**: If first attempt doesn't find what's needed, try different patterns.
|
| 184 |
-
5. **Summarize**: Be self-contained for the calling agent.
|
| 185 |
"""
|
| 186 |
|
| 187 |
# =============================================================================
|
|
@@ -189,19 +150,17 @@ EXPLORE_AGENT_INSTRUCTIONS = f"""You are a codebase exploration specialist. Your
|
|
| 189 |
# =============================================================================
|
| 190 |
|
| 191 |
INSTRUCTIONS = {
|
| 192 |
-
"
|
| 193 |
-
"research": RESEARCH_AGENT_INSTRUCTIONS,
|
| 194 |
-
"explore": EXPLORE_AGENT_INSTRUCTIONS,
|
| 195 |
}
|
| 196 |
|
| 197 |
|
| 198 |
-
def get_instructions(preset: str = "
|
| 199 |
"""Get system instructions by preset name.
|
| 200 |
|
| 201 |
Args:
|
| 202 |
-
preset:
|
| 203 |
|
| 204 |
Returns:
|
| 205 |
System instruction string
|
| 206 |
"""
|
| 207 |
-
return INSTRUCTIONS.get(preset,
|
|
|
|
| 89 |
# Preset-specific instructions
|
| 90 |
# =============================================================================
|
| 91 |
|
| 92 |
+
GENERAL_AGENT_INSTRUCTIONS = f"""You are a helpful general-purpose agent. You solve tasks by combining reasoning, code execution, file operations, and web research as needed.
|
| 93 |
|
| 94 |
## Response Style
|
| 95 |
|
|
|
|
| 111 |
- **ls**: List directory contents.
|
| 112 |
|
| 113 |
### Execution
|
| 114 |
+
- **bash**: Execute shell commands. Use for git, running tests, installing packages, and running Python code (e.g., `python -c "print(2+2)"` or `python script.py`).
|
| 115 |
+
|
| 116 |
+
### Web Research
|
| 117 |
+
- **web_search**: Search the web for current information, facts, or data.
|
| 118 |
+
- **web_fetch**: Fetch and read web page contents.
|
| 119 |
|
| 120 |
### Planning
|
| 121 |
- **think**: Reason through complex problems before acting.
|
|
|
|
| 124 |
|
| 125 |
### Delegation (if available)
|
| 126 |
- **task**: Delegate complex sub-tasks to a specialist agent with isolated context.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
## Problem-Solving Strategy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
+
### For calculations and math problems
|
| 131 |
+
Write and execute code rather than computing in your head. Use `bash` with `python -c "..."` or write a script with `write_file` then run it with `bash` — this avoids arithmetic errors.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
### For questions requiring specific facts or current information
|
| 134 |
+
Use `web_search` to find authoritative sources, then `web_fetch` to read them. Do NOT guess or rely on memory for factual claims like dates, numbers, names, or statistics.
|
| 135 |
|
| 136 |
+
### For complex tasks (data processing, file analysis, media)
|
| 137 |
+
Write code to solve them. Install required libraries with `bash` (e.g., `pip install ...`). Break the problem into steps and verify each step works before moving on.
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
### Anti-hallucination
|
| 140 |
+
- NEVER guess factual answers. If you don't know, search or compute.
|
| 141 |
+
- When a task asks for a specific number or name, verify it from a source.
|
| 142 |
+
- If web search fails, try different queries before giving up.
|
| 143 |
+
- State your confidence level when reporting facts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
{EFFICIENCY_INSTRUCTIONS}
|
| 145 |
+
{BEST_PRACTICES_INSTRUCTIONS}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
"""
|
| 147 |
|
| 148 |
# =============================================================================
|
|
|
|
| 150 |
# =============================================================================
|
| 151 |
|
| 152 |
INSTRUCTIONS = {
|
| 153 |
+
"general": GENERAL_AGENT_INSTRUCTIONS,
|
|
|
|
|
|
|
| 154 |
}
|
| 155 |
|
| 156 |
|
| 157 |
+
def get_instructions(preset: str = "general") -> str:
|
| 158 |
"""Get system instructions by preset name.
|
| 159 |
|
| 160 |
Args:
|
| 161 |
+
preset: Preset name (default: 'general')
|
| 162 |
|
| 163 |
Returns:
|
| 164 |
System instruction string
|
| 165 |
"""
|
| 166 |
+
return INSTRUCTIONS.get(preset, GENERAL_AGENT_INSTRUCTIONS)
|
src/flow/harness/miniagent/otel.py
CHANGED
|
@@ -16,12 +16,12 @@ from opentelemetry import trace
|
|
| 16 |
if TYPE_CHECKING:
|
| 17 |
from .hooks import (
|
| 18 |
Hooks,
|
| 19 |
-
PreModelCallEvent,
|
| 20 |
PostModelCallEvent,
|
| 21 |
-
PreToolUseEvent,
|
| 22 |
-
PreToolUseResult,
|
| 23 |
PostToolUseEvent,
|
| 24 |
PostToolUseResult,
|
|
|
|
|
|
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
__all__ = ["GenAIAttr", "create_otel_hooks", "enable_instrumentation"]
|
|
@@ -157,7 +157,7 @@ class OTelHooks:
|
|
| 157 |
self._llm_spans: dict[int, trace.Span] = {} # iteration -> span
|
| 158 |
self._tool_spans: dict[str, trace.Span] = {} # call_id -> span
|
| 159 |
|
| 160 |
-
async def on_pre_model_call(self, event:
|
| 161 |
"""Start an LLM span before model call.
|
| 162 |
|
| 163 |
Args:
|
|
@@ -166,7 +166,7 @@ class OTelHooks:
|
|
| 166 |
span = start_llm_span(model=self.model)
|
| 167 |
self._llm_spans[event.iteration] = span
|
| 168 |
|
| 169 |
-
async def on_post_model_call(self, event:
|
| 170 |
"""End the LLM span after model call.
|
| 171 |
|
| 172 |
Args:
|
|
@@ -178,7 +178,7 @@ class OTelHooks:
|
|
| 178 |
output_tokens = event.usage.get("output_tokens", 0)
|
| 179 |
end_llm_span(span, input_tokens, output_tokens)
|
| 180 |
|
| 181 |
-
async def on_pre_tool_use(self, event:
|
| 182 |
"""Start a tool span before tool execution.
|
| 183 |
|
| 184 |
Args:
|
|
@@ -191,7 +191,7 @@ class OTelHooks:
|
|
| 191 |
self._tool_spans[event.tool_call_id] = span
|
| 192 |
return None # Don't block
|
| 193 |
|
| 194 |
-
async def on_post_tool_use(self, event:
|
| 195 |
"""End the tool span after tool execution.
|
| 196 |
|
| 197 |
Args:
|
|
@@ -230,7 +230,7 @@ def enable_instrumentation() -> None:
|
|
| 230 |
_instrumentation_enabled = True
|
| 231 |
|
| 232 |
|
| 233 |
-
def create_otel_hooks(model: str = "gpt-4o") ->
|
| 234 |
"""Create a Hooks instance with OTEL instrumentation.
|
| 235 |
|
| 236 |
This is the main entry point for adding OTEL tracing to a MiniAgent.
|
|
|
|
| 16 |
if TYPE_CHECKING:
|
| 17 |
from .hooks import (
|
| 18 |
Hooks,
|
|
|
|
| 19 |
PostModelCallEvent,
|
|
|
|
|
|
|
| 20 |
PostToolUseEvent,
|
| 21 |
PostToolUseResult,
|
| 22 |
+
PreModelCallEvent,
|
| 23 |
+
PreToolUseEvent,
|
| 24 |
+
PreToolUseResult,
|
| 25 |
)
|
| 26 |
|
| 27 |
__all__ = ["GenAIAttr", "create_otel_hooks", "enable_instrumentation"]
|
|
|
|
| 157 |
self._llm_spans: dict[int, trace.Span] = {} # iteration -> span
|
| 158 |
self._tool_spans: dict[str, trace.Span] = {} # call_id -> span
|
| 159 |
|
| 160 |
+
async def on_pre_model_call(self, event: PreModelCallEvent) -> None:
|
| 161 |
"""Start an LLM span before model call.
|
| 162 |
|
| 163 |
Args:
|
|
|
|
| 166 |
span = start_llm_span(model=self.model)
|
| 167 |
self._llm_spans[event.iteration] = span
|
| 168 |
|
| 169 |
+
async def on_post_model_call(self, event: PostModelCallEvent) -> None:
|
| 170 |
"""End the LLM span after model call.
|
| 171 |
|
| 172 |
Args:
|
|
|
|
| 178 |
output_tokens = event.usage.get("output_tokens", 0)
|
| 179 |
end_llm_span(span, input_tokens, output_tokens)
|
| 180 |
|
| 181 |
+
async def on_pre_tool_use(self, event: PreToolUseEvent) -> PreToolUseResult | None:
|
| 182 |
"""Start a tool span before tool execution.
|
| 183 |
|
| 184 |
Args:
|
|
|
|
| 191 |
self._tool_spans[event.tool_call_id] = span
|
| 192 |
return None # Don't block
|
| 193 |
|
| 194 |
+
async def on_post_tool_use(self, event: PostToolUseEvent) -> PostToolUseResult | None:
|
| 195 |
"""End the tool span after tool execution.
|
| 196 |
|
| 197 |
Args:
|
|
|
|
| 230 |
_instrumentation_enabled = True
|
| 231 |
|
| 232 |
|
| 233 |
+
def create_otel_hooks(model: str = "gpt-4o") -> Hooks:
|
| 234 |
"""Create a Hooks instance with OTEL instrumentation.
|
| 235 |
|
| 236 |
This is the main entry point for adding OTEL tracing to a MiniAgent.
|
src/flow/harness/miniagent/tool.py
CHANGED
|
@@ -3,9 +3,10 @@
|
|
| 3 |
Provides a simple way to define tools that can be called by the LLM.
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from typing import Any, Callable, Literal, get_type_hints, get_origin, get_args, Annotated
|
| 8 |
import inspect
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
|
@@ -43,7 +44,7 @@ class Tool:
|
|
| 43 |
result = await result
|
| 44 |
return str(result) if not isinstance(result, str) else result
|
| 45 |
except Exception as e:
|
| 46 |
-
return f"Error executing {self.name}: {
|
| 47 |
|
| 48 |
|
| 49 |
def _python_type_to_json_schema(py_type: Any) -> dict[str, Any]:
|
|
@@ -53,13 +54,13 @@ def _python_type_to_json_schema(py_type: Any) -> dict[str, Any]:
|
|
| 53 |
return {"type": "null"}
|
| 54 |
|
| 55 |
# Handle basic types
|
| 56 |
-
if py_type
|
| 57 |
return {"type": "string"}
|
| 58 |
-
if py_type
|
| 59 |
return {"type": "integer"}
|
| 60 |
-
if py_type
|
| 61 |
return {"type": "number"}
|
| 62 |
-
if py_type
|
| 63 |
return {"type": "boolean"}
|
| 64 |
|
| 65 |
# Handle dict without type args
|
|
|
|
| 3 |
Provides a simple way to define tools that can be called by the LLM.
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
| 6 |
import inspect
|
| 7 |
+
from collections.abc import Callable
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Annotated, Any, Literal, get_args, get_origin, get_type_hints
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
|
|
|
| 44 |
result = await result
|
| 45 |
return str(result) if not isinstance(result, str) else result
|
| 46 |
except Exception as e:
|
| 47 |
+
return f"Error executing {self.name}: {e!s}"
|
| 48 |
|
| 49 |
|
| 50 |
def _python_type_to_json_schema(py_type: Any) -> dict[str, Any]:
|
|
|
|
| 54 |
return {"type": "null"}
|
| 55 |
|
| 56 |
# Handle basic types
|
| 57 |
+
if py_type is str:
|
| 58 |
return {"type": "string"}
|
| 59 |
+
if py_type is int:
|
| 60 |
return {"type": "integer"}
|
| 61 |
+
if py_type is float:
|
| 62 |
return {"type": "number"}
|
| 63 |
+
if py_type is bool:
|
| 64 |
return {"type": "boolean"}
|
| 65 |
|
| 66 |
# Handle dict without type args
|
src/flow/harness/miniagent/tools/__init__.py
CHANGED
|
@@ -33,51 +33,51 @@ Example:
|
|
| 33 |
from flow.tools import (
|
| 34 |
# Base
|
| 35 |
Tool,
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
edit_file,
|
| 40 |
-
multi_edit,
|
| 41 |
glob_files,
|
| 42 |
grep,
|
| 43 |
ls,
|
|
|
|
|
|
|
|
|
|
| 44 |
# Notebook operations
|
| 45 |
notebook_edit,
|
| 46 |
notebook_read,
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# Planning and reasoning
|
| 52 |
think,
|
| 53 |
-
todo_write,
|
| 54 |
todo_read,
|
|
|
|
|
|
|
| 55 |
# Web operations
|
| 56 |
web_search,
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
create_memory_tool,
|
| 61 |
-
# Skills
|
| 62 |
-
skills,
|
| 63 |
-
create_skills_tool,
|
| 64 |
-
# Sub-agent
|
| 65 |
-
task,
|
| 66 |
-
create_task_tool,
|
| 67 |
-
# Presets
|
| 68 |
-
coding_tools,
|
| 69 |
-
planning_tools,
|
| 70 |
web_tools as research_tools,
|
| 71 |
-
notebook_tools,
|
| 72 |
-
all_tools,
|
| 73 |
)
|
| 74 |
|
| 75 |
-
# Compatibility: reset_todos from planning module
|
| 76 |
-
from flow.tools.planning import reset_todos, get_todos
|
| 77 |
-
|
| 78 |
# Compatibility: reset_memory from memory module
|
| 79 |
from flow.tools.memory import reset_memory
|
| 80 |
|
|
|
|
|
|
|
| 81 |
|
| 82 |
__all__ = [
|
| 83 |
# Base
|
|
@@ -102,7 +102,6 @@ __all__ = [
|
|
| 102 |
# Execution
|
| 103 |
"bash",
|
| 104 |
"check_processes",
|
| 105 |
-
"python_repl",
|
| 106 |
# Planning
|
| 107 |
"think",
|
| 108 |
"todo_write",
|
|
|
|
| 33 |
from flow.tools import (
|
| 34 |
# Base
|
| 35 |
Tool,
|
| 36 |
+
all_tools,
|
| 37 |
+
# Execution
|
| 38 |
+
bash,
|
| 39 |
+
check_processes,
|
| 40 |
+
# Presets
|
| 41 |
+
coding_tools,
|
| 42 |
+
create_memory_tool,
|
| 43 |
+
create_skills_tool,
|
| 44 |
+
create_task_tool,
|
| 45 |
edit_file,
|
|
|
|
| 46 |
glob_files,
|
| 47 |
grep,
|
| 48 |
ls,
|
| 49 |
+
# Memory
|
| 50 |
+
memory,
|
| 51 |
+
multi_edit,
|
| 52 |
# Notebook operations
|
| 53 |
notebook_edit,
|
| 54 |
notebook_read,
|
| 55 |
+
notebook_tools,
|
| 56 |
+
planning_tools,
|
| 57 |
+
# File operations
|
| 58 |
+
read_file,
|
| 59 |
+
# Skills
|
| 60 |
+
skills,
|
| 61 |
+
# Sub-agent
|
| 62 |
+
task,
|
| 63 |
# Planning and reasoning
|
| 64 |
think,
|
|
|
|
| 65 |
todo_read,
|
| 66 |
+
todo_write,
|
| 67 |
+
web_fetch,
|
| 68 |
# Web operations
|
| 69 |
web_search,
|
| 70 |
+
write_file,
|
| 71 |
+
)
|
| 72 |
+
from flow.tools import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
web_tools as research_tools,
|
|
|
|
|
|
|
| 74 |
)
|
| 75 |
|
|
|
|
|
|
|
|
|
|
| 76 |
# Compatibility: reset_memory from memory module
|
| 77 |
from flow.tools.memory import reset_memory
|
| 78 |
|
| 79 |
+
# Compatibility: reset_todos from planning module
|
| 80 |
+
from flow.tools.planning import get_todos, reset_todos
|
| 81 |
|
| 82 |
__all__ = [
|
| 83 |
# Base
|
|
|
|
| 102 |
# Execution
|
| 103 |
"bash",
|
| 104 |
"check_processes",
|
|
|
|
| 105 |
# Planning
|
| 106 |
"think",
|
| 107 |
"todo_write",
|
src/flow/harness/miniagent/workspace.py
CHANGED
|
@@ -31,6 +31,7 @@ Usage:
|
|
| 31 |
ws.memory_dir # /path/to/project/.miniagent/memory
|
| 32 |
"""
|
| 33 |
|
|
|
|
| 34 |
import json
|
| 35 |
from pathlib import Path
|
| 36 |
from typing import Any
|
|
@@ -97,7 +98,7 @@ class Workspace:
|
|
| 97 |
try:
|
| 98 |
with open(self.todos_file) as f:
|
| 99 |
return json.load(f) # type: ignore[no-any-return]
|
| 100 |
-
except (json.JSONDecodeError
|
| 101 |
return []
|
| 102 |
|
| 103 |
def save_todos(self, todos: list[dict[str, Any]]) -> None:
|
|
@@ -118,7 +119,7 @@ class Workspace:
|
|
| 118 |
try:
|
| 119 |
with open(filepath) as f:
|
| 120 |
memories.append(json.load(f))
|
| 121 |
-
except (json.JSONDecodeError
|
| 122 |
continue
|
| 123 |
return memories
|
| 124 |
|
|
@@ -130,7 +131,7 @@ class Workspace:
|
|
| 130 |
try:
|
| 131 |
with open(filepath) as f:
|
| 132 |
return json.load(f) # type: ignore[no-any-return]
|
| 133 |
-
except (json.JSONDecodeError
|
| 134 |
return None
|
| 135 |
|
| 136 |
def save_memory(self, memory_id: str, data: dict[str, Any]) -> None:
|
|
@@ -157,7 +158,7 @@ class Workspace:
|
|
| 157 |
try:
|
| 158 |
with open(self.config_file) as f:
|
| 159 |
return json.load(f)
|
| 160 |
-
except (json.JSONDecodeError
|
| 161 |
return {}
|
| 162 |
|
| 163 |
def save_config(self, config: dict[str, Any]) -> None:
|
|
@@ -170,29 +171,31 @@ class Workspace:
|
|
| 170 |
return f"Workspace({self._root})"
|
| 171 |
|
| 172 |
|
| 173 |
-
#
|
| 174 |
-
|
|
|
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
def get_workspace() -> Workspace:
|
| 178 |
-
"""Get the
|
| 179 |
-
|
| 180 |
-
if
|
| 181 |
-
|
| 182 |
-
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
def set_workspace(workspace: Workspace | str | Path) -> Workspace:
|
| 186 |
-
"""Set the
|
| 187 |
-
global _default_workspace
|
| 188 |
if isinstance(workspace, Workspace):
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
| 193 |
|
| 194 |
|
| 195 |
def reset_workspace() -> None:
|
| 196 |
-
"""Reset
|
| 197 |
-
|
| 198 |
-
_default_workspace = None
|
|
|
|
| 31 |
ws.memory_dir # /path/to/project/.miniagent/memory
|
| 32 |
"""
|
| 33 |
|
| 34 |
+
import contextvars
|
| 35 |
import json
|
| 36 |
from pathlib import Path
|
| 37 |
from typing import Any
|
|
|
|
| 98 |
try:
|
| 99 |
with open(self.todos_file) as f:
|
| 100 |
return json.load(f) # type: ignore[no-any-return]
|
| 101 |
+
except (OSError, json.JSONDecodeError):
|
| 102 |
return []
|
| 103 |
|
| 104 |
def save_todos(self, todos: list[dict[str, Any]]) -> None:
|
|
|
|
| 119 |
try:
|
| 120 |
with open(filepath) as f:
|
| 121 |
memories.append(json.load(f))
|
| 122 |
+
except (OSError, json.JSONDecodeError):
|
| 123 |
continue
|
| 124 |
return memories
|
| 125 |
|
|
|
|
| 131 |
try:
|
| 132 |
with open(filepath) as f:
|
| 133 |
return json.load(f) # type: ignore[no-any-return]
|
| 134 |
+
except (OSError, json.JSONDecodeError):
|
| 135 |
return None
|
| 136 |
|
| 137 |
def save_memory(self, memory_id: str, data: dict[str, Any]) -> None:
|
|
|
|
| 158 |
try:
|
| 159 |
with open(self.config_file) as f:
|
| 160 |
return json.load(f)
|
| 161 |
+
except (OSError, json.JSONDecodeError):
|
| 162 |
return {}
|
| 163 |
|
| 164 |
def save_config(self, config: dict[str, Any]) -> None:
|
|
|
|
| 171 |
return f"Workspace({self._root})"
|
| 172 |
|
| 173 |
|
| 174 |
+
# Per-task workspace via contextvars (safe for concurrent async tasks).
|
| 175 |
+
_workspace_var: contextvars.ContextVar[Workspace | None] = contextvars.ContextVar(
|
| 176 |
+
"miniagent_workspace", default=None
|
| 177 |
+
)
|
| 178 |
|
| 179 |
|
| 180 |
def get_workspace() -> Workspace:
|
| 181 |
+
"""Get the current workspace (creates from cwd if not set)."""
|
| 182 |
+
ws = _workspace_var.get()
|
| 183 |
+
if ws is None:
|
| 184 |
+
ws = Workspace()
|
| 185 |
+
_workspace_var.set(ws)
|
| 186 |
+
return ws
|
| 187 |
|
| 188 |
|
| 189 |
def set_workspace(workspace: Workspace | str | Path) -> Workspace:
|
| 190 |
+
"""Set the workspace for the current async task."""
|
|
|
|
| 191 |
if isinstance(workspace, Workspace):
|
| 192 |
+
_workspace_var.set(workspace)
|
| 193 |
+
return workspace
|
| 194 |
+
ws = Workspace(workspace)
|
| 195 |
+
_workspace_var.set(ws)
|
| 196 |
+
return ws
|
| 197 |
|
| 198 |
|
| 199 |
def reset_workspace() -> None:
|
| 200 |
+
"""Reset workspace for the current context (for testing)."""
|
| 201 |
+
_workspace_var.set(None)
|
|
|
src/flow/harness/registry.py
CHANGED
|
@@ -14,10 +14,10 @@ if TYPE_CHECKING:
|
|
| 14 |
from flow.harness.base import BaseHarness
|
| 15 |
from flow.llm import LLMClientConfig
|
| 16 |
|
| 17 |
-
_HARNESSES: dict[str, type[
|
| 18 |
|
| 19 |
|
| 20 |
-
def register(name: str, harness_class: type[
|
| 21 |
"""Register a harness class for a framework.
|
| 22 |
|
| 23 |
Args:
|
|
@@ -27,7 +27,7 @@ def register(name: str, harness_class: type["BaseHarness"]) -> None:
|
|
| 27 |
_HARNESSES[name] = harness_class
|
| 28 |
|
| 29 |
|
| 30 |
-
def get_harness_class(name: str) -> type[
|
| 31 |
"""Get harness class by framework name.
|
| 32 |
|
| 33 |
Args:
|
|
@@ -46,10 +46,10 @@ def get_harness_class(name: str) -> type["BaseHarness"]:
|
|
| 46 |
|
| 47 |
|
| 48 |
def create_harness(
|
| 49 |
-
agent:
|
| 50 |
workspace: Path,
|
| 51 |
-
llm_config:
|
| 52 |
-
) ->
|
| 53 |
"""Create a harness from an Agent spec.
|
| 54 |
|
| 55 |
This is the main entry point for creating harnesses. It looks up
|
|
@@ -71,6 +71,29 @@ def create_harness(
|
|
| 71 |
return harness_class.from_agent(agent, workspace, llm_config=llm_config)
|
| 72 |
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def available_frameworks() -> list[str]:
|
| 75 |
"""Get list of available framework names.
|
| 76 |
|
|
@@ -78,3 +101,17 @@ def available_frameworks() -> list[str]:
|
|
| 78 |
List of registered framework names
|
| 79 |
"""
|
| 80 |
return list(_HARNESSES.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from flow.harness.base import BaseHarness
|
| 15 |
from flow.llm import LLMClientConfig
|
| 16 |
|
| 17 |
+
_HARNESSES: dict[str, type[BaseHarness]] = {}
|
| 18 |
|
| 19 |
|
| 20 |
+
def register(name: str, harness_class: type[BaseHarness]) -> None:
|
| 21 |
"""Register a harness class for a framework.
|
| 22 |
|
| 23 |
Args:
|
|
|
|
| 27 |
_HARNESSES[name] = harness_class
|
| 28 |
|
| 29 |
|
| 30 |
+
def get_harness_class(name: str) -> type[BaseHarness]:
|
| 31 |
"""Get harness class by framework name.
|
| 32 |
|
| 33 |
Args:
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
def create_harness(
|
| 49 |
+
agent: Agent,
|
| 50 |
workspace: Path,
|
| 51 |
+
llm_config: LLMClientConfig | None = None,
|
| 52 |
+
) -> BaseHarness:
|
| 53 |
"""Create a harness from an Agent spec.
|
| 54 |
|
| 55 |
This is the main entry point for creating harnesses. It looks up
|
|
|
|
| 71 |
return harness_class.from_agent(agent, workspace, llm_config=llm_config)
|
| 72 |
|
| 73 |
|
| 74 |
+
def ensure_harnesses_registered() -> None:
|
| 75 |
+
"""Ensure all built-in harnesses are registered.
|
| 76 |
+
|
| 77 |
+
Safe to call multiple times — only imports once.
|
| 78 |
+
"""
|
| 79 |
+
if _HARNESSES:
|
| 80 |
+
return
|
| 81 |
+
|
| 82 |
+
# Import harness modules to trigger their self-registration
|
| 83 |
+
import flow.harness.maf as _maf
|
| 84 |
+
import flow.harness.miniagent as _miniagent
|
| 85 |
+
|
| 86 |
+
_ = (_maf, _miniagent)
|
| 87 |
+
|
| 88 |
+
# LangGraph is optional
|
| 89 |
+
try:
|
| 90 |
+
import flow.harness.langgraph as _lg
|
| 91 |
+
|
| 92 |
+
_ = _lg # type: ignore[assignment]
|
| 93 |
+
except ImportError:
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
def available_frameworks() -> list[str]:
|
| 98 |
"""Get list of available framework names.
|
| 99 |
|
|
|
|
| 101 |
List of registered framework names
|
| 102 |
"""
|
| 103 |
return list(_HARNESSES.keys())
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_registered_harnesses() -> dict[str, type[BaseHarness]]:
|
| 107 |
+
"""Get all registered harness classes.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Dict mapping framework names to harness classes.
|
| 111 |
+
Each harness class has metadata attributes:
|
| 112 |
+
- framework_name: Unique identifier
|
| 113 |
+
- framework_label: Human-readable name
|
| 114 |
+
- framework_description: Short description
|
| 115 |
+
- supported_compaction_strategies: List of supported strategy names
|
| 116 |
+
"""
|
| 117 |
+
return dict(_HARNESSES)
|