import argparse
from contextlib import contextmanager
import dataclasses
from dataclasses import is_dataclass
from distutils.version import LooseVersion
import logging
from pathlib import Path
import time
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
import torch.nn
import torch.optim
from typeguard import check_argument_types

from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.main_funcs.average_nbest_models import average_nbest_models
from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler
from espnet2.schedulers.abs_scheduler import AbsEpochStepScheduler
from espnet2.schedulers.abs_scheduler import AbsScheduler
from espnet2.schedulers.abs_scheduler import AbsValEpochStepScheduler
from espnet2.torch_utils.add_gradient_noise import add_gradient_noise
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.recursive_op import recursive_average
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.reporter import Reporter
from espnet2.train.reporter import SubReporter
from espnet2.utils.build_dataclass import build_dataclass

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
    from torch.utils.tensorboard import SummaryWriter
else:
    from tensorboardX import SummaryWriter
if torch.distributed.is_available():
    if LooseVersion(torch.__version__) > LooseVersion("1.0.1"):
        from torch.distributed import ReduceOp
    else:
        from torch.distributed import reduce_op as ReduceOp
else:
    ReduceOp = None
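# NOTE: ``reduce_op`` is the older spelling of ``ReduceOp`` in torch.distributed;
# the alias above keeps the all_reduce() calls below version-agnostic.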

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
    from torch.cuda.amp import GradScaler
else:
    # Nothing to do if torch<1.6.0: fall back to a no-op context manager
    @contextmanager
    def autocast(enabled=True):
        yield

    GradScaler = None

try:
    import fairscale
except ImportError:
    fairscale = None
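# NOTE: fairscale is optional here; run() checks again and raises a RuntimeError
# if --sharded_ddp is requested while fairscale is not installed.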


@dataclasses.dataclass
class TrainerOptions:
    ngpu: int
    resume: bool
    use_amp: bool
    train_dtype: str
    grad_noise: bool
    accum_grad: int
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    no_forward_run: bool
    use_tensorboard: bool
    use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    seed: int
    sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    early_stopping_criterion: Sequence[str]
    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool


class Trainer:
    """Trainer having an optimizer.

    If you'd like to use multiple optimizers, then inherit this class
    and override the methods as necessary - at least "train_one_epoch()"

    >>> class TwoOptimizerTrainer(Trainer):
    ...     @classmethod
    ...     def add_arguments(cls, parser):
    ...         ...
    ...
    ...     @classmethod
    ...     def train_one_epoch(cls, model, optimizers, ...):
    ...         loss1 = model.model1(...)
    ...         loss1.backward()
    ...         optimizers[0].step()
    ...
    ...         loss2 = model.model2(...)
    ...         loss2.backward()
    ...         optimizers[1].step()

    """

    def __init__(self):
        raise RuntimeError("This class can't be instantiated.")

    @classmethod
    def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval(), and plot_attention()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)
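    # For illustration (hypothetical values): build_dataclass() simply picks the
    # fields declared on TrainerOptions out of the parsed argparse.Namespace, e.g.
    #   args = parser.parse_args()  # carries ngpu=1, resume=True, use_amp=False, ...
    #   options = Trainer.build_options(args)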

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Reserved for future development of another Trainer"""
        pass

    @staticmethod
    def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
    ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
            optimizer.load_state_dict(state)
        for scheduler, state in zip(schedulers, states["schedulers"]):
            if scheduler is not None:
                scheduler.load_state_dict(state)
        if scaler is not None:
            if states["scaler"] is None:
                logging.warning("scaler state is not found")
            else:
                scaler.load_state_dict(states["scaler"])

        logging.info(f"The training was resumed using {checkpoint}")

    @classmethod
    def run(
        cls,
        model: AbsESPnetModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_iter_factory: AbsIterFactory,
        valid_iter_factory: AbsIterFactory,
        plot_attention_iter_factory: Optional[AbsIterFactory],
        trainer_options,
        distributed_option: DistributedOption,
    ) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE: trainer_options is only checked for being a dataclass so that
        # subclasses can pass an extended options class
        assert is_dataclass(trainer_options), type(trainer_options)
        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

        if isinstance(trainer_options.keep_nbest_models, int):
            keep_nbest_models = trainer_options.keep_nbest_models
        else:
            if len(trainer_options.keep_nbest_models) == 0:
                logging.warning("No keep_nbest_models is given. Change to [1]")
                trainer_options.keep_nbest_models = [1]
            keep_nbest_models = max(trainer_options.keep_nbest_models)

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "fairscale is required. Do 'pip install fairscale'"
                    )
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None
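        # With sharded_ddp, gradients and optimizer state are sharded across
        # workers, so fairscale's sharding-aware ShardedGradScaler is used in
        # place of torch.cuda.amp.GradScaler.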

        if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
            cls.resume(
                checkpoint=output_dir / "checkpoint.pth",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        start_epoch = reporter.get_epoch() + 1
        if start_epoch == trainer_options.max_epoch + 1:
            logging.warning(
                f"The training has already reached max_epoch: {start_epoch}"
            )

        if distributed_option.distributed:
            if trainer_options.sharded_ddp:
                dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                    module=model,
                    sharded_optimizer=optimizers,
                )
            else:
                dp_model = torch.nn.parallel.DistributedDataParallel(
                    model,
                    device_ids=(
                        # Multi-process with one GPU per process:
                        # bind this process to its current device
                        [torch.cuda.current_device()]
                        if distributed_option.ngpu == 1
                        # Single process with multiple GPUs:
                        # let DDP use all visible devices
                        else None
                    ),
                    output_device=(
                        torch.cuda.current_device()
                        if distributed_option.ngpu == 1
                        else None
                    ),
                    find_unused_parameters=trainer_options.unused_parameters,
                )
        elif distributed_option.ngpu > 1:
            dp_model = torch.nn.parallel.DataParallel(
                model,
                device_ids=list(range(distributed_option.ngpu)),
            )
        else:
            # NOTE: DataParallel should also work with ngpu=1,
            # but for debuggability it's better to keep this block
            dp_model = model

        if trainer_options.use_tensorboard and (
            not distributed_option.distributed or distributed_option.dist_rank == 0
        ):
            summary_writer = SummaryWriter(str(output_dir / "tensorboard"))
        else:
            summary_writer = None
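        # Only rank 0 writes TensorBoard logs; otherwise every worker would emit
        # duplicated events into the same directory.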

        start_time = time.perf_counter()
        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
            if iepoch != start_epoch:
                logging.info(
                    "{}/{}epoch started. Estimated time to finish: {}".format(
                        iepoch,
                        trainer_options.max_epoch,
                        humanfriendly.format_timespan(
                            (time.perf_counter() - start_time)
                            / (iepoch - start_epoch)
                            * (trainer_options.max_epoch - iepoch + 1)
                        ),
                    )
                )
            else:
                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
            set_all_random_seed(trainer_options.seed + iepoch)

            reporter.set_epoch(iepoch)
            # 1. Train and validation for one epoch
            with reporter.observe("train") as sub_reporter:
                all_steps_are_invalid = cls.train_one_epoch(
                    model=dp_model,
                    optimizers=optimizers,
                    schedulers=schedulers,
                    iterator=train_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    scaler=scaler,
                    summary_writer=summary_writer,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            with reporter.observe("valid") as sub_reporter:
                cls.validate_one_epoch(
                    model=dp_model,
                    iterator=valid_iter_factory.build_iter(iepoch),
                    reporter=sub_reporter,
                    options=trainer_options,
                    distributed_option=distributed_option,
                )

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # att_plot doesn't support distributed mode
                if plot_attention_iter_factory is not None:
                    with reporter.observe("att_plot") as sub_reporter:
                        cls.plot_attention(
                            model=model,
                            output_dir=output_dir / "att_ws",
                            summary_writer=summary_writer,
                            iterator=plot_attention_iter_factory.build_iter(iepoch),
                            reporter=sub_reporter,
                            options=trainer_options,
                        )

            # 2. LR scheduler step
            for scheduler in schedulers:
                if isinstance(scheduler, AbsValEpochStepScheduler):
                    scheduler.step(
                        reporter.get_value(*trainer_options.val_scheduler_criterion)
                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        # Each rank holds only a shard of the optimizer state;
                        # gather it before the checkpoint is saved below
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results
                logging.info(reporter.log_message())
                reporter.matplotlib_plot(output_dir / "images")
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer)
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # 4. Save/Update the checkpoint
                torch.save(
                    {
                        "model": model.state_dict(),
                        "reporter": reporter.state_dict(),
                        "optimizers": [o.state_dict() for o in optimizers],
                        "schedulers": [
                            s.state_dict() if s is not None else None
                            for s in schedulers
                        ],
                        "scaler": scaler.state_dict() if scaler is not None else None,
                    },
                    output_dir / "checkpoint.pth",
                )

                # 5. Save the model file for this epoch
                torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")

                # Create a sym link: latest.pth -> {iepoch}epoch.pth
                p = output_dir / "latest.pth"
                if p.is_symlink() or p.exists():
                    p.unlink()
                p.symlink_to(f"{iepoch}epoch.pth")

                _improved = []
                for _phase, k, _mode in trainer_options.best_model_criterion:
                    # e.g. _phase, k, _mode = "valid", "loss", "min"
                    if reporter.has(_phase, k):
                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                        # Create a sym link if this is the best result so far
                        if best_epoch == iepoch:
                            p = output_dir / f"{_phase}.{k}.best.pth"
                            if p.is_symlink() or p.exists():
                                p.unlink()
                            p.symlink_to(f"{iepoch}epoch.pth")
                            _improved.append(f"{_phase}.{k}")
                if len(_improved) == 0:
                    logging.info("There are no improvements in this epoch")
                else:
                    logging.info(
                        "The best model has been updated: " + ", ".join(_improved)
                    )

                # 6. Remove the model files excluding the n-best epochs
                # and the latest epoch
                _removed = []
                # Get the union of the n-best sets among multiple criteria
                nbests = set().union(
                    *[
                        set(reporter.sort_epochs(ph, k, m)[:keep_nbest_models])
                        for ph, k, m in trainer_options.best_model_criterion
                        if reporter.has(ph, k)
                    ]
                )
                for e in range(1, iepoch):
                    p = output_dir / f"{e}epoch.pth"
                    if p.exists() and e not in nbests:
                        p.unlink()
                        _removed.append(str(p))
                if len(_removed) != 0:
                    logging.info("The model files were removed: " + ", ".join(_removed))

            # 7. If no update has happened during this epoch, stop the training
            if all_steps_are_invalid:
                logging.warning(
                    f"The gradients at all steps are invalid in this epoch. "
                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
                )
                break

            # 8. Check early stopping
            if trainer_options.patience is not None:
                if reporter.check_early_stopping(
                    trainer_options.patience, *trainer_options.early_stopping_criterion
                ):
                    break

        else:
            # for-else: reached only when the loop was not broken
            logging.info(
                f"The training was finished at {trainer_options.max_epoch} epochs"
            )

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            # Generate the n-best averaged model
            average_nbest_models(
                reporter=reporter,
                output_dir=output_dir,
                best_model_criterion=trainer_options.best_model_criterion,
                nbest=keep_nbest_models,
            )
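        # NOTE: average_nbest_models() averages the parameters of the n best
        # checkpoints selected by best_model_criterion and writes the averaged
        # model alongside the per-epoch snapshots.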

    @classmethod
    def train_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        reporter: SubReporter,
        summary_writer: Optional[SummaryWriter],
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> bool:
        assert check_argument_types()

        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:
            try:
                log_interval = max(len(iterator) // 20, 10)
            except TypeError:
                log_interval = 100
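        # Heuristic: log roughly 20 times per epoch but at least every 10
        # iterations; iterators without __len__ raise TypeError above and
        # fall back to a fixed interval of 100.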

        model.train()
        all_steps_are_invalid = True
        # [For distributed] Because iteration counts are not always equal between
        # processes, a stop flag is shared via all_reduce() so that every worker
        # leaves the loop at the same iteration
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue

            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):
                    retval = model(**batch)

                # NOTE: two patterns are supported for the model's return value:
                #   a. dict type
                if isinstance(retval, dict):
                    loss = retval["loss"]
                    stats = retval["stats"]
                    weight = retval["weight"]
                    optim_idx = retval.get("optim_idx")
                    if optim_idx is not None and not isinstance(optim_idx, int):
                        if not isinstance(optim_idx, torch.Tensor):
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {type(optim_idx)}"
                            )
                        if optim_idx.dim() >= 2:
                            raise RuntimeError(
                                "optim_idx must be int or 1dim torch.Tensor, "
                                f"but got {optim_idx.dim()}dim tensor"
                            )
                        if optim_idx.dim() == 1:
                            for v in optim_idx:
                                if v != optim_idx[0]:
                                    raise RuntimeError(
                                        "optim_idx must be 1dim tensor "
                                        "having same values for all entries"
                                    )
                            optim_idx = optim_idx[0].item()
                        else:
                            optim_idx = optim_idx.item()

                #   b. tuple or list type
                else:
                    loss, stats, weight = retval
                    optim_idx = None

                stats = {k: v for k, v in stats.items() if v is not None}
                if ngpu > 1 or distributed:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()

                    # If distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed)

                    # Now weight is the summation over all workers
                    loss /= weight
                    if distributed:
                        # NOTE: multiply by world_size because
                        # DistributedDataParallel automatically normalizes
                        # the gradient by world_size
                        loss *= torch.distributed.get_world_size()

                loss /= accum_grad
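                # NOTE: dividing by accum_grad makes the gradients summed over
                # the next accum_grad backward() calls equivalent to a single
                # step on the concatenated mini-batches.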

            reporter.register(stats, weight)

            with reporter.measure_time("backward_time"):
                if scaler is not None:
                    # Scales loss. Calls backward() on the scaled loss to create
                    # scaled gradients.
                    # Backward passes under autocast are not recommended:
                    # backward ops run in the same dtype that autocast chose
                    # for the corresponding forward ops.
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            if iiter % accum_grad == 0:
                if scaler is not None:
                    # Unscales the gradients of the optimizer's assigned
                    # params in-place
                    for iopt, optimizer in enumerate(optimizers):
                        if optim_idx is not None and iopt != optim_idx:
                            continue
                        scaler.unscale_(optimizer)

                # gradient noise injection
                if grad_noise:
                    add_gradient_noise(
                        model,
                        reporter.get_total_count(),
                        duration=100,
                        eta=1.0,
                        scale_factor=0.55,
                    )

                # compute the gradient norm to check whether it is normal or not
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=grad_clip,
                    norm_type=grad_clip_type,
                )
                # For PyTorch<=1.4, clip_grad_norm_ returns a float value
                if not isinstance(grad_norm, torch.Tensor):
                    grad_norm = torch.tensor(grad_norm)

                if not torch.isfinite(grad_norm):
                    logging.warning(
                        f"The grad norm is {grad_norm}. Skipping updating the model."
                    )

                    # Must invoke scaler.update() if unscale_() was used in this
                    # iteration, to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # Note that if the gradient has inf/nan values,
                    # scaler.step() skips optimizer.step().
                    if scaler is not None:
                        for iopt, optimizer in enumerate(optimizers):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            scaler.step(optimizer)
                            scaler.update()

                else:
                    all_steps_are_invalid = False
                    with reporter.measure_time("optim_step_time"):
                        for iopt, (optimizer, scheduler) in enumerate(
                            zip(optimizers, schedulers)
                        ):
                            if optim_idx is not None and iopt != optim_idx:
                                continue
                            if scaler is not None:
                                # scaler.step() first unscales the gradients of
                                # the optimizer's assigned params.
                                scaler.step(optimizer)
                                # Updates the scale for the next iteration.
                                scaler.update()
                            else:
                                optimizer.step()
                            if isinstance(scheduler, AbsBatchStepScheduler):
                                scheduler.step()
                            optimizer.zero_grad()

                # Register lr and train/load time [sec/step]
                # (one step here means accum_grad mini-batches)
                reporter.register(
                    dict(
                        {
                            f"optim{i}_lr{j}": pg["lr"]
                            for i, optimizer in enumerate(optimizers)
                            for j, pg in enumerate(optimizer.param_groups)
                            if "lr" in pg
                        },
                        train_time=time.perf_counter() - start_time,
                    ),
                )
                start_time = time.perf_counter()

            # NOTE: call log_message() after next()
            reporter.next()
            if iiter % log_interval == 0:
                logging.info(reporter.log_message(-log_interval))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

        else:
            # for-else: runs only when the loop was exhausted without break.
            # Tell the other workers, which may still be waiting in all_reduce(),
            # that this process has finished its iterator.
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

        return all_steps_are_invalid

    @classmethod
    @torch.no_grad()
    def validate_one_epoch(
        cls,
        model: torch.nn.Module,
        iterator: Iterable[Dict[str, torch.Tensor]],
        reporter: SubReporter,
        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

        model.eval()

        # [For distributed] Because iteration counts are not always equal between
        # processes, a stop flag is shared via all_reduce()
        # (see train_one_epoch() for details)
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
        for (_, batch) in iterator:
            assert isinstance(batch, dict), type(batch)
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):
                stats = retval["stats"]
                weight = retval["weight"]
            else:
                _, stats, weight = retval
            if ngpu > 1 or distributed:
                # Apply weighted averaging for stats;
                # if distributed, this method can also apply all_reduce()
                stats, weight = recursive_average(stats, weight, distributed)

            reporter.register(stats, weight)
            reporter.next()

        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

    @classmethod
    @torch.no_grad()
    def plot_attention(
        cls,
        model: torch.nn.Module,
        output_dir: Optional[Path],
        summary_writer: Optional[SummaryWriter],
        iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
        reporter: SubReporter,
        options: TrainerOptions,
    ) -> None:
        assert check_argument_types()
        import matplotlib

        ngpu = options.ngpu
        no_forward_run = options.no_forward_run

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator

        model.eval()
        for ids, batch in iterator:
            assert isinstance(batch, dict), type(batch)
            assert len(next(iter(batch.values()))) == len(ids), (
                len(next(iter(batch.values()))),
                len(ids),
            )
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            # 1. Forward the model and gather all attentions;
            #    calculate_all_attentions() uses a single gpu only
            att_dict = calculate_all_attentions(model, batch)

            # 2. Plot attentions: this part is slow due to matplotlib
            for k, att_list in att_dict.items():
                assert len(att_list) == len(ids), (len(att_list), len(ids))
                for id_, att_w in zip(ids, att_list):

                    if isinstance(att_w, torch.Tensor):
                        att_w = att_w.detach().cpu().numpy()

                    if att_w.ndim == 2:
                        att_w = att_w[None]
                    elif att_w.ndim > 3 or att_w.ndim == 1:
                        raise RuntimeError(f"Must be 2 or 3 dimension: {att_w.ndim}")

                    w, h = plt.figaspect(1.0 / len(att_w))
                    fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
                    axes = fig.subplots(1, len(att_w))
                    if len(att_w) == 1:
                        axes = [axes]

                    for ax, aw in zip(axes, att_w):
                        ax.imshow(aw.astype(np.float32), aspect="auto")
                        ax.set_title(f"{k}_{id_}")
                        ax.set_xlabel("Input")
                        ax.set_ylabel("Output")
                        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
                        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

                    if output_dir is not None:
                        p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png"
                        p.parent.mkdir(parents=True, exist_ok=True)
                        fig.savefig(p)

                    if summary_writer is not None:
                        summary_writer.add_figure(
                            f"{k}_{id_}", fig, reporter.get_epoch()
                        )
            reporter.next()