| """
|
| Agent runner that takes an AgentSpec from LLM-Agent-Factory and runs it
|
| as a single agent via the DiMAS framework (GraphBuilder + MACPRunner).
|
|
|
| The agent answers benchmark questions using its persona/description.
|
| Tracks token usage and latency.
|
| """
|
|
|
| import time
|
| from collections.abc import Callable
|
| from dataclasses import dataclass
|
| from typing import TypeVar
|
|
|
| from openai import OpenAI
|
|
|
| from experiments.benchmark_data import BenchmarkSample
|
|
|
|
|
|
|
T = TypeVar("T")

# Exception types that retry_on_network_error treats as transient.
# ConnectionError and TimeoutError are subclasses of OSError, so the OSError
# entry alone already matches them; they are listed for readability.
_RETRYABLE_EXCEPTIONS: tuple[type[BaseException], ...] = (
    ConnectionError,
    TimeoutError,
    OSError,
)

# Extend the tuple with the OpenAI SDK's transport / server-side errors when
# those names can be imported.
try:
    from openai import APIConnectionError, APITimeoutError, InternalServerError, RateLimitError
except ImportError:
    pass
else:
    _RETRYABLE_EXCEPTIONS += (
        APIConnectionError,
        APITimeoutError,
        InternalServerError,
        RateLimitError,
    )

# Same for the httpx transport errors, when httpx is importable.
try:
    import httpx
except ImportError:
    pass
else:
    _RETRYABLE_EXCEPTIONS += (httpx.ConnectError, httpx.ReadTimeout, httpx.RemoteProtocolError)
|
|
|
|
|
| def retry_on_network_error[T](
|
| fn: Callable[..., T],
|
| *args,
|
| max_retries: int = 60,
|
| initial_wait: float = 5.0,
|
| max_wait: float = 60.0,
|
| backoff_factor: float = 1.5,
|
| **kwargs,
|
| ) -> T:
|
| """
|
| Call *fn* and retry on transient network / server errors.
|
|
|
| Waits with exponential backoff between retries. Prints a message
|
| so the user knows the runner is waiting for connectivity.
|
|
|
| Args:
|
| fn: Callable to execute.
|
| max_retries: How many times to retry (default 60 ≈ up to ~30 min).
|
| initial_wait: Seconds to wait after the first failure.
|
| max_wait: Cap on the wait between retries.
|
| backoff_factor: Multiplier for the wait after each failure.
|
|
|
| """
|
| wait = initial_wait
|
| for attempt in range(1, max_retries + 1):
|
| try:
|
| return fn(*args, **kwargs)
|
| except _RETRYABLE_EXCEPTIONS:
|
| if attempt == max_retries:
|
| raise
|
| time.sleep(wait)
|
| wait = min(wait * backoff_factor, max_wait)
|
|
|
| return fn(*args, **kwargs)
|
|
|
|
|
@dataclass
class AgentAnswer:
    """
    Result of an agent answering a benchmark question.

    Bundles the graded prediction with token-usage and timing metrics and
    the identity of the agent that produced it. Only the first four fields
    are required; every metric defaults to zero / empty.
    """

    # Which sample was answered and how the answer was graded.
    sample_id: str
    predicted_answer: str
    correct_answer: str
    is_correct: bool

    # Token usage accumulated while the agent answered the question.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    # NOTE(review): presumably token usage of the step that generated the
    # agent spec; nothing in this module sets these fields - confirm
    # against the callers.
    gen_prompt_tokens: int = 0
    gen_completion_tokens: int = 0
    gen_total_tokens: int = 0

    # Timings in seconds. run_agent_on_sample stores the same value in
    # execution_time and latency_seconds and leaves retrieval_time at 0.0
    # (presumably filled in by the caller - confirm).
    retrieval_time: float = 0.0
    execution_time: float = 0.0
    latency_seconds: float = 0.0

    # Agent identity, plus the stringified exception when the run failed
    # (None when no error was recorded).
    agent_id: str = ""
    agent_display_name: str = ""
    error: str | None = None
