| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | from caffe2.python import schema |
| | from caffe2.python.layers.layers import ( |
| | ModelLayer, |
| | ) |
| |
|
| |
|
class PairwiseSimilarity(ModelLayer):
    """Model layer that emits pairwise similarity scores between embeddings.

    The input record must be a ``schema.Struct`` holding either a single
    ``all_embeddings`` field (similarities are then computed between that
    field and itself) or both ``x_embeddings`` and ``y_embeddings``. An
    optional ``indices_to_gather`` field selects entries of the flattened
    similarity result.

    Supported values for ``pairwise_similarity_func``:

    * ``'dot'``: ``BatchMatMul`` of the two inputs with ``trans_b=1``.
    * ``'cosine_similarity'``: each input is first passed through
      ``Normalize(axis=1)``, then the same ``BatchMatMul`` is applied.

    NOTE(review): the expected rank/shape of the embedding blobs is not
    visible in this file; confirm against the callers and operator docs.
    """

    def __init__(self, model, input_record, output_dim, pairwise_similarity_func='dot',
                 name='pairwise_similarity', **kwargs):
        """Validate the input record and declare the output schema.

        Args:
            model: forwarded unchanged to ``ModelLayer.__init__``.
            input_record: ``schema.Struct`` with ``all_embeddings`` xor
                (``x_embeddings`` and ``y_embeddings``), and optionally
                ``indices_to_gather``. Each of these must be a
                ``schema.Scalar``.
            output_dim: size used for the 1-D shape ``(output_dim,)`` of the
                output scalar.
            pairwise_similarity_func: ``'dot'`` or ``'cosine_similarity'``.
                The value is only checked when ``add_ops`` runs.
            name: layer name forwarded to ``ModelLayer.__init__``.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if the input record does not have the layout
                described above.
        """
        super(PairwiseSimilarity, self).__init__(model, name, input_record, **kwargs)
        assert isinstance(input_record, schema.Struct), (
            "Incorrect input type. Expected Struct, but received: {0}".
            format(input_record))
        assert (
            ('all_embeddings' in input_record) ^
            ('x_embeddings' in input_record and 'y_embeddings' in input_record)
        ), (
            "either (all_embeddings) xor (x_embeddings and y_embeddings) " +
            "should be given."
        )
        self.pairwise_similarity_func = pairwise_similarity_func

        if 'all_embeddings' in input_record:
            # Self-similarity: the same field serves as both operands.
            x_embeddings = y_embeddings = input_record['all_embeddings']
        else:
            x_embeddings = input_record['x_embeddings']
            y_embeddings = input_record['y_embeddings']

        assert isinstance(x_embeddings, schema.Scalar), (
            "Incorrect input type for x. Expected Scalar, " +
            "but received: {0}".format(x_embeddings))
        assert isinstance(y_embeddings, schema.Scalar), (
            "Incorrect input type for y. Expected Scalar, " +
            "but received: {0}".format(y_embeddings)
        )

        # None means "no gather step"; add_ops checks for that explicitly.
        self.indices_to_gather = None
        if 'indices_to_gather' in input_record:
            indices_to_gather = input_record['indices_to_gather']
            assert isinstance(indices_to_gather, schema.Scalar), (
                "Incorrect type of indices_to_gather. "
                "Expected Scalar, but received: {0}".format(indices_to_gather)
            )
            self.indices_to_gather = indices_to_gather

        self.x_embeddings = x_embeddings
        self.y_embeddings = y_embeddings

        # The output reuses the base dtype of the x embeddings.
        dtype = x_embeddings.field_types()[0].base

        self.output_schema = schema.Scalar(
            (dtype, (output_dim,)),
            self.get_next_blob_reference('output')
        )

    def add_ops(self, net):
        """Add the similarity operators for this layer to ``net``.

        Args:
            net: the net that receives the ``Normalize`` (cosine only),
                ``BatchMatMul``, ``Flatten`` and optional ``BatchGather``
                operators.

        Raises:
            NotImplementedError: if ``pairwise_similarity_func`` is neither
                ``'cosine_similarity'`` nor ``'dot'``.
        """
        # The branches only choose the operands; the matmul itself is shared.
        if self.pairwise_similarity_func == "cosine_similarity":
            lhs = net.Normalize(self.x_embeddings(), axis=1)
            rhs = net.Normalize(self.y_embeddings(), axis=1)
        elif self.pairwise_similarity_func == "dot":
            lhs = self.x_embeddings()
            rhs = self.y_embeddings()
        else:
            raise NotImplementedError(
                "pairwise_similarity_func={} is not valid".format(
                    self.pairwise_similarity_func
                )
            )

        # The intermediate blob name is derived from the left operand's name.
        Y = net.BatchMatMul(
            [lhs, rhs],
            [self.get_next_blob_reference(lhs + '_matmul')],
            trans_b=1,
        )

        # Explicit None check: whether a schema object is truthy is not a
        # reliable signal for "indices were provided".
        if self.indices_to_gather is not None:
            flattened = net.Flatten(
                Y, Y + '_flatten',
            )
            net.BatchGather(
                [flattened, self.indices_to_gather()],
                self.output_schema(),
            )
        else:
            net.Flatten(Y, self.output_schema())
| |
|