| | import json |
| |
|
| | from datasets import Dataset, DatasetDict |
| |
|
| | from opencompass.registry import LOAD_DATASET |
| |
|
| | from .base import BaseDataset |
| |
|
| |
|
@LOAD_DATASET.register_module()
class dropDataset(BaseDataset):
    """Load a JSON file of passages with QA pairs into a ``DatasetDict``.

    NOTE(review): the class name suggests the DROP reading-comprehension
    benchmark; the code itself only relies on the keys documented on
    :meth:`load` and :meth:`get_answers`.
    """

    @staticmethod
    def get_answers(validated_answers):
        """Collect the unique answer strings from validated answers.

        For each answer item, the first applicable rule wins:

        1. a truthy ``'number'`` is used as the answer;
        2. otherwise, if any of the ``'date'`` fields ``'day'``,
           ``'month'``, ``'year'`` is truthy, the three fields are joined
           with single spaces (in that order) and the result is stripped;
        3. otherwise every entry of ``'spans'`` is added.

        Args:
            validated_answers: Iterable of dicts with the keys
                ``'number'``, ``'date'`` and ``'spans'``.

        Returns:
            list: The distinct answers in first-seen order.
        """
        answers = []
        for answer_item in validated_answers:
            if answer_item['number']:
                answers.append(answer_item['number'])
                continue
            date = answer_item['date']
            date_parts = [date[key] for key in ('day', 'month', 'year')]
            if any(date_parts):
                answers.append(' '.join(date_parts).strip())
            else:
                answers.extend(answer_item['spans'])
        # De-duplicate while keeping first-seen order.  ``list(set(...))``
        # would order the strings by their per-process randomised hashes,
        # so the result could differ from one run to the next.
        return list(dict.fromkeys(answers))

    @staticmethod
    def load(path, only_number=True):
        """Read the JSON file at ``path`` and build the dataset.

        The file must contain a JSON object whose values are dicts with a
        ``'passage'`` string and a ``'qa_pairs'`` list; each QA pair needs
        ``'question'`` and ``'validated_answers'``.

        Args:
            path: Path of the JSON file, read as UTF-8.
            only_number: When true, skip every QA pair in which no
                validated answer has a truthy ``'number'``.  A kept pair
                still contributes all of its answers (dates and spans
                included) through :meth:`get_answers`.

        Returns:
            DatasetDict: A single ``'validation'`` split whose rows have
            the keys ``'prompt'``, ``'question'`` and ``'answers'``.
        """
        with open(path, 'r', encoding='utf-8') as f:
            passages = json.load(f)

        rows = []
        for entry in passages.values():
            for qa_pair in entry['qa_pairs']:
                validated_answers = qa_pair['validated_answers']
                if only_number and not any(
                        answer['number'] for answer in validated_answers):
                    continue
                rows.append({
                    'prompt': entry['passage'],
                    'question': qa_pair['question'],
                    'answers': dropDataset.get_answers(validated_answers),
                })

        return DatasetDict({'validation': Dataset.from_list(rows)})
| |
|