| | import json |
| |
|
| | from datasets import Dataset |
| |
|
| | from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS |
| |
|
| | from .base import BaseDataset |
| |
|
| |
|
@LOAD_DATASET.register_module()
class DRCDDataset(BaseDataset):
    """Dataset loader for DRCD question-answering JSON files.

    Registered with ``LOAD_DATASET`` under the class name.
    """

    @staticmethod
    def load(path: str):
        """Flatten the JSON file at ``path`` into one row per question.

        The file is expected to hold a top-level ``data`` list whose items
        each carry a ``paragraphs`` list; every paragraph has a ``context``
        string and a ``qas`` list of ``{'question': ..., 'answers': [...]}``
        entries, where each answer exposes its string under ``text``.

        Args:
            path: Location of the UTF-8 encoded JSON file.

        Returns:
            A ``Dataset`` with the columns ``context``, ``question`` and
            ``answers``. ``answers`` is the list of distinct answer texts
            for the question, in the order they first appear in the file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If an expected key is missing from the structure.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        contexts = []
        questions = []
        answer_lists = []
        for article in data['data']:
            for paragraph in article['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    # dict.fromkeys removes duplicates while keeping the
                    # first-seen order. A set would yield an order that
                    # changes between interpreter runs (string hash
                    # randomization), making the dataset non-reproducible.
                    unique_answers = list(
                        dict.fromkeys(a['text'] for a in qa['answers']))
                    contexts.append(context)
                    questions.append(qa['question'])
                    answer_lists.append(unique_answers)

        return Dataset.from_dict({
            'context': contexts,
            'question': questions,
            'answers': answer_lists,
        })
| |
|
| |
|
@TEXT_POSTPROCESSORS.register_module('drcd')
def drcd_postprocess(text: str) -> str:
    """Extract the text that follows the marker '答案是' in a prediction.

    Args:
        text: Raw model output.

    Returns:
        The segment between the first occurrence of the marker and the
        next occurrence (or the end of the string when the marker appears
        only once). If the marker is absent, ``text`` is returned unchanged.
    """
    pieces = text.split('答案是')
    # A split that yields a single piece means the marker never occurred.
    if len(pieces) > 1:
        return pieces[1]
    return text
| |
|