class JsonFormatter(object):
    '''Dump a DL graph into a JSON file.'''

    def __init__(self, graph):
        '''Keep the graph definition returned by ``graph.as_graph_def()``.'''
        self.graph_def = graph.as_graph_def()

    def dump(self, json_path):
        '''Write the graph definition to ``json_path`` as formatted JSON.

        Parameters
        ----------
        json_path: str
            Path of the file to (over)write.
        '''
        json_txt = json_format.MessageToJson(self.graph_def)
        # Round-trip through ``json`` so the file is written with sorted keys
        # and a 4-space indent.
        formatted = json.dumps(json.loads(json_txt), indent=4, sort_keys=True)
        with open(json_path, 'w') as f:
            f.write(formatted)


class PyWriter(object):
    '''Dump a DL graph into a Python script.

    The generated script defines one class named after the graph; the class
    derives from the network class of the selected target and has a ``setup``
    method containing one statement per graph node.
    '''

    # Lower-cased target name -> base class name used in the generated script.
    _NETWORKS = {
        'tensorflow': 'TensorFlowNetwork',
        'keras': 'KerasNetwork',
        'caffe': 'CaffeNetwork',
    }

    def __init__(self, graph, data, target):
        '''
        Parameters
        ----------
        graph:
            Graph to emit. This class reads ``name``, ``node_dict``,
            ``get_node()`` and ``topologically_sorted()`` from it.
        data:
            Object saved next to the script with ``np.save`` by ``dump``.
        target: str
            Case-insensitive target name: 'tensorflow', 'keras' or 'caffe'.

        Raises
        ------
        ConversionError
            If ``target`` is not one of the supported names.
        '''
        self.graph = graph
        self.data = data
        self.tab = ' ' * 4
        self.prefix = ''
        target = target.lower()
        if target not in self._NETWORKS:
            raise ConversionError('Target %s is not supported yet.' % target)
        self.target = target
        self.net = self._NETWORKS[target]

    def indent(self):
        '''Increase the current indentation by one tab.'''
        self.prefix += self.tab

    def outdent(self):
        '''Decrease the current indentation by one tab.'''
        self.prefix = self.prefix[:-len(self.tab)]

    def statement(self, s):
        '''Return ``s`` as one source line at the current indentation.'''
        return self.prefix + s + '\n'

    def emit_imports(self):
        '''Return the import line of the generated script plus a blank line.'''
        return self.statement('from dlconv.%s import %s\n' % (self.target, self.net))

    def emit_class_def(self, name):
        '''Return the ``class`` header of the generated network class.'''
        return self.statement('class %s(%s):' % (name, self.net))

    def emit_setup_def(self):
        '''Return the header of the generated ``setup`` method.'''
        return self.statement('def setup(self):')

    def emit_node(self, node):
        '''Emits the Python source for this node.

        Every entry of ``node.input`` has the form ``<producer>:<index>`` and
        is replaced by output ``index`` of the producer node. Attributes are
        appended as keyword arguments, followed by ``name='<node name>'``.

        Raises
        ------
        ValueError
            If the text after the last ':' of an input is not an integer.
        AssertionError
            If the producer of an input is not in ``self.graph.node_dict``.
        '''
        def pair(key, value):
            return '%s=%s' % (key, value)

        args = []
        for ref in node.input:
            # Split on the LAST colon only, so a producer name that itself
            # contains ':' is kept intact instead of losing its colons.
            name, _, idx_text = ref.strip().rpartition(':')
            idx = int(idx_text)
            assert name in self.graph.node_dict, 'Unknown input node: %s' % name
            parent = self.graph.get_node(name)
            args.append(parent.output[idx])

        # FIXME: only the first output of the node is assigned; the
        # alternative noted by the original author is ``output = node.output``.
        output = [node.output[0]]

        for key, value in node.attr:
            rendered = fetch_attr_value(value)
            if key == 'cell_type':
                # This attribute is emitted as a quoted string literal.
                rendered = "'" + rendered + "'"
            args.append(pair(key, rendered))
        args.append(pair('name', "'" + node.name + "'"))  # Set the node name
        return self.statement(
            '%s = self.%s(%s)' % (', '.join(output), node.op, ', '.join(args)))

    def dump(self, code_output_dir):
        '''Write the generated script and the data file into a directory.

        The directory is created when missing. Both files are named after the
        lower-cased graph name (``.py`` for the code, ``.npy`` for the data).

        Returns
        -------
        tuple
            ``(code_output_path, data_output_path)``.
        '''
        if not os.path.exists(code_output_dir):
            os.makedirs(code_output_dir)
        file_name = get_lower_case(self.graph.name)
        code_output_path = os.path.join(code_output_dir, file_name + '.py')
        data_output_path = os.path.join(code_output_dir, file_name + '.npy')
        with open(code_output_path, 'w') as f:
            f.write(self.emit())
        with open(data_output_path, 'wb') as f:
            np.save(f, self.data)
        return code_output_path, data_output_path

    def _build_chains(self):
        '''Decompose the DAG into chains of nodes, in topological order.'''
        chains = []
        for node in self.graph.topologically_sorted():
            attach_to_chain = None
            if len(node.input) == 1:
                parent = get_real_name(node.input[0])
                for chain in chains:
                    if chain[-1].name == parent:
                        # Node is part of an existing chain.
                        attach_to_chain = chain
                        break
            if attach_to_chain is None:
                # Start a new chain for this node.
                attach_to_chain = []
                chains.append(attach_to_chain)
            attach_to_chain.append(node)
        return chains

    def emit(self):
        '''Return the full source text of the generated script.

        Nodes are grouped into chains; the statements of one chain form a
        block and blocks are separated by a blank line. The indentation in
        effect before the call is restored afterwards, so calling ``emit``
        (or ``dump``) repeatedly yields the same text each time.
        '''
        chains = self._build_chains()

        saved_prefix = self.prefix
        try:
            # Generate Python code line by line
            source = self.emit_imports()
            source += self.emit_class_def(self.graph.name)
            self.indent()
            source += self.emit_setup_def()
            self.indent()
            blocks = []
            for chain in chains:
                block = ''.join(self.emit_node(node) for node in chain)
                # Drop the trailing newline; blocks are re-joined below.
                blocks.append(block[:-1])
            source += '\n\n'.join(blocks)
        finally:
            # Undo the two indent() calls (also on error) so that the writer
            # can be reused without accumulating indentation.
            self.prefix = saved_prefix
        return source


class ModelSaver(object):
    '''Import a generated network script and let it dump its model files.'''

    def __init__(self, code_output_path, data_output_path):
        '''
        Parameters
        ----------
        code_output_path: str
            Path of the generated Python script.
        data_output_path: str
            Path of the data file handed to the generated class.
        '''
        self.code_output_path = code_output_path
        self.data_output_path = data_output_path

    def dump(self, model_output_dir):
        '''Return the file path containing graph in generated model files.

        The script is imported as a module from its own directory, the class
        named ``get_upper_case(<module name>)`` is looked up in it, and the
        result of ``<class>.dump(data_output_path, model_output_dir)`` is
        returned. ``model_output_dir`` is created when missing.
        '''
        if not os.path.exists(model_output_dir):
            os.makedirs(model_output_dir)
        module_dir = os.path.dirname(self.code_output_path)
        # Register the directory once; repeated calls must not keep growing
        # sys.path with duplicates of the same entry.
        if module_dir not in sys.path:
            sys.path.append(module_dir)
        file_name = os.path.splitext(os.path.basename(self.code_output_path))[0]
        module = import_module(file_name)
        class_name = get_upper_case(file_name)
        net = getattr(module, class_name)
        return net.dump(self.data_output_path, model_output_dir)
% self.toolkit) def _is_web_page(self, path): return path.split('.')[-1] in ('html', 'htm') def _png_to_html(self, png_path, html_path): with open(png_path, "rb") as f: encoded = base64.b64encode(f.read()).decode('utf-8') source = """