Spaces:
Running
Running
File size: 2,471 Bytes
f316449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import os
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint, make_output_dirs, logging_config
from Nested.utils.data import get_dataloaders, parse_conll_files
from Nested.utils.metrics import compute_single_label_metrics, compute_nested_metrics
logger = logging.getLogger(__name__)
def parse_args():
    """Parse and return the command-line arguments for the evaluation script."""
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Where predictions, logs, and metrics are written.
    arg_parser.add_argument(
        "--output_path", type=str, required=True, help="Path to save results"
    )
    # Checkpoint to load the trained tagger from.
    arg_parser.add_argument(
        "--model_path", type=str, required=True, help="Model path"
    )
    # One or more CoNLL-format files to tag/evaluate.
    arg_parser.add_argument(
        "--data_paths",
        nargs="+",
        type=str,
        required=True,
        help="Text or sequence to tag, this is in same format as training data with 'O' tag for all tokens",
    )
    arg_parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size"
    )
    return arg_parser.parse_args()
def main(args):
    """Evaluate a saved tagger on CoNLL files and write predictions + metrics.

    Args:
        args: Parsed CLI namespace with ``output_path``, ``model_path``,
            ``data_paths`` (list of CoNLL files) and ``batch_size``.

    Side effects:
        Creates/overwrites ``args.output_path``, writes ``eval.log`` and one
        ``predictions_<input filename>`` file per input, and logs metrics.
    """
    # Create directory to save predictions (existing contents are overwritten)
    make_output_dirs(args.output_path, overwrite=True)
    logging_config(log_file=os.path.join(args.output_path, "eval.log"))

    # Load tagger checkpoint: model, its tag vocabulary, and training config
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)

    # Parse the input files; keep their token vocabulary but swap in the
    # checkpoint's tag vocabulary so predictions map to the trained tag set
    datasets, vocab = parse_conll_files(args.data_paths)
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    vocab = vocabs(tokens=vocab.tokens, tags=tag_vocab)

    # From the datasets generate the dataloaders (no shuffling at eval time)
    dataloaders = get_dataloaders(
        datasets, vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=[False] * len(datasets)
    )

    # Evaluate the model on each dataloader, one predictions file per input
    for dataloader, input_file in zip(dataloaders, args.data_paths):
        filename = os.path.basename(input_file)
        # Bug fix: the output path must interpolate the input's basename;
        # previously `filename` was computed but never used.
        predictions_file = os.path.join(args.output_path, f"predictions_{filename}")
        _, segments, _, _ = tagger.eval(dataloader)
        tagger.segments_to_file(segments, predictions_file)
        if "Nested" in train_config.trainer_config["fn"]:
            # NOTE(review): tags[1:] presumably drops a special/padding tag
            # from the metric computation — confirm against tag_vocab layout
            compute_nested_metrics(segments, vocab.tags[1:])
        else:
            compute_single_label_metrics(segments)
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the evaluation.
    cli_args = parse_args()
    main(cli_args)
|