| """ |
| Evaluate the perplexity of a trained language model. |
| """ |

import ast
import logging
import math
import os

import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("fairseq_cli.eval_lm")
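

# Tracks aggregate statistics for one word type: summed log-prob, occurrence
# count, and the summed log-prob of the following word when it is available
# (reported with --output-word-probs / --output-word-stats).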
class WordStat(object):
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """Increments counters for the sum of log probs of the current word
        and the next word (given context ending at the current word). Since
        the next word might be at the end of the example, or it might not be
        counted because it is not a word-final subword unit, also keeps track
        of how many such cases we have seen."""
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return "{}\t{}\t{}\t{}\t{}\t{}".format(
            self.word,
            self.count,
            self.log_prob,
            self.is_bpe,
            self.next_word_prob,
            self.count - self.missing_next_words,
        )


def main(parsed_args, **unused_kwargs):
    assert parsed_args.path is not None, "--path required for evaluation!"

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load the model ensemble (possibly sharded checkpoints);
    # --model-overrides is parsed as a Python literal instead of eval()'d
    logger.info("loading model(s) from {}".format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=ast.literal_eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
        strict=(parsed_args.checkpoint_shard_count == 1),
        num_shards=parsed_args.checkpoint_shard_count,
    )
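
    # Override the loaded checkpoint's args with the current command-line
    # args, except for those that determine how the data itself is built.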
    for arg in vars(parsed_args).keys():
        if arg not in {
            "self_target",
            "future_target",
            "past_target",
            "tokens_per_sample",
            "output_size_dictionary",
            "add_bos_token",
        }:
            setattr(args, arg, getattr(parsed_args, arg))
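
    # Reduce tokens per sample by the context window size, so each block of
    # scored tokens plus its conditioning context still fits in one sample.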
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)
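
    # Load the evaluation split; with --context-window > 0 each block is
    # prefixed with tokens from the previous block as extra context.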
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))
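
    # Optimize the ensemble for generation (half precision, device placement)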
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda and not args.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(args)

    assert len(models) > 0

    logger.info(
        "num. model params: {}".format(sum(p.numel() for p in models[0].parameters()))
    )
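
    # Build a fixed-order iterator over the eval data; batches are bounded by
    # max_tokens and inputs longer than the models' max positions are skipped.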
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.0
    count = 0
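
    # With --remove-bpe, collect the indices of all continuation subwords
    # (dictionary entries ending in the marker, e.g. "@@") so that subword
    # scores can be merged into whole-word scores during scoring.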
    if args.remove_bpe is not None:
        if args.remove_bpe == "sentencepiece":
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
            bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            pos_scores = hypo["positional_scores"].float()

            if getattr(args, "add_bos_token", False):
                assert hypo["tokens"][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # compute the length after any BOS stripping so that the
            # BPE-merge loop below stays within bounds
            tgt_len = tokens.numel()
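
            # Fold each continuation subword's score into the next position so
            # the whole word's log-prob lands on its final subword; merged
            # positions are zeroed and excluded from the token count.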
            skipped_toks = 0
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0

            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    task.target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks
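
            # Optionally reassemble subwords into whole words and record
            # per-word probabilities and aggregate statistics.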
|
| if args.output_word_probs or args.output_word_stats: |
| w = "" |
| word_prob = [] |
| is_bpe = False |
| for i in range(len(tokens)): |
| w_ind = tokens[i].item() |
| w += task.source_dictionary[w_ind] |
| if bpe_toks is not None and w_ind in bpe_toks: |
| w = w[:-bpe_len] |
| is_bpe = True |
| else: |
| word_prob.append((w, pos_scores[i].item())) |
|
|
| next_prob = None |
| ind = i + 1 |
| while ind < len(tokens): |
| if pos_scores[ind].item() != 0: |
| next_prob = pos_scores[ind] |
| break |
| ind += 1 |
|
|
| word_stats.setdefault(w, WordStat(w, is_bpe)).add( |
| pos_scores[i].item(), next_prob |
| ) |
| is_bpe = False |
| w = "" |
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:.2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

        wps_meter.update(sample["ntokens"])
        progress.log({"wps": round(wps_meter.avg)})
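
    # positional_scores are natural-log probabilities; dividing by ln(2)
    # converts the average loss to bits, so perplexity is 2 ** avg_nll_loss.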
    avg_nll_loss = -score_sum / count / math.log(2)
    logger.info(
        "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg
        )
    )
    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            avg_nll_loss, 2 ** avg_nll_loss
        )
    )

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)


if __name__ == "__main__":
    cli_main()