| |
|
|
| """Net summarization tool. |
| |
| This tool summarizes the structure of a net in a concise but comprehensive |
| tabular listing, taking a prototxt file as input. |
| |
| Use this tool to check at a glance that the computation you've specified is the |
| computation you expect. |
| """ |
|
|
| from caffe.proto import caffe_pb2 |
| from google import protobuf |
| import re |
| import argparse |
|
|
| |
# ANSI SGR codes used to color blob names. A new color is taken from this
# list, cycling, each time a previously unseen top name is encountered.
# Entries ending in ';30' additionally select a black foreground.
# '44' is the blue background (the former '444' is not a defined SGR code,
# so blobs assigned that slot were printed without any color).
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
          '44', '103;30', '107;30']
# Color for bottom names that have not yet appeared as any layer's top.
DISCONNECTED_COLOR = '41'
|
|
def read_net(filename):
    """Parse a text-format prototxt file into a ``caffe_pb2.NetParameter``.

    Args:
        filename: path to the net prototxt file.

    Returns:
        The populated NetParameter message.

    Raises:
        IOError/OSError if the file cannot be opened, plus whatever
        ``text_format.Parse`` raises on malformed input.
    """
    # Import the submodule explicitly: ``from google import protobuf`` does not
    # by itself load ``google.protobuf.text_format``, so relying on the
    # ``protobuf.text_format`` attribute only works if some other module
    # happened to import it first.
    from google.protobuf import text_format

    net = caffe_pb2.NetParameter()
    with open(filename) as f:
        text_format.Parse(f.read(), net)
    return net
|
|
def format_param(param):
    """Render one entry of a layer's ``param`` list as a short string.

    The pieces, in order and separated by single spaces, are:
      * the entry's name, if non-empty;
      * ``x<lr_mult>`` when lr_mult is not 1;
      * ``Dx<decay_mult>`` when decay_mult is not 1.
    Returns '' when none of the pieces apply.
    """
    name_part = [param.name] if len(param.name) > 0 else []
    lr_part = ['x{}'.format(param.lr_mult)] if param.lr_mult != 1 else []
    decay_part = (['Dx{}'.format(param.decay_mult)]
                  if param.decay_mult != 1 else [])
    return ' '.join(name_part + lr_part + decay_part)
|
|
def printed_len(s):
    """Return how many characters of ``s`` remain once ANSI codes are removed.

    Sequences of the form ESC '[' <one or more digits/semicolons> 'm' are
    stripped before counting, so colored text measures like plain text.
    """
    visible_text = re.sub(r'\033\[[\d;]+m', '', s)
    return len(visible_text)
|
|
def print_table(table, max_width):
    """Print a simple nicely-aligned table.

    table must be a list of (equal-length) lists of strings. Columns are
    space-separated and as narrow as possible, but no wider than max_width.
    Widths are measured with printed_len, so ANSI color codes do not count.
    Text may overflow columns. Unlike string.format, an overflow does not
    shift subsequent columns where it can be avoided.

    An empty table prints nothing.
    """
    if not table:
        # Nothing to align (e.g. a net with no layers); avoids table[0].
        return

    num_columns = len(table[0])
    # Each column is as wide as its widest visible cell plus one separating
    # space, capped at max_width.
    column_widths = [
        min(max(printed_len(row[j]) + 1 for row in table), max_width)
        for j in range(num_columns)
    ]

    for row in table:
        row_str = ''
        right_col = 0
        for cell, width in zip(row, column_widths):
            # right_col is the cumulative boundary this column should end at.
            right_col += width
            row_str += cell + ' '
            # Padding is measured against the cumulative boundary, not the
            # cell alone. An overflowing cell gets no padding, and later
            # columns realign as soon as they have room.
            row_str += ' ' * max(right_col - printed_len(row_str), 0)
        # Function-call form: valid on Python 3 and prints the same on
        # Python 2.
        print(row_str)
|
|
def _ansi(color, text):
    """Wrap ``text`` in the ANSI SGR code ``color``, followed by a reset."""
    return '\033[{}m{}\033[0m'.format(color, text)


