import ast
import collections
import contextlib
import logging
import math
import os
import re
import time
import traceback
from collections import OrderedDict
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

from data import data_utils

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add a random suffix to break ties between equal-metric checkpoints
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write the checkpoint once, then copy it to the other requested names
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # ioPath's async writes don't support delayed copying of a
                # file that may still be in flight, so skip the extra copies
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to the validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)


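# Illustrative usage sketch (hypothetical; `cfg`, `trainer`, `epoch_itr`, and
# `valid_losses` stand in for the objects fairseq's train.py constructs, and
# this helper is not part of the module's API): a training loop typically
# calls save_checkpoint() after each validation pass.
def _example_validate_and_save(cfg, trainer, epoch_itr, valid_losses):
    # cfg.checkpoint is the CheckpointConfig node of a full fairseq config;
    # valid_losses[0] is the metric selected by --best-checkpoint-metric
    save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])

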
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model cannot be set together with --reset-optimizer, "
            "--reset-lr-scheduler, --reset-meters or --reset-dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if cfg.restore_file == "checkpoint_last.pt":
        # default value of restore_file is checkpoint_last.pt
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start finetuning from
            # the pretrained model; otherwise resume from the last checkpoint
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "cannot be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
        # seek the underlying streaming dataset to the number of examples
        # already consumed in this epoch
        _n = itr_state["iterations_in_epoch"]
        offset = sum(len(_) for _ in epoch_itr.batch_sampler[:_n])
        epoch_itr.dataset.dataset._seek(offset=offset)
        # keep the auxiliary pretraining datasets in sync with the restored
        # iterator; true_num is the dataset size rounded up to a multiple of 8
        true_num = int(math.ceil(len(epoch_itr.dataset) / 8)) * 8
        another_offset = ((epoch_itr.epoch - 1) * true_num + offset) // 8
        if hasattr(epoch_itr.dataset, "pure_text_dataset"):
            text_offset = (2 * another_offset) % len(epoch_itr.dataset.pure_text_dataset)
            epoch_itr.dataset.pure_text_dataset._seek(offset=text_offset)
        if hasattr(epoch_itr.dataset, "pure_image_dataset"):
            image_offset = another_offset % len(epoch_itr.dataset.pure_image_dataset)
            epoch_itr.dataset.pure_image_dataset._seek(offset=image_offset)
        if hasattr(epoch_itr.dataset, "detection_dataset"):
            detection_offset = another_offset % len(epoch_itr.dataset.detection_dataset)
            epoch_itr.dataset.detection_dataset._seek(offset=detection_offset)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr


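# Illustrative resume sketch (hypothetical driver code, not part of the API):
# the returned epoch_itr continues from the restored position unless
# --reset-dataloader was given.
def _example_resume(cfg, trainer):
    extra_state, epoch_itr = load_checkpoint(cfg.checkpoint, trainer)
    if extra_state is not None:
        logger.info("resumed training from epoch %d", epoch_itr.epoch)
    return epoch_itr

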
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # the locally cached file may be stale for remote filesystems, so remove
    # the cached copy and re-fetch; with load_on_all_ranks, synchronize so
    # every rank sees a fresh copy
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # a missing cached file is benign: another process on this node
            # may already have removed it
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # hack to allow non-primitive values (e.g. Namespace) inside the
        # DictConfig while it is being constructed
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True

        state["cfg"] = OmegaConf.create(state["cfg"])

        # restore the original check and lock down the config structure
        _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state


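# Illustrative usage (hypothetical path and override key, not part of the
# API): load a checkpoint on CPU while overriding a training-time argument
# before any model is built from it.
def _example_load_cpu_state(path="checkpoints/checkpoint_best.pt"):
    state = load_checkpoint_to_cpu(path, arg_overrides={"data": "/path/to/data"})
    # config node and raw model state_dict, ready for task/model construction
    return state["cfg"], state["model"]

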
def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args


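# Illustrative usage (hypothetical filenames, not part of the API): load two
# checkpoints of the same architecture as an ensemble and put them in eval
# mode.
def _example_load_ensemble():
    models, cfg = load_model_ensemble(
        ["checkpoints/model1.pt", "checkpoints/model2.pt"]
    )
    for model in models:
        model.eval()
    return models, cfg

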
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename


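# Resolution example (hypothetical files; with suffix="" and num_shards=2):
#   get_maybe_sharded_checkpoint_filename("ck.pt", "", 0, 2)
#     -> "ck-shard0.pt"  if that FSDP shard file exists on disk
#     -> "ck_part0.pt"   otherwise (model-parallel naming, since num_shards > 1)
# With num_shards == 1 and no shard file, the plain "ck.pt" is returned.

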
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code gets too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model-parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next shard / next model
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        ensemble.append(model)
    return ensemble, cfg, task


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]


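# Sorting example (hypothetical directory contents): if `path` contains
#   checkpoint9.pt, checkpoint10.pt, checkpoint_best.pt
# then checkpoint_paths(path) with the default pattern returns
#   [".../checkpoint10.pt", ".../checkpoint9.pt"]
# (newest first; checkpoint_best.pt doesn't match the numeric pattern).
# With keep_match=True, each entry is a (path, index) tuple instead.

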
def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        # opena() queues the write and returns immediately; the actual I/O
        # happens on a background thread managed by ioPath
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        with PathManager.open(filename, "wb") as f:
            _torch_persistent_save(obj, f)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            _torch_persistent_save(obj, h)
        return
    # retry up to 3 times; only log and re-raise on the final failure
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into a sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old checkpoints may not have separate source/target positions
    if (
        "args" in state
        and hasattr(state["args"], "max_positions")
        and not hasattr(state["args"], "max_source_positions")
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --post-process
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy ==> wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining ==> audio_pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc ==> wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]
        # remove keys related to teacher-student training
        for key in [
            "static_teachers",
            "static_teacher_weights",
            "dynamic_teachers",
            "dynamic_teacher_weights",
        ]:
            if key in state["args"]:
                delattr(state["args"], key)

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # print_alignment used to be a boolean; it now takes "hard"/"soft"
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state


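# Upgrade sketch (hypothetical legacy checkpoint): a pre-"optimizer_history"
# state such as
#   {"model": ..., "optimizer": ..., "best_loss": 1.2,
#    "epoch": 3, "batch_offset": 0, "val_loss": 1.3}
# is rewritten in place to the modern layout:
#   {"model": ..., "last_optimizer_state": ...,
#    "optimizer_history": [{"criterion_name": "CrossEntropyCriterion",
#                           "optimizer_name": "FairseqNAG",
#                           "lr_scheduler_state": {"best": 1.2},
#                           "num_updates": 0}],
#    "extra_state": {"epoch": 3, "batch_offset": 0, "val_loss": 1.3,
#                    "train_iterator": {"epoch": 3, "iterations_in_epoch": 0}}}

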
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if the key has no layer number, it belongs to a supporting module
        # (e.g. an embedding) and is kept as-is
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise remap (or drop) the layer according to the pruning passes
        original_layer_number = match.group(1)
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # since layers are now pruned, *_layers_to_keep are no longer needed;
    # clear them so the pruned model isn't pruned again on a later load
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict


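# Remapping example (hypothetical keys): with
# model_cfg.encoder_layers_to_keep == "0,2", the pruning pass maps original
# layer 2 to slot 1, so
#   "encoder.layers.0.fc1.weight" -> "encoder.layers.0.fc1.weight"
#   "encoder.layers.2.fc1.weight" -> "encoder.layers.1.fc1.weight"
# and "encoder.layers.1.fc1.weight" is dropped from the returned state_dict.

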
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types is not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component


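# Illustrative usage (hypothetical model and checkpoint path, not part of the
# API): warm-start just the encoder of a freshly built model from an existing
# checkpoint.
def _example_warm_start_encoder(model, checkpoint_path):
    load_pretrained_component_from_model(model.encoder, checkpoint_path)
    return model

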
def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    # verify the directory is writable by round-tripping a dummy file
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    else:
        os.remove(temp_file_path)


def load_ema_from_checkpoint(fpath):
    """Loads exponential moving averaged (EMA) checkpoint from input and
    returns a model with ema weights.

    Args:
        fpath: A string path of checkpoint to load from.

    Returns:
        A dict of string keys mapping to various values. The 'model' key
        from the returned dict should correspond to an OrderedDict mapping
        string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    new_state = None

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

        # the EMA weights live in a separate "ema" entry of extra_state
        model_params = new_state["extra_state"]["ema"]

        for key in list(model_params.keys()):
            p = model_params[key]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if key not in params_dict:
                params_dict[key] = p.clone()
                # NOTE: clone() is needed in case p is a shared parameter
            else:
                raise ValueError("Key {} is repeated in EMA model params.".format(key))

        if len(params_dict) == 0:
            raise ValueError(
                f"Input checkpoint path '{fpath}' does not contain "
                "ema model weights, is this model trained with EMA?"
            )

    new_state["model"] = params_dict
    return new_state
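
# Illustrative usage (hypothetical model and path, not part of the API):
# evaluate with the EMA weights instead of the live training weights.
def _example_load_ema_weights(model, fpath="checkpoints/checkpoint_last.pt"):
    state = load_ema_from_checkpoint(fpath)
    model.load_state_dict(state["model"], strict=True)
    return model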