import glob
import os
import re
from typing import List, Tuple

import faiss
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer

from .config import cfg


def _load_texts(input_dir: str) -> List[Tuple[str, str]]:
    """Recursively collect (path, text) pairs from *input_dir*.

    Plain-text (.txt/.md) files are read as UTF-8 (undecodable bytes
    ignored); PDFs are extracted page-by-page with pypdf. Files that
    cannot be read or parsed are skipped silently — indexing is
    deliberately best-effort over a possibly messy corpus.
    """
    docs: List[Tuple[str, str]] = []
    for path in glob.glob(os.path.join(input_dir, "**/*"), recursive=True):
        if os.path.isdir(path):
            continue
        try:
            if path.lower().endswith((".txt", ".md")):
                with open(path, "r", encoding="utf-8", errors="ignore") as f:
                    docs.append((path, f.read()))
            elif path.lower().endswith(".pdf"):
                reader = PdfReader(path)
                # extract_text() may return None for image-only pages.
                text = "\n".join(p.extract_text() or "" for p in reader.pages)
                docs.append((path, text))
        except Exception:
            # Best-effort: a single corrupt file must not abort indexing.
            pass
    return docs


def _chunk(text: str, size: int = 800, overlap: int = 120) -> List[str]:
    """Split *text* into overlapping chunks of roughly *size* characters.

    The text is split on whitespace (separators are kept, so joining the
    pieces reproduces the original). A chunk is emitted once its
    accumulated character length reaches *size*.

    NOTE(review): *size* is measured in characters but *overlap* is
    measured in split tokens (the last `overlap` tokens are carried into
    the next chunk), so the actual character overlap varies with token
    length — confirm this asymmetry is intended before changing it.
    """
    tokens = re.split(r"(\s+)", text)
    chunks: List[str] = []
    buf: List[str] = []
    length = 0
    for t in tokens:
        buf.append(t)
        length += len(t)
        if length >= size:
            chunks.append("".join(buf))
            buf = buf[-overlap:]
            length = sum(len(x) for x in buf)
    if buf:
        chunks.append("".join(buf))
    return chunks


def _sanitize_for_tsv(chunk: str) -> str:
    """Make a chunk safe to store as a single TSV field.

    meta.tsv is parsed line-by-line with a single tab split, so tabs,
    newlines, and carriage returns inside a chunk would corrupt every
    subsequent record. (The original code also had a backslash escape
    inside an f-string expression, a SyntaxError before Python 3.12 —
    hoisting the replacement here fixes both problems.)
    """
    return chunk.replace("\t", " ").replace("\n", " ").replace("\r", " ")


def build_index(
    input_dir: str = "data/corpus",
    index_dir: str = cfg.index_dir,
    model_name: str = cfg.embedding_model,
) -> int:
    """Embed every chunk of every document under *input_dir* and write a
    FAISS inner-product index plus a TSV metadata file to *index_dir*.

    Embeddings are L2-normalized, so inner product equals cosine
    similarity. Returns the number of chunks indexed (0 if the corpus
    is empty, in which case no index files are written).
    """
    os.makedirs(index_dir, exist_ok=True)
    model = SentenceTransformer(model_name)

    entries: List[Tuple[str, str]] = [
        (path, ch) for path, text in _load_texts(input_dir) for ch in _chunk(text)
    ]
    if not entries:
        # Empty corpus: encoding would yield a 1-D array and
        # embs.shape[1] would raise IndexError. Nothing to index.
        return 0

    texts = [ch for _, ch in entries]
    embs = model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        batch_size=64,
        show_progress_bar=True,
    )
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index, os.path.join(index_dir, "index.faiss"))

    with open(os.path.join(index_dir, "meta.tsv"), "w", encoding="utf-8") as f:
        for path, ch in entries:
            f.write(f"{path}\t{_sanitize_for_tsv(ch)}\n")
    return len(entries)


def search(
    query: str,
    k: int = 4,
    index_dir: str = cfg.index_dir,
    model_name: str = cfg.embedding_model,
) -> List[Tuple[str, str]]:
    """Return up to *k* (path, chunk) pairs most similar to *query*.

    Returns [] when the index or its metadata file does not exist.
    """
    index_path = os.path.join(index_dir, "index.faiss")
    meta_path = os.path.join(index_dir, "meta.tsv")
    # Check both files: an index without metadata is unusable and the
    # original code would have raised FileNotFoundError on meta.tsv.
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        return []

    model = SentenceTransformer(model_name)
    index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = [line.rstrip("\n").split("\t", 1) for line in f]

    q = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _, ids = index.search(q, k)

    results: List[Tuple[str, str]] = []
    for i in ids[0]:
        # FAISS pads with -1 when fewer than k vectors exist; also guard
        # against a metadata file that is out of sync with the index.
        if i < 0 or i >= len(meta):
            continue
        results.append((meta[i][0], meta[i][1]))
    return results