stanza-digphil / stanza /utils /training /remove_constituency_optimizer.py
Albin Thörn Cleland
Clean initial commit with LFS
19b8775
"""Saved a huge, bloated model with an optimizer? Use this to remove it, greatly shrinking the model size

This tries to find reasonable defaults for word vectors and charlm
(which need to be loaded so that the model knows the matrix sizes),
so ideally all that needs to be run is

    python3 stanza/utils/training/remove_constituency_optimizer.py <treebanks>

for example:

    python3 stanza/utils/training/remove_constituency_optimizer.py da_arboretum ...

This can also be used to load and save models as part of an update
to the serialized format.
"""
import argparse
import logging
import os
from stanza.models import constituency_parser
from stanza.models.common.constant import treebank_to_short_name
from stanza.resources.default_packages import default_charlms, default_pretrains
from stanza.utils.training import common
# Messages from this script are emitted through the logger named 'stanza'.
logger = logging.getLogger('stanza')
def parse_args():
    """
    Parse the command line for the optimizer-removal script

    Reads the process arguments (sys.argv) and returns an
    argparse.Namespace with the attributes wordvec_pretrain_file,
    charlm, load_dir, save_dir and treebanks (a list with at least
    one entry).
    """
    # Each entry is (flags, keyword options) and is registered in this order.
    #
    # --charlm is listed before --no_charlm on purpose: both write to
    # dest 'charlm', and argparse takes the fallback value from the first
    # action registered for a dest, so the default stays "default"
    # rather than becoming None.
    option_table = (
        (('--wordvec_pretrain_file',),
         dict(type=str, default=None,
              help='Exact name of the pretrain file to read')),
        (('--charlm',),
         dict(default="default", type=str,
              help='Which charlm to run on. Will use the default charlm for this language/model if not set. Set to None to turn off charlm for languages with a default charlm')),
        (('--no_charlm',),
         dict(dest='charlm', action="store_const", const=None,
              help="Don't use a charlm, even if one is used by default for this package")),
        (('--load_dir',),
         dict(type=str, default="saved_models/constituency",
              help="Root dir for getting the models to resave.")),
        (('--save_dir',),
         dict(type=str, default="resaved_models/constituency",
              help="Root dir for resaving the models.")),
        (('treebanks',),
         dict(type=str, nargs='+',
              help='Which treebanks to run on. Use all_ud or ud_all for all UD treebanks')),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def main():
    """
    For each of the models specified, load and resave the model

    The resaved model will have the optimizer removed

    For every treebank given on the command line, the model is read from
    <load_dir>/<short_name>_constituency.pt and written to
    <save_dir>/<short_name>_constituency.pt by calling
    constituency_parser.main in 'remove_optimizer' mode.  save_dir is
    created if it does not exist yet.

    If --wordvec_pretrain_file was given, that exact file is passed to
    the parser for every treebank; otherwise the default pretrain for
    the treebank's language is looked up.
    """
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)

    for treebank in args.treebanks:
        logger.info("PROCESSING %s", treebank)
        short_name = treebank_to_short_name(treebank)
        # only split on the first underscore: the dataset part may contain more
        language, dataset = short_name.split("_", maxsplit=1)
        logger.info("%s: %s %s", short_name, language, dataset)

        if args.wordvec_pretrain_file:
            # An explicitly requested pretrain has to be forwarded;
            # otherwise the user's choice would silently be ignored
            wordvec_pretrain = args.wordvec_pretrain_file
        else:
            # will throw an error if the pretrain can't be found
            wordvec_pretrain = common.find_wordvec_pretrain(language, default_pretrains)
        wordvec_args = ['--wordvec_pretrain_file', wordvec_pretrain]

        charlm = common.choose_charlm(language, dataset, args.charlm, default_charlms, {})
        charlm_args = common.build_charlm_args(language, charlm, base_args=False)

        base_name = '{}_constituency.pt'.format(short_name)
        load_name = os.path.join(args.load_dir, base_name)
        save_name = os.path.join(args.save_dir, base_name)
        # save_name already contains args.save_dir; "." is presumably passed
        # so that the parser does not put another directory in front of it
        # TODO confirm against constituency_parser
        resave_args = ['--mode', 'remove_optimizer',
                       '--load_name', load_name,
                       '--save_name', save_name,
                       '--save_dir', ".",
                       '--shorthand', short_name]
        resave_args = resave_args + wordvec_args + charlm_args
        constituency_parser.main(resave_args)
# Only run the resave loop when executed as a script, not when imported
if __name__ == '__main__':
    main()