| | import csv |
| | import json |
| | import os.path as osp |
| |
|
| | from datasets import Dataset, DatasetDict |
| |
|
| | from opencompass.registry import LOAD_DATASET |
| |
|
| | from .base import BaseDataset |
| |
|
| |
|
@LOAD_DATASET.register_module()
class CEvalDataset(BaseDataset):
    """Loader that reads the per-split CSV files of one subject."""

    @staticmethod
    def load(path: str, name: str):
        """Read the ``dev``, ``val`` and ``test`` CSV files of one subject.

        Each file is looked up at ``<path>/<split>/<name>_<split>.csv``. The
        first CSV row is treated as the header and its values become the keys
        of every item built from the remaining rows.

        Args:
            path: Root directory that contains one sub-directory per split.
            name: Subject name used as the prefix of the CSV file names.

        Returns:
            A ``DatasetDict`` keyed by split name. A split whose CSV holds a
            header but no data rows is omitted from the result.

        Raises:
            OSError: If one of the split files cannot be opened.
        """
        rows_by_split = {}
        for split in ['dev', 'val', 'test']:
            filename = osp.join(path, split, f'{name}_{split}.csv')
            # newline='' is what the csv module requires: it leaves line
            # ending handling to the reader so that quoted fields containing
            # line breaks are not altered by universal-newline translation.
            with open(filename, encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row in reader:
                    item = dict(zip(header, row))
                    # Guarantee both keys exist on every item, even when the
                    # CSV of this split has no such column.
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    rows_by_split.setdefault(split, []).append(item)
        return DatasetDict({
            split: Dataset.from_list(rows)
            for split, rows in rows_by_split.items()
        })
| |
|
| |
|
class CEvalDatasetClean(BaseDataset):
    """Loader that tags every ``val`` item with a contamination label.

    The labels come from the annotation file published in the releases of
    https://github.com/liyucheng09/Contamination_Detector.
    """

    @staticmethod
    def load_contamination_annotations(path: str, split: str = 'val'):
        """Return the contamination annotations, downloading them if needed.

        The annotations are cached at
        ``<path>/<split>/ceval_contamination_annotations.json``; when that
        file exists it is read and no network request is made.

        Args:
            path: Root directory of the dataset (holds the split directories).
            split: Split to fetch annotations for. Only ``'val'`` is accepted.

        Returns:
            The decoded JSON content of the annotation file.

        Raises:
            AssertionError: If ``split`` is not ``'val'``.
            requests.RequestException: If the download fails, times out or
                the server answers with an HTTP error status.
        """
        import requests

        assert split == 'val', 'Now we only have annotations for val set'
        annotation_cache_path = osp.join(
            path, split, 'ceval_contamination_annotations.json')
        if osp.exists(annotation_cache_path):
            with open(annotation_cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        link_of_annotations = (
            'https://github.com/liyucheng09/Contamination_Detector/'
            'releases/download/v0.1.1rc/ceval_annotations.json')
        # A timeout keeps an unresponsive server from hanging the loader
        # forever; raise_for_status stops an HTTP error page from being
        # handed to the JSON parser.
        response = requests.get(link_of_annotations, timeout=60)
        response.raise_for_status()
        annotations = json.loads(response.text)
        with open(annotation_cache_path, 'w', encoding='utf-8') as f:
            json.dump(annotations, f)
        return annotations

    @staticmethod
    def load(path: str, name: str):
        """Read the ``dev``, ``val`` and ``test`` CSV files of one subject.

        Each file is looked up at ``<path>/<split>/<name>_<split>.csv``. The
        first CSV row is treated as the header and its values become the keys
        of every item. Items of the ``val`` split additionally get an
        ``is_clean`` key: the first element of the annotation stored under
        ``'<name>-<row_index>'`` (zero-based index among the data rows), or
        ``'not labeled'`` when no such annotation exists.

        Args:
            path: Root directory that contains one sub-directory per split.
            name: Subject name used as the prefix of the CSV file names.

        Returns:
            A ``DatasetDict`` keyed by split name. A split whose CSV holds a
            header but no data rows is omitted from the result.

        Raises:
            OSError: If one of the split files cannot be opened.
        """
        rows_by_split = {}
        for split in ['dev', 'val', 'test']:
            # Annotations are only available (and only needed) for 'val'.
            annotations = None
            if split == 'val':
                annotations = (
                    CEvalDatasetClean.load_contamination_annotations(
                        path, split))
            filename = osp.join(path, split, f'{name}_{split}.csv')
            # newline='' is what the csv module requires: it leaves line
            # ending handling to the reader so that quoted fields containing
            # line breaks are not altered by universal-newline translation.
            with open(filename, encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                header = next(reader)
                for row_index, row in enumerate(reader):
                    item = dict(zip(header, row))
                    # Guarantee both keys exist on every item, even when the
                    # CSV of this split has no such column.
                    item.setdefault('explanation', '')
                    item.setdefault('answer', '')
                    if split == 'val':
                        row_id = f'{name}-{row_index}'
                        if row_id in annotations:
                            item['is_clean'] = annotations[row_id][0]
                        else:
                            item['is_clean'] = 'not labeled'
                    rows_by_split.setdefault(split, []).append(item)
        return DatasetDict({
            split: Dataset.from_list(rows)
            for split, rows in rows_by_split.items()
        })
| |
|