File size: 2,648 Bytes
912dbd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from FlagEmbedding import BGEM3FlagModel, FlagReranker
import numpy as np
import json
from pathlib import Path
import faiss
import soundfile as sf

def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    The file is opened as UTF-8 text. Every line is stripped of surrounding
    whitespace; lines that are empty after stripping are skipped, and each
    remaining line is parsed with ``json.loads``.

    Args:
        path: Location of the JSONL file, as accepted by ``open``.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        # Parse inside the ``with`` block: the generator reads lazily from
        # the open file handle.
        return [json.loads(entry) for entry in stripped if entry]

def load_vector_db(faiss_index_path, chunks_path, meta_path, uid_list_path):
    """Load the FAISS index together with its chunk and metadata records.

    Args:
        faiss_index_path: Path of the serialized index, handed to
            ``faiss.read_index``.
        chunks_path: JSONL file holding the chunk records.
        meta_path: JSONL file holding the metadata records.
        uid_list_path: Accepted as part of the signature but not read by
            this function.

    Returns:
        A tuple ``(index, chunks, meta)``.

    Raises:
        AssertionError: If ``index.ntotal``, the number of chunks and the
            number of metadata records are not all equal.
    """
    print("Loading vector database...")
    index = faiss.read_index(faiss_index_path)
    chunks = load_jsonl(chunks_path)
    meta = load_jsonl(meta_path)

    print(f"FAISS ntotal: {index.ntotal}, Chunks: {len(chunks)}, Meta: {len(meta)}")
    # Raised explicitly instead of via an `assert` statement so the check
    # still runs under `python -O`. The exception type and message are kept
    # so existing callers see the same failure as before.
    if not (index.ntotal == len(chunks) == len(meta)):
        raise AssertionError("DB mapping mismatch!")
    return index, chunks, meta

def load_embedding_models(emb_model_dir, rerank_model_dir):
    """Instantiate the embedding model and the reranker.

    Both models are constructed with ``use_fp16=True``. Progress messages
    are printed before and after loading.

    Args:
        emb_model_dir: Model location passed to ``BGEM3FlagModel``.
        rerank_model_dir: Model location passed to ``FlagReranker``.

    Returns:
        A tuple ``(embedder, reranker)``.
    """
    print("Loading embedding and reranking models...")
    # Tuple elements are evaluated left to right, so the embedder is still
    # constructed before the reranker.
    models = (
        BGEM3FlagModel(emb_model_dir, use_fp16=True),
        FlagReranker(rerank_model_dir, use_fp16=True),
    )
    print("Embedding and reranking models loaded successfully!")
    return models

def l2_normalize(v):
    """Scale each row of ``v`` to (approximately) unit Euclidean length.

    Norms are computed along axis 1 with ``keepdims=True`` so they broadcast
    back over ``v``. A small epsilon (1e-12) is added to every norm so that
    an all-zero row divides safely and stays all-zero.

    Args:
        v: Array with at least two dimensions (axis 1 must exist).

    Returns:
        ``v`` divided by its per-row norms plus epsilon.
    """
    row_norms = np.linalg.norm(v, axis=1, keepdims=True)
    return v / (row_norms + 1e-12)

def embed_query(embedder, q):
    """Encode one query string into an L2-normalized float32 dense vector.

    The query is wrapped in a single-element list and encoded with only the
    dense output requested (sparse and ColBERT outputs are disabled). The
    ``"dense_vecs"`` entry of the result is converted to a float32 array and
    normalized with ``l2_normalize``.

    Args:
        embedder: Object exposing ``encode(...)`` that returns a mapping
            with a ``"dense_vecs"`` entry.
        q: The query text.

    Returns:
        A float32 array of normalized embeddings (presumably shape
        ``(1, dim)`` since a single query is encoded — confirm against the
        embedder).
    """
    encoded = embedder.encode(
        [q],
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False,
    )
    dense = np.array(encoded["dense_vecs"], dtype=np.float32)
    unit = l2_normalize(dense)
    return unit.astype("float32")

def retrieve_then_rerank(embedder, reranker, index, chunks, meta, question_bn, top_k=3, faiss_top_n=30):
    """RAG retrieval: fetch candidates from FAISS, then rerank them.

    The question is embedded with ``embed_query`` and used to search
    ``index`` for ``faiss_top_n`` neighbours. Candidates whose chunk has no
    non-blank ``"text"`` / ``"chunk_text"`` are dropped. The rest are scored
    against the question by ``reranker.compute_score`` and the ``top_k``
    highest-scoring ones are returned.

    Args:
        embedder: Embedding model passed through to ``embed_query``.
        reranker: Object exposing ``compute_score(pairs)``.
        index: Search index exposing ``search(vectors, k)``.
        chunks: Sequence of dicts, indexed by the ids returned by ``index``.
        meta: Sequence indexed the same way as ``chunks``; dict entries may
            carry a ``"metadata"`` key.
        question_bn: The question text used for both retrieval and reranking.
        top_k: Number of reranked results to keep.
        faiss_top_n: Number of neighbours requested from the index.

    Returns:
        A tuple ``(top, ctx)``. ``top`` is a list of dicts with keys
        ``"idx"``, ``"text"``, ``"rerank_score"`` and ``"metadata"``, ordered
        by descending rerank score. ``ctx`` is the selected texts joined by
        ``"\\n\\n---\\n\\n"``. Returns ``([], "")`` when no candidate has
        usable text.
    """
    qvec = embed_query(embedder, question_bn)
    _, neighbor_ids = index.search(qvec, faiss_top_n)

    cand_texts = []
    valid_idxs = []
    for i in neighbor_ids[0].tolist():
        # FAISS pads the result with -1 when it finds fewer than the
        # requested number of neighbours. A negative id must be skipped,
        # otherwise Python's negative indexing would silently pull in the
        # last chunk as a bogus candidate.
        if i < 0:
            continue
        record = chunks[i]
        t = (record.get("text") or record.get("chunk_text") or "").strip()
        if t:
            cand_texts.append(t)
            valid_idxs.append(i)

    if not valid_idxs:
        return [], ""

    pairs = [[question_bn, t] for t in cand_texts]
    # compute_score may hand back a bare float for a single pair; force a
    # 1-D array so sorting and indexing below work for any candidate count.
    scores = np.atleast_1d(
        np.asarray(reranker.compute_score(pairs), dtype=np.float32)
    )
    # Negate so argsort's ascending order yields highest score first.
    order = np.argsort(-scores)[:top_k]

    top = []
    for j in order:
        j = int(j)
        idx = valid_idxs[j]
        entry = meta[idx]
        md = entry.get("metadata", {}) if isinstance(entry, dict) else {}
        top.append({
            "idx": idx,
            "text": cand_texts[j],
            "rerank_score": float(scores[j]),
            "metadata": md,
        })

    ctx = "\n\n---\n\n".join(t["text"] for t in top)
    return top, ctx