File size: 6,376 Bytes
5374a2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os 
from typing import Any, Callable
from .benchmark import Benchmark
from .measures import exact_match_score, f1_score, acc_score
from ..core.logging import logger
from ..core.module_utils import load_json
from ..utils.utils import download_file
from ..utils.aflow_utils.data_utils import AFLOW_DATASET_FILES_MAP, download_aflow_benchmark_data
from copy import deepcopy

# Split name -> raw HotPotQA file name (``None`` means no raw file exists for that split).
HOTPOTQA_FILES_MAP = {
    "train": "hotpot_train_v1.1.json",
    "dev": "hotpot_dev_distractor_v1.json",
    "test": None,
}
# The raw file names accepted by ``download_raw_hotpotqa_data`` (all non-None map values, in map order).
VALIDE_RAW_HOTPOTQA_FILES = [fname for fname in HOTPOTQA_FILES_MAP.values() if fname is not None]

def download_raw_hotpotqa_data(name: str, save_folder: str):
    """Download one raw HotPotQA file from the CMU mirror into ``save_folder``.

    Args:
        name: File name to fetch; must be one of ``VALIDE_RAW_HOTPOTQA_FILES``.
        save_folder: Directory the file is saved into, as ``os.path.join(save_folder, name)``.

    Raises:
        ValueError: If ``name`` is not a known raw HotPotQA file name.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``,
    # which would otherwise let an arbitrary name be turned into a download URL.
    if name not in VALIDE_RAW_HOTPOTQA_FILES:
        raise ValueError(
            f"'{name}' is an invalid hotpotqa file name. Available file names: {VALIDE_RAW_HOTPOTQA_FILES}"
        )
    url = f"http://curtis.ml.cmu.edu/datasets/hotpot/{name}"
    # Only used for the log line: the file name tells train vs. dev apart.
    split = "train" if "train" in name else "dev"
    logger.info(f"Downloading HotPotQA {split} data from: {url}")
    download_file(url=url, save_file=os.path.join(save_folder, name))


class PertQA(Benchmark):

    """Question-answering benchmark over perturbation datasets (PertQA).

    Data is read from ``<pertdata>_update_train.json`` and ``<pertdata>_update_test.json``
    in ``../../examples/pertqa`` (relative to ``path``). The dev split reuses the train file.

    Fields of an example read by this class:
    {
        "_id": ...,           # returned by ``_get_id``
        "question_new": str,  # question text passed to the workflow graph
        "answer": str,        # gold label returned by ``_get_label``
    }
    (Other fields may exist in the files; they are not used here.)

    ``evaluate`` scores a prediction with exact match, F1 and accuracy;
    ``async_evaluate`` returns the accuracy only.
    """

    # Accepted values for ``pertdata``; each has a train and a test file.
    SUPPORTED_PERTDATA = ("adamson", "norman", "reploge")
    # Directory of the PertQA data files, relative to ``self.path``.
    _PERTQA_DATA_DIR = "../../examples/pertqa"

    def __init__(self, path: str = None, mode: str = "all", pertdata: str = 'adamson', **kwargs):
        """Create the benchmark.

        Args:
            path: Base directory; defaults to ``~/.evoagentx/data/hotpotqa`` (``~`` is expanded).
            mode: Which splits to load ("train", "dev", "test" or "all").
            pertdata: Dataset name, one of ``SUPPORTED_PERTDATA``.
            **kwargs: Forwarded to ``Benchmark.__init__`` (``pertdata`` is forwarded as well).
        """
        path = os.path.expanduser(path or "~/.evoagentx/data/hotpotqa")
        # Set before ``super().__init__`` because ``_load_data`` reads it
        # (presumably the base initializer triggers loading — confirm in Benchmark).
        self.pertdata = pertdata
        super().__init__(name=type(self).__name__, path=path, mode=mode, pertdata=pertdata, **kwargs)

    def _load_data_from_file(self, file_name: str):
        """Load a JSON file located at ``os.path.join(self.path, file_name)``.

        Returns ``None`` when ``file_name`` is ``None``. A missing raw HotPotQA file is
        downloaded first; any other missing file cannot be downloaded.

        Raises:
            FileNotFoundError: If the file is missing and is not a downloadable HotPotQA file.
        """
        if file_name is None:
            return None
        file_path = os.path.join(self.path, file_name)
        if not os.path.exists(file_path):
            if file_name not in VALIDE_RAW_HOTPOTQA_FILES:
                # PertQA files have no download source; report the missing path directly instead of
                # the misleading "invalid hotpotqa file name" error the downloader would raise.
                raise FileNotFoundError(f"PertQA data file not found: {file_path}")
            download_raw_hotpotqa_data(name=file_name, save_folder=self.path)
        logger.info(f"loading HotPotQA data from {file_path} ...")
        return load_json(path=file_path, type="json")

    def _pertqa_file(self, split: str) -> str:
        """Return the data file (relative to ``self.path``) for ``split`` of ``self.pertdata``."""
        return f"{self._PERTQA_DATA_DIR}/{self.pertdata}_update_{split}.json"

    def _load_data(self):
        """Load the splits selected by ``self.mode`` for ``self.pertdata``.

        The dev split is loaded from the *train* file. ``_dev_data_full`` is always set
        to a deep copy of the training data (``None`` when no training data is present).

        Raises:
            ValueError: If ``self.pertdata`` is not one of ``SUPPORTED_PERTDATA``.
        """
        if self.pertdata not in self.SUPPORTED_PERTDATA:
            # Previously an unknown name silently produced an empty benchmark.
            raise ValueError(
                f"Unknown pertdata '{self.pertdata}'. Supported values: {list(self.SUPPORTED_PERTDATA)}"
            )
        train_file = self._pertqa_file("train")
        if self.mode in ("train", "all"):
            self._train_data = self._load_data_from_file(train_file)
        if self.mode in ("dev", "all"):
            # There is no separate dev file: dev reuses the training file.
            self._dev_data = self._load_data_from_file(train_file)
        if self.mode in ("test", "all"):
            self._test_data = self._load_data_from_file(self._pertqa_file("test"))
        # getattr: outside train/all modes this method never assigns ``_train_data`` itself.
        self._dev_data_full = deepcopy(getattr(self, "_train_data", None))

    def _get_label(self, example: Any) -> Any:
        """Return the gold answer of ``example``."""
        return example["answer"]

    def _get_id(self, example: Any) -> Any:
        """Return the identifier of ``example``."""
        return example["_id"]

    def evaluate(self, prediction: Any, label: Any) -> dict:
        """Score ``prediction`` against ``label``; returns ``{"f1", "em", "acc"}``."""
        em = exact_match_score(prediction=prediction, ground_truth=label)
        f1 = f1_score(prediction=prediction, ground_truth=label)
        acc = acc_score(prediction=prediction, ground_truths=[label])
        return {"f1": f1, "em": em, "acc": acc}

    async def async_evaluate(self, graph: Callable, example: Any) -> float:
        """Run ``graph`` on the example's question and return the accuracy of its answer.

        ``super().async_evaluate`` here resolves to the base class's
        ``(prediction, label)`` variant, whose result exposes an ``"acc"`` entry.
        """
        # generate solution
        prompt = example["question_new"]
        inputs = f"Question: {prompt}\n\nAnswer:"
        solution = await graph(inputs)
        label = self._get_label(example)
        metrics = await super().async_evaluate(prediction=solution, label=label)
        return metrics["acc"]


class AFlowPertQA(PertQA):

    """
    AFlow-specific variant of the PertQA benchmark.

    Only file reading differs from ``PertQA``: files are parsed with
    ``load_json(..., type="jsonl")`` and a missing file first triggers the AFlow
    download helper. Split selection (``_load_data``) and scoring (``evaluate`` /
    ``async_evaluate``) are inherited from ``PertQA``.

    The previous ``async_evaluate`` override called
    ``super().async_evaluate(prediction=..., label=...)``. Inside this class ``super()``
    is ``PertQA``, whose ``async_evaluate`` takes ``(graph, example)``, so that call always
    raised ``TypeError``. The inherited implementation does the same work correctly.
    """

    def _load_data_from_file(self, file_name: str):
        """Load a data file located at ``os.path.join(self.path, file_name)``.

        Returns ``None`` when ``file_name`` is ``None``.

        Raises:
            FileNotFoundError: If the file is still missing after the download attempt.
        """
        if file_name is None:
            return None
        file_path = os.path.join(self.path, file_name)
        if not os.path.exists(file_path):
            # NOTE(review): this fetches the AFlow *HotPotQA* files, which cannot produce a
            # PertQA file; kept for files that do come from that download.
            download_aflow_benchmark_data(dataset="hotpotqa", save_folder=self.path)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"AFlow PertQA data file not found: {file_path}")
        logger.info(f"loading data from {file_path} ...")
        # NOTE(review): PertQA reads the same *.json files with type="json"; confirm which
        # format these files actually use.
        return load_json(path=file_path, type="jsonl")