import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params
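

# A quick, illustrative sanity check for count_params (an addition, not part
# of the original module): nn.Linear(10, 5) holds 10 * 5 weights plus 5
# biases, i.e. 55 parameters in total.
def _count_params_example():
    assert count_params(nn.Linear(10, 5)) == 55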


class ActNorm(nn.Module):
    """Glow-style activation normalization: a learned per-channel shift
    (`loc`) and scale, initialized from the statistics of the first batch
    seen during training so activations start out zero-mean, unit-variance."""

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine  # only the affine variant is implemented
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # persistent flag so the data-dependent init runs exactly once
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        # data-dependent initialization: choose loc/scale from the
        # per-channel mean and std of the first batch
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        # accept 2d inputs (N, C) by treating them as (N, C, 1, 1)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        # lazy data-dependent initialization on the first training batch
        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            # log|det J| of the per-pixel affine map: sum of log|scale|
            # over channels, times the number of spatial positions
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        # invert forward(): h = scale * (x + loc)  =>  x = h / scale - loc
        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
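

# A minimal sanity-check sketch for ActNorm (an addition, not part of the
# original module): forward followed by reverse should reproduce the input
# up to floating-point error, and logdet carries one entry per batch element.
def _actnorm_example():
    layer = ActNorm(num_features=3, logdet=True)
    x = torch.randn(4, 3, 8, 8)
    h, logdet = layer(x)                # first call runs the data-dependent init
    x_rec = layer(h, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-4)
    assert logdet.shape == (4,)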


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        # add a sequence dimension: (B,) class indices -> (B, 1)
        c = c[:, None]
        if self.quantize_interface:
            # mimic a quantizer's (code, loss, info) return signature,
            # with the integer indices in info[2]
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    """Provides a constant start-of-sequence token, e.g. for unconditional
    training."""

    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # replicate the sos token once per batch element: shape (B, 1)
        c = torch.ones(x.shape[0], 1) * self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
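

# Illustrative usage sketch (an addition, not part of the original module):
# both encoders expose the same quantize-style interface, returning
# (code, loss, info) with the integer token indices in info[2].
if __name__ == "__main__":
    labels = torch.tensor([3, 1, 4])
    c, _, info = Labelator(n_classes=10).encode(labels)
    print(c.shape, info[2].dtype)   # torch.Size([3, 1]) torch.int64

    x = torch.randn(3, 16)          # only the batch size of x is used
    c, _, info = SOSProvider(sos_token=0).encode(x)
    print(c.shape, c.dtype)         # torch.Size([3, 1]) torch.int64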