maotao / fairseq /tasks /fairseq_task.py
julse's picture
Upload 551 files
be611b4 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import warnings
import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
logger = logging.getLogger(__name__)
class FairseqTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        pass

    @staticmethod
    def logging_outputs_can_be_summed(criterion) -> bool:
        """
        Whether the logging outputs returned by `train_step` and `valid_step` can
        be summed across workers prior to calling `aggregate_logging_outputs`.
        Setting this to True will improve distributed training speed.

        The decision is delegated to
        ``criterion.logging_outputs_can_be_summed()``.
        """
        return criterion.logging_outputs_can_be_summed()

    def __init__(self, args):
        """Store *args* and start with no loaded datasets or cached iterators.

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        self.args = args
        # split name -> loaded dataset (see :func:`dataset`)
        self.datasets = {}
        # dataset -> epoch iterator cache used by :func:`get_batch_iterator`
        self.dataset_to_epoch_iter = {}

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = Dictionary()
        # Every file contributes its tokens to the same dictionary instance.
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        """Return whether ``args.data`` contains the OS path separator.

        A missing or ``None`` ``data`` attribute is treated as "not sharded"
        instead of raising ``TypeError``.

        Args:
            split (str): name of the split; not consulted by this base
                implementation.
        """
        data = getattr(self.args, 'data', None) or ''
        return os.pathsep in data

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*

        Raises:
            KeyError: if *split* has not been stored in ``self.datasets``.
            TypeError: if the stored object is not a ``FairseqDataset``.
        """
        # FairseqDataset is already imported at module level, so no
        # function-scope import is needed here.
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        if not isinstance(self.datasets[split], FairseqDataset):
            raise TypeError("Datasets are expected to be of type FairseqDataset")
        return self.datasets[split]

    def filter_indices_by_size(
        self, indices, dataset, max_positions, ignore_invalid_inputs
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).

        Returns:
            np.array: array of filtered sample indices

        Raises:
            Exception: if some samples were rejected by the dataset and
                *ignore_invalid_inputs* is false.
        """
        indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
        # len() rather than truthiness: ``ignored`` may be an array-like whose
        # truth value is ambiguous.
        if len(ignored) > 0:
            if not ignore_invalid_inputs:
                raise Exception((
                    'Size of sample #{} is invalid (={}) since max_positions={}, '
                    'skip this example with --skip-invalid-size-inputs-valid-test'
                ).format(ignored[0], dataset.size(ignored[0]), max_positions))
            logger.warning((
                '{} samples have invalid sizes and will be skipped, '
                'max_positions={}, first few sample ids={}'
            ).format(len(ignored), max_positions, ignored[:10]))
        return indices

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        The iterator is cached per dataset: a later call with the same
        *dataset* returns the cached iterator and ignores all other arguments.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        # For default fairseq task, return same iterator across epochs
        # as datasets are not dynamic, can be overridden in task specific
        # setting.
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            indices = self.filter_indices_by_size(
                indices, dataset, max_positions, ignore_invalid_inputs
            )

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=getattr(self.args, 'data_buffer_size', 0)
        )
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter

    def build_model(self, args):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(args, self)
        if getattr(args, 'tpu', False):
            model.prepare_for_tpu_()
        model = quantization_utils.quantize_model_scalar(model, args)
        return model

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(args, self)

    def build_generator(
        self, models, args,
        seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        """
        Build the object used for generation or reference scoring.

        When ``args.score_reference`` is set a ``SequenceScorer`` is returned.
        Otherwise a search strategy is selected from *args* (beam search by
        default) and passed to *seq_gen_cls* together with the generation
        options read from *args*.

        Args:
            models: forwarded as the first argument of *seq_gen_cls*
            args (argparse.Namespace): generation options; every option is
                read with ``getattr`` and a default, so it may be absent
            seq_gen_cls (optional): generator class to instantiate. Defaults
                to ``SequenceGeneratorWithAlignment`` when
                ``args.print_alignment`` is set, else ``SequenceGenerator``.
            extra_gen_cls_kwargs (dict, optional): extra keyword arguments
                forwarded to *seq_gen_cls*

        Raises:
            ValueError: if more than one of the mutually exclusive search
                options is enabled.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import (
            SequenceGenerator,
            SequenceGeneratorWithAlignment,
        )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        # NOTE(review): the exclusivity check above counts diversity_rate only
        # when it is > 0, while this branch fires for any value > -1
        # (including 0) -- confirm which threshold is intended.
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
            else:
                seq_gen_cls = SequenceGenerator
        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            # The backward pass still runs, but on a zeroed loss.
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """
        Evaluate *criterion* on *sample* with *model* in eval mode and with
        gradient tracking disabled.

        Returns:
            tuple: the ``(loss, sample_size, logging_output)`` triple returned
            by *criterion*
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        """
        Run ``generator.generate`` on *sample* with gradient tracking disabled.

        Args:
            generator: object exposing ``generate(models, sample, prefix_tokens=...)``
            models: forwarded to the generator
            sample: forwarded to the generator
            prefix_tokens (optional): forwarded to the generator
                (default: None)

        Returns:
            whatever ``generator.generate`` returns
        """
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens)

    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""
        pass

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """[deprecated] Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        with metrics.aggregate() as agg:
            self.reduce_metrics(logging_outputs, criterion)
            return agg.get_smoothed_values()

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training.

        If a subclass overrides the deprecated ``aggregate_logging_outputs``,
        that override is used and its values are logged as scalars. Otherwise
        ``wpb``/``wps`` and ``bsz`` are logged when ``ntokens`` /
        ``nsentences`` are present, followed by the criterion's own
        ``reduce_metrics``.
        """
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError