Spaces:
Configuration error
Configuration error
| from concurrent.futures import ThreadPoolExecutor | |
| from lightning.pytorch import callbacks as pl_callbacks | |
| from typing_extensions import override | |
| from src.utils import logger | |
class ModelCheckpointParallel(pl_callbacks.ModelCheckpoint):
    """``ModelCheckpoint`` whose checkpoint hooks run on a background worker.

    ``on_train_batch_end``, ``on_train_epoch_end`` and ``on_validation_end``
    do not call the parent implementation directly; they hand it to a
    single-worker ``ThreadPoolExecutor`` and return immediately. Call
    :meth:`wait` (done automatically in ``on_train_end`` and
    ``on_test_start``) to block until every queued hook has finished.

    Exceptions raised by a queued hook are logged through
    ``logger.print_error`` and are not re-raised.

    NOTE(review): the parent hooks execute on the worker thread while the
    caller keeps going; whether they are safe to run concurrently with the
    training loop is not visible from this file -- confirm.
    """

    # Prefix given to the worker thread's name.
    _THREAD_NAME_PREFIX = "ModelCheckpointParallel"

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and create the worker pool."""
        super().__init__(*args, **kwargs)
        # Futures of submitted hooks that have not been reaped yet.
        self.threads = []
        self.thread_pool = self._new_thread_pool()

    def _new_thread_pool(self):
        """Return a fresh executor with exactly one worker thread."""
        return ThreadPoolExecutor(1, thread_name_prefix=self._THREAD_NAME_PREFIX)

    def _parallel_report(self, future) -> None:
        """Block until ``future`` is done and log the exception it raised, if any."""
        try:
            future.result()
        except Exception as e:
            logger.print_error(f"Exception during checkpoint saving in thread: {e}")

    def _parallel_reap(self) -> None:
        """Drop already-finished futures from the front of ``self.threads``.

        Without this the list would gain one entry per submitted hook and
        only be emptied in :meth:`wait`, i.e. grow for the whole run.
        Failures of the reaped futures are logged here, so they are reported
        close to when they happened rather than only at the end.
        """
        finished = 0
        for future in self.threads:
            # The pool has one worker, so tasks complete in submission order
            # and finished futures sit at the front. Stopping at the first
            # pending one keeps this cheap; anything missed is picked up by
            # a later call or by wait().
            if not future.done():
                break
            self._parallel_report(future)
            finished += 1
        if finished:
            del self.threads[:finished]

    def _parallel_submit(self, hook, *args, **kwargs) -> None:
        """Queue ``hook(*args, **kwargs)`` on the worker and track its future."""
        self._parallel_reap()
        self.threads.append(self.thread_pool.submit(hook, *args, **kwargs))

    @override
    def on_train_batch_end(self, *args, **kwargs):
        """Queue the parent's ``on_train_batch_end`` unless saving is skipped.

        The trainer is expected as the first positional argument.
        """
        trainer = args[0]
        if self._should_skip_saving_checkpoint(trainer):
            return
        self._parallel_submit(super().on_train_batch_end, *args, **kwargs)

    @override
    def on_train_epoch_end(self, *args, **kwargs):
        """Queue the parent's ``on_train_epoch_end`` when saving applies here.

        Submitted only if the parent's ``_should_skip_saving_checkpoint`` is
        false and ``_should_save_on_train_epoch_end`` is true for the trainer
        (first positional argument).
        """
        trainer = args[0]
        if self._should_skip_saving_checkpoint(trainer):
            return
        if self._should_save_on_train_epoch_end(trainer):
            self._parallel_submit(super().on_train_epoch_end, *args, **kwargs)

    @override
    def on_validation_end(self, *args, **kwargs):
        """Queue the parent's ``on_validation_end`` when saving applies here.

        Submitted only if the parent's ``_should_skip_saving_checkpoint`` and
        ``_should_save_on_train_epoch_end`` are both false for the trainer
        (first positional argument).
        """
        trainer = args[0]
        if self._should_skip_saving_checkpoint(trainer):
            return
        if not self._should_save_on_train_epoch_end(trainer):
            self._parallel_submit(super().on_validation_end, *args, **kwargs)

    def wait(self):
        """Block until all queued hooks finish, then reset the worker pool.

        Every outstanding future is awaited and its exception, if any, is
        logged. The executor is then shut down and replaced by a new one so
        the callback can be used again afterwards.
        """
        for future in self.threads:
            self._parallel_report(future)
        self.thread_pool.shutdown(wait=True)
        self.thread_pool = self._new_thread_pool()
        self.threads = []

    @override
    def on_train_end(self, *args, **kwargs):
        """Wait for pending checkpoint hooks when training ends."""
        self.wait()

    @override
    def on_test_start(self, *args, **kwargs):
        """Wait for pending checkpoint hooks before testing starts."""
        self.wait()