| import faiss |
| import os |
| import pickle |
| from pypdf import PdfReader |
| from sentence_transformers import SentenceTransformer, CrossEncoder |
| from rank_bm25 import BM25Okapi |
|
|
# Feature flags: use an HNSW approximate-nearest-neighbour FAISS index
# (vs. an exact flat inner-product index), and rerank fused candidates
# with a cross-encoder before returning results.
USE_HNSW = True
USE_RERANKER = True


# Chunking parameters, in characters: target chunk length, and how much of
# the tail of one chunk is carried into the start of the next for context.
CHUNK_SIZE = 800
CHUNK_OVERLAP = 200


# On-disk persistence paths (relative to the working directory).
DB_FILE_INDEX = "vector.index"
DB_FILE_META = "metadata.pkl"
DB_FILE_BM25 = "bm25.pkl"


# In-memory stores, populated by load_db() / ingest_documents().
index = None           # FAISS vector index (None until first ingest/load)
documents = []         # chunk texts; positionally aligned with `metadata`
metadata = []          # per-chunk info dicts ({"source": ..., "page": ...})
bm25 = None            # BM25Okapi sparse index built over `documents`
tokenized_corpus = []  # NOTE(review): never assigned by the functions below — appears unused


# Models are loaded eagerly at import time (downloads weights on first run).
embedder = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
|
def chunk_text(text):
    """Split *text* into chunks of roughly CHUNK_SIZE characters.

    The text is split on sentence boundaries (., !, ?), and sentences are
    packed greedily into a buffer. When adding the next sentence would push
    the buffer past CHUNK_SIZE, the buffer is emitted as a chunk and its
    last CHUNK_OVERLAP characters are carried into the next chunk so that
    context spans chunk boundaries.
    """
    import re
    pieces = re.split(r'(?<=[.!?])\s+', text)
    out = []
    buf = ""
    for piece in pieces:
        overflow = bool(buf) and (len(buf) + len(piece) > CHUNK_SIZE)
        if overflow:
            out.append(buf.strip())
            carry_from = max(0, len(buf) - CHUNK_OVERLAP)
            buf = buf[carry_from:] + " " + piece
        elif buf:
            buf = buf + " " + piece
        else:
            buf = piece
    tail = buf.strip()
    if tail:
        out.append(tail)
    return out
|
|
def save_db():
    """Persist the FAISS index, chunk store, and BM25 index to disk.

    Each artifact is written only when it has been built, so calling this
    before any ingest is a no-op rather than creating empty files.
    """
    # Explicit identity checks: the FAISS index is a SWIG proxy object and
    # BM25Okapi defines no __bool__/__len__, so "is not None" states the
    # intent (was it built?) instead of relying on object truthiness.
    if index is not None:
        faiss.write_index(index, DB_FILE_INDEX)
    if documents:
        with open(DB_FILE_META, "wb") as f:
            pickle.dump({"documents": documents, "metadata": metadata}, f)
    if bm25 is not None:
        with open(DB_FILE_BM25, "wb") as f:
            pickle.dump(bm25, f)
|
|
def load_db():
    """Restore the in-memory stores (index, documents, metadata, bm25) from disk.

    Does nothing if the vector index or metadata file is missing — both are
    required for a usable state. If only the BM25 cache is absent (e.g. a
    database created by an older version), it is rebuilt from the loaded
    documents and written back.
    """
    global index, documents, metadata, bm25
    if os.path.exists(DB_FILE_INDEX) and os.path.exists(DB_FILE_META):
        index = faiss.read_index(DB_FILE_INDEX)
        # NOTE(review): pickle.load is only safe because these are local cache
        # files written by save_db(); never point these paths at untrusted data.
        with open(DB_FILE_META, "rb") as f:
            data = pickle.load(f)
        documents = data["documents"]
        metadata = data["metadata"]

        if os.path.exists(DB_FILE_BM25):
            with open(DB_FILE_BM25, "rb") as f:
                bm25 = pickle.load(f)
        elif documents:
            # Backfill path: tokenize by whitespace (matching search_knowledge's
            # query tokenization), rebuild BM25, and persist it for next time.
            print("Backfilling BM25 index on first load...")
            tokenized_corpus = [doc.split(" ") for doc in documents]
            bm25 = BM25Okapi(tokenized_corpus)
            with open(DB_FILE_BM25, "wb") as f:
                pickle.dump(bm25, f)
|
|
# Warm the in-memory stores from disk at import time so the module is
# immediately queryable after startup.
load_db()
|
|
def clear_database():
    """Reset all in-memory stores and delete every persisted artifact."""
    global index, documents, metadata, bm25
    index = None
    bm25 = None
    documents = []
    metadata = []
    # Remove whichever persistence files currently exist.
    for path in (DB_FILE_INDEX, DB_FILE_META, DB_FILE_BM25):
        if os.path.exists(path):
            os.remove(path)
