Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """This tool creates an html visualization of a TensorFlow Lite graph. | |
| Example usage: | |
| python visualize.py foo.tflite foo.html | |
| """ | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import numpy as np | |
| # pylint: disable=g-import-not-at-top | |
| if not os.path.splitext(__file__)[0].endswith( | |
| os.path.join("tflite_runtime", "visualize")): | |
| # This file is part of tensorflow package. | |
| from tensorflow.lite.python import schema_py_generated as schema_fb | |
| else: | |
| # This file is part of tflite_runtime package. | |
| from tflite_runtime import schema_py_generated as schema_fb | |
| import gradio as gr | |
| from html import escape | |
# A CSS description for making the visualizer
# body {font-family: sans-serif; background-color: #fa0;}
# # font-family: sans-serif;
# NOTE: the <style> literal below is a bare string expression. It is not
# assigned to any name, so it is not part of _CSS.
"""<style>
table {background-color: #eca;}
th {background-color: black; color: white;}
h1 {
background-color: ffaa00;
padding:5px;
color: black;
}
svg {
margin: 10px;
border: 2px;
border-style: solid;
border-color: black;
background: white;
}
div {
border-radius: 5px;
background-color: #fec;
padding:5px;
margin:5px;
}
.tooltip {color: blue;}
.tooltip .tooltipcontent {
visibility: hidden;
color: black;
background-color: yellow;
padding: 5px;
border-radius: 4px;
position: absolute;
z-index: 1;
}
.tooltip:hover .tooltipcontent {
visibility: visible;
}
.edges line {
stroke: #333;
}
text {
font-weight: bold;
}
.nodes text {
color: black;
pointer-events: none;
font-size: 11px;
}
</style>"""
# Markup bound to _CSS: only the d3 v4 <script> tag. It is passed as the
# `head` argument of gr.Blocks where the Gradio app is built in this file.
_CSS = """
<script src="https://d3js.org/d3.v4.min.js"></script>
"""
# JavaScript (d3) snippet emitted once per subgraph by GenerateGraph. It is
# filled in with `%` formatting: `%s` receives the JSON graph description
# (an object with "nodes" and "edges") and `%d` the subgraph index used to
# select the matching `#subgraph<N>` svg element.
_D3_HTML_TEMPLATE = """
<script>
function buildGraph() {
// Build graph data
var graph = %s;
var svg = d3.select("#subgraph%d")
var width = svg.attr("width");
var height = svg.attr("height");
// Make the graph scrollable.
svg = svg.call(d3.zoom().on("zoom", function() {
svg.attr("transform", d3.event.transform);
})).append("g");
var color = d3.scaleOrdinal(d3.schemeDark2);
var simulation = d3.forceSimulation()
.force("link", d3.forceLink().id(function(d) {return d.id;}))
.force("charge", d3.forceManyBody())
.force("center", d3.forceCenter(0.5 * width, 0.5 * height));
var edge = svg.append("g").attr("class", "edges").selectAll("line")
.data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
// Make the node group
var node = svg.selectAll(".nodes")
.data(graph.nodes)
.enter().append("g")
.attr("x", function(d){return d.x})
.attr("y", function(d){return d.y})
.attr("transform", function(d) {
return "translate( " + d.x + ", " + d.y + ")"
})
.attr("class", "nodes")
.call(d3.drag()
.on("start", function(d) {
if(!d3.event.active) simulation.alphaTarget(1.0).restart();
d.fx = d.x;d.fy = d.y;
})
.on("drag", function(d) {
d.fx = d3.event.x; d.fy = d3.event.y;
})
.on("end", function(d) {
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = d.fy = null;
}));
// Within the group, draw a box for the node position and text
// on the side.
var node_width = 150;
var node_height = 30;
node.append("rect")
.attr("r", "5px")
.attr("width", node_width)
.attr("height", node_height)
.attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
.attr("stroke", "#000000")
.attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
node.append("text")
.text(function(d) { return d.name; })
.attr("x", 5)
.attr("y", 20)
.attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
// Setup force parameters and update position callback
var node = svg.selectAll(".nodes")
.data(graph.nodes);
// Bind the links
var name_to_g = {}
node.each(function(data, index, nodes) {
console.log(data.id)
name_to_g[data.id] = this;
});
function proc(w, t) {
return parseInt(w.getAttribute(t));
}
edge.attr("d", function(d) {
function lerp(t, a, b) {
return (1.0-t) * a + t * b;
}
var x1 = proc(name_to_g[d.source],"x") + node_width /2;
var y1 = proc(name_to_g[d.source],"y") + node_height;
var x2 = proc(name_to_g[d.target],"x") + node_width /2;
var y2 = proc(name_to_g[d.target],"y");
var s = "M " + x1 + " " + y1
+ " C " + x1 + " " + lerp(.5, y1, y2)
+ " " + x2 + " " + lerp(.5, y1, y2)
+ " " + x2 + " " + y2
return s;
});
}
console.log("Helllo!");
buildGraph();
</script>
"""
def TensorTypeToName(tensor_type):
  """Looks up the symbolic name of a `schema_fb.TensorType` enum value.

  Args:
    tensor_type: numeric enum value to look up.

  Returns:
    The first attribute name of `schema_fb.TensorType` whose value equals
    `tensor_type`, or None when no attribute matches.
  """
  matching_names = (
      attr_name
      for attr_name, attr_value in vars(schema_fb.TensorType).items()
      if attr_value == tensor_type)
  return next(matching_names, None)
