File size: 1,829 Bytes
1505bbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import faiss
import numpy as np

from transformers import AutoTokenizer, AutoModel

# Sentence-transformers MiniLM encoder used for all token embeddings.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Loaded once at import time (network/disk side effect) and shared by every
# ColBERTRetriever instance below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

class ColBERTRetriever:
    """Late-interaction (ColBERT-style) retriever.

    Documents are embedded at the token level; a query is scored against a
    document with MaxSim (for each query token, take the best-matching
    document token, then sum). A FAISS inner-product index over mean-pooled
    chunk vectors is also built for coarse vector search.

    Fix vs. the original: ``doc_embeddings`` was initialized but never used,
    so ``query`` re-ran the transformer over every chunk on every call.
    Token embeddings are now cached in ``build_index`` and reused, producing
    identical scores at a fraction of the cost. Empty-corpus edge cases are
    also handled.
    """

    def __init__(self):
        # Indexed chunks; each is a dict with at least a "text" key.
        self.chunks = []
        # Per-chunk (tokens, dim) embedding matrices, cached at build time
        # so query() never re-runs the model over the corpus.
        self.doc_embeddings = []
        # FAISS inner-product index over mean-pooled chunk vectors.
        self.index = None

    # -----------------------------
    # EMBED TEXT TOKENS
    # -----------------------------

    def embed(self, text):
        """Return token-level embeddings of *text* as a (tokens, dim) ndarray."""
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # (1, tokens, dim) -> (tokens, dim)
        return outputs.last_hidden_state.squeeze(0).numpy()

    # -----------------------------
    # BUILD INDEX
    # -----------------------------

    def build_index(self, chunks):
        """Embed *chunks*, cache their token embeddings, and build the index.

        Args:
            chunks: list of dicts, each with a "text" field to embed.
        """
        self.chunks = chunks
        self.doc_embeddings = []
        pooled = []

        for chunk in chunks:
            emb = self.embed(chunk["text"])
            # Cache the full token matrix for MaxSim scoring in query().
            self.doc_embeddings.append(emb)
            pooled.append(emb.mean(axis=0))

        # Empty corpus: nothing to index (original crashed on shape[1] here).
        if not pooled:
            self.index = None
            return

        vectors = np.array(pooled, dtype="float32")
        self.index = faiss.IndexFlatIP(vectors.shape[1])
        self.index.add(vectors)

    # -----------------------------
    # QUERY
    # -----------------------------

    def query(self, question, k=20):
        """Return the top-*k* chunks ranked by ColBERT MaxSim score.

        Scores every cached chunk exactly (no approximate prefilter), using
        the token embeddings stored by build_index().
        """
        if not self.chunks:
            return []

        q_emb = self.embed(question)  # (q_tokens, dim)
        scores = []

        for d_emb in self.doc_embeddings:
            # Token-vs-token similarity, then MaxSim: best doc token per
            # query token, summed over query tokens.
            sim = np.matmul(q_emb, d_emb.T)
            scores.append(sim.max(axis=1).sum())

        idx = np.argsort(scores)[::-1][:k]

        return [self.chunks[i] for i in idx]