#ifndef CAFFE_ARGMAX_LAYER_HPP_ #define CAFFE_ARGMAX_LAYER_HPP_ #include #include "caffe/blob.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" namespace caffe { /** * @brief Compute the index of the @f$ K @f$ max values for each datum across * all dimensions @f$ (C \times H \times W) @f$. * * Intended for use after a classification layer to produce a prediction. * If parameter out_max_val is set to true, output is a vector of pairs * (max_ind, max_val) for each image. The axis parameter specifies an axis * along which to maximise. * * NOTE: does not implement Backwards operation. */ template class ArgMaxLayer : public Layer { public: /** * @param param provides ArgMaxParameter argmax_param, * with ArgMaxLayer options: * - top_k (\b optional uint, default 1). * the number @f$ K @f$ of maximal items to output. * - out_max_val (\b optional bool, default false). * if set, output a vector of pairs (max_ind, max_val) unless axis is set then * output max_val along the specified axis. * - axis (\b optional int). * if set, maximise along the specified axis else maximise the flattened * trailing dimensions for each index of the first / num dimension. */ explicit ArgMaxLayer(const LayerParameter& param) : Layer(param) {} virtual void LayerSetUp(const vector*>& bottom, const vector*>& top); virtual void Reshape(const vector*>& bottom, const vector*>& top); virtual inline const char* type() const { return "ArgMax"; } virtual inline int ExactNumBottomBlobs() const { return 1; } virtual inline int ExactNumTopBlobs() const { return 1; } protected: /** * @param bottom input Blob vector (length 1) * -# @f$ (N \times C \times H \times W) @f$ * the inputs @f$ x @f$ * @param top output Blob vector (length 1) * -# @f$ (N \times 1 \times K) @f$ or, if out_max_val * @f$ (N \times 2 \times K) @f$ unless axis set than e.g. * @f$ (N \times K \times H \times W) @f$ if axis == 1 * the computed outputs @f$ * y_n = \arg\max\limits_i x_{ni} * @f$ (for @f$ K = 1 @f$). */ virtual void Forward_cpu(const vector*>& bottom, const vector*>& top); /// @brief Not implemented (non-differentiable function) virtual void Backward_cpu(const vector*>& top, const vector& propagate_down, const vector*>& bottom) { NOT_IMPLEMENTED; } bool out_max_val_; size_t top_k_; bool has_axis_; int axis_; }; } // namespace caffe #endif // CAFFE_ARGMAX_LAYER_HPP_