|
|
from collections import OrderedDict |
|
|
import contextlib |
|
|
from typing import Dict, Any |
|
|
|
|
|
from tensorboard.compat.proto.config_pb2 import RunMetadata |
|
|
from tensorboard.compat.proto.graph_pb2 import GraphDef |
|
|
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats |
|
|
from tensorboard.compat.proto.versions_pb2 import VersionDef |
|
|
|
|
|
import torch |
|
|
from ._proto_graph import node_proto |
|
|
|
|
|
# Methods of torch._C.Node mirrored onto NodePyOP instances (see NodePy):
# each name here is called on the C++ node and the result stored under the
# same attribute name on the Python wrapper.
methods_OP = [
    "attributeNames",
    "hasMultipleOutputs",
    "hasUses",
    "inputs",
    "kind",
    "outputs",
    "outputsSize",
    "scopeName",
]
# Methods mirrored for input/output value nodes (see NodePyIO).
methods_IO = ["node", "offset", "debugName"]

# JIT node kind for attribute accesses (module/parameter lookups).
GETATTR_KIND = "prim::GetAttr"
# JIT type kind for module objects; such values are bookkeeping, not ops.
CLASSTYPE_KIND = "ClassType"
|
|
|
|
|
|
|
|
class NodeBase(object):
    """Minimal record describing one node of the exported graph.

    Instances hold only plain data (names, sizes, attribute strings); the
    ``GraphPy`` helper creates them for every operator output during its
    second pass.
    """

    def __init__(
        self,
        debugName=None,
        inputs=None,
        scope=None,
        tensor_size=None,
        op_type="UnSpecified",
        attributes="",
    ):
        # debugName: unique value name from the JIT graph (later rewritten
        #   to a fully scoped name).
        # inputs: list of debug names of the values feeding this node.
        # scope: hierarchical scope string, e.g. "Net1/Linear[0]".
        # tensor_size: output tensor shape, or None when unknown.
        self.debugName = debugName
        self.inputs = inputs
        self.tensor_size = tensor_size
        self.kind = op_type
        self.attributes = attributes
        self.scope = scope

    def __repr__(self):
        """Dump every public attribute with its value and type, one per line."""
        # NOTE: renamed local from `repr` to avoid shadowing the builtin.
        lines = []
        lines.append(str(type(self)))
        for m in dir(self):
            if "__" not in m:
                lines.append(
                    m + ": " + str(getattr(self, m)) + str(type(getattr(self, m)))
                )
        return "\n".join(lines) + "\n\n"
|
|
|
|
|
|
|
|
class NodePy(NodeBase):
    """Mirror a C++ JIT node as plain Python attributes.

    Each method name listed in ``valid_methods`` is invoked on ``node_cpp``
    and its result stored under the same name on this object.  The
    ``inputs``/``outputs`` methods are special-cased: they yield Value
    objects, which get flattened into parallel lists of debug names and
    tensor sizes (``None`` when the value is not a complete tensor).
    """

    def __init__(self, node_cpp, valid_methods):
        super(NodePy, self).__init__(node_cpp)
        self.inputs = []

        for method_name in list(valid_methods):
            if method_name in ("inputs", "outputs"):
                value_names = []
                value_sizes = []
                for value in list(getattr(node_cpp, method_name)()):
                    value_names.append(value.debugName())
                    value_sizes.append(
                        value.type().sizes() if value.isCompleteTensor() else None
                    )
                setattr(self, method_name, value_names)
                setattr(self, method_name + "tensor_size", value_sizes)
            else:
                setattr(self, method_name, getattr(node_cpp, method_name)())
|
|
|
|
|
|
|
|
class NodePyIO(NodePy):
    """Wrapper for a graph input/output value or a parameter node."""

    def __init__(self, node_cpp, input_or_output=None):
        super(NodePyIO, self).__init__(node_cpp, methods_IO)
        # Some value types carry no size information and raise RuntimeError
        # from sizes(); fall back to a dummy shape of [1] in that case.
        try:
            self.tensor_size = node_cpp.type().sizes()
        except RuntimeError:
            self.tensor_size = [1]

        # Default kind is "Parameter"; values explicitly flagged as graph
        # inputs/outputs are relabelled and remember their direction.
        self.kind = "Parameter"
        if input_or_output:
            self.input_or_output = input_or_output
            self.kind = "IO Node"
|
|
|
|
|
|
|
|
class NodePyOP(NodePy):
    """Wrapper for an operator node of the JIT graph."""

    def __init__(self, node_cpp):
        super(NodePyOP, self).__init__(node_cpp, methods_OP)
        # Stringify the op's attribute dict; single quotes are replaced
        # with spaces (matching upstream output format).
        attr_map = {
            name: _node_get(node_cpp, name) for name in node_cpp.attributeNames()
        }
        self.attributes = str(attr_map).replace("'", " ")
        self.kind = node_cpp.kind()