|
|
def ingest_documents(files):
    """Extract text from uploaded files, chunk it, embed it, and index it.

    Args:
        files: uploaded file objects exposing ``.filename`` and ``.file``
            (e.g. FastAPI UploadFile). Supports ``.pdf`` (converted to
            markdown via pymupdf4llm) and ``.txt``; other extensions are
            silently skipped.

    Returns:
        The total number of chunks stored after ingestion.

    Raises:
        ValueError: if no readable text was extracted from any file.
    """
    # BUGFIX: `bm25` must be declared global here — it was previously assigned
    # to a function local, so search_knowledge() kept using the stale (or None)
    # module-level BM25 index until the process restarted.
    global index, documents, metadata, bm25
    texts, meta = [], []

    for file in files:
        if file.filename.endswith(".pdf"):
            # pymupdf4llm works on paths, so spill the upload to a temp file.
            import tempfile
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(file.file.read())
                tmp_path = tmp.name

            try:
                # Markdown conversion preserves headings/tables better than
                # raw text extraction, which helps chunk coherence.
                import pymupdf4llm
                md_text = pymupdf4llm.to_markdown(tmp_path)
                for chunk in chunk_text(md_text):
                    texts.append(chunk)
                    meta.append({"source": file.filename, "page": "N/A"})
            finally:
                os.remove(tmp_path)

        elif file.filename.endswith(".txt"):
            content = file.file.read().decode("utf-8", errors="ignore")
            for chunk in chunk_text(content):
                texts.append(chunk)
                meta.append({"source": file.filename, "page": "N/A"})

    if not texts:
        raise ValueError("No readable text found (OCR needed for scanned PDFs).")

    # Normalized embeddings + inner-product search == cosine similarity.
    embeddings = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)

    if index is None:
        dim = embeddings.shape[1]
        if USE_HNSW:
            index = faiss.IndexHNSWFlat(dim, 32)
            # BUGFIX: only HNSW indexes expose `.hnsw` — setting these on a
            # flat index raised AttributeError whenever USE_HNSW was False.
            index.hnsw.efConstruction = 200
            index.hnsw.efSearch = 64
        else:
            index = faiss.IndexFlatIP(dim)

    index.add(embeddings)
    documents.extend(texts)
    metadata.extend(meta)

    # Rebuild the module-level BM25 index over the full corpus, using the same
    # whitespace tokenization that search_knowledge() applies to queries.
    tokenized_corpus = [doc.split(" ") for doc in documents]
    bm25 = BM25Okapi(tokenized_corpus)

    save_db()
    return len(documents)
|
|
def search_knowledge(query, top_k=8, rerank_pool=10, final_k=5):
    """Hybrid retrieval: dense vectors + BM25, fused with Reciprocal Rank Fusion.

    Args:
        query: free-text query string.
        top_k: candidates contributed by each retriever (vector and BM25).
        rerank_pool: number of fused candidates rescored by the cross-encoder
            (previously hard-coded to 10).
        final_k: number of results returned (previously hard-coded to 5).

    Returns:
        Up to ``final_k`` dicts with keys ``text``, ``metadata``, ``score``
        (RRF), ``vector_rank`` / ``bm25_rank`` (rank or None), and ``rerank``
        when USE_RERANKER is enabled. Empty list when nothing is indexed.
    """
    if index is None:
        return []

    # --- Dense retrieval: cosine similarity via normalized inner product ---
    qvec = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _scores, indices = index.search(qvec, top_k)

    vector_results = {}
    for rank, idx in enumerate(indices[0]):
        if idx == -1:
            continue  # FAISS pads with -1 when fewer than top_k vectors exist
        vector_results[idx] = rank

    # --- Sparse retrieval: BM25 over whitespace tokens ---
    bm25_results = {}
    if bm25:
        tokenized_query = query.split(" ")
        bm25_scores = bm25.get_scores(tokenized_query)
        top_n = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:top_k]
        for rank, idx in enumerate(top_n):
            bm25_results[idx] = rank

    # --- Reciprocal Rank Fusion: score = sum over retrievers of 1/(k + rank) ---
    k = 60  # conventional RRF damping constant
    candidates_idx = set(vector_results) | set(bm25_results)
    merged_candidates = []

    for idx in candidates_idx:
        v_rank = vector_results.get(idx, float('inf'))
        b_rank = bm25_results.get(idx, float('inf'))
        # A missing retriever contributes 1/inf == 0.0 to the fused score.
        rrf_score = (1 / (k + v_rank)) + (1 / (k + b_rank))

        merged_candidates.append({
            "text": documents[idx],
            "metadata": metadata[idx],
            "score": rrf_score,
            "vector_rank": v_rank if v_rank != float('inf') else None,
            "bm25_rank": b_rank if b_rank != float('inf') else None,
        })

    merged_candidates.sort(key=lambda x: x["score"], reverse=True)
    candidates = merged_candidates[:rerank_pool]

    # --- Optional cross-encoder rerank of the fused pool ---
    if USE_RERANKER and candidates:
        pairs = [(query, c["text"]) for c in candidates]
        rerank_scores = reranker.predict(pairs)
        for c, rs in zip(candidates, rerank_scores):
            c["rerank"] = float(rs)
        candidates.sort(key=lambda x: x["rerank"], reverse=True)

    return candidates[:final_k]
|
|
def get_all_chunks(limit=80):
    """Return up to *limit* stored chunks, each paired with its metadata."""
    pairs = zip(documents[:limit], metadata[:limit])
    return [{"text": text, "metadata": meta_entry} for text, meta_entry in pairs]
|
|