| from typing import Optional | |
| from tensorboard.compat.proto.node_def_pb2 import NodeDef | |
| from tensorboard.compat.proto.attr_value_pb2 import AttrValue | |
| from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto | |
def attr_value_proto(dtype, shape, s):
    """Build the attribute mapping that is attached to a NodeDef.

    The layout follows
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    and the specific keys were reverse engineered from data logged by
    standard TensorBoard.

    Args:
        dtype: Accepted for signature compatibility; not used here.
        shape: Iterable of dimension sizes, or None to omit the
            "_output_shapes" entry.
        s: Text stored (UTF-8 encoded) under the "attr" key, or None to
            omit that entry.

    Returns:
        A dict with zero, one or two AttrValue entries.
    """
    entries = {}
    has_text = s is not None
    has_shape = shape is not None
    if has_text:
        encoded_text = s.encode(encoding="utf_8")
        entries["attr"] = AttrValue(s=encoded_text)
    if has_shape:
        dims_proto = tensor_shape_proto(shape)
        # The shape is wrapped in a single-element list value.
        shape_list = AttrValue.ListValue(shape=[dims_proto])
        entries["_output_shapes"] = AttrValue(list=shape_list)
    return entries
def tensor_shape_proto(outputsize):
    """Wrap a sequence of dimension sizes in a TensorShapeProto.

    The message layout is described at
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto

    Args:
        outputsize: Iterable of dimension sizes; one Dim message is
            created per element, in iteration order.

    Returns:
        A TensorShapeProto whose ``dim`` field holds those Dim messages.
    """
    dim_messages = [TensorShapeProto.Dim(size=extent) for extent in outputsize]
    shape_message = TensorShapeProto(dim=dim_messages)
    return shape_message
def node_proto(
    name,
    op="UnSpecified",
    input=None,
    dtype=None,
    shape: Optional[tuple] = None,
    outputsize=None,
    attributes="",
):
    """Assemble a NodeDef message for a single graph node.

    The message layout is described at
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto

    Args:
        name: Node name; it is UTF-8 encoded before being stored.
        op: Operation label for the node.
        input: None (no inputs), a list of inputs, or a single input
            that is wrapped in a one-element list.
        dtype: Forwarded to ``attr_value_proto``.
        shape: Not used by this function.
        outputsize: Output dimensions forwarded to ``attr_value_proto``
            as the shape; None leaves the shape attribute out.
        attributes: Text forwarded to ``attr_value_proto``.

    Returns:
        The populated NodeDef.
    """
    # Normalise ``input`` to a list: None -> [], list -> unchanged,
    # anything else -> single-element list.
    if input is None:
        sources = []
    elif isinstance(input, list):
        sources = input
    else:
        sources = [input]

    encoded_name = name.encode(encoding="utf_8")
    node_attrs = attr_value_proto(dtype, outputsize, attributes)
    return NodeDef(name=encoded_name, op=op, input=sources, attr=node_attrs)