# build_vector_store.py
import os
import json
import math
from pathlib import Path
from tqdm import tqdm
import numpy as np
import pdfplumber
from sentence_transformers import SentenceTransformer
import faiss
# --------- CONFIG ----------
DOCS_DIR = Path("docs")   # source PDFs are read from this directory
DATA_DIR = Path("data")   # FAISS index + metadata are written here
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # HF model id used for embeddings
CHUNK_CHAR_SIZE = 1000 # ~400-500 tokens approx (tweak if you want)
CHUNK_OVERLAP = 200  # characters shared between consecutive chunks, to preserve context
EMBED_DIM = 384 # embedding dimension of all-MiniLM-L6-v2
BATCH_SIZE = 32  # texts encoded per model.encode() call
TOP_K = 5  # NOTE(review): unused in this script — presumably consumed by the query side; verify
# ---------------------------
DATA_DIR.mkdir(exist_ok=True)  # make sure the output directory exists before writing
def extract_text_from_pdf(pdf_path: Path):
    """Extract plain text from every page of *pdf_path*.

    Returns a list of dicts, one per page, each with a 1-based
    ``page_number`` and the page's ``text`` ("" when extraction
    yields nothing for that page).
    """
    extracted = []
    with pdfplumber.open(pdf_path) as pdf:
        for number, page in enumerate(pdf.pages, start=1):
            extracted.append({
                "page_number": number,
                "text": page.extract_text() or "",
            })
    return extracted
def split_text_into_chunks(text, chunk_size=CHUNK_CHAR_SIZE, overlap=CHUNK_OVERLAP):
    """Split *text* into overlapping character chunks.

    Each chunk is at most ``chunk_size`` characters; consecutive chunks
    share roughly ``overlap`` characters so sentences straddling a cut
    appear in both. When possible, a chunk ends at a sentence/newline
    boundary found in the second half of the window.

    Parameters:
        text: input string (whitespace is stripped; empty -> []).
        chunk_size: maximum chunk length in characters.
        overlap: characters of context repeated between chunks;
            must be smaller than chunk_size.

    Returns:
        List of non-empty, stripped chunk strings.

    Raises:
        ValueError: if overlap >= chunk_size (the window could never
            advance, which previously caused an infinite loop).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    text = text.strip()
    if not text:
        return []
    chunks = []
    start = 0
    text_len = len(text)
    while start < text_len:
        end = start + chunk_size
        # Try to avoid breaking mid-sentence: prefer the last sentence
        # boundary (or newline) in the window, but only if it falls in
        # the second half so chunks do not get too small.
        if end < text_len:
            snippet = text[start:end]
            cut = max(snippet.rfind('\n'), snippet.rfind('. '),
                      snippet.rfind('? '), snippet.rfind('! '))
            if cut != -1 and cut > int(chunk_size * 0.5):
                end = start + cut + 1
        chunk_text = text[start:end].strip()
        if chunk_text:
            chunks.append(chunk_text)
        if end >= text_len:
            break
        # Guarantee forward progress even when a boundary cut lands
        # inside the overlap region (previously this could loop forever
        # for small chunk_size/overlap combinations).
        start = max(end - overlap, start + 1)
    return chunks
def build_embeddings(model, texts):
    """Encode *texts* in fixed-size batches and stack the results.

    Uses BATCH_SIZE texts per ``model.encode`` call and returns a single
    2-D numpy array; an empty input yields an empty (0, dim) array.
    """
    batches = [
        model.encode(texts[pos:pos + BATCH_SIZE],
                     show_progress_bar=False,
                     convert_to_numpy=True)
        for pos in range(0, len(texts), BATCH_SIZE)
    ]
    if not batches:
        return np.empty((0, model.get_sentence_embedding_dimension()))
    return np.vstack(batches)
def normalize_embeddings(embeddings: np.ndarray):
    """L2-normalize each row of *embeddings* in place and return it.

    Unit vectors make an inner-product FAISS index equivalent to cosine
    similarity. This replaces the previous ``faiss.normalize_L2`` call
    with a pure-NumPy equivalent: it works for any float dtype/layout
    (faiss required a contiguous float32 array) and, like faiss, leaves
    all-zero rows untouched instead of producing NaNs.

    Parameters:
        embeddings: 2-D float array, modified in place.

    Returns:
        The same array object, for call-chaining.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # where=norms > 0: skip zero rows so we never divide by zero.
    np.divide(embeddings, norms, out=embeddings, where=norms > 0)
    return embeddings
def main():
    """Build the vector store: PDFs -> page text -> chunks -> embeddings -> FAISS."""
    model = SentenceTransformer(EMBED_MODEL)
    dim = model.get_sentence_embedding_dimension()
    print(f"Loaded embed model '{EMBED_MODEL}' with dim={dim}")

    pdf_files = list(DOCS_DIR.glob("*.pdf"))
    if not pdf_files:
        print("No PDF files found in docs/ — put your PDFs there and re-run.")
        return

    corpus = []     # chunk texts, in index order
    metadata = []   # one record per chunk, parallel to corpus
    next_id = 0
    for pdf_path in pdf_files:
        print(f"Processing: {pdf_path.name}")
        for page in extract_text_from_pdf(pdf_path):
            body = page["text"]
            if not body:
                continue
            for pos, piece in enumerate(split_text_into_chunks(body)):
                metadata.append({
                    "chunk_id": next_id,
                    "source_file": pdf_path.name,
                    "page": page["page_number"],
                    "chunk_index_in_page": pos,
                    "text": piece[:1000]  # store a preview (or store full text if you want)
                })
                corpus.append(piece)
                next_id += 1

    if not corpus:
        print("No text extracted from PDFs.")
        return
    print(f"Total chunks: {len(corpus)}")

    # compute embeddings, then normalize so inner product == cosine similarity
    vectors = build_embeddings(model, corpus)
    print("Embeddings shape:", vectors.shape)
    vectors = normalize_embeddings(vectors)

    index = faiss.IndexFlatIP(dim)
    index.add(vectors.astype('float32'))
    print("FAISS index built. n_total:", index.ntotal)

    # persist the index and its row-aligned metadata
    index_path = DATA_DIR / "vector_store.index"
    faiss.write_index(index, str(index_path))
    meta_path = DATA_DIR / "metadata.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)
    print(f"Saved FAISS index -> {index_path}")
    print(f"Saved metadata -> {meta_path}")
# Run the build pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|