|
|
import json |
|
|
from pathlib import Path |
|
|
import logging |
|
|
from typing import Any, Dict, List, Optional |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import torch |
|
|
import time |
|
|
|
|
|
# Project root: the parent of this script's directory (chunk and DB paths
# below are resolved relative to this).
BASE = Path(__file__).resolve().parent.parent


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


logger = logging.getLogger("emb_chroma")


# Guard the chromadb import so a missing install fails fast at import time
# with an actionable message instead of a bare ImportError deeper in the run.
try:


    import chromadb


    from chromadb.config import Settings


except Exception as e:


    raise RuntimeError("chromadb not installed. pip install chromadb") from e
|
|
|
|
|
|
|
|
def chunk_files_iter(chunks_dir: Path):
    """Yield the chunk JSON files under *chunks_dir* in sorted name order."""
    yield from sorted(chunks_dir.glob("*.json"))
|
|
|
|
|
def load_json(path: Path) -> Dict[str, Any]:
    """Parse *path* as UTF-8 JSON and return the decoded object."""
    with path.open("r", encoding="utf-8") as fh:
        return json.load(fh)
|
|
|
|
|
def save_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write *obj* to *path* as pretty-printed UTF-8 JSON (non-ASCII preserved)."""
    with path.open("w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)
|
|
|
|
|
def prepare_text_for_embedding(chunk: Dict[str, Any]) -> str:
    """Pick the text to embed for *chunk*.

    Prefers the 'chunk_for_embedding' field, falls back to 'chunk_text',
    and returns "" when neither carries a truthy value. Result is stripped
    of surrounding whitespace.
    """
    for key in ("chunk_for_embedding", "chunk_text"):
        value = chunk.get(key)
        if value:
            return value.strip()
    return ""
|
|
|
|
|
|
|
|
|
|
|
class ChromaIndexer:
    """Owns a persistent Chroma collection and a SentenceTransformer encoder.

    Responsible for opening (or creating) the target collection and for
    turning raw text into embedding vectors that can be upserted into it.
    """

    def __init__(self, persist_dir: str, collection_name: str, embedding_model_name: str, device: str = "cpu"):
        """Connect to the collection and load the embedding model.

        Args:
            persist_dir: Directory where Chroma persists its data.
            collection_name: Collection to open, or create on first run.
            embedding_model_name: SentenceTransformer model id or local path.
            device: Torch device string for the encoder (e.g. "cpu", "cuda").
        """
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        self.device = device

        # NOTE(review): "duckdb+parquet" Settings is the legacy (pre-0.4)
        # Chroma client API; newer releases use chromadb.PersistentClient.
        # Kept as-is to avoid changing the on-disk store format.
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=self.persist_dir)
        self.client = chromadb.Client(settings)

        # get_collection raises when the collection does not exist yet,
        # so fall back to creating it.
        try:
            self.collection = self.client.get_collection(self.collection_name)
            logger.info("Opened existing Chroma collection '%s' (persist_dir=%s)", self.collection_name, self.persist_dir)
        except Exception:
            self.collection = self.client.create_collection(self.collection_name)
            logger.info("Created new Chroma collection '%s'", self.collection_name)

        logger.info("Loading embedding model '%s' on device=%s", self.embedding_model_name, self.device)
        self.model = SentenceTransformer(self.embedding_model_name, device=self.device)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return one vector (list of floats) per text."""
        embs = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        # .tolist() yields native Python floats; the previous
        # list(vec.astype(float)) kept numpy.float64 items, which serialize
        # less predictably.
        return [vec.astype(float).tolist() for vec in embs]

    def upsert_batch(self, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict[str, Any]], documents: Optional[List[str]] = None):
        """Upsert one batch of vectors with their metadata into the collection.

        When *documents* is omitted, each metadata's "preview" field is
        stored as the document text.
        """
        docs = documents if documents is not None else [m.get("preview", "") for m in metadatas]
        # BUG FIX: this method previously called .add(), which raises a
        # duplicate-ID error when re-embedding existing chunks (the script
        # runs with force_reembed=True). Prefer a true upsert; clients that
        # predate .upsert() fall back to .add().
        upsert = getattr(self.collection, "upsert", None)
        if callable(upsert):
            upsert(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=docs)
        else:
            self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=docs)