def _summarize_args(lr):
    """Return a compact geometry string for conv/deconv/pooling layers.

    Convolution/Deconvolution (only when exactly one kernel_size is given):
    ``kernel[/stride][+pad] num_output[/group]``.
    Pooling: ``kernel[/stride][+pad]``.
    Stride 1, pad 0 and group 1 are omitted. Any other layer yields ''.
    """
    # TODO support rectangular/ND parameters
    conv_param = lr.convolution_param
    if (lr.type in ['Convolution', 'Deconvolution']
            and len(conv_param.kernel_size) == 1):
        arg_str = str(conv_param.kernel_size[0])
        if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
            arg_str += '/' + str(conv_param.stride[0])
        if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
            arg_str += '+' + str(conv_param.pad[0])
        arg_str += ' ' + str(conv_param.num_output)
        if conv_param.group != 1:
            arg_str += '/' + str(conv_param.group)
        return arg_str
    if lr.type == 'Pooling':
        pool_param = lr.pooling_param
        arg_str = str(pool_param.kernel_size)
        if pool_param.stride != 1:
            arg_str += '/' + str(pool_param.stride)
        if pool_param.pad != 0:
            arg_str += '+' + str(pool_param.pad)
        return arg_str
    return ''


def _summarize_params(lr):
    """Return '(p1, p2, ...)' built with format_param over ``lr.param``.

    Returns '' when the layer has no param entries, or when every entry
    formats to an empty string.
    """
    # A real list, not ``map``: the strings are inspected and then joined.
    # On Python 3 ``map`` returns a one-shot iterator, which would leave the
    # join empty and produce '()'.
    param_strs = [format_param(param) for param in lr.param]
    if any(param_strs):
        return '({})'.format(', '.join(param_strs))
    return ''


def summarize_net(net):
    """Build a summary table with one row per layer of ``net``.

    Each row is ``[name, type, params, bottoms, '->', tops, args]`` (all
    strings), suitable for print_table.

    Blob names are colored with ANSI codes:
      * a top name gets the next color from COLORS (cycling) the first time
        it is seen, and keeps that color afterwards;
      * a bottom reuses the color of its name if it was already seen as a
        top, otherwise it gets DISCONNECTED_COLOR.
    Tops that are never read as a bottom afterwards are also shown bold and
    underlined. When the layer has loss_weight entries, each top is prefixed
    with ``<weight> * ``.
    """
    # A name ends up in this set if the last layer producing it does not also
    # read it as a bottom, and no later layer reads it as a bottom either.
    disconnected_tops = set()
    for lr in net.layer:
        disconnected_tops |= set(lr.top)
        disconnected_tops -= set(lr.bottom)

    table = []
    colors = {}  # blob name -> ANSI color code, assigned when seen as a top
    for lr in net.layer:
        # Tops are handled before bottoms, so this layer's own tops already
        # have colors when its bottoms are looked up.
        tops = []
        for ind, top in enumerate(lr.top):
            color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
            if top in disconnected_tops:
                top = '\033[1;4m' + top  # bold + underline
            if len(lr.loss_weight) > 0:
                top = '{} * {}'.format(lr.loss_weight[ind], top)
            tops.append(_ansi(color, top))
        top_str = ', '.join(tops)

        bottom_str = ', '.join(
            _ansi(colors.get(bottom, DISCONNECTED_COLOR), bottom)
            for bottom in lr.bottom)

        if lr.type == 'Python':
            type_str = lr.python_param.module + '.' + lr.python_param.layer
        else:
            type_str = lr.type

        table.append([lr.name, type_str, _summarize_params(lr), bottom_str,
                      '->', top_str, _summarize_args(lr)])
    return table
|
|
def main():
    """Command-line entry point.

    Parses the prototxt path (and optional -w/--max-width, default 30), then
    prints the per-layer summary table for that net.
    """
    parser = argparse.ArgumentParser(
        description="Print a concise summary of net computation.")
    parser.add_argument('filename', help='net prototxt file to summarize')
    parser.add_argument('-w', '--max-width', type=int, default=30,
                        help='maximum field width')
    args = parser.parse_args()

    print_table(summarize_net(read_net(args.filename)),
                max_width=args.max_width)
|
|
if __name__ == '__main__':
    # Run the command-line tool only when executed as a script, not on import.
    main()
|
|