'''This code is based on the FrEIA framework, source:
https://github.com/VLL-HD/FrEIA

It is an assembly of the modules/functions from FrEIA that are needed for
our purposes.'''

from math import exp

import numpy as np
import torch
import torch.nn as nn
# NOTE: `F` and `Variable` are not used by the code below; the imports are
# kept so that names importable from this module do not change.
import torch.nn.functional as F
from torch.autograd import Variable

VERBOSE = False


class dummy_data:
    '''Shape-only stand-in for a tensor: stores the dimensions it was given
    and exposes them through ``shape``.'''

    def __init__(self, *dims):
        self.dims = dims

    @property
    def shape(self):
        return self.dims


class F_fully_connected(nn.Module):
    '''Fully connected transformation, not reversible, but used below.

    Four linear layers (``size_in -> internal_size -> internal_size ->
    internal_size -> size``); the first three are each followed by dropout
    and ReLU.

    Args:
        size_in: number of input features.
        size: number of output features.
        internal_size: width of the hidden layers; ``2 * size`` if falsy.
        dropout: dropout probability applied after the first three linear
            layers.
    '''

    def __init__(self, size_in, size, internal_size=None, dropout=0.0):
        super().__init__()
        if not internal_size:
            internal_size = 2*size

        # The creation order is kept as-is: the linear layers draw their
        # initial weights from torch's RNG in this order.
        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)
        self.d2b = nn.Dropout(p=dropout)

        self.fc1 = nn.Linear(size_in, internal_size)
        self.fc2 = nn.Linear(internal_size, internal_size)
        self.fc2b = nn.Linear(internal_size, internal_size)
        self.fc3 = nn.Linear(internal_size, size)

        self.nl1 = nn.ReLU()
        self.nl2 = nn.ReLU()
        self.nl2b = nn.ReLU()

        # Not used in forward(); kept because removing it would change the
        # keys of this module's state_dict.
        self.bn = nn.BatchNorm1d(size_in)

    def forward(self, x):
        out = self.nl1(self.d1(self.fc1(x)))
        out = self.nl2(self.d2(self.fc2(out)))
        out = self.nl2b(self.d2b(self.fc2b(out)))
        out = self.fc3(out)
        return out


class permute_layer(nn.Module):
    '''Permutes the input vector in a random but fixed way.

    Args:
        dims_in: list of input dimension tuples; ``dims_in[0][0]`` is the
            number of channels that get permuted.
        seed: seed that determines the permutation.
    '''

    def __init__(self, dims_in, seed):
        super().__init__()
        self.in_channels = dims_in[0][0]

        # A private RandomState yields the same permutation as seeding the
        # global NumPy RNG with `seed`, but leaves the caller's global
        # random state untouched.
        perm = np.random.RandomState(seed).permutation(self.in_channels)

        # Inverse permutation: perm_inv[perm[i]] == i.
        perm_inv = np.zeros_like(perm)
        perm_inv[perm] = np.arange(self.in_channels)

        # TODO: these are plain attributes, not buffers/parameters, so they
        # are not moved by .cuda()/.to() and are not part of the state_dict.
        self.perm = torch.LongTensor(perm)
        self.perm_inv = torch.LongTensor(perm_inv)

    def forward(self, x, rev=False):
        '''Apply the permutation (or its inverse if ``rev``) along dim 1 of
        ``x[0]``; returns a one-element list.'''
        if not rev:
            return [x[0][:, self.perm]]
        else:
            return [x[0][:, self.perm_inv]]

    def jacobian(self, x, rev=False):
        # TODO: use batch size, set as nn.Parameter so cuda() works
        return 0.

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims


