| from collections import OrderedDict, deque |
| from datetime import datetime, timedelta |
| from numbers import Number |
| from time import time |
| from typing import Any, Dict, Union |
|
|
| import pytorch_lightning as pl |
| import torch |
| from pytorch_lightning.callbacks.progress import ProgressBar |
| from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm, TQDMProgressBar |
| from pytorch_lightning.utilities import rank_zero_only |
|
|
| from genmo.utils.pylogger import Log |
|
|
| |
|
|
|
|
def format_num(n):
    """Return the shorter of a 3-significant-digit rendering and ``str(n)``.

    Exponent zero-padding is stripped (e.g. ``1e+05`` -> ``1e+5``) so the
    compact form is as short as possible (same trick as tqdm's postfix).
    """
    compact = "{0:.3g}".format(n)
    for padded, bare in (("+0", "+"), ("-0", "-")):
        compact = compact.replace(padded, bare)
    plain = str(n)
    if len(compact) < len(plain):
        return compact
    return plain
|
|
|
|
def convert_kwargs_to_str(**kwargs):
    """Render keyword metrics as a single ``"key=value, key=value"`` string.

    Keys are processed in sorted order and reduced to the part after the last
    ``"/"`` (so ``"train/loss"`` becomes ``"loss"``); numeric values are
    compacted with ``format_num`` and all other values are stringified.
    """
    shortened = OrderedDict()
    for full_key in sorted(kwargs):
        shortened[full_key.split("/")[-1]] = kwargs[full_key]

    pieces = []
    for name, value in shortened.items():
        if isinstance(value, Number):
            value = format_num(value)
        elif not isinstance(value, str):
            value = str(value)
        pieces.append(name + "=" + value.strip())
    return ", ".join(pieces)
|
|
|
|
def convert_t_to_str(t):
    """Convert a duration in seconds to an ``H:MM:SS`` string.

    Fractional seconds are dropped, and a zero hour component is stripped so
    the result reads ``MM:SS``; minutes and seconds are always shown.
    """
    text = str(timedelta(seconds=t)).partition(".")[0]
    return text[2:] if text.startswith("0:") else text
|
|
|
|
class MyTQDMProgressBar(TQDMProgressBar, pl.Callback):
    """TQDM progress bar that appends peak GPU memory, the logged metrics, and
    an optional per-step ``"message"`` from the training step to the postfix.
    """

    def init_train_tqdm(self):
        """Create the training tqdm bar with a compact single-line format."""
        bar = Tqdm(
            desc="Training",
            bar_format="{desc}{percentage:3.0f}%[{bar:10}][{n_fmt}/{total_fmt}, {elapsed}→{remaining},{rate_fmt}]{postfix}",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            smoothing=0,
            dynamic_ncols=False,
        )
        return bar

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Refresh the bar postfix on rank zero at the configured interval."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

        n = batch_idx + 1
        if self._should_update(n, self.train_progress_bar.total):
            # Peak GPU memory allocated so far, in GiB.
            max_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
            post_fix_str = f"maxGPU={max_mem:.1f}GB"

            # Progress-bar metrics; "v_num" is logger noise.
            training_metrics = self.get_metrics(trainer, pl_module)
            training_metrics.pop("v_num", None)
            post_fix_str += ", " + convert_kwargs_to_str(**training_metrics)

            # `outputs` is whatever training_step returned; it may be None or a
            # tensor, so only look up "message" when it is actually a mapping
            # ("message" in None would raise TypeError).
            if isinstance(outputs, dict) and "message" in outputs:
                post_fix_str += ", " + outputs["message"]

            self.train_progress_bar.set_postfix_str(post_fix_str)
