Spaces:
Sleeping
Sleeping
| # This code is based on https://github.com/openai/guided-diffusion | |
| """ | |
| Various utilities for neural networks. | |
| """ | |
| import torch as th | |
| import torch.nn as nn | |
| # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. | |
class SiLU(nn.Module):
    """
    Sigmoid-weighted linear unit: computes ``x * sigmoid(x)``.

    Defined locally so the module also works on PyTorch 1.5, which predates
    the built-in ``nn.SiLU``.
    """

    def forward(self, x):
        """Return `x` scaled by its own sigmoid."""
        gate = th.sigmoid(x)
        return x * gate
class GroupNorm32(nn.GroupNorm):
    """
    Group normalization that always computes in float32.

    The input is cast to float32 before the parent ``nn.GroupNorm`` runs and
    the result is cast back to the input's original dtype.
    """

    def forward(self, x):
        """Normalize `x` in float32 and return a tensor of `x`'s dtype."""
        original_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(original_dtype)
def linear(*args, **kwargs):
    """
    Build an ``nn.Linear`` layer.

    All positional and keyword arguments are forwarded unchanged to the
    ``nn.Linear`` constructor.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.

    :param tensor: a tensor whose first dimension is the batch dimension.
    :return: a tensor with one mean per batch element (the first dimension
             is kept, every other dimension is reduced).
    """
    non_batch_dims = list(range(1, tensor.dim()))
    if not non_batch_dims:
        # With no non-batch dimensions an empty `dim` list would make
        # `mean` reduce over *every* dimension, collapsing the batch axis.
        # Reduce over a dummy trailing axis instead so each element is
        # returned as its own mean.
        return tensor.unsqueeze(-1).mean(dim=-1)
    return tensor.mean(dim=non_batch_dims)
def sum_flat(tensor):
    """
    Take the sum over all non-batch dimensions.

    :param tensor: a tensor whose first dimension is the batch dimension.
    :return: a tensor with one sum per batch element (the first dimension
             is kept, every other dimension is reduced).
    """
    non_batch_dims = list(range(1, tensor.dim()))
    if not non_batch_dims:
        # With no non-batch dimensions an empty `dim` list would make
        # `sum` reduce over *every* dimension, collapsing the batch axis.
        # Reduce over a dummy trailing axis instead so each element is
        # returned as its own sum.
        return tensor.unsqueeze(-1).sum(dim=-1)
    return tensor.sum(dim=non_batch_dims)
def normalization(channels):
    """
    Build the normalization layer used throughout the model.

    :param channels: number of input channels.
    :return: a ``GroupNorm32`` module with 32 groups over `channels` channels.
    """
    norm_layer = GroupNorm32(num_groups=32, num_channels=channels)
    return norm_layer
def checkpoint(func, inputs, params, flag):
    """
    Run `func` with optional gradient checkpointing.

    With checkpointing enabled, intermediate activations are not cached;
    memory use drops at the price of recomputing them in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    :return: whatever `func(*inputs)` returns.
    """
    if not flag:
        # Checkpointing disabled: plain call, `params` is not needed.
        return func(*inputs)

    # Inputs come first, then the implicit parameters; the input count tells
    # CheckpointFunction where the split is.
    packed = (*inputs, *params)
    return CheckpointFunction.apply(func, len(inputs), *packed)
class CheckpointFunction(th.autograd.Function):
    """
    Autograd function that implements gradient checkpointing.

    ``forward`` evaluates the wrapped function under ``th.no_grad()``, so no
    graph is recorded for it. ``backward`` re-evaluates the function with
    gradients enabled and differentiates the recomputed outputs.

    Invoke it as ``CheckpointFunction.apply(run_function, length, *args)``.
    """

    # `forward` and `backward` must be static methods: that is the
    # `torch.autograd.Function` contract, and older PyTorch releases treat a
    # non-static `forward` as an unsupported legacy function.
    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        Evaluate `run_function` on the first `length` entries of `args`.

        :param ctx: autograd context used to stash state for `backward`.
        :param run_function: the function to evaluate.
        :param length: how many leading entries of `args` are passed to
                       `run_function`; the rest are saved for `backward` only.
        :param args: the explicit inputs followed by any extra tensors.
        :return: the output of `run_function`.
        """
        ctx.run_function = run_function
        ctx.input_length = length
        ctx.save_for_backward(*args)
        with th.no_grad():
            output_tensors = ctx.run_function(*args[:length])
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass and return gradients for `forward`'s inputs.

        :param ctx: the context populated by `forward`.
        :param output_grads: gradients with respect to each forward output.
        :return: a tuple with ``None`` for `run_function` and `length`,
                 followed by one entry per saved tensor (``None`` for tensors
                 that do not require grad).
        """
        args = list(ctx.saved_tensors)

        # Filter for inputs that require grad. If none, exit early.
        input_indices = [i for (i, x) in enumerate(args) if x.requires_grad]
        if not input_indices:
            return (None, None) + tuple(None for _ in args)

        with th.enable_grad():
            for i in input_indices:
                if i < ctx.input_length:
                    # Only the explicit inputs are detached and re-marked as
                    # requiring grad; the trailing tensors are used as saved.
                    # The extra view_as is inherited from the OAI code and
                    # may not be necessary.
                    args[i] = args[i].detach().requires_grad_()
                    args[i] = args[i].view_as(args[i])
            output_tensors = ctx.run_function(*args[: ctx.input_length])

        if isinstance(output_tensors, th.Tensor):
            output_tensors = [output_tensors]

        # Filter for outputs that require grad. If none, exit early.
        out_and_grads = [
            (o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad
        ]
        if not out_and_grads:
            return (None, None) + tuple(None for _ in args)

        # Compute gradients on the filtered tensors.
        computed_grads = th.autograd.grad(
            [o for (o, g) in out_and_grads],
            [args[i] for i in input_indices],
            [g for (o, g) in out_and_grads],
        )

        # Reassemble the complete gradient tuple.
        input_grads = [None for _ in args]
        for i, g in zip(input_indices, computed_grads):
            input_grads[i] = g
        return (None, None) + tuple(input_grads)