Spaces:
Runtime error
Runtime error
| """ | |
| Script for decoding summarization models available through Huggingface Transformers. | |
| To use with one of the 6 standard models: | |
| python generation.py --model <model abbreviation> --data_path <path to data in jsonl format> | |
| where model abbreviation is one of: bart-xsum, bart-cnndm, pegasus-xsum, pegasus-cnndm, pegasus-newsroom, | |
| pegasus-multinews. | |
| To use with arbitrary model: | |
| python generation.py --model_name_or_path <Huggingface model name or local path> --data_path <path to data in jsonl format> | |
| """ | |
| # !/usr/bin/env python | |
| # coding: utf-8 | |
| import argparse | |
| import json | |
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Number of examples decoded per forward pass.
BATCH_SIZE = 8
# Run on GPU when one is available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Huggingface Hub checkpoint names for the six supported summarization models.
BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'
# Maps the --model abbreviation accepted on the command line to its checkpoint.
MODEL_CHECKPOINTS = {
    'bart-xsum': BART_XSUM_CHECKPOINT,
    'bart-cnndm': BART_CNNDM_CHECKPOINT,
    'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
    'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
    'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
    'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT
}
class JSONDataset(torch.utils.data.Dataset):
    """Map-style dataset over a JSON-lines (jsonl) file.

    Each line of the file must be a standalone JSON object; the parsed
    object is returned unchanged by ``__getitem__``.
    """

    def __init__(self, data_path):
        """Load and parse every line of the jsonl file at ``data_path``.

        :param data_path: path to a file with one JSON object per line
        :raises json.JSONDecodeError: if any line is not valid JSON
        """
        # Python 3 zero-argument super(); the original used the legacy
        # two-argument form.
        super().__init__()
        # Explicit encoding so decoding does not depend on the platform
        # default; the whole dataset is read eagerly into memory.
        with open(data_path, encoding='utf-8') as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def postprocess_data(decoded):
    """
    Strip generation artifacts from decoded model outputs.

    Pegasus-style checkpoints emit '<n>' as a newline marker; every
    occurrence is swapped for a plain space.

    :param decoded: iterable of decoded summary strings
    :return: list of cleaned summary strings, in the same order
    """
    cleaned = []
    for summary in decoded:
        cleaned.append(summary.replace('<n>', ' '))
    return cleaned
if __name__ == '__main__':
    # NOTE: the original description was a leftover "Process some integers."
    # copied from the argparse docs example.
    parser = argparse.ArgumentParser(
        description='Decode a summarization model over a jsonl dataset.')
    parser.add_argument('--model', type=str,
                        help='model abbreviation, one of: ' + ', '.join(MODEL_CHECKPOINTS))
    parser.add_argument('--model_name_or_path', type=str,
                        help='Huggingface model name or local path')
    parser.add_argument('--data_path', type=str,
                        help='path to data in jsonl format')
    args = parser.parse_args()
    # Exactly one of --model / --model_name_or_path must be given.
    if not (args.model or args.model_name_or_path):
        raise ValueError('Model is required')
    if args.model and args.model_name_or_path:
        raise ValueError('Specify model or model_name_or_path but not both')
    # Resolve the checkpoint and a filesystem-safe model name for the output file.
    if args.model:
        model_name_or_path = MODEL_CHECKPOINTS[args.model]
        file_model_name = args.model
    else:
        model_name_or_path = args.model_name_or_path
        file_model_name = model_name_or_path.replace("/", "-")
    # Load model & tokenizer, and wrap the jsonl data in a batching loader.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    dataset = JSONDataset(args.data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
    # Predictions go to <model>.<dataset>.predictions, one summary per line.
    file_dataset_name = os.path.splitext(os.path.basename(args.data_path))[0]
    filename = f'{file_model_name}.{file_dataset_name}.predictions'
    model.eval()
    # `with` guarantees the output file is flushed and closed (the original
    # leaked the handle); no_grad disables autograd bookkeeping for inference.
    with open(filename, 'w', encoding='utf-8') as fd_out, torch.no_grad():
        for raw_data in tqdm(dataloader):
            # Tokenize the batch of source documents; truncate to the model
            # maximum and pad to the longest document in the batch.
            batch = tokenizer(raw_data["document"], return_tensors="pt",
                              truncation=True, padding="longest").to(DEVICE)
            summaries = model.generate(input_ids=batch.input_ids,
                                       attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=False)
            for example in postprocess_data(decoded):
                fd_out.write(example + '\n')