| |
|
|
| |
| |
| |
|
|
| import gzip |
| import logging |
| from numpy import array |
| import optparse |
| import os.path |
| import sys |
|
|
| from nbest import * |
| from sampler import * |
| from train import * |
|
|
|
|
| logging.basicConfig(format = "%(asctime)-15s %(message)s") |
| log = logging.getLogger('main') |
| log.setLevel(logging.DEBUG) |
| |
| class Config: |
| def __init__(self): |
| self.parser = optparse.OptionParser(usage="%prog [options] ") |
| self.parser.add_option("-t", "--trainer", action="store",\ |
| dest="trainer", metavar="TYPE", type="choice", choices=("pro","mix"),\ |
| default="pro",\ |
| help="type of trainer to run (pro,mix)") |
| self.parser.add_option("-n", "--nbest", action="append", \ |
| dest="nbest", metavar="NBEST-FILE",\ |
| help="nbest output file(s) from decoder") |
| self.parser.add_option("-S", "--scfile", action="append",\ |
| dest="score", metavar="SCORE-FILE",\ |
| help="score file(s) from extractor (in same order as nbests)") |
| self.parser.add_option("-p", "--phrase-table" , action="append",\ |
| dest="ttable", metavar="TTABLE",\ |
| help="ttable to be used in mixture model training") |
| self.parser.add_option("-i", "--input-file", action="store",\ |
| dest="input_file", metavar="INPUT-FILE", |
| help="source text file") |
| self.parser.add_option("-m", "--moses-bin-dir", action="store",\ |
| dest="moses_bin_dir", metavar="DIR", |
| help="directory containing Moses binaries", |
| default=os.path.expanduser("~/moses/bin")) |
| self.nbest_files = [] |
| self.score_files = [] |
| self.ttables = [] |
|
|
| def parse(self,args=sys.argv[1:]): |
| (options,args) = self.parser.parse_args(args) |
| self.nbest_files = options.nbest |
| self.score_files = options.score |
| self.ttables = options.ttable |
| self.input_file = options.input_file |
| self.trainer = options.trainer |
| self.moses_bin_dir = options.moses_bin_dir |
| if not self.nbest_files: |
| self.nbest_files = ["data/esen.nc.nbest.segment"] |
| if not self.score_files: |
| self.score_files = ["data/esen.nc.scores"] |
| if len(self.nbest_files) != len(self.score_files): |
| self.parser.error("Must have equal numbers of score files and nbest files") |
| if self.trainer == "mix": |
| if not self.input_file or not self.ttables: |
| self.parser.error("Need to specify input file and ttables for mix training") |
| |
| |
|
|
| def main(): |
| config = Config() |
| config.parse() |
|
|
| samples = [] |
| sampler = HopkinsMaySampler() |
| nbests = 0 |
| for nbest_file,score_data_file in zip(config.nbest_files,config.score_files): |
| log.debug("nbest: " + nbest_file + "; score:" + score_data_file) |
| segments = False |
| if config.trainer == "mix": segments = True |
| for nbest in get_scored_nbests(nbest_file, score_data_file, config.input_file, segments=segments): |
| samples += sampler.sample(nbest) |
| nbests += 1 |
| log.debug("Samples loaded") |
| trainer = None |
| if config.trainer == "mix": |
| |
| scorer = MosesPhraseScorer(config.ttables) |
| log.debug("Scoring samples...") |
| for sample in samples: |
| scorer.add_scores(sample.hyp1) |
| scorer.add_scores(sample.hyp2) |
| log.debug("...samples scored") |
| trainer = MixtureModelTrainer(samples) |
| elif config.trainer == "pro": |
| trainer = ProTrainer(samples) |
| else: assert(0) |
| log.debug("Starting training...") |
| weights,mix_weights = trainer.train(debug=False) |
| log.debug("...training complete") |
| for i,w in enumerate(weights): |
| print "F%d %10.8f" % (i,w) |
| for i,f in enumerate(mix_weights): |
| for j,w in enumerate(f): |
| print "M%d_%d %10.8f" % (i,j,w) |
|
|
|
|
|
|
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|