def BuiltinCodeToName(code):
  """Looks up the symbolic name of a `schema_fb.BuiltinOperator` enum value.

  Args:
    code: numeric builtin operator code.

  Returns:
    The first attribute name of `schema_fb.BuiltinOperator` whose value equals
    `code`, or None when no attribute matches.
  """
  matching_names = (
      attr_name
      for attr_name, attr_value in vars(schema_fb.BuiltinOperator).items()
      if attr_value == code)
  return next(matching_names, None)
def NameListToString(name_list):
  """Turns a sequence of character codes into the string they spell.

  Args:
    name_list: a `str` (returned unchanged), None, or an iterable whose items
      can be converted with `int()` to character code points.

  Returns:
    The decoded string; the empty string when `name_list` is None.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  return "".join(chr(int(code_point)) for code_point in name_list)
class OpCodeMapper:
  """Translates an operator-code index into a displayable op name."""

  def __init__(self, data):
    """Builds the index -> name table from `data["operator_codes"]`.

    Entries whose builtin name is "CUSTOM" use their `custom_code` instead.
    """
    self.code_to_name = {}
    for position, op_code in enumerate(data["operator_codes"]):
      op_name = BuiltinCodeToName(op_code["builtin_code"])
      if op_name == "CUSTOM":
        op_name = NameListToString(op_code["custom_code"])
      self.code_to_name[position] = op_name

  def __call__(self, x):
    """Returns "<name> (<index>)", with "<UNKNOWN>" for unmapped indices."""
    label = self.code_to_name.get(x, "<UNKNOWN>")
    return "%s (%d)" % (label, x)
class DataSizeMapper:
  """Formats a buffer as its length in bytes."""

  def __call__(self, x):
    """Returns "<len(x)> bytes", or "--" when `x` is None."""
    return "--" if x is None else "%d bytes" % len(x)
class TensorMapper:
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    """Stores the subgraph dict whose "tensors" list the indices refer to."""
    self.data = subgraph_data

  def __call__(self, x):
    """Renders tensor indices as a tooltip span.

    Args:
      x: iterable of integer tensor indices, or None.

    Returns:
      "" when `x` is None. Otherwise an HTML span whose hidden tooltip lists,
      per index, the tensor's index, name, type, shape and shape signature,
      followed by `repr(x)` as the always-visible text.
    """
    if x is None:
      return ""
    tensors = self.data["tensors"] or []
    rows = []
    for i in x:
      # Python would silently resolve a negative index to a tensor counted
      # from the end, so out-of-range indices are reported without details.
      # NOTE(review): TFLite appears to use -1 for an omitted optional
      # tensor -- confirm against the schema.
      if not 0 <= i < len(tensors):
        rows.append("%d (no tensor)<br>" % i)
        continue
      tensor = tensors[i]
      # The name comes from the model file: escape it before embedding.
      name = escape(NameListToString(tensor["name"]))
      # TensorTypeToName returns None for enum values it does not know.
      type_name = escape(TensorTypeToName(tensor["type"]) or "UNKNOWN")
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      shape_signature = (
          repr(tensor["shape_signature"])
          if "shape_signature" in tensor else "[]")
      rows.append(
          str(i) + " " + name + " " + type_name + " " + shape +
          shape_signature + "<br>")
    return ("<span class='tooltip'><span class='tooltipcontent'>" +
            "".join(rows) + "</span>" + repr(x) + "</span>")
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: index of the subgraph; substituted into the template so the
      script selects the matching `#subgraph<N>` svg element.
    g: subgraph dict with "operators" and "tensors" lists (either may be None).
    opcode_mapper: callable mapping an operator's "opcode_index" to its label.

  Returns:
    The `_D3_HTML_TEMPLATE` script filled in with the JSON node/edge data.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  tensors = g["tensors"] if g["tensors"] is not None else []
  tensor_count = len(tensors)
  edges = []
  nodes = []
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"] or []):
    if op["inputs"] is not None:
      for tensor_input_position, tensor_index in enumerate(op["inputs"]):
        # An edge whose endpoint has no node makes the template's edge layout
        # dereference an undefined element, so skip indices without a tensor.
        # NOTE(review): TFLite appears to use -1 for omitted optional inputs
        # -- confirm against the schema.
        if not 0 <= tensor_index < tensor_count:
          continue
        if tensor_index not in first:
          first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                                 (tensor_input_position + 1) * width_mult)
        edges.append({
            "source": TensorName(tensor_index),
            "target": OpName(op_index)
        })
    if op["outputs"] is not None:
      for tensor_output_position, tensor_index in enumerate(op["outputs"]):
        if not 0 <= tensor_index < tensor_count:
          continue
        if tensor_index not in second:
          second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                  (tensor_output_position + 1) * width_mult)
        edges.append({
            "target": TensorName(tensor_index),
            "source": OpName(op_index)
        })
    nodes.append({
        "id": OpName(op_index),
        "name": opcode_mapper(op["opcode_index"]),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(tensors):
    # Prefer the placement next to the first consumer, then the first
    # producer; the tuples are stored as (y, x).
    position = (
        first[tensor_index] if tensor_index in first else
        second[tensor_index] if tensor_index in second else (0, 0))
    # Tensors are dicts here, so the shape is read by key (an attribute
    # lookup would always fall back to the default).
    shape = tensor.get("shape")
    if shape is None:
      shape = []
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % (shape, tensor_index),
        "group": 1,
        "x": position[1],
        "y": position[0]
    })
  # "</" inside the inline JSON could terminate the surrounding <script>
  # element; "<\/" is the same string to JavaScript but inert to HTML.
  graph_str = json.dumps({"nodes": nodes, "edges": edges}).replace(
      "</", "<\\/")
  html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
  return html
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: Items to print, an array of dicts. None is treated as empty.
    keys_to_print: (key, display_fn) pairs. `key` is looked up in each item
      (missing keys yield None). `display_fn` maps the value to the string
      shown in the cell; None means the value is shown as-is.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table: one header row followed by one row per item.
  """
  # Collect fragments and join once instead of repeated string concatenation.
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for key, _ in keys_to_print:
    parts.append("<th>%s</th>" % (key,))
  parts.append("</tr>\n")
  for idx, item in enumerate(items if items is not None else ()):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for key, mapper in keys_to_print:
      val = item[key] if key in item else None
      val = val if mapper is None else mapper(val)
      # Wrap in a 1-tuple so a tuple-valued cell is formatted, not unpacked.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)
