|
|
"""Utilities related to model visualization.""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
from keras.src import tree |
|
|
from keras.src.api_export import keras_export |
|
|
from keras.src.utils import io_utils |
|
|
|
|
|
# `pydot` is an optional dependency. Try `pydot` first, then fall back to
# the `pydot_ng` and `pydotplus` packages, which are imported under the same
# `pydot` name (presumably because they expose a compatible API -- confirm
# when adding new pydot usages). If none can be imported, `pydot` is set to
# `None`, which `check_pydot()` reports as "not available".
try:
    import pydot
except ImportError:
    # First fallback backend.
    try:
        import pydot_ng as pydot
    except ImportError:
        # Second fallback backend.
        try:
            import pydotplus as pydot
        except ImportError:
            # No backend installed; visualization entry points check for
            # this and raise/print an installation hint.
            pydot = None
|
|
|
|
|
|
|
|
def check_pydot():
    """Tells whether a pydot backend could be imported.

    Returns:
        `False` when every import attempt for the module-level `pydot` name
        failed (it is `None`), `True` otherwise.
    """
    if pydot is None:
        return False
    return True
|
|
|
|
|
|
|
|
def check_graphviz():
    """Returns True if both PyDot and Graphviz are available.

    Graphviz availability is probed by asking pydot to render an empty
    graph. Failures to run the Graphviz binaries are reported as `False`
    instead of being raised.

    Returns:
        `True` if a pydot backend is importable and rendering an empty graph
        succeeds, `False` otherwise.
    """
    if not check_pydot():
        return False

    # Not every backend this module may fall back to is guaranteed to define
    # `PydotException` (or `InvocationException`). An `except` clause that
    # names a missing attribute raises AttributeError while handling the
    # original error, instead of handling it. So only collect the
    # backend-specific exception types that actually exist.
    backend_errors = tuple(
        exc_type
        for exc_type in (
            getattr(pydot, "PydotException", None),
            getattr(pydot, "InvocationException", None),
        )
        if exc_type is not None
    )
    handled_errors = (OSError,) + backend_errors

    try:
        # Attempt to create an image of a blank graph to check the
        # pydot/Graphviz installation.
        pydot.Dot.create(pydot.Dot())
        return True
    except handled_errors:
        return False
|
|
|
|
|
|
|
|
def add_edge(dot, src, dst):
    """Draws an edge from `src` to `dst` in `dot` unless it already exists.

    Endpoints are identified by `str(id(...))` of the given objects, the same
    key `make_node` uses for node names. New edges get a pen width of 2.
    """
    endpoints = (str(id(src)), str(id(dst)))
    if dot.get_edge(*endpoints):
        # Edge already present; avoid drawing a duplicate.
        return
    new_edge = pydot.Edge(*endpoints)
    new_edge.set("penwidth", "2")
    dot.add_edge(new_edge)
|
|
|
|
|
|
|
|
def get_layer_activation_name(layer):
    """Returns a display name for `layer.activation`.

    The activation's `name` attribute is preferred. Otherwise its `__name__`
    (as plain functions have) is used. As a last resort the activation is
    converted with `str()`.
    """
    activation = layer.activation
    for attribute in ("name", "__name__"):
        if hasattr(activation, attribute):
            return getattr(activation, attribute)
    return str(activation)
|
|
|
|
|
|
|
|
def make_layer_label(layer, **kwargs):
    """Builds the Graphviz HTML-like label describing a single layer.

    The label is an HTML table wrapped in an extra pair of angle brackets,
    made of:

    - a header row with the layer name and class name (only the class name
      when `show_layer_names` is False),
    - an optional "Activation" row,
    - an optional row of detail cells (input/output shapes, output dtype,
      trainable status) depending on the flags.

    The header and activation rows span exactly as many columns as the
    detail row actually contains.

    Args:
        layer: The layer/operation to describe. Depending on the flags, its
            `name`, `activation`, `input`, `output`, `trainable` and
            `weights` attributes are read.
        **kwargs: Exactly the boolean flags `show_layer_names`,
            `show_layer_activations`, `show_dtype`, `show_shapes` and
            `show_trainable`.

    Returns:
        The label string, starting with `<<table` and ending with
        `</table>>`.

    Raises:
        KeyError: If one of the required flags is missing.
        ValueError: If any other keyword argument is passed.
    """
    class_name = layer.__class__.__name__

    show_layer_names = kwargs.pop("show_layer_names")
    show_layer_activations = kwargs.pop("show_layer_activations")
    show_dtype = kwargs.pop("show_dtype")
    show_shapes = kwargs.pop("show_shapes")
    show_trainable = kwargs.pop("show_trainable")
    if kwargs:
        raise ValueError(f"Invalid kwargs: {kwargs}")

    # Build the detail cells first, so the header rows can span exactly as
    # many columns as this row really has. Predicting the count from the
    # flags alone over-counts: `InputLayer`s get no "Input shape" cell, and
    # layers without weights get no trainable cell.
    cols = []
    if show_shapes:
        input_shape = None
        output_shape = None
        try:
            input_shape = tree.map_structure(lambda x: x.shape, layer.input)
            output_shape = tree.map_structure(lambda x: x.shape, layer.output)
        except (ValueError, AttributeError):
            # Shapes that could not be retrieved stay `None` and are shown
            # as "?" below.
            pass

        def format_shape(shape):
            # Render a (possibly dict-structured) shape without braces.
            if shape is not None:
                if isinstance(shape, dict):
                    shape_str = ", ".join(
                        [f"{k}: {v}" for k, v in shape.items()]
                    )
                else:
                    shape_str = f"{shape}"
                shape_str = shape_str.replace("}", "").replace("{", "")
            else:
                shape_str = "?"
            return shape_str

        if class_name != "InputLayer":
            cols.append(
                (
                    '<td bgcolor="white"><font point-size="14">'
                    f"Input shape: <b>{format_shape(input_shape)}</b>"
                    "</font></td>"
                )
            )
        cols.append(
            (
                '<td bgcolor="white"><font point-size="14">'
                f"Output shape: <b>{format_shape(output_shape)}</b>"
                "</font></td>"
            )
        )
    if show_dtype:
        dtype = None
        try:
            dtype = tree.map_structure(lambda x: x.dtype, layer.output)
        except (ValueError, AttributeError):
            pass
        cols.append(
            (
                '<td bgcolor="white"><font point-size="14">'
                f"Output dtype: <b>{dtype or '?'}</b>"
                "</font></td>"
            )
        )
    if show_trainable and hasattr(layer, "trainable") and layer.weights:
        if layer.trainable:
            cols.append(
                (
                    '<td bgcolor="forestgreen">'
                    '<font point-size="14" color="white">'
                    "<b>Trainable</b></font></td>"
                )
            )
        else:
            cols.append(
                (
                    '<td bgcolor="firebrick">'
                    '<font point-size="14" color="white">'
                    "<b>Non-trainable</b></font></td>"
                )
            )

    # Full-width rows must span every detail cell (at least one column).
    colspan = max(1, len(cols))

    table = (
        '<<table border="0" cellborder="1" bgcolor="black" cellpadding="10">'
    )

    if show_layer_names:
        title = f"<b>{layer.name}</b> ({class_name})"
    else:
        title = f"<b>{class_name}</b>"
    table += (
        f'<tr><td colspan="{colspan}" bgcolor="black">'
        '<font point-size="16" color="white">'
        f"{title}"
        "</font></td></tr>"
    )

    if (
        show_layer_activations
        and hasattr(layer, "activation")
        and layer.activation is not None
    ):
        table += (
            f'<tr><td bgcolor="white" colspan="{colspan}">'
            '<font point-size="14">'
            f"Activation: <b>{get_layer_activation_name(layer)}</b>"
            "</font></td></tr>"
        )

    if cols:
        table += "<tr>" + "".join(cols) + "</tr>"
    table += "</table>>"
    return table
|
|
|
|
|
|
|
|
def make_node(layer, **kwargs):
    """Creates the pydot node that represents `layer` in the graph.

    The node is named `str(id(layer))` and labelled with the HTML table from
    `make_layer_label`, to which `kwargs` are forwarded. Helvetica font and
    zero border/margin are applied.
    """
    node_cls = pydot.Node
    node_name = str(id(layer))
    node = node_cls(node_name, label=make_layer_label(layer, **kwargs))
    for attribute, value in (
        ("fontname", "Helvetica"),
        ("border", "0"),
        ("margin", "0"),
    ):
        node.set(attribute, value)
    return node
|
|
|
|
|
|
|
|
@keras_export("keras.utils.model_to_dot")
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    subgraph=False,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Convert a Keras model to dot format.

    Models are drawn as follows:

    - A `Sequential` model is drawn as a chain of its layers.
    - A `Functional` model is drawn from its operations, connected by
      following each layer's inbound nodes back to the producing operations.
    - Any other model type is drawn as a single node.

    Args:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_dtype: whether to display the layers' output dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        subgraph: whether to return a `pydot.Cluster` instance.
        show_layer_activations: Display layer activations (only for layers that
            have an `activation` property).
        show_trainable: whether to display if a layer is trainable.
        **kwargs: Must be empty. `layer_range` is rejected with a dedicated
            "no longer supported" error; any other key is rejected as
            unrecognized.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`.

    Raises:
        ValueError: If the model has not been built, or if `kwargs` is not
            empty.
        ImportError: If no pydot backend is installed.
    """
    from keras.src.ops.function import make_node_key

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    # Imported inside the function rather than at module level (possibly to
    # avoid an import cycle -- TODO confirm).
    from keras.src.models import functional
    from keras.src.models import sequential

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    # Nested models (subgraph=True) are rendered as a dashed, labelled
    # cluster; the top-level call creates the root graph and its
    # graph-wide attributes.
    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set("splines", "ortho")
        dot.set_node_defaults(shape="record")

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    # From here on, `kwargs` is reused to hold the label flags forwarded to
    # `make_node` / `make_layer_label`.
    kwargs = {
        "show_layer_names": show_layer_names,
        "show_layer_activations": show_layer_activations,
        "show_dtype": show_dtype,
        "show_shapes": show_shapes,
        "show_trainable": show_trainable,
    }

    if isinstance(model, sequential.Sequential):
        layers = model.layers
    elif not isinstance(model, functional.Functional):
        # Neither Sequential nor Functional: the model's internal structure
        # is not traversed; it is drawn as a single node.
        node = make_node(model, **kwargs)
        dot.add_node(node)
        return dot
    else:
        layers = model._operations

    # Create one graph node per layer (`i` is unused).
    for i, layer in enumerate(layers):
        # With `expand_nested`, nested Functional/Sequential models are
        # rendered recursively as clusters instead of single nodes.
        if expand_nested and isinstance(
            layer, (functional.Functional, sequential.Sequential)
        ):
            submodel = model_to_dot(
                layer,
                show_shapes,
                show_dtype,
                show_layer_names,
                rankdir,
                expand_nested,
                subgraph=True,
                show_layer_activations=show_layer_activations,
                show_trainable=show_trainable,
            )
            dot.add_subgraph(submodel)

        else:
            node = make_node(layer, **kwargs)
            dot.add_node(node)

    # Connect the nodes with edges.
    if isinstance(model, sequential.Sequential):
        if not expand_nested:
            # Flat Sequential model: chain consecutive layers.
            for i in range(len(layers) - 1):
                add_edge(dot, layers[i], layers[i + 1])
            return dot
        else:
            # Nested Sequential: fall through to the generic traversal below,
            # skipping the first layer. Its incoming edge is expected to come
            # from outside this model (the edge drawn to the Sequential model
            # itself) -- presumably because its input has no box of its own
            # in a Sequential plot; TODO confirm.
            layers = model.layers[1:]

    # Generic case (Functional, or Sequential with `expand_nested`): for each
    # layer, follow its inbound nodes back to the operations that produced
    # its inputs.
    for layer in layers:
        for inbound_index, inbound_node in enumerate(layer._inbound_nodes):
            # A layer may be called in several graphs; for Functional models
            # only the inbound nodes that belong to this model are drawn.
            if (
                isinstance(model, functional.Functional)
                and make_node_key(layer, inbound_index) not in model._nodes
            ):
                continue

            for input_index, input_tensor in enumerate(
                inbound_node.input_tensors
            ):
                # `_keras_history` records (operation, node_index,
                # tensor_index) for the tensor.
                input_history = input_tensor._keras_history
                if input_history.operation is None:
                    # No producing operation to connect from (presumably a
                    # raw graph input -- TODO confirm).
                    continue

                # The producing operation's node and which of its outputs
                # this tensor is.
                input_node = input_history.operation._inbound_nodes[
                    input_history.node_index
                ]
                output_index = input_history.tensor_index

                # Tentative edge endpoints; they may be refined below when
                # either side is a nested model.
                source = input_node.operation
                destination = layer

                if not expand_nested:
                    # Nested models are single nodes: connect directly.
                    add_edge(dot, source, layer)
                    continue

                # Resolve the source: a nested model has no explicit output
                # box, so descend (possibly several levels) to the real layer
                # that produced output `output_index`.
                while isinstance(
                    source,
                    (functional.Functional, sequential.Sequential),
                ):
                    source, _, output_index = source.outputs[
                        output_index
                    ]._keras_history

                # Resolve the destination: descend into nested models until a
                # plain layer is reached.
                while isinstance(
                    destination,
                    (functional.Functional, sequential.Sequential),
                ):
                    if isinstance(destination, functional.Functional):
                        # Point to the operation that produced the
                        # Functional model's `input_index`-th input.
                        destination = destination.inputs[
                            input_index
                        ]._keras_history.operation
                    else:
                        # Sequential: point to its first layer.
                        destination = destination.layers[0]

                add_edge(dot, source, destination)
    return dot
