# Source: Hugging Face upload by Aravindhan11 — commit 9776024 ("Upload 33 files").
import os, re, glob
from typing import List, Tuple
import faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
from .config import cfg
def _load_texts(input_dir: str) -> List[Tuple[str, str]]:
docs = []
for path in glob.glob(os.path.join(input_dir, "**/*"), recursive=True):
if os.path.isdir(path):
continue
try:
if path.lower().endswith(('.txt', '.md')):
with open(path, 'r', encoding='utf-8', errors='ignore') as f:
docs.append((path, f.read()))
elif path.lower().endswith('.pdf'):
reader = PdfReader(path)
text = "\n".join([p.extract_text() or "" for p in reader.pages])
docs.append((path, text))
except Exception:
pass
return docs
def _chunk(text: str, size: int = 800, overlap: int = 120) -> List[str]:
tokens = re.split(r"(\s+)", text)
chunks, buf, length = [], [], 0
for t in tokens:
buf.append(t)
length += len(t)
if length >= size:
chunks.append("".join(buf))
buf = buf[-overlap:]
length = sum(len(x) for x in buf)
if buf:
chunks.append("".join(buf))
return chunks
def build_index(input_dir: str = "data/corpus", index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Embed every chunk of every document under *input_dir* into a FAISS index.

    Writes two aligned artifacts into *index_dir*:
      - index.faiss: inner-product index over L2-normalized embeddings, so
        inner product equals cosine similarity.
      - meta.tsv: one "path<TAB>chunk" row per vector, row i <-> index id i.

    Returns the number of indexed chunks (0 when the corpus is empty).

    Fixes vs. the original:
      - f-string expression contained a backslash escape (SyntaxError before
        Python 3.12); the sanitized chunk is now computed outside the f-string.
      - newlines inside a chunk (PDF text is newline-joined) corrupted the
        one-row-per-vector TSV that search() relies on; they are flattened.
      - an empty corpus crashed on embs.shape[1]; now returns 0 early.
    """
    os.makedirs(index_dir, exist_ok=True)
    model = SentenceTransformer(model_name)
    entries = []
    for path, text in _load_texts(input_dir):
        for ch in _chunk(text):
            entries.append((path, ch))
    if not entries:
        # Nothing to embed: avoid indexing an empty matrix.
        return 0
    texts = [ch for _, ch in entries]
    embs = model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        batch_size=64,
        show_progress_bar=True,
    )
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index, os.path.join(index_dir, "index.faiss"))
    with open(os.path.join(index_dir, "meta.tsv"), "w", encoding="utf-8") as f:
        for path, ch in entries:
            # Tabs/newlines inside the chunk would break the row <-> id
            # alignment that search() depends on, so flatten them to spaces.
            safe = ch.replace("\t", " ").replace("\r", " ").replace("\n", " ")
            f.write(f"{path}\t{safe}\n")
    return len(entries)
def search(query: str, k: int = 4, index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Return up to *k* (path, chunk) pairs most similar to *query*.

    Embeds the query with the same normalized-embedding setup used at build
    time and searches the FAISS inner-product index. Returns [] when the
    index has not been built yet.

    Fixes vs. the original: the early-return now also covers a missing
    meta.tsv (previously a FileNotFoundError when only the index file
    existed), and the embedding model is only loaded after the artifacts are
    confirmed to exist.
    """
    index_path = os.path.join(index_dir, "index.faiss")
    meta_path = os.path.join(index_dir, "meta.tsv")
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        return []
    model = SentenceTransformer(model_name)
    index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = [line.rstrip("\n").split("\t", 1) for line in f]
    q = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _scores, ids = index.search(q, k)
    results = []
    for i in ids[0]:
        # faiss pads with -1 when fewer than k vectors exist; also guard
        # against a meta file shorter than the index.
        if 0 <= i < len(meta):
            results.append((meta[i][0], meta[i][1]))
    return results