|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import warnings |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
from data import data_utils |
|
|
from data.ofa_dataset import OFADataset |
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Ignore UserWarnings whose message matches "(Possibly )?corrupt EXIF data"
# (presumably emitted by an image decoder such as Pillow). Nothing in this
# file visibly decodes images, so this looks carried over from the image
# datasets; it is kept because it affects process-wide warning filters.
warnings.filterwarnings(
    "ignore",
    message="(Possibly )?corrupt EXIF data",
    category=UserWarning,
)
|
|
|
|
|
|
|
|
def collate(samples, pad_idx, eos_idx):
    """Merge a list of per-example dicts into a single mini-batch dict.

    Args:
        samples: list of example dicts (as built by
            ``SummaryDataset.__getitem__``). Every example must provide
            ``"source"`` and ``"target_str"``; ``"target"`` and
            ``"prev_output_tokens"`` are optional, and their presence is
            decided by looking at the first example only.
        pad_idx: padding id; forwarded to ``data_utils.collate_tokens`` and
            used to count non-padding tokens per example.
        eos_idx: end-of-sentence id; forwarded to
            ``data_utils.collate_tokens``.

    Returns:
        ``{}`` when ``samples`` is empty. Otherwise a dict with keys
        ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths``, ``prev_output_tokens``), ``target`` and
        ``target_strs``. ``ntokens`` is the total number of non-padding
        target tokens when targets exist, else the total number of
        non-padding source tokens.
    """
    if not samples:
        return {}

    first = samples[0]

    def pad_field(field):
        # Combine every example's `field` tensor via data_utils.collate_tokens.
        return data_utils.collate_tokens(
            [example[field] for example in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    def non_pad_counts(field):
        # One entry per example: number of `field` tokens not equal to pad_idx.
        return torch.LongTensor(
            [example[field].ne(pad_idx).long().sum() for example in samples]
        )

    src_tokens = pad_field("source")
    src_lengths = non_pad_counts("source")

    if first.get("target", None) is None:
        # No targets: token count falls back to the source side.
        target = None
        prev_output_tokens = None
        ntokens = src_lengths.sum().item()
    else:
        target = pad_field("target")
        ntokens = non_pad_counts("target").sum().item()
        if first.get("prev_output_tokens", None) is None:
            prev_output_tokens = None
        else:
            prev_output_tokens = pad_field("prev_output_tokens")

    target_strs = np.array([example["target_str"] for example in samples])

    return {
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "prev_output_tokens": prev_output_tokens,
        },
        "target": target,
        "target_strs": target_strs,
    }
|
|
|
|
|
|
|
|
class SummaryDataset(OFADataset):
    """Summarization dataset: a prompted article in, its reference summary out.

    Each record of the wrapped ``dataset`` is unpacked as ``(source, target)``:
    the article text and its reference summary. The article is inserted into
    a BPE-specific prompt template and encoded as ``"source"``; the summary is
    encoded as ``"target"`` / ``"prev_output_tokens"``.
    """

    # Prompt templates keyed by ``type(bpe).__name__``. The Chinese template
    # reads roughly "{} Please briefly summarize the text above in one sentence:".
    _PROMPTS = {
        'GPT2BPE': ' what is the summary of article " {} "?',
        'BertBPE': "{} 请用一个句子简单总结上文:",
    }

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        code_dict_size=8192,
        num_bins=1000,
        max_src_length=512,
        max_tgt_length=128,
        noise_ratio=0.0
    ):
        """Store length limits and noise settings, then select the prompt.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``OFADataset.__init__``.
            code_dict_size: together with ``num_bins``, the number of entries
                at the end of ``src_dict`` that ``add_noise_to_tgt`` never
                samples.
            num_bins: see ``code_dict_size``.
            max_src_length: passed to ``pre_caption`` (``max_words``) and to
                ``encode_text`` (``length``) for the article.
            max_tgt_length: passed to ``pre_caption`` (``max_words``) for the
                summary.
            noise_ratio: per-token probability that a token of
                ``prev_output_tokens`` is replaced by a random id.

        Raises:
            ValueError: if no prompt template exists for ``type(bpe).__name__``
                (only ``GPT2BPE`` and ``BertBPE`` are supported).
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.code_dict_size = code_dict_size
        self.num_bins = num_bins
        self.noise_ratio = noise_ratio

        bpe_name = type(bpe).__name__
        # Fail at construction instead of with an opaque AttributeError on
        # ``self.prompt`` the first time an item is fetched.
        if bpe_name not in self._PROMPTS:
            raise ValueError(
                "SummaryDataset: unsupported BPE type {!r}; expected one of {}".format(
                    bpe_name, sorted(self._PROMPTS)
                )
            )
        self.prompt = self._PROMPTS[bpe_name]

    def __getitem__(self, index):
        """Build one encoded example.

        Returns:
            dict with
              - ``"source"``: ``bos_item + encode_text(prompt(article)) + eos_item``
              - ``"target"``: ``encode_text(summary) + eos_item``
              - ``"prev_output_tokens"``: ``bos_item + encode_text(summary)``
                with each token noised with probability ``self.noise_ratio``
              - ``"target_str"``: the raw reference summary, lowercased
        """
        source, target = self.dataset[index]
        # Lowercased raw reference, taken before pre_caption; presumably
        # consumed by metric computation downstream (see "target_strs").
        target_str = target.lower()

        # pre_caption comes from OFADataset; presumably it normalizes the
        # text and truncates it to ``max_words`` words.
        source = self.pre_caption(source, max_words=self.max_src_length)
        target = self.pre_caption(target, max_words=self.max_tgt_length)
        # Turn literal '<unk>' markers into the plain word 'unk', presumably so
        # they are not encoded as the dictionary's unknown-token symbol.
        source = source.replace('<unk>', 'unk')
        target = target.replace('<unk>', 'unk')

        src_item = self.encode_text(
            self.prompt.format(source),
            length=self.max_src_length
        )
        tgt_item = self.encode_text(target)
        # Noise a copy so that ``target`` keeps the clean tokens.
        noise_tgt_item = self.add_noise_to_tgt(tgt_item.clone(), self.noise_ratio)

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, noise_tgt_item])

        example = {
            "source": src_item,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "target_str": target_str
        }
        return example

    def add_noise_to_tgt(self, target, p):
        """Randomly overwrite tokens of ``target`` in place and return it.

        Each position along dim 0 is selected independently with probability
        ``p``. Selected positions receive ids drawn uniformly from
        ``[4, len(self.src_dict) - self.code_dict_size - self.num_bins)``.
        NOTE(review): the lower bound presumably skips the dictionary's
        special symbols, and the upper bound presumably excludes code/bin
        tokens appended at the end of ``src_dict``. Confirm against the
        dictionary layout.

        The tensor is mutated, so callers should pass a copy.
        """
        noise_indices = torch.FloatTensor(target.size(0)).uniform_() < p
        # Pass a plain Python int in ``size`` rather than a 0-d tensor.
        num_noised = int(noise_indices.sum())
        target[noise_indices] = torch.randint(
            4, len(self.src_dict) - self.code_dict_size - self.num_bins, size=(num_noised,)
        )
        return target

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length: accepted but not used by this implementation.

        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)