|
|
|
|
|
| def _extract_answer(response_text: str, sample: BenchmarkSample) -> str:
|
| """Extract the answer from model response."""
|
| text = response_text.strip()
|
|
|
|
|
| if sample.dataset_name in ("mmlu", "bigbench") and sample.choices:
|
| chr(64 + len(sample.choices)) if sample.choices else "D"
|
| valid_letters = [chr(65 + i) for i in range(len(sample.choices))] if sample.choices else ["A", "B", "C", "D"]
|
|
|
|
|
| for line in text.split("\n"):
|
| line = line.strip().upper()
|
| if line in valid_letters:
|
| return line
|
|
|
| for letter in valid_letters:
|
| if line.startswith((f"{letter}.", f"({letter})")) or line == f"**{letter}**":
|
| return letter
|
|
|
|
|
| if text and text[0].upper() in valid_letters:
|
| return text[0].upper()
|
|
|
|
|
| for ch in text.upper():
|
| if ch in valid_letters:
|
| return ch
|
|
|
|
|
|
|
| lower = text.lower()
|
| for prefix in ["the answer is ", "answer: ", "answer is "]:
|
| if prefix in lower:
|
| idx = lower.index(prefix) + len(prefix)
|
| return text[idx:].strip().rstrip(".")
|
|
|
| return text
|
|
|
|
|
def _check_correct(predicted: str, correct: str, dataset_name: str) -> bool:
    """
    Check if the predicted answer matches the correct answer.

    Both strings are compared after stripping and lower-casing.

    * ``mmlu``: only the first character (the option letter) is compared.
    * ``bigbench`` with a single-letter reference: first character only.
    * Otherwise: exact match, or one string containing the other. For
      datasets other than ``bigbench`` a parenthesised reference such as
      ``"(a)"`` is also matched by a prediction starting with its
      non-empty inner text.

    An empty prediction or an empty reference never matches through the
    substring rules (``"" in s`` is always true and would otherwise mark
    an empty model response as correct); two equal strings still match.

    Args:
        predicted: Answer extracted from the model response.
        correct: Reference answer of the sample.
        dataset_name: Name of the benchmark the sample belongs to.

    Returns:
        True when the prediction is accepted as correct.

    """
    pred = predicted.strip().lower()
    corr = correct.strip().lower()

    if dataset_name == "mmlu":
        return pred[:1] == corr[:1]

    if dataset_name == "bigbench" and len(corr) == 1 and corr.isalpha():
        return pred[:1] == corr[:1]

    if pred == corr:
        return True

    # Guard the substring rules: the empty string is contained in every
    # string, so an empty side must not count as a match.
    if not pred or not corr:
        return False

    if corr in pred or pred in corr:
        return True

    if dataset_name == "bigbench":
        return False

    # References written as "(x)": accept a prediction that starts with x.
    if corr.startswith("(") and corr.endswith(")"):
        inner = corr[1:-1]
        return bool(inner) and pred.startswith(inner)

    return False
|
|
|
|
|
def _build_dimas_answer_prompt(agent_spec: dict, question: str) -> str:
    """
    Compose the task query handed to DiMAS for one benchmark question.

    The query is a fixed block of answering instructions followed by the
    question text. ``agent_spec`` is accepted but not read here.

    Args:
        agent_spec: Agent specification dict (unused by this function).
        question: Benchmark question to embed after the instructions.

    Returns:
        The instruction block, a blank line, then the question.

    """
    segments = (
        "Answer the following question.\n",
        "For multiple-choice questions respond with ONLY the letter (A, B, C, or D).\n",
        "For free-form questions give a concise answer.\n",
        "Do NOT explain your reasoning.\n\n",
        f"{question}",
    )
    return "".join(segments)
