joebruce1313's picture
Upload 38004 files
1f5470c verified
"""Utilities related to model visualization."""
import os
import sys
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.utils import io_utils
try:
import pydot
except ImportError:
# pydot_ng and pydotplus are older forks of pydot
# which may still be used by some users
try:
import pydot_ng as pydot
except ImportError:
try:
import pydotplus as pydot
except ImportError:
pydot = None
def check_pydot():
    """Tell whether a pydot-compatible module was imported.

    Returns:
        `True` when the module-level `pydot` name is bound to an imported
        module, `False` when every import attempt failed and it was set to
        `None`.
    """
    pydot_missing = pydot is None
    return not pydot_missing
def check_graphviz():
    """Returns True if both PyDot and Graphviz are available.

    The check renders an empty graph, which forces pydot to invoke the
    Graphviz executables.

    Returns:
        `False` when pydot is not importable or when rendering the empty
        graph fails with an `OSError` or a pydot error; `True` otherwise.
    """
    if not check_pydot():
        return False

    # `pydot` may be bound to one of the fallback modules (pydot_ng,
    # pydotplus), which are not guaranteed to define `PydotException`.
    # Looking the classes up with `getattr` keeps the `except` clause from
    # raising `AttributeError` while it is handling the real failure.
    candidate_errors = (
        getattr(pydot, name, None)
        for name in ("PydotException", "InvocationException", "Error")
    )
    pydot_errors = tuple(
        exc
        for exc in candidate_errors
        if isinstance(exc, type) and issubclass(exc, BaseException)
    )

    try:
        # Attempt to create an image of a blank graph
        # to check the pydot/graphviz installation.
        pydot.Dot.create(pydot.Dot())
    except (OSError, *pydot_errors):
        return False
    return True
def add_edge(dot, src, dst):
    """Connect `src` to `dst` in `dot` unless that edge already exists.

    Nodes are identified by the string form of `id(...)` of the objects.

    Args:
        dot: Graph object exposing `get_edge` and `add_edge`.
        src: Object at the tail of the edge.
        dst: Object at the head of the edge.
    """
    tail_id, head_id = str(id(src)), str(id(dst))
    if dot.get_edge(tail_id, head_id):
        # Edge already present; do not add a duplicate.
        return
    link = pydot.Edge(tail_id, head_id)
    link.set("penwidth", "2")
    dot.add_edge(link)
def get_layer_activation_name(layer):
    """Return a display name for `layer.activation`.

    The activation's `name` attribute is preferred, then `__name__`; when
    neither exists the `str()` of the activation is returned.

    Args:
        layer: Object with an `activation` attribute.

    Returns:
        The name chosen by the rules above.
    """
    activation = layer.activation
    for attribute in ("name", "__name__"):
        if hasattr(activation, attribute):
            return getattr(activation, attribute)
    return str(activation)
