| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #ifndef FLANN_KDTREE_INDEX_H_ |
| | #define FLANN_KDTREE_INDEX_H_ |
| |
|
| | #include <algorithm> |
| | #include <map> |
| | #include <cassert> |
| | #include <cstring> |
| | #include <stdarg.h> |
| | #include <cmath> |
| | #include <random> |
| |
|
| | #include "FLANN/general.h" |
| | #include "FLANN/algorithms/nn_index.h" |
| | #include "FLANN/util/dynamic_bitset.h" |
| | #include "FLANN/util/matrix.h" |
| | #include "FLANN/util/result_set.h" |
| | #include "FLANN/util/heap.h" |
| | #include "FLANN/util/allocator.h" |
| | #include "FLANN/util/random.h" |
| | #include "FLANN/util/saving.h" |
| |
|
| |
|
| | namespace flann |
| | { |
| |
|
| | struct KDTreeIndexParams : public IndexParams |
| | { |
| | KDTreeIndexParams(int trees = 4) |
| | { |
| | (*this)["algorithm"] = FLANN_INDEX_KDTREE; |
| | (*this)["trees"] = trees; |
| | } |
| | }; |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | template <typename Distance> |
| | class KDTreeIndex : public NNIndex<Distance> |
| | { |
| | public: |
| | typedef typename Distance::ElementType ElementType; |
| | typedef typename Distance::ResultType DistanceType; |
| |
|
| | typedef NNIndex<Distance> BaseClass; |
| |
|
| | typedef bool needs_kdtree_distance; |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) : |
| | BaseClass(params, d), mean_(NULL), var_(NULL) |
| | { |
| | trees_ = get_param(index_params_,"trees",4); |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(), |
| | Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL) |
| | { |
| | trees_ = get_param(index_params_,"trees",4); |
| |
|
| | setDataset(dataset); |
| | } |
| |
|
| | KDTreeIndex(const KDTreeIndex& other) : BaseClass(other), |
| | trees_(other.trees_) |
| | { |
| | tree_roots_.resize(other.tree_roots_.size()); |
| | for (size_t i=0;i<tree_roots_.size();++i) { |
| | copyTree(tree_roots_[i], other.tree_roots_[i]); |
| | } |
| | } |
| |
|
| | KDTreeIndex& operator=(KDTreeIndex other) |
| | { |
| | this->swap(other); |
| | return *this; |
| | } |
| |
|
| | |
| | |
| | |
| | virtual ~KDTreeIndex() |
| | { |
| | freeIndex(); |
| | } |
| |
|
| | BaseClass* clone() const |
| | { |
| | return new KDTreeIndex(*this); |
| | } |
| |
|
| | using BaseClass::buildIndex; |
| |
|
| | void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2) |
| | { |
| | assert(points.cols==veclen_); |
| |
|
| | size_t old_size = size_; |
| | extendDataset(points); |
| |
|
| | if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) { |
| | buildIndex(); |
| | } |
| | else { |
| | for (size_t i=old_size;i<size_;++i) { |
| | for (int j = 0; j < trees_; j++) { |
| | addPointToTree(tree_roots_[j], i); |
| | } |
| | } |
| | } |
| | } |
| |
|
| | flann_algorithm_t getType() const |
| | { |
| | return FLANN_INDEX_KDTREE; |
| | } |
| |
|
| |
|
| | template<typename Archive> |
| | void serialize(Archive& ar) |
| | { |
| | ar.setObject(this); |
| |
|
| | ar & *static_cast<NNIndex<Distance>*>(this); |
| |
|
| | ar & trees_; |
| |
|
| | if (Archive::is_loading::value) { |
| | tree_roots_.resize(trees_); |
| | } |
| | for (size_t i=0;i<tree_roots_.size();++i) { |
| | if (Archive::is_loading::value) { |
| | tree_roots_[i] = new(pool_) Node(); |
| | } |
| | ar & *tree_roots_[i]; |
| | } |
| |
|
| | if (Archive::is_loading::value) { |
| | index_params_["algorithm"] = getType(); |
| | index_params_["trees"] = trees_; |
| | } |
| | } |
| |
|
| |
|
| | void saveIndex(FILE* stream) |
| | { |
| | serialization::SaveArchive sa(stream); |
| | sa & *this; |
| | } |
| |
|
| |
|
| | void loadIndex(FILE* stream) |
| | { |
| | freeIndex(); |
| | serialization::LoadArchive la(stream); |
| | la & *this; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | int usedMemory() const |
| | { |
| | return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int)); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const |
| | { |
| | int maxChecks = searchParams.checks; |
| | float epsError = 1+searchParams.eps; |
| |
|
| | if (maxChecks==FLANN_CHECKS_UNLIMITED) { |
| | if (removed_) { |
| | getExactNeighbors<true>(result, vec, epsError); |
| | } |
| | else { |
| | getExactNeighbors<false>(result, vec, epsError); |
| | } |
| | } |
| | else { |
| | if (removed_) { |
| | getNeighbors<true>(result, vec, maxChecks, epsError); |
| | } |
| | else { |
| | getNeighbors<false>(result, vec, maxChecks, epsError); |
| | } |
| | } |
| | } |
| |
|
| | protected: |
| |
|
| | |
| | |
| | |
| | void buildIndexImpl() |
| | { |
| | |
| | std::vector<int> ind(size_); |
| | for (size_t i = 0; i < size_; ++i) { |
| | ind[i] = int(i); |
| | } |
| |
|
| | mean_ = new DistanceType[veclen_]; |
| | var_ = new DistanceType[veclen_]; |
| |
|
| | std::default_random_engine generator; |
| |
|
| | tree_roots_.resize(trees_); |
| | |
| | for (int i = 0; i < trees_; i++) { |
| | |
| | std::shuffle(ind.begin(), ind.end(), generator); |
| | tree_roots_[i] = divideTree(&ind[0], int(size_) ); |
| | } |
| | delete[] mean_; |
| | delete[] var_; |
| | } |
| |
|
| | void freeIndex() |
| | { |
| | for (size_t i=0;i<tree_roots_.size();++i) { |
| | |
| | if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node(); |
| | } |
| | pool_.free(); |
| | } |
| |
|
| |
|
| | private: |
| |
|
| | |
| | struct Node |
| | { |
| | |
| | |
| | |
| | int divfeat; |
| | |
| | |
| | |
| | DistanceType divval; |
| | |
| | |
| | |
| | ElementType* point; |
| | |
| | |
| | |
| | Node* child1, *child2; |
| | Node(){ |
| | child1 = NULL; |
| | child2 = NULL; |
| | } |
| | ~Node() { |
| | if (child1 != NULL) { child1->~Node(); child1 = NULL; } |
| |
|
| | if (child2 != NULL) { child2->~Node(); child2 = NULL; } |
| | } |
| |
|
| | private: |
| | template<typename Archive> |
| | void serialize(Archive& ar) |
| | { |
| | typedef KDTreeIndex<Distance> Index; |
| | Index* obj = static_cast<Index*>(ar.getObject()); |
| |
|
| | ar & divfeat; |
| | ar & divval; |
| |
|
| | bool leaf_node = false; |
| | if (Archive::is_saving::value) { |
| | leaf_node = ((child1==NULL) && (child2==NULL)); |
| | } |
| | ar & leaf_node; |
| |
|
| | if (leaf_node) { |
| | if (Archive::is_loading::value) { |
| | point = obj->points_[divfeat]; |
| | } |
| | } |
| |
|
| | if (!leaf_node) { |
| | if (Archive::is_loading::value) { |
| | child1 = new(obj->pool_) Node(); |
| | child2 = new(obj->pool_) Node(); |
| | } |
| | ar & *child1; |
| | ar & *child2; |
| | } |
| | } |
| | friend struct serialization::access; |
| | }; |
| | typedef Node* NodePtr; |
| | typedef BranchStruct<NodePtr, DistanceType> BranchSt; |
| | typedef BranchSt* Branch; |
| |
|
| |
|
| | void copyTree(NodePtr& dst, const NodePtr& src) |
| | { |
| | dst = new(pool_) Node(); |
| | dst->divfeat = src->divfeat; |
| | dst->divval = src->divval; |
| | if (src->child1==NULL && src->child2==NULL) { |
| | dst->point = points_[dst->divfeat]; |
| | dst->child1 = NULL; |
| | dst->child2 = NULL; |
| | } |
| | else { |
| | copyTree(dst->child1, src->child1); |
| | copyTree(dst->child2, src->child2); |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | NodePtr divideTree(int* ind, int count) |
| | { |
| | NodePtr node = new(pool_) Node(); |
| |
|
| | |
| | if (count == 1) { |
| | node->child1 = node->child2 = NULL; |
| | node->divfeat = *ind; |
| | node->point = points_[*ind]; |
| | } |
| | else { |
| | int idx; |
| | int cutfeat; |
| | DistanceType cutval; |
| | meanSplit(ind, count, idx, cutfeat, cutval); |
| |
|
| | node->divfeat = cutfeat; |
| | node->divval = cutval; |
| | node->child1 = divideTree(ind, idx); |
| | node->child2 = divideTree(ind+idx, count-idx); |
| | } |
| |
|
| | return node; |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval) |
| | { |
| | memset(mean_,0,veclen_*sizeof(DistanceType)); |
| | memset(var_,0,veclen_*sizeof(DistanceType)); |
| |
|
| | |
| | |
| | |
| | int cnt = std::min((int)SAMPLE_MEAN+1, count); |
| | for (int j = 0; j < cnt; ++j) { |
| | ElementType* v = points_[ind[j]]; |
| | for (size_t k=0; k<veclen_; ++k) { |
| | mean_[k] += v[k]; |
| | } |
| | } |
| | DistanceType div_factor = DistanceType(1)/cnt; |
| | for (size_t k=0; k<veclen_; ++k) { |
| | mean_[k] *= div_factor; |
| | } |
| |
|
| | |
| | for (int j = 0; j < cnt; ++j) { |
| | ElementType* v = points_[ind[j]]; |
| | for (size_t k=0; k<veclen_; ++k) { |
| | DistanceType dist = v[k] - mean_[k]; |
| | var_[k] += dist * dist; |
| | } |
| | } |
| | |
| | cutfeat = selectDivision(var_); |
| | cutval = mean_[cutfeat]; |
| |
|
| | int lim1, lim2; |
| | planeSplit(ind, count, cutfeat, cutval, lim1, lim2); |
| |
|
| | if (lim1>count/2) index = lim1; |
| | else if (lim2<count/2) index = lim2; |
| | else index = count/2; |
| |
|
| | |
| | |
| | |
| | if ((lim1==count)||(lim2==0)) index = count/2; |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | int selectDivision(DistanceType* v) |
| | { |
| | int num = 0; |
| | size_t topind[RAND_DIM]; |
| |
|
| | |
| | for (size_t i = 0; i < veclen_; ++i) { |
| | if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) { |
| | |
| | if (num < RAND_DIM) { |
| | topind[num++] = i; |
| | } |
| | else { |
| | topind[num-1] = i; |
| | } |
| | |
| | int j = num - 1; |
| | while (j > 0 && v[topind[j]] > v[topind[j-1]]) { |
| | std::swap(topind[j], topind[j-1]); |
| | --j; |
| | } |
| | } |
| | } |
| | |
| | int rnd = rand_int(num); |
| | return (int)topind[rnd]; |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2) |
| | { |
| | |
| | int left = 0; |
| | int right = count-1; |
| | for (;; ) { |
| | while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left; |
| | while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right; |
| | if (left>right) break; |
| | std::swap(ind[left], ind[right]); ++left; --right; |
| | } |
| | lim1 = left; |
| | right = count-1; |
| | for (;; ) { |
| | while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left; |
| | while (left<=right && points_[ind[right]][cutfeat]>cutval) --right; |
| | if (left>right) break; |
| | std::swap(ind[left], ind[right]); ++left; --right; |
| | } |
| | lim2 = left; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | template<bool with_removed> |
| | void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const |
| | { |
| | |
| |
|
| | if (trees_ > 1) { |
| | fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search"); |
| | } |
| | if (trees_>0) { |
| | searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError); |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | template<bool with_removed> |
| | void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const |
| | { |
| | int i; |
| | BranchSt branch; |
| |
|
| | int checkCount = 0; |
| | Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_); |
| | DynamicBitset checked(size_); |
| |
|
| | |
| | for (i = 0; i < trees_; ++i) { |
| | searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked); |
| | } |
| |
|
| | |
| | while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) { |
| | searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked); |
| | } |
| |
|
| | delete heap; |
| |
|
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | template<bool with_removed> |
| | void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck, |
| | float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const |
| | { |
| | if (result_set.worstDist()<mindist) { |
| | |
| | return; |
| | } |
| |
|
| | |
| | if ((node->child1 == NULL)&&(node->child2 == NULL)) { |
| | int index = node->divfeat; |
| | if (with_removed) { |
| | if (removed_points_.test(index)) return; |
| | } |
| | |
| | if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return; |
| | checked.set(index); |
| | checkCount++; |
| |
|
| | DistanceType dist = distance_(node->point, vec, veclen_); |
| | result_set.addPoint(dist,index); |
| | return; |
| | } |
| |
|
| | |
| | ElementType val = vec[node->divfeat]; |
| | DistanceType diff = val - node->divval; |
| | NodePtr bestChild = (diff < 0) ? node->child1 : node->child2; |
| | NodePtr otherChild = (diff < 0) ? node->child2 : node->child1; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat); |
| | |
| | if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) { |
| | heap->insert( BranchSt(otherChild, new_distsq) ); |
| | } |
| |
|
| | |
| | searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked); |
| | } |
| |
|
| | |
| | |
| | |
| | template<bool with_removed> |
| | void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const |
| | { |
| | |
| | if ((node->child1 == NULL)&&(node->child2 == NULL)) { |
| | int index = node->divfeat; |
| | if (with_removed) { |
| | if (removed_points_.test(index)) return; |
| | } |
| | DistanceType dist = distance_(node->point, vec, veclen_); |
| | result_set.addPoint(dist,index); |
| |
|
| | return; |
| | } |
| |
|
| | |
| | ElementType val = vec[node->divfeat]; |
| | DistanceType diff = val - node->divval; |
| | NodePtr bestChild = (diff < 0) ? node->child1 : node->child2; |
| | NodePtr otherChild = (diff < 0) ? node->child2 : node->child1; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat); |
| |
|
| | |
| | searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError); |
| |
|
| | if (mindist*epsError<=result_set.worstDist()) { |
| | searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError); |
| | } |
| | } |
| |
|
| | void addPointToTree(NodePtr node, int ind) |
| | { |
| | ElementType* point = points_[ind]; |
| |
|
| | if ((node->child1==NULL) && (node->child2==NULL)) { |
| | ElementType* leaf_point = node->point; |
| | ElementType max_span = 0; |
| | size_t div_feat = 0; |
| | for (size_t i=0;i<veclen_;++i) { |
| | ElementType span = std::abs(point[i]-leaf_point[i]); |
| | if (span > max_span) { |
| | max_span = span; |
| | div_feat = i; |
| | } |
| | } |
| | NodePtr left = new(pool_) Node(); |
| | left->child1 = left->child2 = NULL; |
| | NodePtr right = new(pool_) Node(); |
| | right->child1 = right->child2 = NULL; |
| |
|
| | if (point[div_feat]<leaf_point[div_feat]) { |
| | left->divfeat = ind; |
| | left->point = point; |
| | right->divfeat = node->divfeat; |
| | right->point = node->point; |
| | } |
| | else { |
| | left->divfeat = node->divfeat; |
| | left->point = node->point; |
| | right->divfeat = ind; |
| | right->point = point; |
| | } |
| | node->divfeat = div_feat; |
| | node->divval = (point[div_feat]+leaf_point[div_feat])/2; |
| | node->child1 = left; |
| | node->child2 = right; |
| | } |
| | else { |
| | if (point[node->divfeat]<node->divval) { |
| | addPointToTree(node->child1,ind); |
| | } |
| | else { |
| | addPointToTree(node->child2,ind); |
| | } |
| | } |
| | } |
| | private: |
| | void swap(KDTreeIndex& other) |
| | { |
| | BaseClass::swap(other); |
| | std::swap(trees_, other.trees_); |
| | std::swap(tree_roots_, other.tree_roots_); |
| | std::swap(pool_, other.pool_); |
| | } |
| |
|
| | private: |
| |
|
| | enum |
| | { |
| | |
| | |
| | |
| | |
| | |
| | SAMPLE_MEAN = 100, |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | RAND_DIM=5 |
| | }; |
| |
|
| |
|
| | |
| | |
| | |
| | int trees_; |
| |
|
| | DistanceType* mean_; |
| | DistanceType* var_; |
| |
|
| | |
| | |
| | |
| | std::vector<NodePtr> tree_roots_; |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | PooledAllocator pool_; |
| |
|
| | USING_BASECLASS_SYMBOLS |
| | }; |
| |
|
| | } |
| |
|
| | #endif |
| |
|