# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only

# The imports below are only needed if the commented-out MachineTranslationLogEvalCallback is re-enabled.
# import numpy as np
# from sacrebleu import corpus_bleu


class LogEpochTimeCallback(Callback):
    """Simple callback that logs how long each training epoch takes, in seconds, to the PyTorch Lightning logger."""

    def on_train_epoch_start(self, trainer, pl_module):
        # Record the wall-clock time when the epoch begins.
        self.epoch_start = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        # Compute the elapsed time and log it against the current global step.
        curr_time = time.time()
        duration = curr_time - self.epoch_start
        trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
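

# Example usage (an illustrative sketch, not part of this module): attach the callback to a
# PyTorch Lightning Trainer so every training epoch's duration is written to the logger.
#
#     import pytorch_lightning as pl
#
#     trainer = pl.Trainer(max_epochs=3, callbacks=[LogEpochTimeCallback()])
#     trainer.fit(model)  # "model" stands in for any LightningModule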


# class MachineTranslationLogEvalCallback(Callback):
#     def _on_eval_end(self, trainer, pl_module, mode):
#         counts = np.array(self._non_pad_tokens)
#         eval_loss = np.sum(np.array(self._losses) * counts) / np.sum(counts)
#         sacre_bleu = corpus_bleu(self._translations, [self._ground_truths], tokenize="13a")
#         print(f"{mode} results for process with global rank {pl_module.global_rank}".upper())
#         for i in range(pl_module.num_examples[mode]):
#             print('\u0332'.join(f"EXAMPLE {i}:"))  # Underline output
#             sent_id = np.random.randint(len(self._translations))
#             print(f"Ground truth: {self._ground_truths[sent_id]}\n")
#             print(f"Translation: {self._translations[sent_id]}\n")
#             print()
#         print("-" * 50)
#         print(f"loss: {eval_loss:.3f}")
#         print(f"SacreBLEU: {sacre_bleu}")
#         print("-" * 50)
#
#     @rank_zero_only
#     def on_test_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "test")
#
#     @rank_zero_only
#     def on_validation_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "val")
#
#     @rank_zero_only
#     def on_sanity_check_end(self, trainer, pl_module):
#         self._on_eval_end(trainer, pl_module, "val")
#
#     def _on_eval_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, mode):
#         self._translations.extend(outputs['translations'])
#         self._ground_truths.extend(outputs['ground_truths'])
#         self._non_pad_tokens.append(outputs['num_non_pad_tokens'])
#         self._losses.append(outputs[f'{mode}_loss'])
#
#     # Hook signatures follow the PyTorch Lightning argument order (outputs, batch, ...),
#     # and the same order is passed through to _on_eval_batch_end.
#     @rank_zero_only
#     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
#         self._on_eval_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, 'test')
#
#     @rank_zero_only
#     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
#         self._on_eval_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, 'val')
#
#     def _on_eval_start(self, trainer, pl_module):
#         self._translations = []
#         self._ground_truths = []
#         self._losses = []
#         self._non_pad_tokens = []
#
#     @rank_zero_only
#     def on_test_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)
#
#     @rank_zero_only
#     def on_validation_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)
#
#     @rank_zero_only
#     def on_sanity_check_start(self, trainer, pl_module):
#         self._on_eval_start(trainer, pl_module)