import logging import argparse from collections import namedtuple from Nested.utils.helpers import load_checkpoint from Nested.utils.data import get_dataloaders, text2segments logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--model_path", type=str, required=True, help="Model path", ) parser.add_argument( "--text", type=str, required=True, help="Text or sequence to tag", ) parser.add_argument( "--batch_size", type=int, default=32, help="Batch size", ) args = parser.parse_args() return args def main(args): # Load tagger tagger, tag_vocab, train_config = load_checkpoint(args.model_path) # Convert text to a tagger dataset and index the tokens in args.text dataset, token_vocab = text2segments(args.text) vocabs = namedtuple("Vocab", ["tags", "tokens"]) vocab = vocabs(tokens=token_vocab, tags=tag_vocab) # From the datasets generate the dataloaders dataloader = get_dataloaders( (dataset,), vocab, train_config.data_config, batch_size=args.batch_size, shuffle=(False,), )[0] # Perform inference on the text and get back the tagged segments segments = tagger.infer(dataloader) # Print results for segment in segments: s = [ f"{token.text} ({'|'.join([t['tag'] for t in token.pred_tag])})" for token in segment ] print(" ".join(s)) if __name__ == "__main__": main(parse_args())