# snote/scripts/embedding_index.py
# Uploaded via huggingface_hub by xuanbao01 (commit 44c5827, verified).
import json
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
from sentence_transformers import SentenceTransformer
import torch
import time
# Repository root: one level above this script's directory (scripts/ -> repo root).
BASE = Path(__file__).resolve().parent.parent
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("emb_chroma")
# chromadb is a hard requirement; fail fast with an actionable install hint.
try:
    import chromadb
    from chromadb.config import Settings
except Exception as e:
    raise RuntimeError("chromadb not installed. pip install chromadb") from e
# --- Helpers ---------------------------------------------------------------
def chunk_files_iter(chunks_dir: Path):
    """Yield the chunk JSON files under *chunks_dir* in sorted filename order."""
    yield from sorted(chunks_dir.glob("*.json"))
def load_json(path: Path) -> Dict[str, Any]:
    """Read *path* as UTF-8 and parse it as JSON."""
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)
def save_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON (non-ASCII kept verbatim)."""
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
def prepare_text_for_embedding(chunk: Dict[str, Any]) -> str:
    """Return the text to embed for *chunk*.

    Prefers the 'chunk_for_embedding' field, falls back to 'chunk_text',
    and returns '' when neither holds a non-empty value. Whitespace is stripped.
    """
    for key in ("chunk_for_embedding", "chunk_text"):
        value = chunk.get(key)
        if value:
            return value.strip()
    return ""
# --- Main ------------------------------------------------------------------
class ChromaIndexer:
    """Wraps a persistent Chroma collection together with a SentenceTransformer encoder."""

    def __init__(self, persist_dir: str, collection_name: str, embedding_model_name: str, device: str = "cpu"):
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        self.device = device
        # Chroma client backed by on-disk duckdb+parquet storage.
        client_settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=self.persist_dir)
        self.client = chromadb.Client(client_settings)
        # Reuse the collection if it already exists; otherwise create a fresh one.
        try:
            self.collection = self.client.get_collection(self.collection_name)
            logger.info("Opened existing Chroma collection '%s' (persist_dir=%s)", self.collection_name, self.persist_dir)
        except Exception:
            self.collection = self.client.create_collection(self.collection_name)
            logger.info("Created new Chroma collection '%s'", self.collection_name)
        # Load the sentence-embedding model last so collection errors surface first.
        logger.info("Loading embedding model '%s' on device=%s", self.embedding_model_name, self.device)
        self.model = SentenceTransformer(self.embedding_model_name, device=self.device)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return plain Python float lists (what Chroma expects)."""
        vectors = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        return [list(vec.astype(float)) for vec in vectors]

    def upsert_batch(self, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict[str, Any]], documents: Optional[List[str]] = None):
        """Add one batch to the collection; defaults documents to metadata previews."""
        if documents is None:
            documents = [meta.get("preview", "") for meta in metadatas]
        self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
def _chunk_metadata(chunk: Dict[str, Any]) -> Dict[str, Any]:
    """Build a Chroma-safe metadata dict for one chunk.

    Chroma metadata values must be str/int/float, so None entries are dropped
    and list values are flattened to ' | '-joined strings. The 'preview' key
    is always present (possibly empty), since callers use it as the document text.
    """
    raw = {
        "doc_id": chunk.get("doc_id"),
        "source_filename": chunk.get("source_filename"),
        "chapter": chunk.get("chapter"),
        "article": chunk.get("article"),
        "clause": chunk.get("clause"),
        "point": chunk.get("point"),
        "content_type": chunk.get("content_type"),
        "table_id": chunk.get("table_id"),
        "checksum": chunk.get("checksum"),
        "path": chunk.get("path"),
        "preview": (chunk.get("chunk_text") or "")[:2000],
        "chunk_for_embedding": chunk.get("chunk_for_embedding"),
        "token_count": chunk.get("token_count"),
    }
    meta: Dict[str, Any] = {}
    for key, value in raw.items():
        if value is None:
            continue
        meta[key] = " | ".join(str(item) for item in value) if isinstance(value, list) else value
    return meta


def main(chunks_dir: str, persist_dir: str, collection: str, model_name: str, batch_size: int, device: str, force_reembed: bool):
    """Embed all pending chunk files under BASE/chunks_dir and upsert them into Chroma.

    Chunks already marked with '_embedded' are skipped unless *force_reembed*.
    Each successfully upserted chunk file is rewritten with '_embedded': True.
    Raises on embedding or upsert failure (after logging the exception).
    """
    chunks_dir_path = BASE / chunks_dir
    persist_dir_path = BASE / persist_dir
    # Ensure the persistence directory exists before the client opens it.
    persist_dir_path.mkdir(parents=True, exist_ok=True)
    indexer = ChromaIndexer(str(persist_dir_path), collection, model_name, device=device)

    # Collect (path, chunk, text) triples that still need embedding.
    to_process = []
    for p in chunk_files_iter(chunks_dir_path):
        try:
            chunk = load_json(p)
        except Exception as e:
            logger.warning("Skip unreadable chunk %s: %s", p, e)
            continue
        if chunk.get("_embedded", False) and not force_reembed:
            continue
        text = prepare_text_for_embedding(chunk)
        if not text:
            continue
        to_process.append((p, chunk, text))

    logger.info("Found %d chunks to embed", len(to_process))
    if not to_process:
        return

    for i in range(0, len(to_process), batch_size):
        batch = to_process[i:i + batch_size]
        paths = [t[0] for t in batch]
        chunks = [t[1] for t in batch]
        texts = [t[2] for t in batch]
        # Stable id per chunk: explicit id, then checksum, then positional fallback.
        # (The original zip unpacked `idx, c, idx`, silently shadowing the path.)
        ids = [c.get("id") or c.get("checksum") or f"chunk-{idx}" for idx, c in enumerate(chunks, start=i)]

        try:
            start_time = time.time()
            embeddings = indexer.embed_texts(texts)
            logger.info("Embedding time: %s seconds", time.time() - start_time)
        except Exception as e:
            logger.exception("Embedding failed for batch starting %d: %s", i, e)
            raise

        metas = [_chunk_metadata(c) for c in chunks]

        try:
            indexer.upsert_batch(ids, embeddings, metas, documents=[m["preview"] for m in metas])
        except Exception as e:
            logger.exception("Chroma upsert failed: %s", e)
            raise

        # Mark source chunk files so later runs skip them unless force_reembed.
        for pth, ch in zip(paths, chunks):
            ch["_embedded"] = True
            save_json(pth, ch)
        logger.info("Upserted batch %d -> %d vectors", i // batch_size + 1, len(batch))

    logger.info("Done. Chroma persist dir: %s", persist_dir_path)
if __name__ == "__main__":
    import os

    # Resolve <repo_root>/chroma_db relative to this script's location.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_dir = os.path.dirname(script_dir)
    main(
        chunks_dir="chunks",
        persist_dir=os.path.join(repo_dir, "chroma_db"),
        collection="snote",
        model_name="AITeamVN/Vietnamese_Embedding_v2",
        batch_size=100,
        device="cpu",
        force_reembed=True,
    )