"""Embed JSON chunk files with a SentenceTransformer model and index them in ChromaDB.

Reads ``*.json`` chunk files from a directory, embeds their text, upserts the
vectors (plus filtered metadata) into a persistent Chroma collection, and marks
each source file with ``_embedded: true`` so re-runs can skip it.
"""

import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch  # noqa: F401 -- kept: torch backs SentenceTransformer device placement
from sentence_transformers import SentenceTransformer

BASE = Path(__file__).resolve().parent.parent

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("emb_chroma")

try:
    import chromadb
    from chromadb.config import Settings
except Exception as e:
    raise RuntimeError("chromadb not installed. pip install chromadb") from e


# --- Helpers ---------------------------------------------------------------

def chunk_files_iter(chunks_dir: Path):
    """Yield the chunk JSON files in *chunks_dir*, sorted by filename."""
    yield from sorted(chunks_dir.glob("*.json"))


def load_json(path: Path) -> Dict[str, Any]:
    """Read *path* as UTF-8 JSON and return the decoded object."""
    return json.loads(path.read_text(encoding="utf-8"))


def save_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON."""
    path.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")


def prepare_text_for_embedding(chunk: Dict[str, Any]) -> str:
    """Return the text to embed: prefer 'chunk_for_embedding', fall back to 'chunk_text'."""
    txt = chunk.get("chunk_for_embedding") or chunk.get("chunk_text") or ""
    return txt.strip()


# Metadata fields copied straight through from each chunk dict.
_META_KEYS = (
    "doc_id", "source_filename", "chapter", "article", "clause", "point",
    "content_type", "table_id", "checksum", "path",
)


def _build_metadata(c: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a chunk dict into Chroma-compatible metadata.

    Chroma only accepts str/int/float values, so ``None`` entries are dropped
    and lists are joined into a single ``" | "``-separated string.
    """
    meta: Dict[str, Any] = {k: c.get(k) for k in _META_KEYS}
    meta["preview"] = (c.get("chunk_text") or "")[:2000]
    meta["chunk_for_embedding"] = c.get("chunk_for_embedding")
    meta["token_count"] = c.get("token_count")

    filtered: Dict[str, Any] = {}
    for k, v in meta.items():
        if v is None:
            continue
        filtered[k] = " | ".join(str(item) for item in v) if isinstance(v, list) else v
    return filtered


# --- Main ------------------------------------------------------------------

class ChromaIndexer:
    """Thin wrapper pairing a persistent Chroma collection with an embedding model."""

    def __init__(self, persist_dir: str, collection_name: str,
                 embedding_model_name: str, device: str = "cpu"):
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        self.device = device
        # NOTE(review): the 'duckdb+parquet' impl was removed in chromadb >= 0.4;
        # newer installs should use chromadb.PersistentClient(path=...) -- confirm
        # the pinned chromadb version before upgrading.
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=self.persist_dir)
        self.client = chromadb.Client(settings)
        # EAFP: open the collection if it exists, otherwise create it.
        try:
            self.collection = self.client.get_collection(self.collection_name)
            logger.info("Opened existing Chroma collection '%s' (persist_dir=%s)",
                        self.collection_name, self.persist_dir)
        except Exception:
            self.collection = self.client.create_collection(self.collection_name)
            logger.info("Created new Chroma collection '%s'", self.collection_name)
        logger.info("Loading embedding model '%s' on device=%s",
                    self.embedding_model_name, self.device)
        self.model = SentenceTransformer(self.embedding_model_name, device=self.device)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return vectors as plain Python float lists.

        ``.tolist()`` (rather than ``list(...)``) converts the numpy scalars to
        native ``float``, which keeps the vectors JSON-serializable.
        """
        embs = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        return [vec.astype(float).tolist() for vec in embs]

    def upsert_batch(self, ids: List[str], embeddings: List[List[float]],
                     metadatas: List[Dict[str, Any]],
                     documents: Optional[List[str]] = None):
        """Add one batch of vectors; documents default to each metadata's preview."""
        docs = documents if documents is not None else [m.get("preview", "") for m in metadatas]
        self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=docs)


def main(chunks_dir: str, persist_dir: str, collection: str, model_name: str,
         batch_size: int, device: str, force_reembed: bool):
    """Embed all pending chunks under *chunks_dir* and upsert them into Chroma.

    Chunks already marked ``_embedded`` are skipped unless *force_reembed* is set.
    """
    chunks_dir_path = BASE / chunks_dir
    # Path '/' passes an absolute persist_dir through unchanged.
    persist_dir_path = BASE / persist_dir
    persist_dir_path.mkdir(parents=True, exist_ok=True)
    indexer = ChromaIndexer(str(persist_dir_path), collection, model_name, device=device)

    # Collect (path, chunk, text) triples still needing embedding.
    to_process = []
    for p in chunk_files_iter(chunks_dir_path):
        try:
            chunk = load_json(p)
        except Exception as e:
            logger.warning("Skip unreadable chunk %s: %s", p, e)
            continue
        if chunk.get("_embedded", False) and not force_reembed:
            continue  # already indexed; skip unless forced
        text = prepare_text_for_embedding(chunk)
        if not text:
            continue  # nothing to embed
        to_process.append((p, chunk, text))

    logger.info("Found %d chunks to embed", len(to_process))
    if not to_process:
        return

    for i in range(0, len(to_process), batch_size):
        batch = to_process[i:i + batch_size]
        paths = [t[0] for t in batch]
        chunks = [t[1] for t in batch]
        texts = [t[2] for t in batch]
        # Stable ids: prefer the chunk's own id, then checksum, then a
        # positional fallback based on the global index.
        # (Fixes the original comprehension, which bound `idx` twice in
        #  `for idx, c, idx in zip(paths, chunks, range(...))`.)
        ids = [
            c.get("id") or c.get("checksum") or f"chunk-{idx}"
            for idx, c in enumerate(chunks, start=i)
        ]

        try:
            start_time = time.time()
            embeddings = indexer.embed_texts(texts)
            logger.info("Embedding time: %s seconds", time.time() - start_time)
        except Exception as e:
            logger.exception("Embedding failed for batch starting %d: %s", i, e)
            raise

        metas = [_build_metadata(c) for c in chunks]

        try:
            indexer.upsert_batch(ids, embeddings, metas,
                                 documents=[m["preview"] for m in metas])
        except Exception as e:
            logger.exception("Chroma upsert failed: %s", e)
            raise

        # Mark source files so re-runs skip them unless force_reembed is set.
        for pth, ch in zip(paths, chunks):
            ch["_embedded"] = True
            save_json(pth, ch)
        logger.info("Upserted batch %d -> %d vectors", i // batch_size + 1, len(batch))

    logger.info("Done. Chroma persist dir: %s", persist_dir)


if __name__ == "__main__":
    import os

    current_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(current_dir)
    main(
        chunks_dir="chunks",
        persist_dir=os.path.join(parent_dir, "chroma_db"),
        collection="snote",
        model_name="AITeamVN/Vietnamese_Embedding_v2",
        batch_size=100,
        device="cpu",
        force_reembed=True,
    )