Spaces:
Build error
Build error
| import random | |
| import torch | |
| import comet.src.data.config as cfg | |
| import comet.src.train.atomic_train as base_train | |
| import comet.src.train.batch as batch_utils | |
| import comet.src.evaluate.conceptnet_evaluate as evaluate | |
| import comet.src.evaluate.conceptnet_generate as gen | |
| def make_trainer(opt, *args): | |
| return ConceptNetGenerationIteratorTrainer(opt, *args) | |
| class ConceptNetGenerationIteratorTrainer( | |
| base_train.AtomicGenerationIteratorTrainer): | |
| def set_evaluator(self, opt, model, data_loader): | |
| self.evaluator = evaluate.make_evaluator( | |
| opt, model, data_loader) | |
| def set_generator(self, opt, model, data_loader): | |
| self.generator = gen.make_generator( | |
| opt, model, data_loader) | |
| def batch(self, opt, *args): | |
| outputs = batch_utils.batch_atomic_generate(opt, *args) | |
| token_loss = outputs["loss"] | |
| nums = outputs["nums"] | |
| reset = outputs["reset"] | |
| return token_loss, nums, reset | |
| def update_top_score(self, opt): | |
| print(self.top_score) | |
| tracked_scores = self.get_tracked_score() | |
| if self.top_score is None: | |
| self.top_score = \ | |
| self.top_score = {"epoch": {}, "score": {}} | |
| self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch | |
| self.top_score["score"]["total_micro"] = tracked_scores["total_micro"] | |
| else: | |
| if tracked_scores["total_micro"] < self.top_score["score"]["total_micro"]: | |
| self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch | |
| self.top_score["score"]["total_micro"] = tracked_scores["total_micro"] | |
| print(self.top_score) | |
| def get_tracked_score(self): | |
| return { | |
| "total_micro": self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch] | |
| } | |
| def decide_to_save(self): | |
| to_save = cfg.save and not cfg.toy | |
| curr_epoch = self.opt.train.dynamic.epoch | |
| to_save = to_save or cfg.test_save | |
| print(cfg.save_strategy) | |
| if cfg.save_strategy == "best": | |
| if ((self.top_score["epoch"]["total_micro"] != curr_epoch)): | |
| to_save = False | |
| return to_save | |