|
|
|
|
|
|
|
|
@keras_export("keras.utils.plot_model")
def plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=False,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Renders a Keras model as a graph and writes it to `to_file`.

    Example:

    ```python
    inputs = ...
    outputs = ...
    model = keras.Model(inputs=inputs, outputs=outputs)

    dot_img_file = '/tmp/model_1.png'
    keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
    ```

    Args:
        model: A Keras model instance
        to_file: File name of the plot image. Its extension selects the
            output format; `"png"` is used when there is none.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        show_layer_activations: Display layer activations (only for layers that
            have an `activation` property).
        show_trainable: whether to display if a layer is trainable.
        **kwargs: Must be empty; `layer_range` is no longer supported.

    Returns:
        An `IPython.display.Image` of the written file when IPython can be
        imported and the output format is not `"pdf"`; otherwise `None`.
        If pydot or Graphviz is missing while `IPython.core.magics.namespace`
        is loaded (presumably an interactive IPython/Jupyter session), a
        message is printed and `None` is returned instead of raising.

    Raises:
        ValueError: If the model is not built, or unexpected keyword
            arguments are passed.
        ImportError: If pydot or Graphviz is missing outside of IPython.
    """
    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    # Dependencies are checked in order; Graphviz is only probed once a
    # pydot backend is known to be importable.
    requirements = (
        (
            check_pydot,
            "You must install pydot (`pip install pydot`) "
            "for `plot_model` to work.",
        ),
        (
            check_graphviz,
            "You must install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for `plot_model` to work.",
        ),
    )
    for is_installed, message in requirements:
        if is_installed():
            continue
        if "IPython.core.magics.namespace" not in sys.modules:
            raise ImportError(message)
        # Inside IPython, print the hint instead of raising (presumably so
        # the notebook cell does not error out).
        io_utils.print_msg(message)
        return

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        dpi=dpi,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )
    to_file = str(to_file)
    if dot is None:
        # Defensive: nothing to write.
        return

    # `splitext` keeps the leading dot; strip it to get the pydot format.
    _, suffix = os.path.splitext(to_file)
    extension = suffix[1:] if suffix else "png"
    dot.write(to_file, format=extension)

    # No inline preview is returned for PDF output.
    if extension == "pdf":
        return
    try:
        from IPython import display

        return display.Image(filename=to_file)
    except ImportError:
        # IPython is unavailable; the file has still been written.
        return
|
|
|