from FlagEmbedding import BGEM3FlagModel, FlagReranker
import numpy as np
import json
from pathlib import Path
import faiss
import soundfile as sf


def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped; every other line must be a
    complete JSON document.

    Args:
        path: File path (str or os.PathLike), read as UTF-8 text.

    Returns:
        list: One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if ln:
                out.append(json.loads(ln))
    return out


def load_vector_db(faiss_index_path, chunks_path, meta_path, uid_list_path):
    """Load the FAISS index together with its row-aligned chunk and metadata records.

    Row ``i`` of the index is looked up as ``chunks[i]`` / ``meta[i]`` by the
    retrieval code, so the three sizes are checked before returning.

    Args:
        faiss_index_path: Path to a serialized FAISS index (str or PathLike).
        chunks_path: JSONL file with one chunk record per index row.
        meta_path: JSONL file with one metadata record per index row.
        uid_list_path: Currently unused; kept so existing callers keep working.

    Returns:
        tuple: ``(index, chunks, meta)``.

    Raises:
        ValueError: If the index size and the two record counts disagree.
    """
    print("Loading vector database...")
    # faiss.read_index expects a plain str; str() lets callers pass pathlib.Path too.
    index = faiss.read_index(str(faiss_index_path))
    chunks = load_jsonl(chunks_path)
    meta = load_jsonl(meta_path)
    print(f"FAISS ntotal: {index.ntotal}, Chunks: {len(chunks)}, Meta: {len(meta)}")
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # a misaligned DB would otherwise hand back text/metadata for the wrong rows.
    if not (index.ntotal == len(chunks) == len(meta)):
        raise ValueError(
            "DB mapping mismatch! "
            f"FAISS ntotal={index.ntotal}, chunks={len(chunks)}, meta={len(meta)}"
        )
    return index, chunks, meta


def load_embedding_models(emb_model_dir, rerank_model_dir):
    """Load the BGE-M3 embedder and the cross-encoder reranker (both in fp16).

    Args:
        emb_model_dir: Model directory/name passed to ``BGEM3FlagModel``.
        rerank_model_dir: Model directory/name passed to ``FlagReranker``.

    Returns:
        tuple: ``(embedder, reranker)``.
    """
    print("Loading embedding and reranking models...")
    embedder = BGEM3FlagModel(emb_model_dir, use_fp16=True)
    reranker = FlagReranker(rerank_model_dir, use_fp16=True)
    print("Embedding and reranking models loaded successfully!")
    return embedder, reranker


def l2_normalize(v):
    """Scale each row of a 2-D array to unit L2 norm.

    A tiny epsilon is added to the norm so all-zero rows stay zero instead of
    producing NaN from a division by zero.

    Args:
        v: 2-D array; normalization is along ``axis=1`` (per row).

    Returns:
        np.ndarray: Array of the same shape with unit-norm rows.
    """
    denom = np.linalg.norm(v, axis=1, keepdims=True) + 1e-12
    return v / denom


def embed_query(embedder, q):
    """Embed a single query string as an L2-normalized float32 vector batch.

    Args:
        embedder: A loaded ``BGEM3FlagModel``.
        q: Query text.

    Returns:
        np.ndarray: float32 array holding the query embedding as a row
        (presumably shape ``(1, dim)`` — TODO confirm), suitable for
        ``index.search``.
    """
    out = embedder.encode([q], return_dense=True, return_sparse=False, return_colbert_vecs=False)
    vec = np.array(out["dense_vecs"], dtype=np.float32)
    # Unit-norm queries make inner-product search behave like cosine similarity
    # (assumes the index was built from normalized vectors — TODO confirm).
    return l2_normalize(vec).astype("float32")


def retrieve_then_rerank(embedder, reranker, index, chunks, meta, question_bn, top_k=3, faiss_top_n=30):
    """RAG: retrieve candidates from FAISS, then rerank them with a cross-encoder.

    Args:
        embedder: Loaded ``BGEM3FlagModel`` used to embed the question.
        reranker: Loaded ``FlagReranker`` used to score (question, chunk) pairs.
        index: FAISS index whose row ``i`` corresponds to ``chunks[i]``/``meta[i]``.
        chunks: Chunk records; text is read from ``"text"`` or ``"chunk_text"``.
        meta: Metadata records; a dict record's ``"metadata"`` value is returned.
        question_bn: Question text, used verbatim for retrieval and reranking.
        top_k: Number of reranked hits to return.
        faiss_top_n: Number of FAISS neighbours to fetch before reranking.

    Returns:
        tuple: ``(top, ctx)`` where ``top`` is a list of dicts with keys
        ``idx``, ``text``, ``rerank_score``, ``metadata`` (best first) and
        ``ctx`` is their texts joined by ``"\\n\\n---\\n\\n"``. Returns
        ``([], "")`` when no candidate has non-empty text.
    """
    qvec = embed_query(embedder, question_bn)
    _, ids = index.search(qvec, faiss_top_n)

    cand_texts = []
    valid_idxs = []
    for i in ids[0].tolist():
        # FAISS pads results with -1 when the index holds fewer than
        # faiss_top_n vectors; chunks[-1] would silently return the last chunk.
        if i < 0:
            continue
        chunk = chunks[i]
        t = (chunk.get("text") or chunk.get("chunk_text") or "").strip()
        if t:
            cand_texts.append(t)
            valid_idxs.append(i)

    if not valid_idxs:
        return [], ""

    pairs = [[question_bn, t] for t in cand_texts]
    # FlagReranker.compute_score may return a bare float (not a list) for a
    # single pair; reshape(-1) guarantees a 1-D array so scores[j] works.
    scores = np.asarray(reranker.compute_score(pairs), dtype=np.float32).reshape(-1)
    # Stable sort keeps FAISS order among equal rerank scores.
    order = np.argsort(-scores, kind="stable")[:top_k]

    top = []
    for j in order:
        idx = valid_idxs[j]
        record = meta[idx]
        md = record.get("metadata", {}) if isinstance(record, dict) else {}
        top.append({
            "idx": idx,
            "text": cand_texts[j],
            "rerank_score": float(scores[j]),
            "metadata": md,
        })

    ctx = "\n\n---\n\n".join(t["text"] for t in top)
    return top, ctx