def make_layer_label(layer, **kwargs):
    """Build the Graphviz HTML-like table label describing `layer`.

    The label always has a header row (layer name and/or class name) and,
    depending on the flags, an activation row and a row of detail cells
    (input/output shape, output dtype, trainable status).

    Args:
        layer: The layer (or model) to describe.
        **kwargs: Must contain exactly the flags `show_layer_names`,
            `show_layer_activations`, `show_dtype`, `show_shapes` and
            `show_trainable`.

    Returns:
        A string of the form `<<table ...>...</table>>`.

    Raises:
        KeyError: If one of the required flags is missing.
        ValueError: If any other keyword argument is passed.
    """
    from html import escape

    def text(value):
        # Graphviz parses the label as HTML, so `&`, `<` and `>` coming from
        # dynamic values (e.g. the `str()` of an activation callable, which
        # looks like `<function ...>`) must be escaped.
        return escape(f"{value}", quote=False)

    def format_shape(shape):
        if shape is None:
            return "?"
        if isinstance(shape, dict):
            shape_str = ", ".join([f"{k}: {v}" for k, v in shape.items()])
        else:
            shape_str = f"{shape}"
        return shape_str.replace("}", "").replace("{", "")

    class_name = layer.__class__.__name__

    show_layer_names = kwargs.pop("show_layer_names")
    show_layer_activations = kwargs.pop("show_layer_activations")
    show_dtype = kwargs.pop("show_dtype")
    show_shapes = kwargs.pop("show_shapes")
    show_trainable = kwargs.pop("show_trainable")
    if kwargs:
        raise ValueError(f"Invalid kwargs: {kwargs}")

    table = (
        '<<table border="0" cellborder="1" bgcolor="black" cellpadding="10">'
    )

    # Width of the header/activation rows: one column per enabled detail
    # cell, where shapes account for two cells (input and output).
    colspan_max = sum(int(x) for x in (show_dtype, show_trainable))
    if show_shapes:
        colspan_max += 2
    colspan = max(1, colspan_max)

    if show_layer_names:
        table += (
            f'<tr><td colspan="{colspan}" bgcolor="black">'
            '<font point-size="16" color="white">'
            f"<b>{text(layer.name)}</b> ({text(class_name)})"
            "</font></td></tr>"
        )
    else:
        table += (
            f'<tr><td colspan="{colspan}" bgcolor="black">'
            '<font point-size="16" color="white">'
            f"<b>{text(class_name)}</b>"
            "</font></td></tr>"
        )

    if (
        show_layer_activations
        and hasattr(layer, "activation")
        and layer.activation is not None
    ):
        table += (
            f'<tr><td bgcolor="white" colspan="{colspan}">'
            '<font point-size="14">'
            f"Activation: <b>{text(get_layer_activation_name(layer))}</b>"
            "</font></td></tr>"
        )

    cols = []
    if show_shapes:
        input_shape = None
        output_shape = None
        try:
            input_shape = tree.map_structure(lambda x: x.shape, layer.input)
            output_shape = tree.map_structure(lambda x: x.shape, layer.output)
        except (ValueError, AttributeError):
            # Best effort: shapes that cannot be read are rendered as "?".
            pass

        if class_name != "InputLayer":
            cols.append(
                '<td bgcolor="white"><font point-size="14">'
                f"Input shape: <b>{text(format_shape(input_shape))}</b>"
                "</font></td>"
            )
        cols.append(
            '<td bgcolor="white"><font point-size="14">'
            f"Output shape: <b>{text(format_shape(output_shape))}</b>"
            "</font></td>"
        )

    if show_dtype:
        dtype = None
        try:
            dtype = tree.map_structure(lambda x: x.dtype, layer.output)
        except (ValueError, AttributeError):
            # Best effort: an unreadable dtype is rendered as "?".
            pass
        cols.append(
            '<td bgcolor="white"><font point-size="14">'
            f"Output dtype: <b>{text(dtype or '?')}</b>"
            "</font></td>"
        )

    if show_trainable and hasattr(layer, "trainable") and layer.weights:
        if layer.trainable:
            cols.append(
                '<td bgcolor="forestgreen">'
                '<font point-size="14" color="white">'
                "<b>Trainable</b></font></td>"
            )
        else:
            cols.append(
                '<td bgcolor="firebrick">'
                '<font point-size="14" color="white">'
                "<b>Non-trainable</b></font></td>"
            )

    if cols:
        table += "<tr>" + "".join(cols) + "</tr>"
    table += "</table>>"
    return table
def make_node(layer, **kwargs):
    """Create the pydot node that represents `layer`.

    Args:
        layer: Object to draw; the node name is `str(id(layer))`.
        **kwargs: Display flags forwarded to `make_layer_label`.

    Returns:
        A `pydot.Node` carrying the layer label and fixed font, border and
        margin attributes.
    """
    node_factory = pydot.Node
    node_name = str(id(layer))
    html_label = make_layer_label(layer, **kwargs)
    graph_node = node_factory(node_name, label=html_label)
    for attribute, value in (
        ("fontname", "Helvetica"),
        ("border", "0"),
        ("margin", "0"),
    ):
        graph_node.set(attribute, value)
    return graph_node
