camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
# from sacrebleu import corpus_bleu
class LogEpochTimeCallback(Callback):
"""Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log
"""
@rank_zero_only
def on_train_epoch_start(self, trainer, pl_module):
self.epoch_start = time.time()
@rank_zero_only
def on_train_epoch_end(self, trainer, pl_module):
curr_time = time.time()
duration = curr_time - self.epoch_start
trainer.logger.log_metrics({"epoch_time": duration}, step=trainer.global_step)
# class MachineTranslationLogEvalCallback(Callback):
# def _on_eval_end(self, trainer, pl_module, mode):
# counts = np.array(self._non_pad_tokens)
# eval_loss = np.sum(np.array(self._losses) * counts) / np.sum(counts)
# sacre_bleu = corpus_bleu(self._translations, [self._ground_truths], tokenize="13a")
# print(f"{mode} results for process with global rank {pl_module.global_rank}".upper())
# for i in range(pl_module.num_examples[mode]):
# print('\u0332'.join(f"EXAMPLE {i}:")) # Underline output
# sent_id = np.random.randint(len(self._translations))
# print(f"Ground truth: {self._ground_truths[sent_id]}\n")
# print(f"Translation: {self._translations[sent_id]}\n")
# print()
# print("-" * 50)
# print(f"loss: {eval_loss:.3f}")
# print(f"SacreBLEU: {sacre_bleu}")
# print("-" * 50)
# @rank_zero_only
# def on_test_end(self, trainer, pl_module):
# self._on_eval_end(trainer, pl_module, "test")
# @rank_zero_only
# def on_validation_end(self, trainer, pl_module):
# self._on_eval_end(trainer, pl_module, "val")
# @rank_zero_only
# def on_sanity_check_end(self, trainer, pl_module):
# self._on_eval_end(trainer, pl_module, "val")
# def _on_eval_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, mode):
# self._translations.extend(outputs['translations'])
# self._ground_truths.extend(outputs['ground_truths'])
# self._non_pad_tokens.append(outputs['num_non_pad_tokens'])
# self._losses.append(outputs[f'{mode}_loss'])
# @rank_zero_only
# def on_test_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
# self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'test')
# @rank_zero_only
# def on_validation_batch_end(self, trainer, pl_module, batch, outputs, batch_idx, dataloader_idx):
# self._on_eval_batch_end(trainer, pl_module, batch, outputs, batch_idx, dataloader_idx, 'val')
# def _on_eval_start(self, trainer, pl_module):
# self._translations = []
# self._ground_truths = []
# self._losses = []
# self._non_pad_tokens = []
# @rank_zero_only
# def on_test_start(self, trainer, pl_module):
# self._on_eval_start(trainer, pl_module)
# @rank_zero_only
# def on_validation_start(self, trainer, pl_module):
# self._on_eval_start(trainer, pl_module)
# @rank_zero_only
# def on_sanity_check_start(self, trainer, pl_module):
# self._on_eval_start(trainer, pl_module)