| """ |
| Helpers to train with 16-bit precision. |
| """ |
|
|
| import torch.nn as nn |
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors |
|
|
|
|
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are
    touched; any other module is left unchanged. The conversion is done
    in place by replacing the ``.data`` of the weight (and of the bias,
    when the layer has one).

    :param l: the module to convert.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        # Layers built with bias=False expose ``bias`` as None.
        if l.bias is not None:
            l.bias.data = l.bias.data.half()
|
|
|
|
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Only ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` instances are
    touched; any other module is left unchanged. The conversion is done
    in place by replacing the ``.data`` of the weight (and of the bias,
    when the layer has one).

    :param l: the module to convert.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        # Layers built with bias=False expose ``bias`` as None.
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
|
|
|
|
def make_master_params(model_params):
    """
    Build the full-precision master copy of the model parameters.

    Each parameter is detached, cast to float32, and all of them are
    flattened into a single dense tensor, which is then wrapped in one
    trainable ``nn.Parameter``.

    :param model_params: an iterable of the model's parameters.
    :return: a one-element list holding the flat float32 master parameter.
    """
    fp32_copies = []
    for param in model_params:
        fp32_copies.append(param.detach().float())
    flat_master = nn.Parameter(
        _flatten_dense_tensors(fp32_copies), requires_grad=True
    )
    return [flat_master]
|
|
|
|
def model_grads_to_master_grads(model_params, master_params):
    """
    Transfer the model gradients onto the flat master parameter.

    The gradient of every model parameter is detached, cast to float32,
    and all of them are flattened into one dense tensor that becomes the
    ``.grad`` of ``master_params[0]``.

    :param model_params: an iterable of model parameters; each one is
                         expected to already carry a gradient.
    :param master_params: the list returned by make_master_params().
    """
    fp32_grads = []
    for param in model_params:
        fp32_grads.append(param.grad.data.detach().float())
    flat_master = master_params[0]
    flat_master.grad = _flatten_dense_tensors(fp32_grads)
|
|
|
|
def master_params_to_model_params(model_params, master_params):
    """
    Write the values of the master parameters back into the model.

    The flat master parameter is unflattened to the shapes of the model
    parameters and each piece is copied in place into the matching model
    parameter.

    :param model_params: an iterable of the model's parameters.
    :param master_params: the list returned by make_master_params().
    """
    # The parameters are traversed twice (once to unflatten, once to copy),
    # so materialize them first in case a one-shot iterable was passed.
    targets = list(model_params)
    sources = unflatten_master_params(targets, master_params)
    for target, source in zip(targets, sources):
        target.detach().copy_(source)
|
|
|
|
def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.

    :param model_params: an iterable of tensors/parameters whose sizes
                         define how the flat master tensor is split up.
    :param master_params: the list returned by make_master_params(); only
                          its first element is read.
    :return: a sequence of detached views into the master parameter, one
             per entry of model_params.
    """
    # Materialize the iterable so that one-shot iterators (e.g. the
    # generator returned by ``module.parameters()``) are accepted as well.
    return _unflatten_dense_tensors(
        master_params[0].detach(), list(model_params)
    )
|
|
|
|
def zero_grad(model_params):
    """
    Reset the gradients of the given parameters to zero, in place.

    Parameters whose ``.grad`` is None are skipped. Each existing gradient
    is detached in place before being zeroed.

    :param model_params: an iterable of parameters.
    """
    for grad in (param.grad for param in model_params):
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()
|
|