| """ | |
| Script for joining dataset of documents/reference summaries with generated summaries (likely from generate.py). | |
| Usage with custom datasets in JSONL format: | |
| python join.py --data_path <path to data in jsonl format> --generation_paths <paths to generated predictions> --output_path <path to output file> | |
| Optionally specify --model_names to override default model names. | |
| """ | |
| # !/usr/bin/env python | |
| # coding: utf-8 | |
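# Assumed data layout (inferred from the field accesses below, not documented upstream):
# each line of --data_path is a JSON object with at least "document" and "summary:reference" keys, e.g.
#   {"document": "Full article text ...", "summary:reference": "Gold summary ..."}
# and each file in --generation_paths holds one generated summary per line, aligned line-by-line
# with the records in --data_path. Adjust to match your data if your schema differs.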
import argparse
import json
import os
from pathlib import Path

import torch
from tqdm import tqdm

BATCH_SIZE = 8
class JSONDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that loads a JSONL file into memory, one record per line."""

    def __init__(self, data_path):
        super().__init__()
        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
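# Illustrative sketch only: the main block below iterates over JSONDataset directly, but it can
# also be wrapped in a DataLoader (the unused BATCH_SIZE constant suggests such usage elsewhere):
#
#   loader = torch.utils.data.DataLoader(JSONDataset("val.jsonl"), batch_size=BATCH_SIZE)
#   for batch in loader:
#       ...  # with the default collate_fn, batch maps each JSON key to a list of string values
#
# The file name "val.jsonl" is a placeholder, not part of this repository.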
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True)
    parser.add_argument('--generation_paths', type=str, nargs="+", required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--model_names', type=str, nargs="+")
    args = parser.parse_args()

    if args.model_names and len(args.generation_paths) != len(args.model_names):
        raise ValueError('Length of args.generation_paths must equal length of args.model_names')

    # Fall back to the file stems of the generation files if no model names are given
    if args.model_names:
        model_names = args.model_names
    else:
        model_names = [Path(p).name.split(".")[0] for p in args.generation_paths]

    # Metadata derived from the input path (not used further in this script)
    args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
    args.split = 'user'
    # Load data
    dataset = JSONDataset(args.data_path)

    # Join files and write out a single JSONL dataset
    generation_files = [open(fname) for fname in args.generation_paths]
    with open(args.output_path, 'w') as outp:
        for row in tqdm(zip(dataset, *generation_files)):
            # Process each original data record together with the generation(s) of the model(s)
            result = {}
            data = row[0]
            generations = row[1:]
            result['summary:reference'] = data['summary:reference']
            result['document'] = data['document']
            for model_name, gen in zip(model_names, generations):
                # Strip the trailing newline from each generated summary line
                result[f'summary:{model_name}'] = gen.rstrip("\n")
            outp.write(json.dumps(result) + '\n')

    for file in generation_files:
        file.close()
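# Example invocation (file and model names below are hypothetical):
#   python join.py \
#       --data_path data/test.jsonl \
#       --generation_paths preds/bart.jsonl preds/pegasus.jsonl \
#       --model_names bart pegasus \
#       --output_path joined.jsonl
# Each line of joined.jsonl would then contain the keys:
#   document, summary:reference, summary:bart, summary:pegasus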