class glow_coupling_layer(nn.Module):
    '''Affine coupling layer: the channels are split into two parts and each
    part is scaled and shifted by a subnetwork evaluated on the other part.

    Args:
        dims_in: list of input dimension tuples; ``dims_in[0][0]`` is the
            channel count.
        F_class: class of the subnetworks, constructed as
            ``F_class(n_in, n_out, **F_args)``.
        F_args: keyword arguments for ``F_class`` (``None`` means none).
        clamp: bound used by the soft clamping of the log-scale, see
            ``log_e``.
    '''

    def __init__(self, dims_in, F_class=F_fully_connected, F_args=None,
                 clamp=5.):
        super().__init__()
        # Avoid a shared mutable default argument.
        if F_args is None:
            F_args = {}

        channels = dims_in[0][0]
        self.ndims = len(dims_in[0])

        self.split_len1 = channels // 2
        self.split_len2 = channels - channels // 2

        self.clamp = clamp
        self.max_s = exp(clamp)
        self.min_s = exp(-clamp)

        # Each subnetwork outputs scale and shift for the *other* part,
        # hence twice the other part's width.
        self.s1 = F_class(self.split_len1, self.split_len2*2, **F_args)
        self.s2 = F_class(self.split_len2, self.split_len1*2, **F_args)

    def e(self, s):
        '''Scale factor: exponential of the soft-clamped log-scale.'''
        return torch.exp(self.log_e(s))

    def log_e(self, s):
        '''Soft-clamped log-scale. 0.636 is approximately 2/pi, so the
        result stays within roughly (-clamp, clamp).'''
        return self.clamp * 0.636 * torch.atan(s / self.clamp)

    def forward(self, x, rev=False):
        '''Apply the coupling (or its inverse if ``rev``) to ``x[0]``;
        returns a one-element list.'''
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = self.e(s1) * x2 + t1

        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = (x1 - t2) / self.e(s2)

        y = torch.cat((y1, y2), 1)
        # Hard limit on the output magnitude.
        y = torch.clamp(y, -1e6, 1e6)
        return [y]

    def jacobian(self, x, rev=False):
        '''Sum of the soft-clamped log-scales of both coupling halves,
        reduced over every dimension except the first.'''
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]

        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]

        jac = (torch.sum(self.log_e(s1), dim=1)
               + torch.sum(self.log_e(s2), dim=1))
        for i in range(self.ndims-1):
            jac = torch.sum(jac, dim=1)

        return jac

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims


class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.

    Args:
        inputs: list of ``(node, output_index)`` tuples, e.g. ``other.out0``.
        module_type: class of the module, constructed in ``build_modules``
            as ``module_type(input_dims, **module_args)``.
        module_args: keyword arguments for ``module_type``.
        name: optional display name; derived from ``id(self)`` if falsy.
    '''

    def __init__(self, inputs, module_type, module_args, name=None):
        self.inputs = inputs
        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims, self.module = None, None
        self.computed = None
        self.computed_rev = None
        self.id = None

        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]

        # Handles node.out0 ... node.out254, used to wire this node's
        # outputs into the `inputs` of other nodes.
        for i in range(255):
            setattr(self, 'out{0}'.format(i), (self, i))

    def build_modules(self, verbose=VERBOSE):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            try:
                self.module = self.module_type(self.input_dims,
                                               **self.module_args)
            except Exception:
                print('Error in node %s' % (self.name))
                raise

            if verbose:
                print("Node %s has following input dimensions:" % (self.name))
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print("\t Output #%i of node %s:" % (c, n.name), d)
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs)'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register yourself as an output in the input node
                n.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            op_list.append((self.id, output_vars, self.computed_rev))

        return self.computed_rev


class InputNode(Node):
    '''Special type of node that represents the input data of the whole net
    (or output when running reverse).

    Args:
        *dims: dimensions of one input sample.
        name: display name of the node.
    '''

    def __init__(self, *dims, name='node'):
        self.name = name
        self.data = dummy_data(*dims)
        self.outputs = []
        self.module = None
        self.computed_rev = None
        self.n_outputs = 1
        self.input_vars = []
        # Overwritten by ReversibleGraphNet; initialised here so the
        # attribute always exists, as it does on Node and OutputNode.
        self.id = None
        self.out0 = (self, 0)

    def build_modules(self, verbose=VERBOSE):
        return [self.data.shape]

    def run_forward(self, op_list):
        return [(self.id, 0)]


class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or
    the input when running in reverse).

    Args:
        inputs: list of ``(node, output_index)`` tuples feeding this node.
        name: display name of the node.
    '''

    class dummy(nn.Module):
        '''Placeholder module that returns its arguments.'''

        def __init__(self, *args):
            super().__init__()

        # NOTE: both methods take only *args, so when called on an instance
        # the returned tuple starts with the instance itself. This is kept
        # as-is because Node.build_modules derives n_outputs from the length
        # of output_dims().
        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        self.module_type, self.module_args = self.dummy, {}
        self.output_dims = []
        self.inputs = inputs
        self.input_dims, self.module = None, None
        self.computed = None
        self.id = None
        self.name = name

        # Register this node as a consumer of each of its input nodes.
        for c, inp in enumerate(self.inputs):
            inp[0].outputs.append((self, c))

    def run_backward(self, op_list):
        return [(self.id, 0)]


