Spaces:
Build error
Build error
| import random | |
| import comet.src.train.train as base_train | |
| import comet.src.train.batch as batch | |
| import comet.src.evaluate.atomic_evaluate as evaluate | |
| # import comet.src.evaluate.atomic_generate as gen | |
def make_trainer(opt, *args):
    """Construct the ATOMIC generation trainer.

    ``opt`` and any extra positional arguments are forwarded unchanged to
    ``AtomicGenerationIteratorTrainer``; the new instance is returned.
    """
    trainer_cls = AtomicGenerationIteratorTrainer
    trainer = trainer_cls(opt, *args)
    return trainer
class AtomicGenerationIteratorTrainer(base_train.IteratorTrainer):
    """Iterator-based trainer specialised for ATOMIC generation.

    Extends ``base_train.IteratorTrainer`` with the ATOMIC evaluator, sampler
    registration, batching via ``batch.batch_atomic_generate``, and
    per-category micro/macro loss bookkeeping.
    """

    def __init__(self, opt, *args):
        super(AtomicGenerationIteratorTrainer, self).__init__(opt, *args)

        # Create train-loss buckets for the overall totals and for every
        # category listed in opt.data["categories"] (none when the key is
        # absent). The zeroed counter dict returned here is not needed.
        self.initialize_losses(opt.data.get("categories", []))

    def set_evaluator(self, opt, model, data_loader):
        """Attach the ATOMIC evaluator built for ``model`` and ``data_loader``."""
        self.evaluator = evaluate.make_evaluator(
            opt, model, data_loader)

    # Generator support is disabled (its import is commented out at file top).
    # def set_generator(self, opt, model, data_loader, scores, reward=None):
    #     self.generator = gen.make_generator(
    #         opt, model, data_loader, scores, reward)

    def set_sampler(self, opt):
        """Register the sampler named by ``opt.train.static.samp`` and expose it.

        The sampler is created only the first time its name is seen and is
        cached in ``self.samplers`` under that name. The whole
        ``self.samplers`` mapping is then published as
        ``self.batch_variables["sampler"]``.
        (``self.samplers`` / ``self.batch_variables`` are presumably set up by
        the base trainer — not visible here.)
        """
        # ``sampling`` was referenced but never imported in this module, so
        # every call raised NameError. A local import fixes that without
        # touching the file's import block.
        # NOTE(review): module path follows COMET's ``src.evaluate.sampler``
        # under the ``comet.`` prefix this file uses — confirm.
        import comet.src.evaluate.sampler as sampling

        samp_name = opt.train.static.samp
        if samp_name not in self.samplers:
            self.samplers[samp_name] = sampling.make_sampler(
                samp_name, opt, self.data_loader, batch_mode=True)
        self.batch_variables["sampler"] = self.samplers

    def batch(self, opt, *args):
        """Run one ATOMIC generation batch.

        Returns:
            Tuple ``(token_loss, nums, reset)`` taken from the ``"loss"``,
            ``"nums"`` and ``"reset"`` entries of the dict returned by
            ``batch.batch_atomic_generate``.
        """
        outputs = batch.batch_atomic_generate(opt, *args)

        token_loss = outputs["loss"]
        nums = outputs["nums"]
        reset = outputs["reset"]

        return token_loss, nums, reset

    def initialize_losses(self, categories):
        """Reset ``self.losses["train"]`` and return a matching zeroed counter.

        Keys are ``total_micro``/``total_macro`` plus ``<category>_micro`` and
        ``<category>_macro`` for each entry in ``categories``. Each loss entry
        starts as ``[0]``; each counter starts at ``0``.
        """
        self.losses["train"] = {
            "total_micro": [0],
            "total_macro": [0]
        }

        nums = {"total_micro": 0, "total_macro": 0}

        for category in categories:
            micro_name = "{}_micro".format(category)
            macro_name = "{}_macro".format(category)

            self.losses["train"][micro_name] = [0]
            self.losses["train"][macro_name] = [0]

            nums[micro_name] = 0
            nums[macro_name] = 0

        return nums

    def update_top_score(self, opt):
        """Record ``(epoch, score)`` when the tracked dev score improves.

        Lower is better: ``self.top_score`` is replaced when it is unset or
        when the current tracked score is strictly smaller than the stored
        one. ``opt`` is unused; ``self.opt`` is read instead. The previous and
        resulting ``top_score`` are printed.
        """
        print(self.top_score)
        # Look the score up once instead of up to three times.
        current = self.get_tracked_score()
        if self.top_score is None or current < self.top_score[-1]:
            self.top_score = (self.opt.train.dynamic.epoch, current)
        print(self.top_score)

    def get_tracked_score(self):
        """Return the dev ``total_micro`` loss recorded for the current epoch."""
        return self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch]

    def counter(self, nums):
        """Return the ``total_macro`` count from a counter dict."""
        return nums["total_macro"]