varox34's picture
Upload 64 files
366b225 verified
# -*- coding: utf-8 -*-
from datetime import datetime
from parser import Model
from parser.cmds.cmd import CMD
from parser.utils.corpus import Corpus
from parser.utils.data import TextDataset, batchify
class Predict(CMD):
    """Sub-command that runs a trained model over a dataset and saves the output."""

    def add_subparser(self, name, parser):
        """Register this sub-command and its options on ``parser``.

        Args:
            name: Name under which the sub-command is registered.
            parser: Object exposing ``add_parser`` (presumably the value
                returned by ``ArgumentParser.add_subparsers()`` -- confirm
                against the caller).

        Returns:
            The sub-parser created for this command, carrying the
            ``--batch-size``, ``--fdata`` and ``--fpred`` options.
        """
        subparser = parser.add_parser(
            name, help='Use a trained model to make predictions.'
        )
        subparser.add_argument('--batch-size', default=1000, type=int,
                               help='batch size')
        subparser.add_argument('--fdata', default='data/ptb/tamtest.conllx',
                               help='path to dataset')
        subparser.add_argument('--fpred', default='pred.conllx',
                               help='path to predicted result')

        return subparser

    def __call__(self, args):
        """Load the data and the model, predict, and write the result.

        Args:
            args: Parsed command-line options. This method reads
                ``args.fdata``, ``args.batch_size``, ``args.model`` and
                ``args.fpred``.
        """
        # NOTE(review): self.fields, self.WORD and self.FEAT are not set in
        # this class; presumably the parent __call__ prepares them -- confirm.
        super().__call__(args)

        print("Load the dataset")
        corpus = Corpus.load(args.fdata, self.fields)
        dataset = TextDataset(corpus, [self.WORD, self.FEAT])
        # set the data loader
        dataset.loader = batchify(dataset, args.batch_size)
        print(f"{len(dataset)} sentences, "
              f"{len(dataset.loader)} batches")

        print("Load the model")
        self.model = Model.load(args.model)
        print(f"{self.model}\n")

        print("Make predictions on the dataset")
        # The timer spans both prediction and writing the output file, so the
        # reported throughput includes the time spent saving.
        start = datetime.now()
        corpus.heads, corpus.rels = self.predict(dataset.loader)
        print(f"Save the predicted result to {args.fpred}")
        corpus.save(args.fpred)
        total_time = datetime.now() - start
        print(f"{total_time}s elapsed, "
              f"{len(dataset) / total_time.total_seconds():.2f} Sents/s")