|
|
|
|
|
|
|
|
class GraphPy(object):
    """Helper class to convert torch.nn.Module to GraphDef proto and visualization
    with TensorBoard.

    GraphDef generation operates in two passes:

    In the first pass, all nodes are read and saved to two lists.
    One list is for input/output nodes (nodes_io), which only have inbound
    or outbound connections, but not both. Another list is for internal
    operator nodes (nodes_op). The first pass also saves all scope name
    appeared in the nodes in scope_name_appeared list for later processing.

    In the second pass, scope names are fully applied to all nodes.
    debugNameToScopedName is a mapping from a node's ID to its fully qualified
    scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have
    totally correct scope output, so this is nontrivial. The function
    populate_namespace_from_OP_to_IO and find_common_root are used to
    assign scope name to a node based on the connection between nodes
    in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name
    and scope_name_appeared.
    """

    def __init__(self):
        # Operator nodes (NodePyOP) in insertion order.
        self.nodes_op = []
        # debugName -> node for IO/parameter nodes; the second pass also
        # adds a NodeBase entry for every operator output.
        self.nodes_io = OrderedDict()
        # Value debug name -> fully scoped display name.
        self.unique_name_to_scoped_name = {}
        # Fallback root scope used to prefix nodes with an empty scope.
        self.shallowest_scope_name = "default"
        # Every scopeName observed while walking operator outputs.
        self.scope_name_appeared = []

    def append(self, x):
        """Register a parsed node; IO nodes are keyed by their debugName."""
        if isinstance(x, NodePyIO):
            self.nodes_io[x.debugName] = x
        if isinstance(x, NodePyOP):
            self.nodes_op.append(x)

    def printall(self):
        """Debugging helper: print every stored node."""
        print("all nodes")
        for node in self.nodes_op:
            print(node)
        for key in self.nodes_io:
            print(self.nodes_io[key])

    def find_common_root(self):
        """Set shallowest_scope_name to the first path component of the
        last non-empty scope name seen (heuristic common root)."""
        for fullscope in self.scope_name_appeared:
            if fullscope:
                self.shallowest_scope_name = fullscope.split("/")[0]

    def populate_namespace_from_OP_to_IO(self):
        """Second pass: create nodes for operator outputs, then rewrite
        every node's inputs and debugName to fully scoped names.

        NOTE: the successive assignments into unique_name_to_scoped_name
        below deliberately overwrite each other — later rules take
        precedence — so statement order matters here.
        """
        # Create a NodeBase entry in nodes_io for every operator output,
        # recording each op's scope along the way.
        for node in self.nodes_op:
            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
                self.scope_name_appeared.append(node.scopeName)
                self.nodes_io[node_output] = NodeBase(
                    node_output,
                    node.inputs,
                    node.scopeName,
                    outputSize,
                    op_type=node.kind,
                    attributes=node.attributes,
                )

        self.find_common_root()

        # An op's inputs inherit that op's scope.
        for node in self.nodes_op:
            for input_node_id in node.inputs:
                self.unique_name_to_scoped_name[input_node_id] = (
                    node.scopeName + "/" + input_node_id
                )

        for key, node in self.nodes_io.items():
            # Exact-type check: only the synthetic NodeBase entries created
            # above, not NodePyIO subclass instances.
            if type(node) == NodeBase:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
            if hasattr(node, "input_or_output"):
                # Graph inputs/outputs live under an "input"/"output" scope.
                self.unique_name_to_scoped_name[key] = (
                    node.input_or_output + "/" + node.debugName
                )

            if hasattr(node, "scope") and node.scope is not None:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
                if node.scope == "" and self.shallowest_scope_name:
                    # Unscoped nodes fall back under the common root scope.
                    self.unique_name_to_scoped_name[node.debugName] = (
                        self.shallowest_scope_name + "/" + node.debugName
                    )

        # Replace each node's input names and debugName with the scoped
        # versions computed above (mutates nodes_io values in place).
        for key, node in self.nodes_io.items():
            self.nodes_io[key].inputs = [
                self.unique_name_to_scoped_name[node_input_id]
                for node_input_id in node.inputs
            ]
            if node.debugName in self.unique_name_to_scoped_name:
                self.nodes_io[key].debugName = self.unique_name_to_scoped_name[
                    node.debugName
                ]

    def to_proto(self):
        """
        Converts graph representation of GraphPy object to TensorBoard
        required format.
        """
        # Emit one NodeDef proto per stored node via the node_proto helper.
        nodes = []
        for v in self.nodes_io.values():
            nodes.append(
                node_proto(
                    v.debugName,
                    input=v.inputs,
                    outputsize=v.tensor_size,
                    op=v.kind,
                    attributes=v.attributes,
                )
            )
        return nodes
