Spaces:
Runtime error
Runtime error
| import os, re, glob | |
| from typing import List, Tuple | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from pypdf import PdfReader | |
| from .config import cfg | |
| def _load_texts(input_dir: str) -> List[Tuple[str, str]]: | |
| docs = [] | |
| for path in glob.glob(os.path.join(input_dir, "**/*"), recursive=True): | |
| if os.path.isdir(path): | |
| continue | |
| try: | |
| if path.lower().endswith(('.txt', '.md')): | |
| with open(path, 'r', encoding='utf-8', errors='ignore') as f: | |
| docs.append((path, f.read())) | |
| elif path.lower().endswith('.pdf'): | |
| reader = PdfReader(path) | |
| text = "\n".join([p.extract_text() or "" for p in reader.pages]) | |
| docs.append((path, text)) | |
| except Exception: | |
| pass | |
| return docs | |
| def _chunk(text: str, size: int = 800, overlap: int = 120) -> List[str]: | |
| tokens = re.split(r"(\s+)", text) | |
| chunks, buf, length = [], [], 0 | |
| for t in tokens: | |
| buf.append(t) | |
| length += len(t) | |
| if length >= size: | |
| chunks.append("".join(buf)) | |
| buf = buf[-overlap:] | |
| length = sum(len(x) for x in buf) | |
| if buf: | |
| chunks.append("".join(buf)) | |
| return chunks | |
def build_index(input_dir: str = "data/corpus", index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Embed every chunk of every document under input_dir and persist a FAISS index.

    Writes two artifacts into index_dir:
      - index.faiss : inner-product index over L2-normalized embeddings
        (inner product on unit vectors == cosine similarity)
      - meta.tsv    : one "path<TAB>chunk" row per index vector, in row-id order

    Returns the number of indexed chunks (0 for an empty corpus).
    """
    os.makedirs(index_dir, exist_ok=True)
    model = SentenceTransformer(model_name)
    entries = []
    for path, text in _load_texts(input_dir):
        for ch in _chunk(text):
            entries.append((path, ch))
    if not entries:
        # Bug fix: with an empty corpus, model.encode([]) returns an array
        # with no second dimension and embs.shape[1] raises. Nothing to
        # index, so stop here.
        return 0
    texts = [entry[1] for entry in entries]
    embs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True,
                        batch_size=64, show_progress_bar=True)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index, os.path.join(index_dir, "index.faiss"))
    with open(os.path.join(index_dir, "meta.tsv"), "w", encoding="utf-8") as f:
        for path, ch in entries:
            # Bug fixes: (1) a backslash escape inside an f-string expression
            # (f"...{ch.replace('\t', ' ')}...") is a SyntaxError before
            # Python 3.12, so sanitize outside the f-string; (2) chunks can
            # contain newlines, which would split one entry across several
            # TSV rows and desynchronize meta rows from FAISS row ids —
            # flatten tabs, newlines and carriage returns to spaces.
            clean = ch.replace("\t", " ").replace("\n", " ").replace("\r", " ")
            f.write(f"{path}\t{clean}\n")
    return len(entries)
def search(query: str, k: int = 4, index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Return up to k (path, chunk) pairs whose embeddings are most similar to query.

    Returns [] when the index artifacts have not been built yet. Row ids
    returned by FAISS map 1:1 onto lines of meta.tsv.
    """
    index_path = os.path.join(index_dir, "index.faiss")
    meta_path = os.path.join(index_dir, "meta.tsv")
    # Bug fixes: check BOTH artifacts before doing any heavy work — the
    # original loaded the (expensive) SentenceTransformer model even when
    # the index was absent, and raised FileNotFoundError when only
    # meta.tsv was missing.
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        return []
    index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = [line.rstrip("\n").split("\t", 1) for line in f]
    model = SentenceTransformer(model_name)
    q = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _, ids = index.search(q, k)
    results = []
    for i in ids[0]:
        if i < 0 or i >= len(meta):
            # FAISS pads with -1 when fewer than k vectors exist.
            continue
        row = meta[i]
        if len(row) == 2:
            results.append((row[0], row[1]))
        else:
            # Bug fix: a malformed (tab-less) metadata row previously raised
            # IndexError at meta[i][1]; degrade to an empty chunk instead.
            results.append((row[0], ""))
    return results