|
|
|
|
class ProgressReporter(ProgressBar, pl.Callback):
    """Plain-text progress reporter for non-interactive logs.

    Instead of rendering a live bar, prints one status line roughly every
    ``log_every_percent`` of an epoch, plus a summary at each epoch end.
    """

    def __init__(
        self,
        log_every_percent: float = 0.1,
        exp_name=None,
        data_name=None,
        **kwargs,
    ):
        """
        Args:
            log_every_percent: fraction of an epoch between progress lines
                (0.1 logs roughly every 10% of the batches).
            exp_name: experiment name; if ``None``, read from
                ``pl_module.exp_name`` during ``setup``.
            data_name: dataset name; if ``None``, read from
                ``pl_module.data_name`` during ``setup``.
        """
        super().__init__()
        self.enable = True
        self.log_every_percent = log_every_percent
        self.exp_name = exp_name
        self.data_name = data_name
        # Timestamps of the most recent batches; used to smooth speed.
        self.batch_time_queue = deque(maxlen=5)
        self.start_prompt = "🚀"
        self.finish_prompt = "✅"
        # Number of validation batches finished in the current epoch.
        self.n_finished = 0
        self.time_train_epoch_start = time()

    def disable(self):
        self.enable = False

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Record stage and experiment start; resolve names from the module."""
        super().setup(trainer, pl_module, stage)
        self.stage = stage
        self.time_exp_start = time()
        # First epoch of this run, so resumed runs estimate remaining time
        # from epochs actually executed here.
        self.epoch_exp_start = trainer.current_epoch

        if self.exp_name is None:
            if hasattr(pl_module, "exp_name"):
                self.exp_name = pl_module.exp_name
            else:
                self.exp_name = "Unnamed Experiment"
        if self.data_name is None:
            if hasattr(pl_module, "data_name"):
                self.data_name = pl_module.data_name
            else:
                self.data_name = "Unknown Data"

    def print(self, *args: Any, **kwargs: Any) -> None:
        # Forward keyword arguments (sep/end/flush/...) to builtin print;
        # the previous version silently dropped them.
        print(*args, **kwargs)

    def get_metrics(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> Dict[str, Union[str, float]]:
        """Get metrics from trainer for progress bar, without "v_num"."""
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items

    def _should_update(self, n_finished: int, total: int) -> bool:
        """
        Rule: Log every `log_every_percent` percent, or the last batch.
        For long epochs (interval > 10 batches), also log batches 5 and 10
        so the user gets early feedback.
        """
        log_interval = max(int(total * self.log_every_percent), 1)
        able = n_finished % log_interval == 0 or n_finished == total
        if log_interval > 10:
            able = able or n_finished in [5, 10]
        able = able and self.enable
        return able

    @rank_zero_only
    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        """Print an epoch banner and reset the epoch timer."""
        self.print("=" * 80)
        Log.info(
            f"{self.start_prompt}[FIT][Epoch {trainer.current_epoch}] Data: {self.data_name} Experiment: {self.exp_name}"
        )
        self.time_train_epoch_start = time()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Print a progress line (timing, speed, metrics) at the log interval."""
        super().on_train_batch_end(
            trainer, pl_module, outputs, batch, batch_idx
        )
        total = self.total_train_batches

        # Timing/speed bookkeeping runs every batch so the sliding window
        # stays warm even between printed lines.
        n_finished = batch_idx + 1
        percent = 100 * n_finished / total
        time_current = time()
        self.batch_time_queue.append(time_current)
        time_elapsed = time_current - self.time_train_epoch_start
        time_remaining = time_elapsed * (total - n_finished) / n_finished
        if len(self.batch_time_queue) == 1:
            # Guard against a zero elapsed time on a very fast first batch.
            speed = 1 / max(time_elapsed, 1e-9)
        else:
            window = self.batch_time_queue[-1] - self.batch_time_queue[0]
            speed = (len(self.batch_time_queue) - 1) / max(window, 1e-9)

        if not self._should_update(n_finished, total):
            return

        desc = "[Train]"

        time_elapsed_str = convert_t_to_str(time_elapsed)
        time_remaining_str = convert_t_to_str(time_remaining)
        speed_str = f"{speed:.2f}it/s" if speed > 1 else f"{1 / speed:.1f}s/it"
        n_digit = len(str(total))
        desc_speed = f"[{n_finished:{n_digit}d}/{total}={percent:3.0f}%, {time_elapsed_str} → {time_remaining_str}, {speed_str}]"

        # Peak GPU memory allocated so far, in GiB.
        max_mem = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
        post_fix_str = f"maxGPU={max_mem:.1f}GB"

        # Per-step metrics only (epoch aggregates are reported at epoch end).
        train_metrics = self.get_metrics(trainer, pl_module)
        train_metrics = {
            k: v
            for k, v in train_metrics.items()
            if ("train" in k and "epoch" not in k)
        }
        post_fix_str += ", " + convert_kwargs_to_str(**train_metrics)

        # `outputs` may be None or a tensor; only read "message" from a mapping.
        if isinstance(outputs, dict) and "message" in outputs:
            post_fix_str += ", " + outputs["message"]
        post_fix_str = f"[{post_fix_str}]"

        bar_output = f"{desc}{desc_speed}{post_fix_str}"
        self.print(bar_output)

    @rank_zero_only
    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Log the epoch summary with experiment-level time estimates."""
        super().on_train_epoch_end(trainer, pl_module)

        # Speed window must not leak across epochs.
        self.batch_time_queue.clear()

        # Estimate remaining wall time from epochs finished in *this* run.
        n_finished = trainer.current_epoch + 1 - self.epoch_exp_start
        n_to_finish = trainer.max_epochs - trainer.current_epoch - 1
        time_current = time()
        time_elapsed = time_current - self.time_exp_start
        time_remaining = time_elapsed * n_to_finish / n_finished
        time_elapsed_str = convert_t_to_str(time_elapsed)
        time_remaining_str = convert_t_to_str(time_remaining)

        # Epoch-aggregated training metrics only.
        train_metrics = self.get_metrics(trainer, pl_module)
        train_metrics = {
            k: v for k, v in train_metrics.items() if ("train" in k and "epoch" in k)
        }
        train_metrics_str = convert_kwargs_to_str(**train_metrics)

        Log.info(
            f"{self.finish_prompt}[FIT][Epoch {trainer.current_epoch}] finished! {time_elapsed_str}→{time_remaining_str} | {train_metrics_str}"
        )

    @rank_zero_only
    def on_validation_epoch_start(self, trainer, pl_module):
        self.time_val_epoch_start = time()

    @rank_zero_only
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        """Print validation progress; counts batches across dataloaders."""
        self.n_finished += 1
        n_finished = self.n_finished
        total = self.total_val_batches
        if not self._should_update(n_finished, total):
            return

        desc = "[Val]"

        percent = 100 * n_finished / total
        time_current = time()
        time_elapsed = time_current - self.time_val_epoch_start
        time_remaining = time_elapsed * (total - n_finished) / n_finished
        time_elapsed_str = convert_t_to_str(time_elapsed)
        time_remaining_str = convert_t_to_str(time_remaining)
        desc_speed = f"[{n_finished}/{total} ={percent:3.0f}%, {time_elapsed_str}→{time_remaining_str}]"

        bar_output = f"{desc} {desc_speed}"
        self.print(bar_output)

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        # Reset the cross-dataloader batch counter for the next epoch.
        self.n_finished = 0
|
|
|
|
class EmojiProgressReporter(ProgressBar, pl.Callback):
    """Progress reporter that prints a multi-line, emoji-prefixed status block
    every ``refresh_rate_batch`` batches and ``refresh_rate_epoch`` epochs.
    """

    def __init__(
        self,
        refresh_rate_batch: Union[
            int, None
        ] = 1,
        refresh_rate_epoch: int = 1,
        **kwargs,
    ):
        """
        Args:
            refresh_rate_batch: print every this many batches; ``None`` or 0
                disables per-batch reporting.
            refresh_rate_epoch: print the epoch summary every this many epochs.
        """
        super().__init__()
        self.enable = True
        self.refresh_rate_batch = refresh_rate_batch
        self.refresh_rate_epoch = refresh_rate_epoch

        # Emoji prefixes for each line of the report block.
        self.title_prompt = "📝"
        self.prog_prompt = "🚀"
        self.timer_prompt = "⌛️"
        self.metric_prompt = "📌"
        self.finish_prompt = "✅"

    def disable(self):
        self.enable = False

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str):
        """Record the stage, reset timers, and resolve the experiment name."""
        super().setup(trainer, pl_module, stage)
        self.stage = stage
        self.time_start_batch = None
        self.time_start_epoch = None
        if hasattr(pl_module, "exp_name"):
            self.exp_name = pl_module.exp_name
        else:
            self.exp_name = "Unnamed Experiment"
            Log.warn(
                "Experiment name not found, please set it to `pl_module.exp_name`!"
            )

    def print(self, *args: Any, **kwargs: Any):
        # Forward keyword arguments (sep/end/flush/...) to builtin print;
        # the previous version silently dropped them.
        print(*args, **kwargs)

    def get_metrics(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> Dict[str, Union[str, float]]:
        """Get metrics from trainer for progress bar, sorted, without "v_num"."""
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return dict(sorted(items.items()))

    def _should_log_batch(self, n: int) -> bool:
        """Gate per-batch reporting; always log the final batch of the epoch."""
        # Treat 0 like None (disabled); 0 would crash the modulo below.
        if not self.refresh_rate_batch:
            return False
        able = n % self.refresh_rate_batch == 0 or n == self.total_train_batches - 1
        able = able and self.enable
        return able

    def _should_log_epoch(self, n: int) -> bool:
        """Gate per-epoch reporting; always log the final epoch."""
        able = n % self.refresh_rate_epoch == 0 or n == self.trainer.max_epochs - 1
        able = able and self.enable
        return able

    def timestamp_delta_to_str(self, timestamp_delta: float):
        """Format a duration in seconds as e.g. ``"2h 3m 4s"``.

        Zero components are omitted; a zero (or negative) duration yields
        ``"0s"`` instead of an empty string.
        """
        # BUG FIX: the previous version read ``timedelta.seconds``, which
        # silently drops whole days; derive everything from total seconds.
        total_seconds = max(int(timestamp_delta), 0)
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)

        parts = []
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0:
            parts.append(f"{minutes}m")
        if seconds > 0:
            parts.append(f"{seconds}s")
        return " ".join(parts) if parts else "0s"

    @rank_zero_only
    def on_train_batch_start(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        batch: Any,
        batch_idx: int,
    ):
        super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
        # Anchor the epoch's batch timer on the first batch only.
        if self.time_start_batch is None:
            self.time_start_batch = datetime.now().timestamp()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Print the per-batch status block at the configured refresh rate."""
        super().on_train_batch_end(
            trainer, pl_module, outputs, batch, batch_idx
        )
        # Check the gate first so skipped batches do no formatting work.
        if not self._should_log_batch(batch_idx):
            return

        epoch_idx = trainer.current_epoch
        percent = 100 * (batch_idx + 1) / (self.total_train_batches + 1)
        metrics = self.get_metrics(trainer, pl_module)

        # Current wall time and estimated time left in this epoch.
        time_cur_stamp = datetime.now().timestamp()
        time_cur_str = datetime.fromtimestamp(time_cur_stamp).strftime("%m-%d %H:%M:%S")
        time_rest_stamp = (
            (time_cur_stamp - self.time_start_batch) * (100 - percent) / percent
        )
        time_rest_str = self.timestamp_delta_to_str(time_rest_stamp)

        self.print(
            f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}..."
        )
        self.print(
            f"{self.prog_prompt} Ep {epoch_idx}: {int(percent):02d}% <= [{batch_idx}/{self.total_train_batches}]"
        )
        self.print(
            f"{self.timer_prompt} Time: {time_cur_str} | Ep Rest: {time_rest_str}"
        )
        for k, v in metrics.items():
            self.print(f"{self.metric_prompt} {k}: {v}")
        self.print("")

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        super().on_train_epoch_start(trainer, pl_module)
        # Restart the batch timer each epoch; anchor the run timer once.
        self.time_start_batch = None
        if self.time_start_epoch is None:
            self.time_start_epoch = datetime.now().timestamp()

    @rank_zero_only
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Print the epoch summary block at the configured epoch refresh rate."""
        super().on_train_epoch_end(trainer, pl_module)

        epoch_idx = trainer.current_epoch
        # BUG FIX: this hook previously gated on _should_log_batch, so the
        # epoch summary followed the *batch* refresh rate.
        if not self._should_log_epoch(epoch_idx):
            return

        # max(..., 1) guards max_epochs <= 0 (e.g. -1 for "unlimited").
        percent = 100 * (epoch_idx + 1) / max(self.trainer.max_epochs + 1, 1)
        metrics = self.get_metrics(trainer, pl_module)

        time_cur = datetime.now().timestamp()
        # BUG FIX: the format string had a stray space ("%H: %M:%S").
        time_str = datetime.fromtimestamp(time_cur).strftime("%m-%d %H:%M:%S")
        # Estimated time left for the whole run.
        time_rest_stamp = (time_cur - self.time_start_epoch) * (100 - percent) / percent
        time_rest_str = self.timestamp_delta_to_str(time_rest_stamp)

        self.print(">> >> >> >>")
        self.print(f"{self.title_prompt} [{self.stage.upper()}] Exp: {self.exp_name}")
        self.print(f"{self.finish_prompt} Ep {epoch_idx} finished!")
        self.print(f"{self.timer_prompt} Time: {time_str} | Rest: {time_rest_str}")
        for k, v in metrics.items():
            self.print(f"{self.metric_prompt} {k}: {v}")
        self.print("<< << << <<")
        self.print("")
|