Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os.path as osp | |
| from src.benchmarks.qa_datasets import AmazonSTaRKDataset, PrimeKGSTaRKDataset, MAGSTaRKDataset, STaRKDataset | |
def get_qa_dataset(name, root='data/', human_generated_eval=False):
    """Build the STaRK QA dataset object for ``name``.

    Every dataset is read from ``<root>/<name>/``. Its ``stark_qa``
    subdirectory is passed as the first constructor argument and its
    ``split`` subdirectory as the second.

    Args:
        name: Dataset name. ``'amazon'``, ``'primekg'`` and ``'mag'`` map to
            their dedicated dataset classes. Any other name is loaded with the
            generic ``STaRKDataset``.
        root: Directory containing one subdirectory per dataset.
        human_generated_eval: Forwarded to the dedicated dataset classes only.
            It is NOT passed to the generic ``STaRKDataset`` loader.

    Returns:
        The constructed dataset instance.

    Raises:
        Whatever ``STaRKDataset`` raises for an unknown/external dataset. The
        exception is re-raised unchanged after printing a hint about the
        dataset name, path, or format.
    """
    qa_root = osp.join(root, name)
    # All datasets share the same on-disk layout, so build the paths once
    # instead of repeating them per branch.
    split_dir = osp.join(qa_root, 'split')
    stark_qa_dir = osp.join(qa_root, 'stark_qa')

    known_datasets = {
        'amazon': AmazonSTaRKDataset,
        'primekg': PrimeKGSTaRKDataset,
        'mag': MAGSTaRKDataset,
    }
    dataset_cls = known_datasets.get(name)
    if dataset_cls is not None:
        return dataset_cls(stark_qa_dir, split_dir,
                           human_generated_eval=human_generated_eval)

    # Fallback for datasets outside the built-in three.
    # NOTE(review): human_generated_eval is not forwarded here; presumably
    # STaRKDataset does not accept it. Confirm against its signature.
    print('loading dataset from external data')
    try:
        return STaRKDataset(stark_qa_dir, split_dir)
    except Exception:
        print('Please check dataset name, path, or format\n')
        # Bare raise re-raises the active exception with its original
        # traceback, so callers still see the same exception type.
        raise