| namespace caffe { | |
| /** | |
| * @brief Compute the index of the @f$ K @f$ max values for each datum across | |
| * all dimensions @f$ (C \times H \times W) @f$. | |
| * | |
| * Intended for use after a classification layer to produce a prediction. | |
| * If parameter out_max_val is set to true, output is a vector of pairs | |
| * (max_ind, max_val) for each image. The axis parameter specifies an axis | |
| * along which to maximise. | |
| * | |
| * NOTE: does not implement Backwards operation. | |
| */ | |
| template <typename Dtype> | |
| class ArgMaxLayer : public Layer<Dtype> { | |
| public: | |
| /** | |
| * @param param provides ArgMaxParameter argmax_param, | |
| * with ArgMaxLayer options: | |
| * - top_k (\b optional uint, default 1). | |
| * the number @f$ K @f$ of maximal items to output. | |
| * - out_max_val (\b optional bool, default false). | |
| * if set, output a vector of pairs (max_ind, max_val) unless axis is set then | |
| * output max_val along the specified axis. | |
| * - axis (\b optional int). | |
| * if set, maximise along the specified axis else maximise the flattened | |
| * trailing dimensions for each index of the first / num dimension. | |
| */ | |
| explicit ArgMaxLayer(const LayerParameter& param) | |
| : Layer<Dtype>(param) {} | |
| virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, | |
| const vector<Blob<Dtype>*>& top); | |
| virtual void Reshape(const vector<Blob<Dtype>*>& bottom, | |
| const vector<Blob<Dtype>*>& top); | |
| virtual inline const char* type() const { return "ArgMax"; } | |
| virtual inline int ExactNumBottomBlobs() const { return 1; } | |
| virtual inline int ExactNumTopBlobs() const { return 1; } | |
| protected: | |
| /** | |
| * @param bottom input Blob vector (length 1) | |
| * -# @f$ (N \times C \times H \times W) @f$ | |
| * the inputs @f$ x @f$ | |
| * @param top output Blob vector (length 1) | |
| * -# @f$ (N \times 1 \times K) @f$ or, if out_max_val | |
| * @f$ (N \times 2 \times K) @f$ unless axis set than e.g. | |
| * @f$ (N \times K \times H \times W) @f$ if axis == 1 | |
| * the computed outputs @f$ | |
| * y_n = \arg\max\limits_i x_{ni} | |
| * @f$ (for @f$ K = 1 @f$). | |
| */ | |
| virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, | |
| const vector<Blob<Dtype>*>& top); | |
| /// @brief Not implemented (non-differentiable function) | |
| virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, | |
| const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { | |
| NOT_IMPLEMENTED; | |
| } | |
| bool out_max_val_; | |
| size_t top_k_; | |
| bool has_axis_; | |
| int axis_; | |
| }; | |
| } // namespace caffe | |