"""Command-line inference script: load a pretrained tagger checkpoint and tag input text."""
import logging
import argparse
from collections import namedtuple
from Nested.utils.helpers import load_checkpoint
from Nested.utils.data import get_dataloaders, text2segments
logger = logging.getLogger(__name__)
def parse_args():
    """Parse and return the command-line arguments for the tagger CLI.

    Returns:
        argparse.Namespace with ``model_path``, ``text`` and ``batch_size``.
    """
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Path to the trained checkpoint to load.
    arg_parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Model path",
    )
    # Raw input to run the tagger on.
    arg_parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Text or sequence to tag",
    )
    # Inference batch size; default shown in --help via the formatter class.
    arg_parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )
    return arg_parser.parse_args()
def main(args):
    """Tag the input text with a pretrained tagger and print the tagged tokens.

    Args:
        args: Parsed CLI arguments providing ``model_path``, ``text`` and
            ``batch_size`` (see ``parse_args``).
    """
    # Load the tagger together with the tag vocabulary and training config
    # stored in the checkpoint.
    tagger, tag_vocab, train_config = load_checkpoint(args.model_path)
    # Convert the raw text into a tagger dataset and index its tokens.
    dataset, token_vocab = text2segments(args.text)
    # PascalCase for the namedtuple class (it is a type, not an instance).
    Vocab = namedtuple("Vocab", ["tags", "tokens"])
    vocab = Vocab(tokens=token_vocab, tags=tag_vocab)
    # Build a single dataloader; no shuffling at inference time.
    dataloader = get_dataloaders(
        (dataset,),
        vocab,
        train_config.data_config,
        batch_size=args.batch_size,
        shuffle=(False,),
    )[0]
    # Run inference and get back the tagged segments.
    segments = tagger.infer(dataloader)
    # Print each segment as "token (tag1|tag2|...)" pairs joined by spaces.
    for segment in segments:
        tagged_tokens = [
            f"{token.text} ({'|'.join([t['tag'] for t in token.pred_tag])})"
            for token in segment
        ]
        print(" ".join(tagged_tokens))
if __name__ == "__main__":
    # Parse CLI arguments first, then hand them to the entry point.
    cli_args = parse_args()
    main(cli_args)