| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | from caffe2.python import schema |
| | from caffe2.python.layers.layers import ( |
| | ModelLayer, |
| | ) |
| | from future.utils import viewitems |
| | import numpy as np |
| | from collections import defaultdict |
| |
|
| | import logging |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def get_concatenated_feature_to_index(blobs_to_concat):
    """Merge the ``feature_to_index`` metadata of inputs being concatenated.

    Each input's indices are shifted by the total width of the inputs that
    precede it, so that they address positions in the concatenated output.

    Args:
        blobs_to_concat: iterable of schema scalars (anything exposing
            ``dtype.shape``), in concatenation order. An input contributes
            indices only when ``metadata.feature_specs.feature_to_index`` is
            a dict; every input advances the running offset.

    Returns:
        A dict mapping each feature key to its list of shifted indices, or
        ``None`` when no input carried a non-empty ``feature_to_index``.
    """
    concat_feature_to_index = defaultdict(list)
    start_pos = 0
    for scalar in blobs_to_concat:
        shape = scalar.dtype.shape
        # A 0-d input (e.g. a per-example scalar stacked with add_axis=1)
        # occupies a single slot; indexing shape[0] would raise IndexError.
        num_dims = shape[0] if shape else 1

        # Walk metadata -> feature_specs -> feature_to_index, tolerating a
        # missing (or None) link anywhere along the chain.
        metadata = getattr(scalar, 'metadata', None)
        feature_specs = getattr(metadata, 'feature_specs', None)
        feature_to_index = getattr(feature_specs, 'feature_to_index', None)
        if isinstance(feature_to_index, dict):
            for key, indices in feature_to_index.items():
                concat_feature_to_index[key].extend(
                    start_pos + index for index in indices
                )
        start_pos += num_dims
    return dict(concat_feature_to_index) if concat_feature_to_index else None
| |
|
| |
|
class Concat(ModelLayer):
    """
    Construct Concat layer.

    Concatenates the Scalar fields of a Struct along ``axis``. The first
    dimension of every input is assumed to be the batch dimension, and
    ``axis`` counts it (axis=1 is the first per-example dimension).

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields) * embedding_dim, )))),
            output
        )

        # Note that the Concat layer assumes the first dimension is batch,
        # so each input is B * embedding_dim.
        # add_axis=1 makes each input B * 1 * embedding_dim, and
        # concatenating on axis=1 then makes the output B * N * embedding_dim.

        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32, ((len(input_record.fields), embedding_dim)))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        """
        Args:
            model: the model this layer belongs to (forwarded to ModelLayer).
            input_record: schema.Struct whose fields must all be
                schema.Scalar; fields are concatenated in field order.
            axis: axis to concatenate on, counting the batch dimension as 0.
            add_axis: if truthy, insert a new size-1 dimension at ``axis`` in
                every input before concatenating (not allowed with axis=0).
            name: layer name.
        """
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis
        assert not (axis == 0 and add_axis == 1), \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct), \
            "Incorrect input type. Expected Struct, but received: {0}".\
            format(input_record)
        # Without this, an empty Struct fails below with an opaque
        # IndexError (input_record[0] / shapes[0]).
        assert len(input_record.fields) > 0, \
            "Concat expects at least one input field"

        # Position of the concat axis inside a per-example shape, which
        # excludes the batch dimension.
        concat_idx = axis - 1

        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar), \
                "Incorrect input type for {}. Expected Scalar, but got: {}".\
                format(field_name, field_type)
            shape = list(field_type.field_type().shape)
            if add_axis:
                shape.insert(concat_idx, 1)
            # shape[concat_idx] must exist for the size bookkeeping below.
            assert len(shape) >= axis, \
                "Concat on axis={} needs inputs with at least {} non-batch " \
                "dimension(s), but {} has shape {}".format(
                    axis, axis, field_name, shape)
            shapes.append(shape)
        logger.info('Concat Layer input shapes: %s', shapes)

        if axis == 0:
            # Concatenating along the batch dimension keeps the per-example
            # shape, so the output schema is modeled on the first input.
            self.output_schema = schema.from_blob_list(
                input_record[0],
                [self.get_next_blob_reference('output')]
            )
            return

        # Sum the sizes along the concat axis; every other dimension must
        # match the first input's. Copies are masked so `shapes` is untouched.
        reference = list(shapes[0])
        reference[concat_idx] = 0
        concat_dim = 0
        for shape in shapes:
            concat_dim += shape[concat_idx]
            masked = list(shape)
            masked[concat_idx] = 0
            assert masked == reference, \
                "Shapes {0} and {1} are not compatible for Concat".\
                format(masked, reference)
        output_dims = list(reference)
        output_dims[concat_idx] = concat_dim
        logger.info('Concat Layer output_dims: %s', output_dims)

        # NOTE(review): the output is declared float32 regardless of the
        # inputs' dtypes; confirm this is intended for non-float32 inputs.
        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))

        # Propagate feature_to_index metadata, shifting each input's indices
        # by the widths of the inputs before it.
        # NOTE(review): the offsets are based on each input's first
        # per-example dimension, so they presumably only line up with the
        # output layout for axis=1 — confirm before relying on them elsewhere.
        concated_feature_to_index = get_concatenated_feature_to_index(
            input_record.fields.values()
        )
        if concated_feature_to_index:
            metadata = schema.Metadata(
                feature_specs=schema.FeatureSpec(
                    feature_to_index=concated_feature_to_index
                )
            )
            self.output_schema.set_metadata(metadata)

    def add_ops(self, net):
        """Add the Concat op writing the layer output.

        A second output blob named ``<output>_concat_dims`` is also requested
        from the op (presumably the per-input sizes along the concat axis —
        verify against the Caffe2 Concat operator).
        """
        output = self.output_schema.field_blobs()[0]
        net.Concat(
            self.input_record.field_blobs(),
            [output, output + "_concat_dims"],
            axis=self.axis,
            add_axis=self.add_axis,
        )
| |
|