Spaces:
Runtime error
Runtime error
Upload evaluate.py
Browse files- evaluate.py +152 -0
evaluate.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import json
|
| 11 |
+
from itertools import chain
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from fairseq import distributed_utils, options, tasks, utils
|
| 17 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 18 |
+
from fairseq.logging import progress_bar
|
| 19 |
+
from fairseq.utils import reset_logging
|
| 20 |
+
from omegaconf import DictConfig
|
| 21 |
+
|
| 22 |
+
from utils import checkpoint_utils
|
| 23 |
+
from utils.eval_utils import eval_step
|
| 24 |
+
|
| 25 |
+
# Root logging is sent to stdout; verbosity is taken from the LOGLEVEL
# environment variable and falls back to INFO when it is unset.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("ofa.evaluate")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def apply_half(t):
    """Return ``t`` cast to half precision if it is a float32 tensor.

    Tensors of any other dtype are handed back unchanged (same object).
    """
    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main(cfg: DictConfig):
    """Evaluate the model ensemble on ``cfg.dataset.gen_subset`` and dump predictions.

    Loads the checkpoint(s) named by ``cfg.common_eval.path`` together with
    their task, iterates over this worker's shard of the subset, runs
    ``eval_step`` on every batch, and accumulates the per-batch results and
    scores. When ``distributed_world_size > 1`` the results are gathered and
    the score accumulators all-reduced across workers. The single process
    (or rank 0 when distributed) writes the merged results to
    ``<cfg.common_eval.results_path>/<gen_subset>_predict.json``.

    Args:
        cfg: Configuration produced by ``convert_namespace_to_omegaconf``;
            the ``common``, ``common_eval``, ``checkpoint``, ``dataset``,
            ``distributed_training`` and ``generation`` groups are read.

    Raises:
        AssertionError: If neither ``cfg.dataset.max_tokens`` nor
            ``cfg.dataset.batch_size`` is set.
    """
    utils.import_user_module(cfg.common)

    reset_logging()
    logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
    # NOTE(review): eval() executes whatever expression is passed as the model
    # overrides string. Only run this with trusted command-line input; consider
    # ast.literal_eval if overrides are always plain literals.
    overrides = eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded
    # so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Put models in eval mode, optionally cast to fp16, and move them to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded): each worker iterates its own shard
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    results = []
    # Allocate the accumulators on the device actually in use. The previous
    # unconditional .cuda() crashed on CPU-only hosts and with --cpu.
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    score_sum = torch.zeros(1, dtype=torch.float32, device=device)
    score_cnt = torch.zeros(1, dtype=torch.float32, device=device)
    for sample in progress:
        # Batches without model input carry nothing to evaluate.
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(task, generator, models, sample)
        results += result
        # eval_step may return no scores; then only the results are kept.
        score_sum += sum(scores) if scores is not None else 0
        score_cnt += len(scores) if scores is not None else 0
        progress.log({"sentences": sample["nsentences"]})

    gather_results = None
    if cfg.distributed_training.distributed_world_size > 1:
        # Collect every worker's result list and sum the score accumulators.
        gather_results = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(gather_results, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)
    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

    # Only one process writes the output file.
    if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "{}_predict.json".format(cfg.dataset.gen_subset),
        )
        # Flatten the per-worker lists into a single list of results.
        gather_results = list(chain(*gather_results)) if gather_results is not None else results
        with open(output_path, 'w') as fw:
            json.dump(gather_results, fw)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def cli_main():
    """Command-line entry point.

    Builds the generation argument parser, parses the arguments, converts
    the resulting namespace to an OmegaConf config, and hands it with
    ``main`` to ``distributed_utils.call_main``.
    """
    args = options.parse_args_and_arch(options.get_generation_parser())
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# Run the evaluation CLI only when this file is executed as a script.
if __name__ == "__main__":
    cli_main()
|