# snote/scripts/bm25_index.py
# Uploaded by xuanbao01 via huggingface_hub (commit 44c5827, verified).
import os, json, glob, pickle, logging
from typing import List, Dict, Any
from underthesea import word_tokenize
from rank_bm25 import BM25Okapi
from pathlib import Path
# ---------------------------
# Config & Logging
# ---------------------------
# Log level is overridable at runtime via the LOG_LEVEL env var (default INFO).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s"
)
logger = logging.getLogger("bm25_local_indexer")
# Paths are resolved relative to the project root (the parent of scripts/).
BASE_DIR = Path(__file__).resolve().parent.parent
CHUNKS_DIR = BASE_DIR / "chunks"  # input: one JSON object per chunk file
INDEX_OUT = BASE_DIR / "bm25_index.pkl"  # output: pickled index bundle
MAX_TOKENS = 512  # NOTE(review): unused in this file — presumably a chunking limit; confirm
# ---------------------------
# Tokenization (Tiếng Việt)
# ---------------------------
def tokenize_vi(text: str) -> List[str]:
    """Segment Vietnamese *text* into lowercase word tokens.

    underthesea's ``word_tokenize(..., format="text")`` joins multi-word
    units with underscores, so a plain whitespace split then yields one
    token per segmented word.
    """
    segmented = word_tokenize(text, format="text")
    return segmented.lower().split()
# ---------------------------
# Load Chunks
# ---------------------------
def load_chunks(chunks_dir: Path) -> List[Dict[str, Any]]:
    """Load chunk records from ``*.json`` files under *chunks_dir*.

    Manifest files, unreadable files, and chunks missing a mandatory
    ``"id"`` field are skipped with a log message instead of aborting
    the whole run.

    Args:
        chunks_dir: Directory containing one JSON object per chunk file.

    Returns:
        A list of dicts with keys ``id``, ``doc_id``, ``path``, ``text``,
        ``chunk_for_embedding`` and ``token_count``. Missing optional
        fields default to ``None`` (``""`` for ``text``).
    """
    docs: List[Dict[str, Any]] = []
    # sorted() makes corpus order — and therefore the BM25 index —
    # deterministic across runs and filesystems.
    for fp in sorted(chunks_dir.glob("*.json")):
        # Skip manifest files
        if "manifest" in fp.name.lower():
            logger.info("Skipping manifest file: %s", fp)
            continue
        try:
            with open(fp, "r", encoding="utf-8") as f:
                ch = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # Narrowed from bare Exception: only I/O and parse failures
            # are expected here; anything else should surface.
            logger.warning("Failed to read %s: %s", fp, e)
            continue
        # A chunk without an "id" cannot be referenced later; skip it
        # rather than crash with a KeyError outside the try block
        # (the original behavior).
        if "id" not in ch:
            logger.warning("Chunk file %s has no 'id' field, skipping", fp)
            continue
        docs.append({
            "id": ch["id"],
            "doc_id": ch.get("doc_id"),
            "path": ch.get("path"),
            "text": ch.get("chunk_text", ""),
            "chunk_for_embedding": ch.get("chunk_for_embedding"),
            "token_count": ch.get("token_count"),
        })
    logger.info("Loaded %d chunks", len(docs))
    return docs
# ---------------------------
# Build BM25 Index
# ---------------------------
def build_bm25_index(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Tokenize every chunk and fit a BM25Okapi model over the corpus.

    Args:
        chunks: Chunk dicts as produced by ``load_chunks``; only the
            ``"text"`` field of each chunk is tokenized.

    Returns:
        A bundle dict with the fitted ``bm25`` model, the original
        ``chunks``, and the ``tokenized_corpus`` (lists are parallel:
        ``tokenized_corpus[i]`` corresponds to ``chunks[i]``).

    Raises:
        ValueError: If *chunks* is empty — BM25Okapi would otherwise
            fail with a division by zero while computing the average
            document length.
    """
    if not chunks:
        raise ValueError("Cannot build a BM25 index from an empty chunk list")
    corpus = [tokenize_vi(c["text"]) for c in chunks]
    bm25 = BM25Okapi(corpus)
    return {
        "bm25": bm25,
        "chunks": chunks,
        "tokenized_corpus": corpus,
    }
# ---------------------------
# Save & Load index
# ---------------------------
def save_index(index: Dict[str, Any], out_path: str):
    """Serialize the BM25 index bundle to disk with pickle.

    Args:
        index: The bundle dict produced by ``build_bm25_index``.
        out_path: Destination file path for the pickle payload.
    """
    payload = pickle.dumps(index)
    with open(out_path, "wb") as fh:
        fh.write(payload)
    logger.info("Saved BM25 index to %s", out_path)
def load_index(path: str) -> Dict[str, Any]:
    """Deserialize a previously saved BM25 index bundle.

    NOTE: ``pickle.load`` must only be used on trusted files; this one
    is produced locally by ``save_index``.
    """
    with open(path, "rb") as fh:
        bundle = pickle.load(fh)
    return bundle
# ---------------------------
# CLI
# ---------------------------
def main(reindex: bool, check: bool):
    """Entry point: optionally (re)build the index, then sanity-check it.

    Args:
        reindex: When True, load all chunks from CHUNKS_DIR, build the
            BM25 index, and write it to INDEX_OUT.
        check: When True, reload the pickle from INDEX_OUT and report
            its chunk count.
    """
    if reindex:
        loaded_chunks = load_chunks(CHUNKS_DIR)
        save_index(build_bm25_index(loaded_chunks), INDEX_OUT)
    if check:
        bundle = load_index(INDEX_OUT)
        logger.info("Index contains %d chunks", len(bundle["chunks"]))


if __name__ == "__main__":
    main(reindex=True, check=True)