| """ |
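Logging helpers for training runs: write scalars, text and figures to TensorBoard, messages to a Python logger, images/audio/spectrogram plots to the run directory, and optional email notifications.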
| """ |
|
|
| import datetime |
| import logging |
| import math |
| import os |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Optional, Union |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torchaudio |
| from PIL import Image |
| from pytz import timezone |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from mmaudio.utils.email_utils import EmailSender |
| from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator |
| from mmaudio.utils.timezone import my_timezone |
|
|
|
|
def tensor_to_numpy(image: torch.Tensor):
    """Convert an image tensor with values in [0, 1] to a uint8 numpy array.

    The tensor is detached and moved to the CPU first, so tensors that require
    grad or live on an accelerator are accepted. Values are clamped to [0, 1]
    before scaling by 255: casting out-of-range floats to uint8 is undefined
    and typically wraps around (e.g. a slightly negative pixel turning bright).
    Scaled values are truncated toward zero (not rounded) by the uint8 cast.
    """
    image = image.detach().cpu().clamp(0, 1)
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np
|
|
|
|
def detach_to_cpu(x: torch.Tensor):
    """Detach ``x`` from the autograd graph and return it on the CPU."""
    no_grad_view = x.detach()
    return no_grad_view.cpu()
|
|
|
|
def fix_width_trunc(x: float):
    """Format ``x`` with nine decimal places, then keep the first nine characters.

    For finite values this yields a nine-character string. The cut is a plain
    truncation of the formatted text, not a rounding to nine significant digits.
    """
    full = f'{x:0.9f}'
    return full[:9]
|
|
|
|
def plot_spectrogram(spectrogram: np.ndarray, title=None, ylabel="freq_bin", ax=None):
    """Draw ``spectrogram`` as an image on ``ax``.

    When ``ax`` is None a new single-axes figure is created via
    ``plt.subplots``. The image is drawn with the origin at the bottom,
    auto aspect ratio and nearest-neighbour interpolation. An optional
    ``title`` and the y-axis label are set first.
    """
    target_ax = ax if ax is not None else plt.subplots(1, 1)[1]
    if title is not None:
        target_ax.set_title(title)
    target_ax.set_ylabel(ylabel)
    target_ax.imshow(spectrogram, origin="lower", aspect="auto", interpolation="nearest")
|
|
|
|
class TensorboardLogger:
    """Fan-out logger for one experiment run.

    Scalars, text and figures are written to a TensorBoard ``SummaryWriter``
    (created on rank 0 only). Human-readable messages go to the supplied
    Python logger. Images, audio and spectrogram plots are written as files
    under ``run_dir``. Selected events (SLURM job start, repeated NaNs,
    critical errors, completion) are forwarded to an ``EmailSender``.
    """

    def __init__(self,
                 exp_id: str,
                 run_dir: Union[Path, str],
                 py_logger: logging.Logger,
                 *,
                 is_rank0: bool = False,
                 enable_email: bool = False):
        """
        Args:
            exp_id: experiment identifier; prefixes every metrics line and is
                passed to the email sender.
            run_dir: directory for TensorBoard event files and media outputs.
            py_logger: logger that receives text messages.
            is_rank0: only when True is a SummaryWriter created; otherwise all
                TensorBoard writes are skipped.
            enable_email: enable email notifications (effective only on rank 0).
        """
        self.exp_id = exp_id
        self.run_dir = Path(run_dir)
        self.py_log = py_logger
        self.email_sender = EmailSender(exp_id, enable=(is_rank0 and enable_email))
        self.tb_log = SummaryWriter(run_dir) if is_rank0 else None

        # Record the code version for reproducibility; this is best effort and
        # must never prevent the logger from being constructed.
        git_info = self._fetch_git_info()
        if git_info is None:
            self.py_log.warning('Failed to fetch git info. Defaulting to None')
            git_info = 'None'
        self.log_string('git', git_info)

        # When running under SLURM, record the job id and announce the start.
        job_id = os.environ.get('SLURM_JOB_ID', None)
        if job_id is not None:
            self.log_string('slurm_job_id', job_id)
            self.email_sender.send(f'Job {job_id} started', f'Job started {run_dir}')

        # Nothing in this class assigns the timers; they are expected to be set
        # externally. While batch_timer is None, log_metrics skips timing/ETA.
        self.batch_timer: Optional[TimeEstimator] = None
        self.data_timer: Optional[PartialTimeEstimator] = None

        # Per-tag count of consecutive NaN scalars (see log_scalar).
        self.nan_count = defaultdict(int)

    @staticmethod
    def _fetch_git_info() -> Optional[str]:
        """Return '<branch> <commit sha>' for the repository at '.', or None.

        GitPython is optional. Any of the following yields None instead of
        raising: failure to import it, the directory not being a repository,
        a detached HEAD (TypeError), or a repository without commits
        (ValueError).
        """
        try:
            import git
        except ImportError:
            return None
        try:
            repo = git.Repo('.')
            return str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
        except (git.exc.GitError, RuntimeError, TypeError, ValueError):
            return None

    @staticmethod
    def _histogram_bar_width(n_bins: int) -> float:
        """Width of one bar when ``n_bins`` bars are spread evenly over [0, 1].

        Guards against division by zero when there are fewer than two bins.
        """
        return 1 / (n_bins - 1) if n_bins > 1 else 1.0

    @staticmethod
    def _format_remaining(est: datetime.timedelta) -> str:
        """Coarse duration string: 'Dd Hh' when at least a day, else 'Hh Mm'."""
        if est.days > 0:
            return f'{est.days}d {est.seconds // 3600}h'
        return f'{est.seconds // 3600}h {(est.seconds % 3600) // 60}m'

    def log_scalar(self, tag: str, x: float, it: int):
        """Write scalar ``x`` under ``tag`` at step ``it`` (no-op without a writer).

        A NaN value increments a per-tag counter, and an email is sent when it
        reaches 10 consecutive NaNs. Tags containing 'grad_norm' are exempt from
        this tracking. Any non-NaN value resets the counter.
        """
        if self.tb_log is None:
            return
        if math.isnan(x) and 'grad_norm' not in tag:
            self.nan_count[tag] += 1
            if self.nan_count[tag] == 10:
                self.email_sender.send(
                    f'Nan detected in {tag} @ {self.run_dir}',
                    f'Nan detected in {tag} at iteration {it}; run_dir: {self.run_dir}')
        else:
            self.nan_count[tag] = 0
        self.tb_log.add_scalar(tag, x, it)

    def log_metrics(self,
                    prefix: str,
                    metrics: dict[str, float],
                    it: int,
                    ignore_timer: bool = False):
        """Log ``metrics`` to TensorBoard and emit one summary line to py_log.

        Each metric is written as ``{prefix}/{name}``, in sorted name order.
        When a batch timer is attached and ``ignore_timer`` is False, the
        following happens as well. The batch timer is updated and both timers'
        averages are read and reset. Average step time and data time are
        logged as scalars. Remaining time and an ETA in ``my_timezone`` are
        appended to the message.
        """
        msg = f'{self.exp_id}-{prefix} - it {it:6d}: '
        metric_parts = []
        for k, v in sorted(metrics.items()):
            self.log_scalar(f'{prefix}/{k}', v, it)
            metric_parts.append(f'{k: >10}:{v:.7f},\t')
        metrics_msg = ''.join(metric_parts)

        if self.batch_timer is not None and not ignore_timer:
            self.batch_timer.update()
            avg_time = self.batch_timer.get_and_reset_avg_time()
            # NOTE(review): assumes data_timer is assigned whenever batch_timer is.
            data_time = self.data_timer.get_and_reset_avg_time()

            self.log_scalar(f'{prefix}/avg_time', avg_time, it)
            self.log_scalar(f'{prefix}/data_time', data_time, it)

            est = datetime.timedelta(seconds=self.batch_timer.get_est_remaining(it))
            remaining_str = self._format_remaining(est)
            eta = datetime.datetime.now(timezone(my_timezone)) + est
            eta_str = eta.strftime('%Y-%m-%d %H:%M:%S %Z%z')
            time_msg = f'avg_time:{avg_time:.3f},data:{data_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
            msg = f'{msg} {time_msg}'

        msg = f'{msg} {metrics_msg}'
        self.py_log.info(msg)

    def log_histogram(self, tag: str, hist: torch.Tensor, it: int):
        """Render ``hist`` as a bar chart and add it to TensorBoard as a figure.

        Bars are placed at evenly spaced x positions in [0, 1], one per entry
        of ``hist``, with each entry as a bar height. No-op without a writer.
        """
        if self.tb_log is None:
            return

        hist = hist.cpu().numpy()
        n_bins = len(hist)
        fig, ax = plt.subplots()
        x_range = np.linspace(0, 1, n_bins)
        ax.bar(x_range, hist, width=self._histogram_bar_width(n_bins))
        ax.set_xticks(x_range)
        ax.set_xticklabels(x_range)
        fig.tight_layout()
        self.tb_log.add_figure(tag, fig, it)
        # Close this specific figure; a bare plt.close() could close an
        # unrelated "current" figure.
        plt.close(fig)

    def log_image(self, prefix: str, tag: str, image: np.ndarray, it: int):
        """Save ``image`` as ``{run_dir}/{prefix}_images/{it:09d}_{tag}.png``.

        ``image`` is passed to ``PIL.Image.fromarray``. Presumably it is uint8
        HxW or HxWxC; confirm against callers.
        """
        image_dir = self.run_dir / f'{prefix}_images'
        image_dir.mkdir(exist_ok=True, parents=True)

        image = Image.fromarray(image)
        image.save(image_dir / f'{it:09d}_{tag}.png')

    def log_audio(self,
                  prefix: str,
                  tag: str,
                  waveform: torch.Tensor,
                  it: Optional[int] = None,
                  *,
                  subdir: Optional[Path] = None,
                  sample_rate: int = 16000) -> Path:
        """Save ``waveform`` as a FLAC file under ``run_dir[/subdir]/prefix``.

        The file is named ``{tag}.flac``, or ``{it:09d}_{tag}.flac`` when
        ``it`` is given. The waveform is written channels-first
        (channels, time) as float32.

        Returns:
            The directory the file was written to (not the file path).
        """
        if subdir is None:
            audio_dir = self.run_dir / prefix
        else:
            audio_dir = self.run_dir / subdir / prefix
        audio_dir.mkdir(exist_ok=True, parents=True)

        if it is None:
            name = f'{tag}.flac'
        else:
            name = f'{it:09d}_{tag}.flac'

        torchaudio.save(audio_dir / name,
                        waveform.cpu().float(),
                        sample_rate=sample_rate,
                        channels_first=True)
        return Path(audio_dir)

    def log_spectrogram(
        self,
        prefix: str,
        tag: str,
        spec: torch.Tensor,
        it: Optional[int],
        *,
        subdir: Optional[Path] = None,
    ):
        """Plot ``spec`` and save it as a PNG under ``run_dir[/subdir]/prefix``.

        The file is named ``{tag}.png``, or ``{it:09d}_{tag}.png`` when ``it``
        is not None.
        """
        if subdir is None:
            spec_dir = self.run_dir / prefix
        else:
            spec_dir = self.run_dir / subdir / prefix
        spec_dir.mkdir(exist_ok=True, parents=True)

        if it is None:
            name = f'{tag}.png'
        else:
            name = f'{it:09d}_{tag}.png'

        # Work on an explicit figure rather than pyplot's implicit "current"
        # figure. Detach so tensors that require grad can be converted.
        fig, ax = plt.subplots(1, 1)
        plot_spectrogram(spec.detach().cpu().float().numpy(), ax=ax)
        fig.tight_layout()
        fig.savefig(spec_dir / name)
        plt.close(fig)

    def log_string(self, tag: str, x: str):
        """Log ``'{tag} - {x}'`` to py_log and, if a writer exists, add it as TensorBoard text."""
        self.py_log.info(f'{tag} - {x}')
        if self.tb_log is None:
            return
        self.tb_log.add_text(tag, x)

    def debug(self, x):
        """Forward ``x`` to py_log at DEBUG level."""
        self.py_log.debug(x)

    def info(self, x):
        """Forward ``x`` to py_log at INFO level."""
        self.py_log.info(x)

    def warning(self, x):
        """Forward ``x`` to py_log at WARNING level."""
        self.py_log.warning(x)

    def error(self, x):
        """Forward ``x`` to py_log at ERROR level."""
        self.py_log.error(x)

    def critical(self, x):
        """Log ``x`` at CRITICAL level and send it by email."""
        self.py_log.critical(x)
        self.email_sender.send(f'Error occurred in {self.run_dir}', x)

    def complete(self):
        """Send a job-completed email notification."""
        self.email_sender.send(f'Job completed in {self.run_dir}', 'Job completed')
|
|