def CamelCaseToSnakeCase(camel_case_input):
  """Rewrites a CamelCase identifier as lower-case snake_case."""
  # Pass 1: put "_" before every capitalised word that follows a character.
  word_split = re.compile("(.)([A-Z][a-z]+)").sub(r"\1_\2", camel_case_input)
  # Pass 2: put "_" between a lower-case letter or digit and a capital.
  fully_split = re.compile("([a-z0-9])([A-Z])").sub(r"\1_\2", word_split)
  return fully_split.lower()
def FlatbufferToDict(fb, preserve_as_numpy):
  """Recursively converts a flatbuffer object tree into dicts and lists.

  Large numeric arrays can be kept as numpy arrays rather than expanded into
  python lists, which keeps conversion of big graphs fast.

  Args:
    fb: a flat buffer structure (i.e. ModelT) or any value nested inside one.
    preserve_as_numpy: when true, np.ndarray values below this point are
      returned as-is; when false they are converted with `tolist()`. The
      attribute named "buffers" always preserves arrays beneath it.

  Returns:
    ints, floats and strs unchanged; objects with a `__dict__` as a dict keyed
    by the snake_case names of their public, non-callable attributes; arrays
    as described above; other sized containers as lists; anything else as-is.
  """
  if isinstance(fb, (int, float, str)):
    return fb
  if hasattr(fb, "__dict__"):
    converted = {}
    for attribute_name in dir(fb):
      attribute = getattr(fb, attribute_name)
      if callable(attribute) or attribute_name.startswith("_"):
        continue
      key = CamelCaseToSnakeCase(attribute_name)
      keep_numpy = attribute_name == "buffers" or preserve_as_numpy
      converted[key] = FlatbufferToDict(attribute, keep_numpy)
    return converted
  if isinstance(fb, np.ndarray):
    return fb if preserve_as_numpy else fb.tolist()
  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(element, preserve_as_numpy) for element in fb]
  return fb
