Spaces:
Runtime error
Runtime error
import collections
import glob
import logging
import os
from typing import List, Tuple

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.serialization import default_restore_location
| logger = logging.getLogger() | |
| CheckpointState = collections.namedtuple( | |
| "CheckpointState", | |
| [ | |
| "model_dict", | |
| "optimizer_dict", | |
| "scheduler_dict", | |
| "offset", | |
| "epoch", | |
| "encoder_params", | |
| ], | |
| ) | |
| def setup_for_distributed_mode( | |
| model: nn.Module, | |
| optimizer: torch.optim.Optimizer, | |
| device: object, | |
| n_gpu: int = 1, | |
| local_rank: int = -1, | |
| fp16: bool = False, | |
| fp16_opt_level: str = "O1", | |
| ) -> (nn.Module, torch.optim.Optimizer): | |
| model.to(device) | |
| if fp16: | |
| try: | |
| import apex | |
| from apex import amp | |
| apex.amp.register_half_function(torch, "einsum") | |
| except ImportError: | |
| raise ImportError( | |
| "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." | |
| ) | |
| model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) | |
| if n_gpu > 1: | |
| model = torch.nn.DataParallel(model) | |
| if local_rank != -1: | |
| model = torch.nn.parallel.DistributedDataParallel( | |
| model, | |
| device_ids=[local_rank], | |
| output_device=local_rank, | |
| find_unused_parameters=True, | |
| ) | |
| return model, optimizer | |
| def move_to_cuda(sample): | |
| if len(sample) == 0: | |
| return {} | |
| def _move_to_cuda(maybe_tensor): | |
| if torch.is_tensor(maybe_tensor): | |
| return maybe_tensor.cuda() | |
| elif isinstance(maybe_tensor, dict): | |
| return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} | |
| elif isinstance(maybe_tensor, list): | |
| return [_move_to_cuda(x) for x in maybe_tensor] | |
| elif isinstance(maybe_tensor, tuple): | |
| return [_move_to_cuda(x) for x in maybe_tensor] | |
| else: | |
| return maybe_tensor | |
| return _move_to_cuda(sample) | |
| def move_to_device(sample, device): | |
| if len(sample) == 0: | |
| return {} | |
| def _move_to_device(maybe_tensor, device): | |
| if torch.is_tensor(maybe_tensor): | |
| return maybe_tensor.to(device) | |
| elif isinstance(maybe_tensor, dict): | |
| return { | |
| key: _move_to_device(value, device) | |
| for key, value in maybe_tensor.items() | |
| } | |
| elif isinstance(maybe_tensor, list): | |
| return [_move_to_device(x, device) for x in maybe_tensor] | |
| elif isinstance(maybe_tensor, tuple): | |
| return [_move_to_device(x, device) for x in maybe_tensor] | |
| else: | |
| return maybe_tensor | |
| return _move_to_device(sample, device) | |
| def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): | |
| """Create a schedule with a learning rate that decreases linearly after | |
| linearly increasing during a warmup period. | |
| """ | |
| def lr_lambda(current_step): | |
| if current_step < warmup_steps: | |
| return float(current_step) / float(max(1, warmup_steps)) | |
| return max( | |
| 0.0, | |
| float(training_steps - current_step) | |
| / float(max(1, training_steps - warmup_steps)), | |
| ) | |
| return LambdaLR(optimizer, lr_lambda, last_epoch) | |
| def init_weights(modules: List): | |
| for module in modules: | |
| if isinstance(module, (nn.Linear, nn.Embedding)): | |
| module.weight.data.normal_(mean=0.0, std=0.02) | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
| if isinstance(module, nn.Linear) and module.bias is not None: | |
| module.bias.data.zero_() | |
| def get_model_obj(model: nn.Module): | |
| return model.module if hasattr(model, "module") else model | |
| def get_model_file(args, file_prefix) -> str: | |
| if args.model_file and os.path.exists(args.model_file): | |
| return args.model_file | |
| out_cp_files = ( | |
| glob.glob(os.path.join(args.output_dir, file_prefix + "*")) | |
| if args.output_dir | |
| else [] | |
| ) | |
| logger.info("Checkpoint files %s", out_cp_files) | |
| model_file = None | |
| if len(out_cp_files) > 0: | |
| model_file = max(out_cp_files, key=os.path.getctime) | |
| return model_file | |
| def load_states_from_checkpoint(model_file: str) -> CheckpointState: | |
| logger.info("Reading saved model from s", model_file) | |
| if isinstance(model_file, tuple): | |
| model_file = model_file[0] | |
| state_dict = torch.load( | |
| model_file, map_location=lambda s, l: default_restore_location(s, "cpu") | |
| ) | |
| logger.info("model_state_dict keys %s", state_dict.keys()) | |
| return CheckpointState(**state_dict) |