Spaces:
Running
Running
File size: 2,818 Bytes
#include "IndexFlat.h"
#include <immintrin.h>

#include <cstddef>
#include <queue>
#include <utility>
#include <vector>
/// Appends `n` vectors to the end of the index storage.
///
/// @param n  Number of vectors supplied in `x`. Values <= 0 are a no-op.
/// @param x  Pointer to `n * d` contiguous floats; the whole range is copied
///           onto the end of `xb`.
void IndexFlatL2::add(int n, const float *x){
    if (n <= 0) {
        // Nothing to add; also avoids forming an invalid range from a negative count.
        return;
    }
    // Widen before multiplying so that n * d cannot overflow `int` on large batches.
    const std::size_t float_count = static_cast<std::size_t>(n) * static_cast<std::size_t>(d);
    xb.insert(xb.end(), x, x + float_count);
    ntotal += n;
}
namespace {

/// Squared Euclidean (L2) distance between two `dim`-element float vectors.
///
/// Processes eight floats per iteration with AVX/FMA intrinsics, reduces the
/// eight partial sums to one scalar, then adds the remaining (dim % 8) elements
/// with a scalar loop. Unaligned loads are used, so neither pointer needs any
/// particular alignment.
inline float l2_sqr_avx(const float *a, const float *b, std::size_t dim){
    std::size_t m = 0;
    __m256 acc = _mm256_setzero_ps();
    for (; m + 8 <= dim; m += 8) {
        const __m256 va = _mm256_loadu_ps(a + m);
        const __m256 vb = _mm256_loadu_ps(b + m);
        const __m256 diff = _mm256_sub_ps(va, vb);
        acc = _mm256_fmadd_ps(diff, diff, acc);  // acc += diff * diff
    }

    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    const __m128 upper = _mm256_extractf128_ps(acc, 1);
    const __m128 lower = _mm256_castps256_ps128(acc);
    const __m128 sum4 = _mm_add_ps(upper, lower);
    const __m128 sum2 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
    const __m128 sum1 = _mm_add_ps(sum2, _mm_shuffle_ps(sum2, sum2, 1));
    float total = _mm_cvtss_f32(sum1);

    // Scalar tail for dimensions that are not a multiple of 8.
    for (; m < dim; ++m) {
        const float diff = a[m] - b[m];
        total += diff * diff;
    }
    return total;
}

} // namespace

/// Exhaustive k-nearest-neighbour search under squared L2 distance.
///
/// Every one of the `n` query vectors is compared against all `ntotal` stored
/// vectors. For query `i`, results are written to `distances[i*k .. i*k+k-1]`
/// and `labels[i*k .. i*k+k-1]`, sorted by ascending distance. If fewer than
/// `k` vectors are stored, the unused trailing slots are filled with a
/// distance of -1 and a label of -1.
///
/// @param n          Number of query vectors in `x`.
/// @param x          Pointer to `n * d` contiguous floats (one query per `d` floats).
/// @param k          Number of neighbours requested per query.
/// @param distances  Output buffer with room for `n * k` floats.
/// @param labels     Output buffer with room for `n * k` ints (indices into the
///                   stored vectors).
///
/// Calls with `n <= 0` or `k <= 0` return immediately without touching the
/// output buffers.
void IndexFlatL2::search(int n, const float *x, int k, float *distances, int *labels){
    // k == 0 would otherwise reach top() on an empty heap (undefined behaviour),
    // and a negative k would be converted to a huge unsigned bound.
    if (n <= 0 || k <= 0) {
        return;
    }

    // Do all offset arithmetic in size_t so i*d, j*d and i*k cannot overflow `int`.
    const std::size_t dim = d > 0 ? static_cast<std::size_t>(d) : 0;
    const std::size_t top_k = static_cast<std::size_t>(k);

    for (int i = 0; i < n; ++i) {
        const float *query = x + static_cast<std::size_t>(i) * dim;

        // Max-heap on (distance, label): the worst of the current best-k is on
        // top, so it can be evicted as soon as a closer candidate appears.
        std::priority_queue<std::pair<float, int>> heap;
        for (int j = 0; j < ntotal; ++j) {
            const float *candidate = &xb[static_cast<std::size_t>(j) * dim];
            const float dist = l2_sqr_avx(candidate, query, dim);
            if (heap.size() < top_k) {
                heap.push({dist, j});
            } else if (dist < heap.top().first) {
                heap.pop();
                heap.push({dist, j});
            }
        }

        float *out_dist = distances + static_cast<std::size_t>(i) * top_k;
        int *out_label = labels + static_cast<std::size_t>(i) * top_k;

        // The heap yields the largest distance first, so fill the found slots
        // back-to-front to obtain ascending order.
        const std::size_t found = heap.size();
        for (std::size_t slot = found; slot-- > 0;) {
            out_dist[slot] = heap.top().first;
            out_label[slot] = heap.top().second;
            heap.pop();
        }

        // Fewer than k stored vectors: mark the remaining slots as "no result".
        for (std::size_t slot = found; slot < top_k; ++slot) {
            out_dist[slot] = -1.0f;
            out_label[slot] = -1;
        }
    }
}