|
|
|
|
|
def run_agent_on_sample(
    agent_spec: dict,
    sample: BenchmarkSample,
    client: OpenAI,
    model: str = "gpt-oss",
    temperature: float = 0.1,
    max_tokens: int = 256,
    timeout: int = 120,
) -> AgentAnswer:
    """
    Answer one benchmark sample with one agent through the DiMAS framework.

    A graph containing a task node and a single agent node is assembled
    with GraphBuilder and executed for one round by MACPRunner. The
    runner talks to the model through a local callback that uses *client*
    and accumulates the token usage reported by the API. Any exception
    raised while building, running or grading is caught and reported in
    the ``error`` field of the returned AgentAnswer.

    Args:
        agent_spec: Agent specification dict; the keys ``agent_id``,
            ``display_name``, ``persona``, ``description`` and ``tools``
            are read, each with a default.
        sample: The benchmark question.
        client: OpenAI client configured for the API.
        model: Model name.
        temperature: Sampling temperature.
        max_tokens: Max tokens for response.
        timeout: Timeout passed to the runner configuration.

    Returns:
        AgentAnswer with results and metrics.

    """
    from rustworkx_framework.builder.graph_builder import BuilderConfig, GraphBuilder
    from rustworkx_framework.execution.runner import MACPRunner, RunnerConfig

    agent_id = agent_spec.get("agent_id", "agent")
    display_name = agent_spec.get("display_name", "AI Assistant")
    agent_persona = agent_spec.get("persona", "")
    agent_description = agent_spec.get("description", "")
    agent_tools = agent_spec.get("tools", [])

    task_query = _build_dimas_answer_prompt(agent_spec, sample.question)

    # Running totals updated by every LLM call made on behalf of the runner.
    usage_totals = {"prompt": 0, "completion": 0, "total": 0}

    def llm_caller(prompt: str) -> str:
        """Send *prompt* to the model, record token usage, return the reply text."""
        completion = retry_on_network_error(
            client.chat.completions.create,
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens,
        )

        reported = completion.usage
        if reported:
            usage_totals["prompt"] += reported.prompt_tokens
            usage_totals["completion"] += reported.completion_tokens
            usage_totals["total"] += reported.total_tokens

        # Fall back to the reasoning text when the regular content is empty.
        message = completion.choices[0].message
        return message.content or getattr(message, "reasoning_content", None) or ""

    def build_answer(
        predicted: str,
        is_correct: bool,
        total_tokens: int,
        seconds: float,
        error: str | None = None,
    ) -> AgentAnswer:
        """Assemble the AgentAnswer shared by the success and failure paths."""
        return AgentAnswer(
            sample_id=sample.sample_id,
            predicted_answer=predicted,
            correct_answer=sample.correct_answer,
            is_correct=is_correct,
            prompt_tokens=usage_totals["prompt"],
            completion_tokens=usage_totals["completion"],
            total_tokens=total_tokens,
            execution_time=seconds,
            latency_seconds=seconds,
            agent_id=agent_id,
            agent_display_name=display_name,
            error=error,
        )

    started = time.perf_counter()
    try:
        # Graph: one task node feeding one agent node.
        builder = GraphBuilder(BuilderConfig(include_task_node=True, validate=True))
        builder.add_task(
            task_id="__task__",
            query=task_query,
            description="Benchmark question to answer",
        )
        builder.add_agent(
            agent_id=agent_id,
            display_name=display_name,
            persona=agent_persona,
            description=agent_description,
            llm_backbone=model,
            base_url=str(client.base_url),
            api_key=client.api_key,
            temperature=temperature,
            max_tokens=max_tokens,
            tools=agent_tools if isinstance(agent_tools, list) else [],
        )
        builder.connect_task_to_agents(agent_ids=[agent_id], bidirectional=False)
        graph = builder.build()

        runner_config = RunnerConfig(
            timeout=float(timeout),
            adaptive=False,
            update_states=True,
            broadcast_task_to_all=True,
        )
        runner = MACPRunner(llm_caller=llm_caller, config=runner_config)
        result = runner.run_round(graph, final_agent_id=agent_id)
        elapsed = time.perf_counter() - started

        runner_tokens = result.total_tokens
        runner_seconds = result.total_time

        # Prefer the API-reported usage; use the runner's count only when
        # the callback recorded nothing.
        total_tokens = usage_totals["total"]
        if total_tokens == 0 and runner_tokens > 0:
            total_tokens = runner_tokens

        # Cap the wall-clock measurement at the runner's own timing (+10 ms).
        if runner_seconds > 0:
            elapsed = min(elapsed, runner_seconds + 0.01)

        answer_text = result.final_answer or ""
        if not answer_text and result.messages:
            answer_text = result.messages.get(agent_id, "")

        predicted = _extract_answer(answer_text, sample)
        verdict = _check_correct(predicted, sample.correct_answer, sample.dataset_name)
        return build_answer(predicted, verdict, total_tokens, elapsed)

    except Exception as e:
        # Report the failure as an incorrect answer with the usage so far.
        elapsed = time.perf_counter() - started
        return build_answer("", False, usage_totals["total"], elapsed, error=str(e))
|
|
|