|
|
import os |
|
|
import json |
|
|
import random |
|
|
from typing import Any, Dict, Callable, List |
|
|
from .benchmark import Benchmark |
|
|
from .measures import exact_match_score, f1_score, acc_score |
|
|
from ..core.logging import logger |
|
|
from ..core.module_utils import load_json |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
WORFBENCH_FILES_MAP = { |
|
|
"train": "worfbench_train.json", |
|
|
"test": "worfbench_test.json" |
|
|
} |
|
|
VALID_WORFBENCH_FILES = list(WORFBENCH_FILES_MAP.values()) |
|
|
|
|
|
def evaluate_workflow_sequence(prediction: List[Any], ground_truth: List[Any]) -> float: |
|
|
"""Evaluate F1 score for sequence workflow.""" |
|
|
from .measures import f1_score |
|
|
return f1_score(prediction=prediction, ground_truth=ground_truth) |
|
|
|
|
|
def evaluate_workflow_graph(prediction: Dict[str, Any], ground_truth: Dict[str, Any]) -> float: |
|
|
"""Evaluate F1 score for graph workflow.""" |
|
|
pred_nodes = set(prediction.get("nodes", [])) |
|
|
true_nodes = set(ground_truth.get("nodes", [])) |
|
|
pred_edges = set(tuple(edge) for edge in prediction.get("edges", [])) |
|
|
true_edges = set(tuple(edge) for edge in ground_truth.get("edges", [])) |
|
|
|
|
|
node_precision = len(pred_nodes & true_nodes) / len(pred_nodes) if pred_nodes else 0 |
|
|
node_recall = len(pred_nodes & true_nodes) / len(true_nodes) if true_nodes else 0 |
|
|
edge_precision = len(pred_edges & true_edges) / len(pred_edges) if pred_edges else 0 |
|
|
edge_recall = len(pred_edges & true_edges) / len(true_edges) if true_edges else 0 |
|
|
|
|
|
node_f1 = 2 * (node_precision * node_recall) / (node_precision + node_recall) if (node_precision + node_recall) > 0 else 0 |
|
|
edge_f1 = 2 * (edge_precision * edge_recall) / (edge_precision + edge_recall) if (edge_precision + edge_recall) > 0 else 0 |
|
|
|
|
|
return (node_f1 + edge_f1) / 2 |
|
|
|
|
|
def download_worfbench_data(dataset: str, save_folder: str) -> None: |
|
|
""" |
|
|
Download WorfBench dataset from Hugging Face. |
|
|
|
|
|
Args: |
|
|
dataset (str): Dataset name ("worfbench"). |
|
|
save_folder (str): Directory to save data. |
|
|
""" |
|
|
datasets_map = { |
|
|
"train": {"repo_id": "zjunlp/WorFBench_train", "filename": "worfbench_train.json", "split": "train"}, |
|
|
"test": {"repo_id": "zjunlp/WorFBench_test", "filename": "worfbench_test.json", "split": "test"} |
|
|
} |
|
|
|
|
|
os.makedirs(save_folder, exist_ok=True) |
|
|
for split, info in datasets_map.items(): |
|
|
repo_id = info["repo_id"] |
|
|
filename = info["filename"] |
|
|
dataset_split = info["split"] |
|
|
save_path = os.path.join(save_folder, filename) |
|
|
|
|
|
if not os.path.exists(save_path): |
|
|
logger.info(f"Downloading {split} split of {dataset} from {repo_id}...") |
|
|
try: |
|
|
|
|
|
ds = load_dataset(repo_id, split=dataset_split) |
|
|
|
|
|
data = [item for item in ds] |
|
|
with open(save_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(data, f, ensure_ascii=False, indent=2) |
|
|
logger.info(f"Successfully downloaded and saved {filename} to {save_path}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to download or save {filename}: {e}") |
|
|
raise |
|
|
else: |
|
|
logger.info(f"File {save_path} already exists, skipping download.") |
|
|
|
|
|
class WorfBench(Benchmark): |
|
|
""" |
|
|
WorfBench evaluation class for assessing LLM agents on complex workflow generation tasks. |
|
|
Assumed data structure: |
|
|
{ |
|
|
"id": str, |
|
|
"task": str, |
|
|
"context": list of dicts (e.g., [{"title": str, "content": list of str}]), |
|
|
"expected_output": str or dict (sequence or graph), |
|
|
"type": str, |
|
|
"level": str |
|
|
} |
|
|
""" |
|
|
def __init__(self, path: str = None, mode: str = "test", **kwargs): |
|
|
path = os.path.expanduser(path or "~/.worfbench/data") |
|
|
super().__init__(name=type(self).__name__, path=path, mode=mode, **kwargs) |
|
|
|
|
|
def _load_data_from_file(self, file_name: str) -> Dict: |
|
|
if file_name is None: |
|
|
return None |
|
|
file_path = os.path.join(self.path, file_name) |
|
|
if not os.path.exists(file_path): |
|
|
download_worfbench_data(dataset="worfbench", save_folder=self.path) |
|
|
if not os.path.exists(file_path): |
|
|
logger.error(f"File {file_path} still does not exist after download attempt!") |
|
|
return None |
|
|
logger.info(f"Loading WorfBench data from {file_path} ...") |
|
|
data = load_json(path=file_path, type="json") |
|
|
if data is None: |
|
|
logger.error(f"Failed to load data from {file_path}") |
|
|
return None |
|
|
return data |
|
|
|
|
|
def _load_data(self) -> None: |
|
|
if self.mode in ["train", "dev"]: |
|
|
self._train_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["train"]) |
|
|
if self.mode == "dev": |
|
|
if self._train_data: |
|
|
random.seed(42) |
|
|
keys = list(self._train_data.keys()) |
|
|
n_dev = len(self._train_data[keys[0]]) // 10 or 1 |
|
|
indices = list(range(len(self._train_data[keys[0]]))) |
|
|
random.shuffle(indices) |
|
|
self._train_data = {k: [v[i] for i in indices[:n_dev]] for k, v in self._train_data.items()} |
|
|
if self.mode == "test": |
|
|
self._test_data = self._load_data_from_file(file_name=WORFBENCH_FILES_MAP["test"]) |
|
|
|
|
|
def _get_label(self, example: Dict) -> Any: |
|
|
return example.get("expected_output", "") |
|
|
|
|
|
def _get_id(self, example: Dict) -> Any: |
|
|
return example.get("id", "") |
|
|
|
|
|
def evaluate(self, prediction: Any, label: Any) -> Dict: |
|
|
if isinstance(prediction, list) and isinstance(label, list): |
|
|
f1 = evaluate_workflow_sequence(prediction, label) |
|
|
elif isinstance(prediction, dict) and isinstance(label, dict): |
|
|
f1 = evaluate_workflow_graph(prediction, label) |
|
|
else: |
|
|
f1 = f1_score(prediction=str(prediction), ground_truth=str(label)) |
|
|
em = exact_match_score(prediction=prediction, ground_truth=label) |
|
|
acc = acc_score(prediction=prediction, ground_truths=[label]) |
|
|
return {"em": em, "f1": f1, "acc": acc} |
|
|
|
|
|
async def async_evaluate(self, graph: Callable, example: Dict) -> float: |
|
|
task = example.get("task", "") |
|
|
context = "\n".join( |
|
|
f"{ctx.get('title', '')}: {' '.join(ctx.get('content', []))}" |
|
|
for ctx in example.get("context", []) |
|
|
if isinstance(ctx, dict) |
|
|
) |
|
|
inputs = f"Task: {task}\nContext: {context}\nGenerate workflow:\nAnswer:" |
|
|
try: |
|
|
generated_workflow = await graph(inputs) |
|
|
except Exception as e: |
|
|
logger.error(f"Error generating workflow: {e}") |
|
|
generated_workflow = "" |
|
|
label = self._get_label(example) |
|
|
metrics = self.evaluate(prediction=generated_workflow, label=label) |
|
|
return metrics["f1"] |