class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=False):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indexes of the special nodes InputNode and OutputNode
        in this list. If an index argument is None, the corresponding nodes
        are found by their type.'''
        super().__init__()

        # Gather lists of input and output nodes
        if ind_in is not None:
            if isinstance(ind_in, int):
                self.ind_in = [ind_in]
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i, node in enumerate(node_list)
                           if isinstance(node, InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."

        if ind_out is not None:
            if isinstance(ind_out, int):
                self.ind_out = [ind_out]
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i, node in enumerate(node_list)
                            if isinstance(node, OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."

        self.return_vars = []
        self.input_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i

        # Recursively build the nodes nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # Collect every (node ID, channel) variable that occurs in an op
        variables = set()
        for o in ops:
            variables.update(o[1])
            variables.update(o[2])
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        # Empty placeholders, one per variable; they are overwritten with
        # real tensors in forward().
        self.variable_list = [torch.empty(0, requires_grad=True)
                              for _ in variables]

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID,
        channel), to variable IDs.

        Ops belonging to input/output nodes are not returned; their first
        input variable is recorded in self.input_vars / self.return_vars
        instead.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]

            # Collect input/output nodes in separate lists, but don't add to
            # indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue

            result.append((o[0], vars_in, vars_out))

        # Sort input/output variables so they correspond to initial node list
        # order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, rev=False):
        '''Forward or backward computation of the whole net.

        Args:
            x: a single tensor, or a list/tuple with one tensor per input
                node (per output node if ``rev``).
            rev: run the net in reverse if true.

        Returns:
            A single tensor if there is exactly one result, otherwise a list.

        Raises:
            RuntimeError: if calling a node's module raises TypeError.
        '''
        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for var_index, tensor in zip(input_vars, x):
                self.variable_list[var_index] = tensor
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self.variable_list[input_vars[0]] = x

        for o in use_list:
            try:
                results = self.module_list[o[0]](
                    [self.variable_list[i] for i in o[1]], rev=rev)
            except TypeError as err:
                # E.g. the module is None because the node was not in
                # node_list; keep the original error as the cause.
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from err
            for i, r in zip(o[2], results):
                self.variable_list[i] = r

        out = [self.variable_list[i] for i in output_vars]
        if len(out) == 1:
            return out[0]
        else:
            return out

    def jacobian(self, x=None, rev=False, run_forward=True):
        '''Sum the ``jacobian`` terms of all nodes of the net.

        Args:
            x: input for the pass; required if ``run_forward`` is true.
            rev: use the reverse order of operations if true.
            run_forward: run ``forward(x, rev=rev)`` first so the
                intermediate variables are up to date.

        Raises:
            RuntimeError: if ``run_forward`` is true and ``x`` is None, or
                if calling a node's ``jacobian`` raises TypeError.
        '''
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, rev=rev)

        for o in use_list:
            try:
                node_jac = self.module_list[o[0]].jacobian(
                    [self.variable_list[i] for i in o[1]], rev=rev
                )
            except TypeError as err:
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?") from err
            # Out-of-place sum, so no tensor is modified in place.
            jacobian = jacobian + node_jac

        return jacobian