| |
| |
| |
| |
|
|
| from __future__ import print_function, unicode_literals |
|
|
| import logging |
| import argparse |
| import subprocess |
| import sys |
| import os |
|
|
| logging.basicConfig( |
| format='%(asctime)s %(levelname)s: %(message)s', |
| datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG) |
| parser = argparse.ArgumentParser() |
| parser.add_argument("-w", "--working-dir", dest="working_dir") |
| parser.add_argument("-c", "--corpus", dest="corpus_stem") |
| parser.add_argument("-l", "--nplm-home", dest="nplm_home") |
| parser.add_argument("-e", "--epochs", dest="epochs", type=int) |
| parser.add_argument("-n", "--ngram-size", dest="ngram_size", type=int) |
| parser.add_argument("-b", "--minibatch-size", dest="minibatch_size", type=int) |
| parser.add_argument("-s", "--noise", dest="noise", type=int) |
| parser.add_argument("-d", "--hidden", dest="hidden", type=int) |
| parser.add_argument( |
| "-i", "--input-embedding", dest="input_embedding", type=int) |
| parser.add_argument( |
| "-o", "--output-embedding", dest="output_embedding", type=int) |
| parser.add_argument("-t", "--threads", dest="threads", type=int) |
| parser.add_argument("-m", "--output-model", dest="output_model") |
| parser.add_argument("-r", "--output-dir", dest="output_dir") |
| parser.add_argument("-f", "--config-options-file", dest="config_options_file") |
| parser.add_argument("-g", "--log-file", dest="log_file") |
| parser.add_argument("-v", "--validation-ngrams", dest="validation_file") |
| parser.add_argument("-a", "--activation-function", dest="activation_fn") |
| parser.add_argument("-z", "--learning-rate", dest="learning_rate") |
| parser.add_argument("--input-words-file", dest="input_words_file") |
| parser.add_argument("--output-words-file", dest="output_words_file") |
| parser.add_argument("--input_vocab_size", dest="input_vocab_size", type=int) |
| parser.add_argument("--output_vocab_size", dest="output_vocab_size", type=int) |
| parser.add_argument("--mmap", dest="mmap", action="store_true", |
| help="Use memory-mapped file (for lower memory consumption).") |
| parser.add_argument("--extra-settings", dest="extra_settings", |
| help="Extra settings to be passed to NPLM") |
| parser.add_argument( |
| "--train-host", dest="train_host", |
| help="Execute nplm training on this host, via ssh") |
|
|
| parser.set_defaults( |
| working_dir="working", |
| corpus_stem="train.10k", |
| nplm_home="/home/bhaddow/tools/nplm", |
| epochs=10, |
| ngram_size=14, |
| minibatch_size=1000, |
| noise=100, |
| hidden=0, |
| input_embedding=150, |
| output_embedding=750, |
| threads=1, |
| output_model="train.10k", |
| output_dir=None, |
| config_options_file="config", |
| log_file="log", |
| validation_file=None, |
| activation_fn="rectifier", |
| learning_rate=1, |
| input_words_file=None, |
| output_words_file=None, |
| input_vocab_size=0, |
| output_vocab_size=0 |
| ) |
|
|
|
|
| def main(options): |
|
|
| vocab_command = [] |
| if options.input_words_file is not None: |
| vocab_command += ['--input_words_file', options.input_words_file] |
| if options.output_words_file is not None: |
| vocab_command += ['--output_words_file', options.output_words_file] |
| if options.input_vocab_size: |
| vocab_command += ['--input_vocab_size', str(options.input_vocab_size)] |
| if options.output_vocab_size: |
| vocab_command += [ |
| '--output_vocab_size', str(options.output_vocab_size)] |
|
|
| |
| validations_command = [] |
| if options.validation_file is not None: |
| validations_command = [ |
| "--validation_file", (options.validation_file + ".numberized")] |
|
|
| |
| |
| |
| |
| |
|
|
| if options.output_dir is None: |
| options.output_dir = options.working_dir |
| else: |
| |
| if not os.path.exists(options.output_dir): |
| os.makedirs(options.output_dir) |
|
|
| config_file = os.path.join( |
| options.output_dir, |
| options.config_options_file + '-' + options.output_model) |
| log_file = os.path.join( |
| options.output_dir, options.log_file + '-' + options.output_model) |
| log_file_write = open(log_file, 'w') |
| config_file_write = open(config_file, 'w') |
|
|
| config_file_write.write("Called: " + ' '.join(sys.argv) + '\n\n') |
|
|
| in_file = os.path.join( |
| options.working_dir, |
| os.path.basename(options.corpus_stem) + ".numberized") |
|
|
| mmap_command = [] |
| if options.mmap: |
| in_file += '.mmap' |
| mmap_command = ['--mmap_file', '1'] |
|
|
| model_prefix = os.path.join( |
| options.output_dir, options.output_model + ".model.nplm") |
| train_args = [] |
| if options.train_host: |
| train_args = ["ssh", options.train_host] |
| train_args += [ |
| options.nplm_home + "/src/trainNeuralNetwork", |
| "--train_file", in_file, |
| "--num_epochs", str(options.epochs), |
| "--model_prefix", model_prefix, |
| "--learning_rate", str(options.learning_rate), |
| "--minibatch_size", str(options.minibatch_size), |
| "--num_noise_samples", str(options.noise), |
| "--num_hidden", str(options.hidden), |
| "--input_embedding_dimension", str(options.input_embedding), |
| "--output_embedding_dimension", str(options.output_embedding), |
| "--num_threads", str(options.threads), |
| "--activation_function", options.activation_fn, |
| "--ngram_size", str(options.ngram_size), |
| ] + validations_command + vocab_command + mmap_command |
| if options.extra_settings: train_args += options.extra_settings.split() |
| print("Train model command: ") |
| print(', '.join(train_args)) |
|
|
| config_file_write.write("Training step:\n" + ' '.join(train_args) + '\n') |
| config_file_write.close() |
|
|
| log_file_write.write("Training output:\n") |
| ret = subprocess.call( |
| train_args, stdout=log_file_write, stderr=log_file_write) |
| if ret: |
| raise Exception("Training failed") |
|
|
| log_file_write.close() |
|
|
|
|
| if __name__ == "__main__": |
| options = parser.parse_args() |
| main(options) |
|
|