|
|
|
|
|
|
|
|
def _build_metadata(chunk: Dict[str, Any]) -> Dict[str, Any]:
    """Build a Chroma-compatible metadata dict for one chunk record.

    None values are dropped (Chroma metadata must not contain them) and
    list values are flattened to a " | "-joined string, since metadata
    values must be scalars.
    """
    raw = {
        "doc_id": chunk.get("doc_id"),
        "source_filename": chunk.get("source_filename"),
        "chapter": chunk.get("chapter"),
        "article": chunk.get("article"),
        "clause": chunk.get("clause"),
        "point": chunk.get("point"),
        "content_type": chunk.get("content_type"),
        "table_id": chunk.get("table_id"),
        "checksum": chunk.get("checksum"),
        "path": chunk.get("path"),
        # Truncated chunk text doubles as the stored document; always a
        # string, so it survives the None filter below.
        "preview": (chunk.get("chunk_text") or "")[:2000],
        "chunk_for_embedding": chunk.get("chunk_for_embedding"),
        "token_count": chunk.get("token_count"),
    }
    meta: Dict[str, Any] = {}
    for key, value in raw.items():
        if value is None:
            continue
        if isinstance(value, list):
            meta[key] = " | ".join(str(item) for item in value)
        else:
            meta[key] = value
    return meta


def main(chunks_dir: str, persist_dir: str, collection: str, model_name: str, batch_size: int, device: str, force_reembed: bool):
    """Embed all pending chunk files and upsert them into Chroma.

    Args:
        chunks_dir: Directory of per-chunk JSON files, joined onto BASE.
        persist_dir: Chroma persistence directory, joined onto BASE (an
            absolute path overrides BASE via pathlib "/" semantics).
        collection: Chroma collection name.
        model_name: SentenceTransformer model id.
        batch_size: Number of chunks embedded and upserted per batch.
        device: Torch device for the encoder.
        force_reembed: When True, re-embed chunks already marked "_embedded".
    """
    chunks_dir_path = BASE / chunks_dir
    persist_dir_path = BASE / persist_dir

    indexer = ChromaIndexer(str(persist_dir_path), collection, model_name, device=device)

    # Collect (path, parsed chunk, embeddable text) for every chunk that
    # still needs embedding; unreadable and empty chunks are skipped.
    to_process = []
    for p in chunk_files_iter(chunks_dir_path):
        try:
            chunk = load_json(p)
        except Exception as e:
            logger.warning("Skip unreadable chunk %s: %s", p, e)
            continue
        if chunk.get("_embedded", False) and not force_reembed:
            continue
        text = prepare_text_for_embedding(chunk)
        if not text:
            continue
        to_process.append((p, chunk, text))

    logger.info("Found %d chunks to embed", len(to_process))
    if not to_process:
        return

    for i in range(0, len(to_process), batch_size):
        batch = to_process[i:i + batch_size]
        paths = [t[0] for t in batch]
        chunks = [t[1] for t in batch]
        texts = [t[2] for t in batch]
        # BUG FIX: the original unpacked "idx, c, idx" from a 3-way zip,
        # binding idx twice (path first, then range value) so the path was
        # silently discarded. The fallback id keeps the same value as
        # before: the chunk's absolute position in to_process.
        ids = [
            c.get("id") or c.get("checksum") or f"chunk-{i + offset}"
            for offset, c in enumerate(chunks)
        ]

        try:
            start_time = time.time()
            embeddings = indexer.embed_texts(texts)
            # Lazy %-args: the message is only formatted if INFO is enabled.
            logger.info("Embedding time: %s seconds", time.time() - start_time)
        except Exception as e:
            logger.exception("Embedding failed for batch starting %d: %s", i, e)
            raise

        metas = [_build_metadata(c) for c in chunks]

        try:
            indexer.upsert_batch(ids, embeddings, metas, documents=[m["preview"] for m in metas])
        except Exception as e:
            logger.exception("Chroma upsert failed: %s", e)
            raise

        # Flag chunks as embedded only after a successful upsert, so a
        # failed batch is retried on the next run.
        for pth, ch in zip(paths, chunks):
            ch["_embedded"] = True
            save_json(pth, ch)

        logger.info("Upserted batch %d -> %d vectors", i // batch_size + 1, len(batch))

    logger.info("Done. Chroma persist dir: %s", persist_dir)
|
|
|
|
|
|
|
|
if __name__ == "__main__":


    import os


    # Directory containing this script, then its parent (the project root).
    current_dir = os.path.dirname(os.path.abspath(__file__))


    parent_dir = os.path.dirname(current_dir)


    main(


        # Relative: resolved against BASE inside main() -> <root>/chunks.
        chunks_dir="chunks",


        # Absolute path: BASE / <absolute> resolves to the absolute path
        # itself under pathlib "/" semantics, so this wins over BASE.
        persist_dir= os.path.join(parent_dir, "chroma_db"),


        collection="snote",


        model_name="AITeamVN/Vietnamese_Embedding_v2",


        batch_size=100,


        device="cpu",


        # True re-embeds every chunk, including those already flagged
        # "_embedded" in their JSON files.
        force_reembed=True


    )