| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | from caffe2.python import core, schema |
| | from caffe2.python.layers.layers import ModelLayer |
| | import numpy as np |
| |
|
| |
|
class LabelSmooth(ModelLayer):
    """Layer that replaces hard labels with smoothed label values.

    The flattened ``smooth_matrix`` selects one of two modes:

    * Exactly 2 elements -- binary-probability mode. The label is cast to
      float32 if needed and passed through ``StumpFunc`` with threshold 0.5,
      using ``smooth_matrix[0]`` as ``low_value`` and ``smooth_matrix[1]`` as
      ``high_value``. The output schema has shape ``(1,)``.
    * ``dim * dim`` elements -- categorical mode. The matrix is reshaped to
      ``(dim, dim)`` and registered as a global constant. Each label is cast
      to int64 if needed, one-hot encoded to width ``dim`` and multiplied by
      the matrix, so the output for label ``k`` is row ``k`` of the matrix.
      The output schema has shape ``(dim,)``.

    Args:
        model: model the layer belongs to; in categorical mode it is used to
            register the smoothing matrix and label dimension as global
            constants.
        label: schema record holding the input label.
        smooth_matrix: array-like with either 2 values or a positive
            perfect-square number of values (any shape; it is flattened).
        name: layer name; also used as the prefix of the registered
            constants.

    Raises:
        ValueError: if ``smooth_matrix`` is empty, or its size is neither 2
            nor a perfect square.
    """

    def __init__(
        self, model, label, smooth_matrix, name='label_smooth', **kwargs
    ):
        super(LabelSmooth, self).__init__(model, name, label, **kwargs)
        self.label = label
        # astype() copies, so later mutation of the caller's array cannot
        # change the values this layer uses.
        smooth_matrix = np.array(smooth_matrix).astype(np.float32).flatten()
        # set_dim must run first: set_smooth_matrix reads binary_prob_label
        # and dim.
        self.set_dim(smooth_matrix)
        self.set_smooth_matrix(smooth_matrix)
        self.output_schema = schema.Scalar(
            (np.float32, (self.dim, )),
            self.get_next_blob_reference('smoothed_label')
        )

    def set_dim(self, smooth_matrix):
        """Derive the layer mode and output dimension from the matrix size.

        Sets ``self.binary_prob_label`` (True iff the flattened matrix has
        exactly 2 elements) and ``self.dim`` (1 in binary mode, otherwise the
        side length of the square matrix).

        Raises:
            ValueError: if the size is 0, or is neither 2 nor a perfect
                square.
        """
        num_elements = smooth_matrix.size
        self.binary_prob_label = (num_elements == 2)
        if self.binary_prob_label:
            self.dim = 1
            return
        # Exact integer test: the float comparison np.sqrt(n) ** 2 == n is
        # rounding-sensitive and can misjudge non-square sizes.
        dim = int(round(np.sqrt(num_elements)))
        if num_elements == 0 or dim * dim != num_elements:
            raise ValueError(
                'smooth_matrix must contain 2 elements or a positive '
                'perfect-square number of elements, got %d' % num_elements
            )
        self.dim = dim

    def set_smooth_matrix(self, smooth_matrix):
        """Store the smoothing values in the form add_ops needs.

        In categorical mode, registers the ``(dim, dim)`` float32 matrix and
        the int64 label dimension as model global constants, stored in
        ``self.smooth_matrix`` and ``self.len``. In binary mode the flat
        2-element array is kept as-is.
        """
        if not self.binary_prob_label:
            self.smooth_matrix = self.model.add_global_constant(
                '%s_label_smooth_matrix' % self.name,
                array=smooth_matrix.reshape((self.dim, self.dim)),
                dtype=np.dtype(np.float32),
            )
            # Width passed to the OneHot op in add_ops_for_categorical_label.
            self.len = self.model.add_global_constant(
                '%s_label_dim' % self.name,
                array=self.dim,
                dtype=np.dtype(np.int64),
            )
        else:
            self.smooth_matrix = smooth_matrix

    def add_ops_for_binary_prob_label(self, net):
        """Emit ops that map the label to low/high values via StumpFunc."""
        if self.label.field_type().base != np.float32:
            float32_label = net.NextScopedBlob('float32_label')
            net.Cast([self.label()], [float32_label], to=core.DataType.FLOAT)
        else:
            float32_label = self.label()
        net.StumpFunc(
            float32_label,
            self.output_schema(),
            threshold=0.5,
            # Plain Python floats keep the operator arguments free of numpy
            # scalar types.
            low_value=float(self.smooth_matrix[0]),
            high_value=float(self.smooth_matrix[1]),
        )

    def add_ops_for_categorical_label(self, net):
        """Emit ops that replace each label with its smoothing-matrix row."""
        if self.label.field_type().base != np.int64:
            int64_label = net.NextScopedBlob('int64_label')
            net.Cast([self.label()], [int64_label], to=core.DataType.INT64)
        else:
            int64_label = self.label()
        one_hot_label = net.NextScopedBlob('one_hot_label')
        net.OneHot([int64_label, self.len], [one_hot_label])
        # one-hot rows of width dim times the (dim, dim) matrix selects the
        # matrix row indexed by each label.
        net.MatMul([one_hot_label, self.smooth_matrix], self.output_schema())

    def add_ops(self, net):
        """Dispatch to the binary or categorical op builder."""
        if self.binary_prob_label:
            self.add_ops_for_binary_prob_label(net)
        else:
            self.add_ops_for_categorical_label(net)
| |
|