@keras_export("keras.utils.model_to_dot")
def model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    subgraph=False,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Convert a Keras model to dot format.

    Args:
        model: A Keras model instance.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        subgraph: whether to return a `pydot.Cluster` instance.
        show_layer_activations: Display layer activations (only for layers
            that have an `activation` property).
        show_trainable: whether to display if a layer is trainable.

    Returns:
        A `pydot.Dot` instance representing the Keras model or
        a `pydot.Cluster` instance representing nested model if
        `subgraph=True`.

    Raises:
        ValueError: If the model is not built, if `layer_range` is passed,
            or if any other unrecognized keyword argument is passed.
        ImportError: If pydot is not available.
    """
    from keras.src.ops.function import make_node_key

    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    from keras.src.models import functional
    from keras.src.models import sequential

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    if subgraph:
        # Nested models are drawn as a labelled, dashed cluster.
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", rankdir)
        dot.set("concentrate", True)
        dot.set("dpi", dpi)
        dot.set("splines", "ortho")
        dot.set_node_defaults(shape="record")

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    # Display flags shared by every node label.
    label_options = {
        "show_layer_names": show_layer_names,
        "show_layer_activations": show_layer_activations,
        "show_dtype": show_dtype,
        "show_shapes": show_shapes,
        "show_trainable": show_trainable,
    }

    nested_model_types = (functional.Functional, sequential.Sequential)
    is_sequential = isinstance(model, sequential.Sequential)
    is_functional = isinstance(model, functional.Functional)

    if not (is_sequential or is_functional):
        # We treat subclassed models as a single node.
        dot.add_node(make_node(model, **label_options))
        return dot

    layers = model.layers if is_sequential else model._operations

    # Create graph nodes.
    for layer in layers:
        if expand_nested and isinstance(layer, nested_model_types):
            # Nested functional and sequential models become clusters.
            cluster = model_to_dot(
                layer,
                show_shapes=show_shapes,
                show_dtype=show_dtype,
                show_layer_names=show_layer_names,
                rankdir=rankdir,
                expand_nested=expand_nested,
                subgraph=True,
                show_layer_activations=show_layer_activations,
                show_trainable=show_trainable,
            )
            dot.add_subgraph(cluster)
        else:
            dot.add_node(make_node(layer, **label_options))

    # Connect nodes with edges.
    if is_sequential:
        if not expand_nested:
            # Single Sequential case: chain consecutive layers.
            for upstream, downstream in zip(layers, layers[1:]):
                add_edge(dot, upstream, downstream)
            return dot
        # The first layer is connected to the `InputLayer`, which is not
        # represented for Sequential models, so we skip it. What will draw
        # the incoming edge from outside of the sequential model is the
        # edge connecting the Sequential model itself.
        layers = model.layers[1:]

    # Functional and nested Sequential case.
    for layer in layers:
        # Walk from the current layer to its inbound `Node`s.
        for node_position, inbound in enumerate(layer._inbound_nodes):
            if (
                is_functional
                and make_node_key(layer, node_position) not in model._nodes
            ):
                continue

            # Walk from the inbound `Node` to each input `KerasTensor`.
            for input_index, tensor in enumerate(inbound.input_tensors):
                history = tensor._keras_history
                if history.operation is None:
                    # Operation is `None` for `Input` tensors.
                    continue

                # The `Node` of the operation that produced this tensor.
                producer_node = history.operation._inbound_nodes[
                    history.node_index
                ]
                output_index = history.tensor_index

                # Tentative source and destination of the edge.
                source = producer_node.operation
                destination = layer

                if not expand_nested:
                    # No nesting, connect directly.
                    add_edge(dot, source, layer)
                    continue

                # ==== Potentially nested models case ====

                # Resolve the source: models have no explicit "output" box,
                # so descend into (possibly several levels of) nested models
                # until reaching the real layer that produces the output.
                while isinstance(source, nested_model_types):
                    source, _, output_index = source.outputs[
                        output_index
                    ]._keras_history

                # Resolve the destination the same way.
                while isinstance(destination, nested_model_types):
                    if isinstance(destination, functional.Functional):
                        # Point to the specific `InputLayer` in the model.
                        destination = destination.inputs[
                            input_index
                        ]._keras_history.operation
                    else:
                        # A `Sequential` has no explicit "input" box, so
                        # point to its first box, which may itself be
                        # another model (hence the enclosing loop).
                        destination = destination.layers[0]

                add_edge(dot, source, destination)
    return dot
@keras_export("keras.utils.plot_model")
def plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=False,
    rankdir="TB",
    expand_nested=False,
    dpi=200,
    show_layer_activations=False,
    show_trainable=False,
    **kwargs,
):
    """Converts a Keras model to dot format and save to a file.

    Example:

    ```python
    inputs = ...
    outputs = ...
    model = keras.Model(inputs=inputs, outputs=outputs)

    dot_img_file = '/tmp/model_1.png'
    keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
    ```

    Args:
        model: A Keras model instance
        to_file: File name of the plot image. The output format is taken
            from its extension; PNG is used when there is none.
        show_shapes: whether to display shape information.
        show_dtype: whether to display layer dtypes.
        show_layer_names: whether to display layer names.
        rankdir: `rankdir` argument passed to PyDot,
            a string specifying the format of the plot: `"TB"`
            creates a vertical plot; `"LR"` creates a horizontal plot.
        expand_nested: whether to expand nested Functional models
            into clusters.
        dpi: Image resolution in dots per inch.
        show_layer_activations: Display layer activations (only for layers
            that have an `activation` property).
        show_trainable: whether to display if a layer is trainable.

    Returns:
        A Jupyter notebook display object (`Image`, or `SVG` for `.svg`
        files) if IPython is installed and the format can be displayed
        in-line, otherwise `None`. This enables in-line display of the model
        plots in notebooks.

    Raises:
        ValueError: If the model is not built, if `layer_range` is passed,
            or if any other unrecognized keyword argument is passed.
        ImportError: If pydot or graphviz is missing and the code does not
            appear to be running under IPython.
    """
    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    # Checked in order and lazily: graphviz is only probed once pydot is
    # known to be importable.
    dependency_checks = (
        (
            check_pydot,
            "You must install pydot (`pip install pydot`) "
            "for `plot_model` to work.",
        ),
        (
            check_graphviz,
            "You must install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for `plot_model` to work.",
        ),
    )
    for is_available, message in dependency_checks:
        if is_available():
            continue
        if "IPython.core.magics.namespace" in sys.modules:
            # We don't raise an exception here in order to avoid crashing
            # notebook tests where the dependency is not available.
            io_utils.print_msg(message)
            return None
        raise ImportError(message)

    if kwargs.pop("layer_range", None) is not None:
        raise ValueError("Argument `layer_range` is no longer supported.")
    if kwargs:
        raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_dtype=show_dtype,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=expand_nested,
        dpi=dpi,
        show_layer_activations=show_layer_activations,
        show_trainable=show_trainable,
    )
    to_file = str(to_file)
    if dot is None:
        return None

    _, extension = os.path.splitext(to_file)
    extension = extension[1:] if extension else "png"

    # Save image to disk.
    dot.write(to_file, format=extension)

    # Return the image as a Jupyter display object, to be displayed in-line.
    # Note that we cannot easily detect whether the code is running in a
    # notebook, and thus we always return the object if IPython is available.
    display_format = extension.lower()
    if display_format == "pdf":
        return None
    try:
        from IPython import display
    except ImportError:
        return None
    if display_format == "svg":
        # `display.Image` cannot embed SVG; IPython has a dedicated class.
        return display.SVG(filename=to_file)
    try:
        return display.Image(filename=to_file)
    except ValueError:
        # The file has already been written; formats that IPython cannot
        # embed in-line simply have no display object to return.
        return None