|
|
|
|
|
|
|
|
def parse(graph, trace, args=None, omit_useless_nodes=True):
    """Parse an optimized PyTorch model graph into TensorBoard node protos.

    Args:
        graph (torch._C.Graph): The optimized JIT graph to be parsed.
        trace (PyTorch JIT TracedModule): The model trace whose module
            hierarchy is used to build readable scope names.
        args (tuple): input tensor[s] for the model. Unused here; kept for
            backward compatibility of the signature (may be None).
        omit_useless_nodes (boolean): Whether to remove graph inputs that
            have no uses.

    Returns:
        list: node protos suitable for building a ``GraphDef``.
    """
    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes and len(node.uses()) == 0:
            # Input value is never consumed; skip it.
            continue

        # Module ("ClassType") inputs are bookkeeping, not real graph IO.
        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, "input"))

    # Maps a GetAttr output's debug name to its scope string, e.g.
    # "__module/sub1.sub2", reconstructed by chaining prim::GetAttr nodes.
    attr_to_scope: Dict[Any, str] = {}
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s("name")
            attr_key = node.output().debugName()
            parent = node.input().node()
            if parent.kind() == GETATTR_KIND:
                # Nested attribute access: extend the parent's scope.
                parent_attr_key = parent.output().debugName()
                parent_scope = attr_to_scope[parent_attr_key]
                attr_scope = parent_scope.split("/")[-1]
                attr_to_scope[attr_key] = "{}/{}.{}".format(
                    parent_scope, attr_scope, attr_name
                )
            else:
                attr_to_scope[attr_key] = "__module.{}".format(attr_name)

            # Only non-module GetAttr outputs (e.g. parameters) become
            # visible operator nodes; submodule accesses are scope-only.
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[attr_key]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    # Create sink nodes for the graph's outputs, numbered output.1, output.2, ...
    for i, node in enumerate(graph.outputs()):
        node_pyio = NodePyIO(node, "output")
        node_pyio.debugName = "output.{}".format(i + 1)
        node_pyio.inputs = [node.debugName()]
        nodes_py.append(node_pyio)

    def parse_traced_name(module):
        # Human-readable class name for a (possibly traced) module.
        if isinstance(module, torch.jit.TracedModule):
            return module._name
        return getattr(module, "original_name", "Module")

    # Map each module path alias to "ClassName[attr_name]" for display.
    alias_to_name = {}
    base_name = parse_traced_name(trace)
    for name, module in trace.named_modules(prefix="__module"):
        mod_name = parse_traced_name(module)
        attr_name = name.split(".")[-1]
        alias_to_name[name] = "{}[{}]".format(mod_name, attr_name)

    # Rewrite every op's scopeName into the readable hierarchy rooted at
    # the top-level module's class name.
    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split("/")
        replacements = [
            alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1]
            for alias in module_aliases
        ]
        node.scopeName = base_name
        if any(replacements):
            node.scopeName += "/" + "/".join(replacements)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
|
|
|
|
|
|
|
|
def graph(model, args, verbose=False, use_strict_trace=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
        use_strict_trace (bool): Whether to pass keyword argument `strict` to
            `torch.jit.trace`. Pass False when you want the tracer to
            record your mutable container types (list, dict)

    Returns:
        tuple: ``(GraphDef, RunMetadata)`` for TensorBoard's graph plugin.

    Raises:
        RuntimeError: re-raised from ``torch.jit.trace`` when tracing fails.
    """
    with _set_model_to_eval(model):
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            # Inline calls so the whole model appears in one flat graph.
            jit_graph = trace.graph
            torch._C._jit_pass_inline(jit_graph)
        except RuntimeError as e:
            print(e)
            print("Error occurs, No graph saved")
            # Bare raise preserves the original traceback.
            raise

    if verbose:
        print(jit_graph)
    list_of_nodes = parse(jit_graph, trace, args)
    # Minimal step stats: a single CPU device entry, as expected by the
    # TensorBoard graph plugin alongside the GraphDef.
    stepstats = RunMetadata(
        step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
    )
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
def _set_model_to_eval(model): |
|
|
"""A context manager to temporarily set the training mode of ``model`` to eval.""" |
|
|
if not isinstance(model, torch.jit.ScriptFunction): |
|
|
originally_training = model.training |
|
|
model.train(False) |
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
model.train(originally_training) |
|
|
else: |
|
|
|
|
|
try: |
|
|
yield |
|
|
finally: |
|
|
pass |
|
|
|
|
|
|
|
|
def _node_get(node: torch._C.Node, key: str): |
|
|
"""Gets attributes of a node which is polymorphic over return type.""" |
|
|
sel = node.kindOf(key) |
|
|
return getattr(node, sel)(key) |
|
|
|