# scholar-rag-engine / retrieval_colbert.py
# (source: snakeeee — "Initial commit - Scholar RAG Engine", rev 1505bbf)
import torch
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModel
# Embedding backbone. NOTE(review): this is a MiniLM sentence-transformer,
# not an actual ColBERT checkpoint — its per-token hidden states are used
# as ColBERT-style token vectors below.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Loaded eagerly at import time (first run downloads the weights).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
class ColBERTRetriever:
    """Late-interaction (ColBERT-style) retriever over text chunks.

    ``build_index`` embeds each chunk once and caches the per-token
    embeddings; ``query`` ranks chunks by the MaxSim score (sum over query
    tokens of the max similarity against any document token).

    Fixes vs. the original:
    - ``doc_embeddings`` was declared but never populated, and ``query``
      re-ran the model over every chunk on every call (O(n) forward passes
      per query). Token embeddings are now cached at index-build time.
    - ``build_index([])`` no longer crashes with an IndexError.
    """

    def __init__(self):
        self.chunks = []          # chunk dicts; each must carry a "text" key
        self.doc_embeddings = []  # cached (tokens, dim) arrays, one per chunk
        self.index = None         # coarse FAISS inner-product index (mean-pooled)

    # -----------------------------
    # EMBED TEXT TOKENS
    # -----------------------------
    def embed(self, text):
        """Return per-token embeddings of *text* as a (tokens, dim) ndarray.

        Truncates to 256 tokens; runs the model without gradient tracking.
        """
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # (1, tokens, dim) -> (tokens, dim)
        return outputs.last_hidden_state.squeeze(0).numpy()

    # -----------------------------
    # BUILD INDEX
    # -----------------------------
    def build_index(self, chunks):
        """Embed *chunks* once, cache token embeddings, build the FAISS index.

        Caching here is what lets ``query`` avoid re-embedding every chunk
        on every call.
        """
        self.chunks = chunks
        self.doc_embeddings = [self.embed(c["text"]) for c in chunks]
        if not chunks:
            # Nothing to index; avoid IndexError on an empty vector matrix.
            self.index = None
            return
        # Mean-pool each chunk's token embeddings for the coarse index.
        # NOTE(review): IndexFlatIP on unnormalized vectors ranks by raw
        # inner product, not cosine — confirm this is intended before
        # using the index for candidate selection.
        pooled = np.array(
            [emb.mean(axis=0) for emb in self.doc_embeddings]
        ).astype("float32")
        self.index = faiss.IndexFlatIP(pooled.shape[1])
        self.index.add(pooled)

    # -----------------------------
    # QUERY
    # -----------------------------
    def query(self, question, k=20):
        """Return up to *k* chunks ranked by MaxSim against *question*.

        Uses the token embeddings cached by ``build_index``; only the
        question is embedded per call.
        """
        if not self.chunks:
            return []
        q_emb = self.embed(question)  # (q_tokens, dim)
        scores = []
        for d_emb in self.doc_embeddings:
            sim = np.matmul(q_emb, d_emb.T)       # (q_tokens, d_tokens)
            scores.append(sim.max(axis=1).sum())  # MaxSim over doc tokens
        top = np.argsort(scores)[::-1][:k]
        return [self.chunks[i] for i in top]