File size: 4,796 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from datasets import DatasetDict, load_dataset
from typing import Dict, List
import random
 
from larm.data.builders.base_builder import BaseDatasetBuilder
from larm.data.interactions.singleturn_interaction import SingleTurnInteractionManager
from larm.common.registry import registry
from larm.data.envs.gpqa_env import GPQAEnv

@registry.register_builder("gpqa")
class GPQABuilder(BaseDatasetBuilder):   # Env
    """Dataset builder for the GPQA multiple-choice benchmark.

    Downloads ``Idavidrein/gpqa``: the ``gpqa_main`` split is divided into
    train/valid, and ``gpqa_diamond`` is held out as the test set. Each
    example's four answer candidates are shuffled into lettered options
    (A-D), and the builder emits ``prompt`` / ``completion`` / ``solution``
    columns suitable for both SFT and RL.
    """

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/gpqa/default.yaml",
    }
    # No preprocessed cache path is reused for GPQA.
    CACHE_PATH = None

    def _build_datasets(self) -> DatasetDict:
        """Download, split, persist, and preprocess the GPQA datasets.

        Returns:
            DatasetDict: ``train`` / ``valid`` / ``test`` splits, each
            restricted to the columns listed by :meth:`_keep_keys`.
        """
        # gpqa_main feeds train/valid; gpqa_diamond is the test set.
        raw_train_dataset = load_dataset("Idavidrein/gpqa", "gpqa_main")["train"]
        raw_test_dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond")["train"]
        val_size = int(len(raw_train_dataset) * self.config.valid_ratio)
        # NOTE(review): no seed is supplied, so the train/valid split differs
        # between runs — confirm whether reproducibility is required.
        split = raw_train_dataset.train_test_split(test_size=val_size, shuffle=True)
        raw_train_dataset, raw_valid_dataset = split["train"], split["test"]

        # Persist the raw splits as JSON for inspection/reuse.
        cache_root = registry.get_path("cache_root")
        save_dir = os.path.join(cache_root, "gpqa")
        # Fix: Dataset.to_json does not create parent directories, so a
        # fresh cache root would previously raise FileNotFoundError.
        os.makedirs(save_dir, exist_ok=True)
        raw_train_dataset.to_json(os.path.join(save_dir, "train.json"))
        raw_valid_dataset.to_json(os.path.join(save_dir, "valid.json"))
        raw_test_dataset.to_json(os.path.join(save_dir, "test.json"))

        # Preprocess into prompt/completion/solution columns only.
        train_dataset = raw_train_dataset.map(self._preprocess).select_columns(self._keep_keys())
        valid_dataset = raw_valid_dataset.map(self._preprocess).select_columns(self._keep_keys())
        test_dataset = raw_test_dataset.map(self._preprocess).select_columns(self._keep_keys())

        return DatasetDict(
            {
                "train": train_dataset,
                "valid": valid_dataset,
                "test": test_dataset,
            }
        )

    def _build_sft_datasets(self) -> DatasetDict:
        """SFT uses the same prompt/completion datasets as the base build."""
        return self._build_datasets()

    def _build_rl_datasets(self) -> DatasetDict:
        """RL uses the same prompt/completion datasets as the base build."""
        return self._build_datasets()

    @classmethod
    def _preprocess(cls, example: Dict) -> Dict[str, str]:
        """Convert one raw GPQA row into prompt/completion/solution fields.

        The correct answer and three distractors are shuffled into lettered
        options; the completion appends the correct letter in ``\\boxed{}``
        after the dataset's explanation.

        Args:
            example: Raw row with "Question", "Explanation", "Correct Answer",
                and "Incorrect Answer 1..3" keys.

        Returns:
            Dict with "prompt", "completion", and "solution" strings.
        """

        def build_answer_map(candidates: List[str]) -> Dict[str, Dict[str, object]]:
            """Shuffle candidates and assign each a letter; index 0 is correct.

            NOTE(review): random.shuffle is unseeded, so option order varies
            between runs — confirm this randomization is intentional.
            """
            indices = list(range(len(candidates)))
            random.shuffle(indices)

            orders = [chr(ord("A") + i) for i in range(len(candidates))]

            answer_map = {}
            for idx, candidate_idx in enumerate(indices):
                answer = candidates[candidate_idx]
                answer_map[answer] = {
                    "order": orders[idx],
                    # Candidates are passed with the correct answer first.
                    "is_correct": (candidate_idx == 0)
                }

            return answer_map

        def build_question(question: str, answer_map: Dict) -> str:
            """Render the question followed by its lettered options."""
            result = question.strip() + "\n\nPlease choose one of the following options:\n"

            # Letters sort lexicographically, so this restores A, B, C, D order.
            sorted_items = sorted(answer_map.items(), key=lambda x: x[1]["order"])

            for answer, meta in sorted_items:
                result += f"{meta['order']}. {answer}\n"

            return result

        def build_answer(rationale: str, answer_map: Dict) -> str:
            """Append the correct option letter (boxed) to the rationale."""
            # Explicit search + ValueError instead of a bare assert, which
            # would be stripped under ``python -O``.
            correct_answer = next(
                (meta.get("order") for meta in answer_map.values() if meta.get("is_correct") is True),
                None,
            )
            if correct_answer is None:
                raise ValueError("answer_map contains no correct option")
            return rationale + f"\n\nTherefore, the final answer is \\boxed{{{correct_answer}}}"

        question = example["Question"].strip()
        explanation = example["Explanation"].strip()
        correct_answer = example["Correct Answer"].strip()
        incorrect_answer1 = example["Incorrect Answer 1"].strip()
        incorrect_answer2 = example["Incorrect Answer 2"].strip()
        incorrect_answer3 = example["Incorrect Answer 3"].strip()

        # Correct answer must stay first: build_answer_map marks index 0 correct.
        answers_map = build_answer_map([correct_answer, incorrect_answer1, incorrect_answer2, incorrect_answer3])
        question = build_question(question, answers_map)
        answer = build_answer(explanation, answers_map)

        format_template = r"""Solve the problem with proper reasoning, and make sure to put the FINAL CHOICE inside \boxed{}."""
        prompt_template = "Question: {prompt}\n"
        processed_prompt = format_template + prompt_template.format(prompt=question)

        return {
            "prompt": processed_prompt,
            "completion": answer,
            "solution": answer
        }

    @classmethod
    def _keep_keys(cls) -> List[str]:
        """Columns retained after preprocessing."""
        return ["prompt", "completion", "solution"]

    def get_env_cls(self):
        """Environment class used to score GPQA rollouts."""
        return GPQAEnv

    def get_generation_manager_cls(self):
        """GPQA is answered in a single turn."""
        return SingleTurnInteractionManager