'''This code is based on the FrEIA framework, source: https://github.com/VLL-HD/FrEIA
It is an assembly of the necessary modules/functions from FrEIA that are needed for our purposes.'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
import numpy as np

VERBOSE = False


class dummy_data:
    '''Minimal stand-in for input data that only carries its shape.'''
    def __init__(self, *dims):
        self.dims = dims

    @property
    def shape(self):
        # Exposed as a property so it can be read like tensor.shape
        # (InputNode.build_modules returns [self.data.shape]).
        return self.dims


class F_fully_connected(nn.Module):
    '''Fully connected transformation, not reversible, but used below.'''
    def __init__(self, size_in, size, internal_size=None, dropout=0.0):
        super(F_fully_connected, self).__init__()
        if not internal_size:
            internal_size = 2*size

        self.d1 = nn.Dropout(p=dropout)
        self.d2 = nn.Dropout(p=dropout)
        self.d2b = nn.Dropout(p=dropout)

        self.fc1 = nn.Linear(size_in, internal_size)
        self.fc2 = nn.Linear(internal_size, internal_size)
        self.fc2b = nn.Linear(internal_size, internal_size)
        self.fc3 = nn.Linear(internal_size, size)

        self.nl1 = nn.ReLU()
        self.nl2 = nn.ReLU()
        self.nl2b = nn.ReLU()

        self.bn = nn.BatchNorm1d(size_in)

    def forward(self, x):
        out = self.nl1(self.d1(self.fc1(x)))
        out = self.nl2(self.d2(self.fc2(out)))
        out = self.nl2b(self.d2b(self.fc2b(out)))
        out = self.fc3(out)
        return out
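

# Usage sketch (illustrative, not part of the original file):
# F_fully_connected(size_in=8, size=16) maps (batch, 8) -> (batch, 16).
#   net = F_fully_connected(8, 16)
#   t = torch.randn(4, 8)
#   assert net(t).shape == (4, 16)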


class permute_layer(nn.Module):
    '''Permutes the input vector in a random but fixed way.'''
    def __init__(self, dims_in, seed):
        super(permute_layer, self).__init__()
        self.in_channels = dims_in[0][0]

        np.random.seed(seed)
        self.perm = np.random.permutation(self.in_channels)
        np.random.seed()

        self.perm_inv = np.zeros_like(self.perm)
        for i, p in enumerate(self.perm):
            self.perm_inv[p] = i

        self.perm = torch.LongTensor(self.perm)
        self.perm_inv = torch.LongTensor(self.perm_inv)

    def forward(self, x, rev=False):
        if not rev:
            return [x[0][:, self.perm]]
        else:
            return [x[0][:, self.perm_inv]]

    def jacobian(self, x, rev=False):
        # TODO: use batch size, set as nn.Parameter so cuda() works
        return 0.

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims
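

# Invertibility sketch (illustrative, not part of the original file):
# applying the permutation and then its inverse recovers the input exactly.
#   perm = permute_layer([(6,)], seed=0)
#   t = torch.randn(2, 6)
#   y = perm([t])[0]
#   assert torch.equal(perm([y], rev=True)[0], t)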


class glow_coupling_layer(nn.Module):
    def __init__(self, dims_in, F_class=F_fully_connected, F_args={},
                 clamp=5.):
        super(glow_coupling_layer, self).__init__()
        channels = dims_in[0][0]
        self.ndims = len(dims_in[0])
        self.split_len1 = channels // 2
        self.split_len2 = channels - channels // 2

        self.clamp = clamp
        self.max_s = exp(clamp)
        self.min_s = exp(-clamp)

        self.s1 = F_class(self.split_len1, self.split_len2*2, **F_args)
        self.s2 = F_class(self.split_len2, self.split_len1*2, **F_args)

    def e(self, s):
        return torch.exp(self.log_e(s))

    def log_e(self, s):
        # Soft clamping: 0.636 ~= 2/pi, so the log-scale is smoothly
        # bounded to the open interval (-clamp, clamp).
        return self.clamp * 0.636 * torch.atan(s / self.clamp)
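
    # Numeric illustration (sketch, not in the original): with clamp=5.,
    # log_e saturates at +/- 5 * 0.636 * (pi/2) ~= +/- 4.995, so the
    # multiplicative scale e(s) stays within roughly [exp(-5), exp(5)].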

    def forward(self, x, rev=False):
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = self.e(s1) * x2 + t1
        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = (x1 - t2) / self.e(s2)

        y = torch.cat((y1, y2), 1)
        y = torch.clamp(y, -1e6, 1e6)
        return [y]

    def jacobian(self, x, rev=False):
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]

        jac = (torch.sum(self.log_e(s1), dim=1)
               + torch.sum(self.log_e(s2), dim=1))
        for i in range(self.ndims-1):
            jac = torch.sum(jac, dim=1)

        return jac

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims
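

# Usage sketch (illustrative, not part of the original file): a coupling
# layer over 6 channels splits the input into halves of 3; jacobian()
# returns one log-determinant value per batch element.
#   layer = glow_coupling_layer([(6,)], F_class=F_fully_connected, clamp=5.)
#   t = torch.randn(2, 6)
#   y = layer([t])[0]
#   t_back = layer([y], rev=True)[0]   # recovers t up to numerical error
#   log_det = layer.jacobian([t])      # shape (2,)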


class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.'''

    def __init__(self, inputs, module_type, module_args, name=None):
        self.inputs = inputs
        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims, self.module = None, None
        self.computed = None
        self.computed_rev = None
        self.id = None

        if name:
            self.name = name
        else:
            self.name = hex(id(self))[-6:]
        # Expose out0, out1, ... as (node, channel) tuples, so connections
        # can be written as e.g. some_node.out0.
        for i in range(255):
            setattr(self, 'out{}'.format(i), (self, i))

    def build_modules(self, verbose=VERBOSE):
        '''Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.
        '''
        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            try:
                self.module = self.module_type(self.input_dims,
                                               **self.module_args)
            except Exception as e:
                print('Error in node %s' % (self.name))
                raise e

            if verbose:
                print("Node %s has following input dimensions:" % (self.name))
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print("\t Output #%i of node %s:" % (c, n.name), d)
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs).'''
        if not self.computed:
            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register yourself as an output in the input node
                n.outputs.append((self, i))

            # All outputs can now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward; this is the same, only for the reverse computation.
        run_forward must be called first, otherwise this function will not
        work.'''
        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:
            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            op_list.append((self.id, output_vars, self.computed_rev))

        return self.computed_rev
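

# Wiring sketch (illustrative, not part of the original file): nodes are
# connected by passing (node, channel) tuples such as parent.out0 in the
# inputs list, e.g.
#   inp = InputNode(6, name='input')
#   cpl = Node([inp.out0], glow_coupling_layer, {'clamp': 5.})
#   out = OutputNode([cpl.out0], name='output')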


class InputNode(Node):
    '''Special type of node that represents the input data of the whole net
    (or the output when running in reverse).'''

    def __init__(self, *dims, name='node'):
        self.name = name
        self.data = dummy_data(*dims)
        self.outputs = []
        self.module = None
        self.computed_rev = None
        self.n_outputs = 1
        self.input_vars = []
        self.out0 = (self, 0)

    def build_modules(self, verbose=VERBOSE):
        return [self.data.shape]

    def run_forward(self, op_list):
        return [(self.id, 0)]


class OutputNode(Node):
    '''Special type of node that represents the output of the whole net
    (or the input when running in reverse).'''

    class dummy(nn.Module):
        '''Placeholder module that passes its arguments through unchanged.'''

        def __init__(self, *args):
            super(OutputNode.dummy, self).__init__()

        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        self.module_type, self.module_args = self.dummy, {}
        self.output_dims = []
        self.inputs = inputs
        self.input_dims, self.module = None, None
        self.computed = None
        self.id = None
        self.name = name

        for c, inp in enumerate(self.inputs):
            inp[0].outputs.append((self, c))

    def run_backward(self, op_list):
        return [(self.id, 0)]


class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=False):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indices of the special nodes InputNode and OutputNode
        in this list.'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input and output nodes
        if ind_in is not None:
            if isinstance(ind_in, int):
                self.ind_in = [ind_in]
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            if isinstance(ind_out, int):
                self.ind_out = [ind_out]
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."

        self.return_vars = []
        self.input_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i

        # Recursively build the nodes' nn.Modules and determine the order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # Create the list of PyTorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        # Placeholder slots; each entry is overwritten with an actual tensor
        # during forward().
        self.variable_list = [Variable(requires_grad=True) for _ in variables]

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the variables of each operation from
        (origin node ID, channel) form to plain variable IDs.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]

            # Collect input/output nodes in separate lists, but don't add to
            # indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue

            result.append((o[0], vars_in, vars_out))

        # Sort input/output variables so they correspond to the initial node
        # list order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, rev=False):
        '''Forward or backward computation of the whole net.'''
        if rev:
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self.variable_list[input_vars[i]] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self.variable_list[input_vars[0]] = x

        for o in use_list:
            try:
                results = self.module_list[o[0]]([self.variable_list[i]
                                                  for i in o[1]], rev=rev)
            except TypeError:
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?")
            for i, r in zip(o[2], results):
                self.variable_list[i] = r

        out = [self.variable_list[output_vars[i]]
               for i in range(len(output_vars))]
        if len(out) == 1:
            return out[0]
        else:
            return out

    def jacobian(self, x=None, rev=False, run_forward=True):
        '''Compute the log-determinant of the Jacobian of the whole net by
        accumulating the per-node contributions.'''
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, rev=rev)

        for o in use_list:
            try:
                node_jac = self.module_list[o[0]].jacobian(
                    [self.variable_list[i] for i in o[1]], rev=rev
                )
                jacobian += node_jac
            except TypeError:
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?")

        return jacobian
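

# End-to-end usage sketch (illustrative, not part of the original file):
# build a small invertible net from the classes above, run it forward,
# invert it, and query the accumulated log-Jacobian. Dimensions, seeds and
# layer counts below are arbitrary choices for demonstration.
if __name__ == '__main__':
    ndim = 6
    nodes = [InputNode(ndim, name='input')]
    for k in range(3):
        # Alternate coupling blocks with fixed random permutations
        nodes.append(Node([nodes[-1].out0], glow_coupling_layer,
                          {'clamp': 5., 'F_class': F_fully_connected,
                           'F_args': {'internal_size': 32}},
                          name='coupling_%i' % k))
        nodes.append(Node([nodes[-1].out0], permute_layer, {'seed': k},
                          name='permute_%i' % k))
    nodes.append(OutputNode([nodes[-1].out0], name='output'))

    model = ReversibleGraphNet(nodes)

    x = torch.randn(4, ndim)
    z = model(x)                     # forward pass
    x_back = model(z, rev=True)      # inverse pass
    log_jac = model.jacobian(x)      # log |det J|, one value per sample

    print('reconstruction error:', (x - x_back).abs().max().item())
    print('log-Jacobian shape:', tuple(log_jac.shape))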