#ifndef CAFFE_BATCHREINDEX_LAYER_HPP_ #define CAFFE_BATCHREINDEX_LAYER_HPP_ #include #include #include "caffe/blob.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" namespace caffe { /** * @brief Index into the input blob along its first axis. * * This layer can be used to select, reorder, and even replicate examples in a * batch. The second blob is cast to int and treated as an index into the * first axis of the first blob. */ template class BatchReindexLayer : public Layer { public: explicit BatchReindexLayer(const LayerParameter& param) : Layer(param) {} virtual void Reshape(const vector*>& bottom, const vector*>& top); virtual inline const char* type() const { return "BatchReindex"; } virtual inline int ExactNumBottomBlobs() const { return 2; } virtual inline int ExactNumTopBlobs() const { return 1; } protected: /** * @param bottom input Blob vector (length 2+) * -# @f$ (N \times ...) @f$ * the inputs @f$ x_1 @f$ * -# @f$ (M) @f$ * the inputs @f$ x_2 @f$ * @param top output Blob vector (length 1) * -# @f$ (M \times ...) @f$: * the reindexed array @f$ * y = x_1[x_2] * @f$ */ virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); virtual void Forward_gpu(const vector*>& bottom, const vector*>& top); /** * @brief Computes the error gradient w.r.t. the reordered input. * * @param top output Blob vector (length 1), providing the error gradient * with respect to the outputs * -# @f$ (M \times ...) @f$: * containing error gradients @f$ \frac{\partial E}{\partial y} @f$ * with respect to concatenated outputs @f$ y @f$ * @param propagate_down see Layer::Backward. * @param bottom input Blob vector (length 2): * - @f$ \frac{\partial E}{\partial y} @f$ is de-indexed (summing where * required) back to the input x_1 * - This layer cannot backprop to x_2, i.e. propagate_down[1] must be * false. */ virtual void Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); virtual void Backward_gpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom); private: struct pair_sort_first { bool operator()(const std::pair &left, const std::pair &right) { return left.first < right.first; } }; void check_batch_reindex(int initial_num, int final_num, const Dtype* ridx_data); }; } // namespace caffe #endif // CAFFE_BATCHREINDEX_LAYER_HPP_