def CreateDictFromFlatbuffer(buffer_data):
  """Parses serialized model bytes into a nested dict via `FlatbufferToDict`.

  Args:
    buffer_data: buffer holding the flatbuffer model, read from offset 0.

  Returns:
    The model as nested dicts/lists, converted with preserve_as_numpy=False.
  """
  root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  unpacked_model = schema_fb.ModelT.InitFromObj(root)
  return FlatbufferToDict(unpacked_model, preserve_as_numpy=False)
def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input is not valid.
  """
  # Load the model into a nested dict, either by parsing the flatbuffer or by
  # reading an already-converted JSON file.
  if input_is_filepath:
    if not os.path.exists(tflite_input):
      raise RuntimeError("Invalid filename %r" % tflite_input)
    if tflite_input.endswith((".tflite", ".bin", ".tf_lite")):
      with open(tflite_input, "rb") as file_handle:
        file_data = bytearray(file_handle.read())
      data = CreateDictFromFlatbuffer(file_data)
    elif tflite_input.endswith(".json"):
      # Use a context manager so the file handle is closed deterministically.
      with open(tflite_input) as file_handle:
        data = json.load(file_handle)
    else:
      raise RuntimeError("Input file was not .tflite or .json")
  else:
    data = CreateDictFromFlatbuffer(tflite_input)

  data["filename"] = tflite_input if input_is_filepath else (
      "Null (used model object)")  # Avoid special case

  # d3 is emitted first: the inline per-subgraph scripts call d3 as soon as
  # they are parsed, so the library has to precede them in the document.
  parts = [""" <script src="https://d3js.org/d3.v4.min.js"></script> """]
  parts.append("<h1>TensorFlow Lite Model</h1>")
  parts.append("<table>\n")
  for key in ("filename", "version", "description"):
    # Filename and description come from the caller / model file: escape them.
    parts.append("<tr><th>%s</th><td>%s</td></tr>\n" %
                 (key, escape(str(data.get(key)))))
  parts.append("</table>\n")

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]

  # Update builtin code fields.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"])

  # The opcode table depends only on the model, so build it once, after the
  # builtin codes above have been updated.
  opcode_mapper = OpCodeMapper(data)

  for subgraph_idx, g in enumerate(data["subgraphs"]):
    # Subgraph local specs on what to display
    parts.append("<div class='subgraph'>")
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index", opcode_mapper)]
    tensor_keys_to_display = [("name", NameListToString),
                              ("type", TensorTypeToName), ("shape", None),
                              ("shape_signature", None), ("buffer", None),
                              ("quantization", None)]

    parts.append("<h2>Subgraph %d</h2>\n" % subgraph_idx)

    # Inputs and outputs.
    parts.append("<h3>Inputs/Outputs</h3>\n")
    parts.append(GenerateTableHtml([{
        "inputs": g["inputs"],
        "outputs": g["outputs"]
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                                   display_index=False))

    # Print the tensors.
    parts.append("<h3>Tensors</h3>\n")
    parts.append(GenerateTableHtml(g["tensors"], tensor_keys_to_display))

    # Print the ops.
    if g["operators"]:
      parts.append("<h3>Ops</h3>\n")
      parts.append(GenerateTableHtml(g["operators"], op_keys_to_display))

    # Visual graph.
    parts.append("<svg id='subgraph%d' width='1600' height='900'></svg>\n" %
                 (subgraph_idx,))
    parts.append(GenerateGraph(subgraph_idx, g, opcode_mapper))
    parts.append("</div>")

  # Buffers have no data, but maybe in the future they will
  parts.append("<h2>Buffers</h2>\n")
  parts.append(GenerateTableHtml(data["buffers"], buffer_keys_to_display))

  # Operator codes
  parts.append("<h2>Operator Codes</h2>\n")
  parts.append(
      GenerateTableHtml(data["operator_codes"], operator_keys_to_display))

  return "".join(parts)
def main(argv):
  """Command-line entry point: writes the HTML dump of a model to a file.

  Args:
    argv: argument list; argv[1] is the input model path and argv[2] the
      output html path. With fewer arguments a usage line is printed instead.
  """
  if len(argv) < 3:
    print("Usage: %s <input tflite> <output html>" % (argv[0]))
    return
  tflite_input, html_output = argv[1], argv[2]
  html = create_html(tflite_input)
  with open(html_output, "w") as output_file:
    output_file.write(html)
def process_file(file):
  """Gradio callback: renders the uploaded model as HTML.

  Args:
    file: the upload value; either an object exposing the path as `.name`, a
      plain path, or None when no file is selected.

  Returns:
    The generated HTML, "" when no file is selected, or an "Error: ..."
    message (HTML-escaped) when the model could not be rendered.
  """
  if file is None:
    # Nothing uploaded (or the upload was cleared): show an empty panel
    # rather than an attribute-error message.
    return ""
  try:
    path = getattr(file, "name", file)
    return create_html(path)
  except Exception as e:  # pylint: disable=broad-except
    # Deliberate catch-all at the UI boundary; the text is shown in an HTML
    # component, so escape it.
    return "Error: " + escape(str(e))
# Gradio UI: an upload widget whose change event renders the generated HTML.
# _CSS (the d3 <script> tag) is passed as the page `head`.
with gr.Blocks(head=_CSS, ) as demo:
  gr.Markdown(
"""
## TensorFlow Lite Model Visualizer
Drag and drop your `.tflite`, `.bin` or `.tf_lite` model files below to analyze them.
""")
  file_input = gr.File(label="Upload TFLite File")
  html_output = gr.HTML(label="Generated HTML", container=True)
  # Call process_file on every change of the upload and show its return value.
  file_input.change(process_file, inputs=file_input, outputs=html_output)
# NOTE: the app is launched unconditionally when this module is executed or
# imported; the command-line entry point below is commented out.
demo.launch()
# if __name__ == "__main__":
# main(sys.argv)