File size: 2,872 Bytes
9776024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os, re, glob
from typing import List, Tuple
import faiss
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader

from .config import cfg

def _load_texts(input_dir: str) -> List[Tuple[str, str]]:
    docs = []
    for path in glob.glob(os.path.join(input_dir, "**/*"), recursive=True):
        if os.path.isdir(path):
            continue
        try:
            if path.lower().endswith(('.txt', '.md')):
                with open(path, 'r', encoding='utf-8', errors='ignore') as f:
                    docs.append((path, f.read()))
            elif path.lower().endswith('.pdf'):
                reader = PdfReader(path)
                text = "\n".join([p.extract_text() or "" for p in reader.pages])
                docs.append((path, text))
        except Exception:
            pass
    return docs

def _chunk(text: str, size: int = 800, overlap: int = 120) -> List[str]:
    tokens = re.split(r"(\s+)", text)
    chunks, buf, length = [], [], 0
    for t in tokens:
        buf.append(t)
        length += len(t)
        if length >= size:
            chunks.append("".join(buf))
            buf = buf[-overlap:]
            length = sum(len(x) for x in buf)
    if buf:
        chunks.append("".join(buf))
    return chunks

def build_index(input_dir: str = "data/corpus", index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Embed every chunk of every document under *input_dir* and persist a
    FAISS inner-product index plus a ``meta.tsv`` sidecar mapping row -> (path, chunk).

    Returns the number of indexed chunks (0 when the corpus is empty, in which
    case no files are written and any existing index is left untouched).
    """
    os.makedirs(index_dir, exist_ok=True)
    docs = _load_texts(input_dir)
    entries = [(path, ch) for path, text in docs for ch in _chunk(text)]
    if not entries:
        # model.encode([]) yields an empty array with no usable dim; bail out
        # early instead of crashing on embs.shape[1].
        return 0
    model = SentenceTransformer(model_name)
    texts = [chunk for _, chunk in entries]
    # Normalized embeddings + inner-product index == cosine similarity.
    embs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True,
                        batch_size=64, show_progress_bar=True)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    faiss.write_index(index, os.path.join(index_dir, "index.faiss"))
    with open(os.path.join(index_dir, "meta.tsv"), "w", encoding="utf-8") as f:
        for path, chunk in entries:
            # Chunks routinely contain newlines (PDF pages are joined with
            # "\n"); embedded tabs/newlines would corrupt the one-row-per-line
            # TSV layout and desynchronize meta rows from FAISS row ids.
            # Hoisting the replace() also avoids a backslash inside an
            # f-string expression, a SyntaxError before Python 3.12.
            flat = chunk.replace("\t", " ").replace("\r", " ").replace("\n", " ")
            f.write(f"{path}\t{flat}\n")
    return len(entries)

def search(query: str, k: int = 4, index_dir: str = cfg.index_dir, model_name: str = cfg.embedding_model):
    """Return up to *k* ``(path, chunk)`` pairs most similar to *query*.

    Returns ``[]`` when the index has not been built yet.
    """
    index_path = os.path.join(index_dir, "index.faiss")
    meta_path = os.path.join(index_dir, "meta.tsv")
    # Both artifacts are written together by build_index(); missing either one
    # means there is nothing searchable. (Previously a present index with a
    # missing meta.tsv raised FileNotFoundError instead of returning [].)
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        return []
    index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = [line.rstrip("\n").split("\t", 1) for line in f]
    model = SentenceTransformer(model_name)
    q = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _, ids = index.search(q, k)
    results = []
    for row in ids[0]:
        # FAISS pads with -1 when fewer than k vectors exist; also guard
        # against a stale index that outgrew meta.tsv.
        if 0 <= row < len(meta):
            results.append((meta[row][0], meta[row][1]))
    return results