| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | from caffe2.python import core, schema |
| | from caffe2.python.layers.layers import ModelLayer |
| |
|
| |
|
class ReservoirSampling(ModelLayer):
    """
    Collect samples from input record w/ reservoir sampling. If you have complex
    data, use PackRecords to pack it before using this layer.

    This layer is not thread safe.
    NOTE(review): a mutex blob is created and handed to the operator below;
    confirm whether the statement above is still accurate.

    The layer owns three parameter blobs, all excluded from optimization
    (``model.NoOptim``): ``reservoir``, ``num_visited`` and ``mutex``. When the
    input record has an ``object_id`` field, two more blobs
    (``object_to_pos`` and ``pos_to_object``) are created and passed to the
    operator together with the ``object_id`` blob.

    The output schema is a Struct with the fields ``reservoir``,
    ``num_visited`` and ``mutex``.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='reservoir_sampling', **kwargs):
        """
        Create the parameter blobs and the output schema of the layer.

        Args:
            model: layer model helper; forwarded to ModelLayer and used for
                its ``NoOptim`` optimizer marker.
            input_record: schema record with a ``data`` field and, optionally,
                an ``object_id`` field.
            num_to_collect: positive number forwarded to the ReservoirSampling
                operator as ``num_to_collect`` (presumably the reservoir
                capacity -- confirm against the operator documentation).
            name: layer name forwarded to ModelLayer.
            **kwargs: extra keyword arguments forwarded to ModelLayer.

        Raises:
            AssertionError: if ``num_to_collect`` is not greater than zero.
        """
        super(ReservoirSampling, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, 'num_to_collect must be positive'
        self.num_to_collect = num_to_collect

        # Starts out empty (shape [0]); the operator writes it in place.
        self.reservoir = self.create_param(
            param_name='reservoir',
            shape=[0],
            initializer=('ConstantFill',),
            optimizer=model.NoOptim,
        )
        # Scalar int64 blob initialized to zero, updated in place by the op.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        self.mutex = self.create_param(
            param_name='mutex',
            shape=[],
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # Optional blobs that are only wired in when the record carries an
        # 'object_id' field.
        self.extra_input_blobs = []
        self.extra_output_blobs = []
        if 'object_id' in input_record:
            object_to_pos = self.create_param(
                param_name='object_to_pos',
                shape=None,
                initializer=('CreateMap', {
                    'key_dtype': core.DataType.INT64,
                    # The CreateMap argument is spelled 'value_dtype'; the
                    # former 'valued_dtype' key was not a recognized argument,
                    # so the operator silently fell back to its default.
                    'value_dtype': core.DataType.INT32,
                }),
                optimizer=model.NoOptim,
            )
            pos_to_object = self.create_param(
                param_name='pos_to_object',
                shape=[0],
                initializer=('ConstantFill', {
                    'value': 0,
                    'dtype': core.DataType.INT64,
                }),
                optimizer=model.NoOptim,
            )
            self.extra_input_blobs.append(input_record.object_id())
            self.extra_input_blobs.extend([object_to_pos, pos_to_object])
            self.extra_output_blobs.extend([object_to_pos, pos_to_object])

        self.output_schema = schema.Struct(
            (
                'reservoir',
                # Reuse the schema of the input data over the reservoir blob.
                schema.from_blob_list(input_record.data, [self.reservoir])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """
        Add the ReservoirSampling operator to ``net``.

        The reservoir and num_visited blobs (and the object-id bookkeeping
        blobs, when present) appear in both the input and the output list, so
        the operator updates them in place.
        """
        net.ReservoirSampling(
            [self.reservoir, self.num_visited_blob, self.input_record.data(),
             self.mutex] + self.extra_input_blobs,
            [self.reservoir, self.num_visited_blob] + self.extra_output_blobs,
            num_to_collect=self.num_to_collect,
        )
| |
|