File size: 2,818 Bytes
e87a50a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "IndexFlat.h"

#include <immintrin.h>

#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

/// Appends `n` vectors to the end of the flat storage `xb`.
///
/// @param n  Number of vectors pointed to by `x`. Non-positive values are
///           treated as "nothing to add" and leave the index unchanged.
/// @param x  Pointer to `n * d` contiguous floats (vector after vector).
///           A null pointer is ignored.
void IndexFlatL2::add(int n, const float *x){
    // A negative count would hand std::vector::insert an invalid range and
    // decrement ntotal, so reject it (and a null source) up front.
    if(n <= 0 || x == nullptr){
        return;
    }

    // Compute the element count in size_t: n * d can exceed INT_MAX for
    // large batches, and signed overflow is undefined behaviour.
    const std::size_t float_count =
        static_cast<std::size_t>(n) * static_cast<std::size_t>(d);

    xb.insert(xb.end(), x, x + float_count);
    ntotal += n;
}

namespace {

/// Squared Euclidean distance between two `dim`-element float vectors.
///
/// Processes 8 floats per iteration with AVX/FMA, reduces the 8 partial sums
/// to one scalar, then handles the remaining (dim % 8) elements in a scalar
/// loop. Unaligned loads are used, so the inputs need no special alignment.
///
/// @param a    First vector (at least `dim` floats).
/// @param b    Second vector (at least `dim` floats).
/// @param dim  Number of elements in each vector.
/// @return     Sum over all elements of (a[m] - b[m])^2.
float l2_sqr_avx2_fma(const float *a, const float *b, int dim){
    int m = 0;
    __m256 sumvec = _mm256_setzero_ps();

    // Main body: 8 lanes at a time, accumulating diff*diff into sumvec.
    for(; m + 7 < dim; m += 8){
        const __m256 va = _mm256_loadu_ps(a + m);
        const __m256 vb = _mm256_loadu_ps(b + m);
        const __m256 diff = _mm256_sub_ps(va, vb);
        sumvec = _mm256_fmadd_ps(diff, diff, sumvec);
    }

    // Horizontal reduction 8 -> 4 -> 2 -> 1 lanes.
    const __m128 upper = _mm256_extractf128_ps(sumvec, 1);
    const __m128 lower = _mm256_castps256_ps128(sumvec);
    const __m128 sum4 = _mm_add_ps(upper, lower);
    const __m128 sum2 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));
    const __m128 sum1 = _mm_add_ps(sum2, _mm_shuffle_ps(sum2, sum2, 1));
    float total = _mm_cvtss_f32(sum1);

    // Scalar tail for the elements that did not fill a full 8-wide register.
    for(; m < dim; m++){
        const float diff = a[m] - b[m];
        total += diff * diff;
    }
    return total;
}

} // namespace

/// Exhaustive k-nearest-neighbour search using squared L2 distance.
///
/// Every query is compared against every stored vector; the k smallest
/// distances are kept in a bounded max-heap.
///
/// @param n          Number of query vectors in `x`.
/// @param x          Pointer to `n * d` contiguous floats (query after query).
/// @param k          Number of neighbours to return per query. If `k <= 0`
///                   nothing is written and the function returns immediately.
/// @param distances  Output, `n * k` floats. Row `i` holds the squared
///                   distances for query `i` in ascending order; unused slots
///                   (fewer than `k` stored vectors) are set to -1.0.
/// @param labels     Output, `n * k` ints. Row `i` holds the positions of the
///                   matching stored vectors; unused slots are set to -1.
void IndexFlatL2::search(int n, const float *x, int k, float *distances, int *labels){
    // With k == 0 the heap would stay empty and top()/pop() on it is undefined
    // behaviour; a negative k would also produce negative output offsets.
    if(k <= 0){
        return;
    }

    // All offsets are computed in size_t so that ntotal * d and n * k cannot
    // overflow int on large inputs.
    const std::size_t dim = static_cast<std::size_t>(d);
    const std::size_t topk = static_cast<std::size_t>(k);

    for(int i = 0; i < n; i++){
        // The query pointer only depends on i, so compute it once per query.
        const float *query = x + static_cast<std::size_t>(i) * dim;

        // Max-heap on (distance, label): top() is the worst of the current
        // best-k candidates and is the one to evict.
        std::priority_queue<std::pair<float, int>> pq;

        for(int j = 0; j < ntotal; j++){
            const float *candidate = xb.data() + static_cast<std::size_t>(j) * dim;
            const float curr_distance = l2_sqr_avx2_fma(candidate, query, d);

            if(pq.size() < topk){
                pq.emplace(curr_distance, j);
            }else if(curr_distance < pq.top().first){
                pq.pop();
                pq.emplace(curr_distance, j);
            }
        }

        float *out_distances = distances + static_cast<std::size_t>(i) * topk;
        int *out_labels = labels + static_cast<std::size_t>(i) * topk;

        // The heap pops largest-first, so fill the found results back to
        // front to end up with ascending distances.
        const std::size_t found = pq.size();
        for(std::size_t slot = found; slot-- > 0; ){
            out_distances[slot] = pq.top().first;
            out_labels[slot] = pq.top().second;
            pq.pop();
        }

        // Fewer than k stored vectors: mark the remaining slots as invalid
        // instead of leaving uninitialised values in the output.
        for(std::size_t slot = found; slot < topk; slot++){
            out_distances[slot] = -1.0f;
            out_labels[slot] = -1;
        }
    }
}