# emb_chroma.py — embed chunk JSON files with SentenceTransformers and index them into a persistent ChromaDB collection.
import json
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
from sentence_transformers import SentenceTransformer
import torch
import time
# Project root: two directories above this script file.
BASE = Path(__file__).resolve().parent.parent
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("emb_chroma")
# chromadb is a hard requirement; fail fast with an actionable install hint.
try:
    import chromadb
    from chromadb.config import Settings
except Exception as e:
    raise RuntimeError("chromadb not installed. pip install chromadb") from e
# --- Helpers ---------------------------------------------------------------
def chunk_files_iter(chunks_dir: Path):
    """Yield the chunk JSON files in *chunks_dir*, sorted by filename."""
    yield from sorted(chunks_dir.glob("*.json"))
def load_json(path: Path) -> Dict[str, Any]:
    """Parse the UTF-8 JSON file at *path* and return the decoded object."""
    with path.open(encoding="utf-8") as fh:
        return json.load(fh)
def save_json(path: Path, obj: Dict[str, Any]) -> None:
    """Serialize *obj* as pretty-printed UTF-8 JSON and write it to *path*."""
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
def prepare_text_for_embedding(chunk: Dict[str, Any]) -> str:
    """Return the stripped text to embed for *chunk*.

    Prefers the "chunk_for_embedding" field, falls back to "chunk_text",
    and returns the empty string when neither holds a non-empty value.
    """
    for key in ("chunk_for_embedding", "chunk_text"):
        value = chunk.get(key)
        if value:
            return value.strip()
    return ""
# --- Main ------------------------------------------------------------------
class ChromaIndexer:
    """A persistent Chroma collection paired with a SentenceTransformer model.

    Opens (or creates) *collection_name* under *persist_dir* and loads the
    embedding model eagerly, so configuration problems surface at
    construction time rather than mid-run.
    """

    def __init__(self, persist_dir: str, collection_name: str, embedding_model_name: str, device: str = "cpu"):
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        self.device = device
        # Persist vectors to disk via the duckdb+parquet backend
        # (NOTE(review): this Settings signature is the pre-0.4 chromadb API —
        # confirm the installed version supports it).
        settings = Settings(chroma_db_impl="duckdb+parquet", persist_directory=self.persist_dir)
        self.client = chromadb.Client(settings)
        # Reuse an existing collection when present; otherwise create it.
        try:
            self.collection = self.client.get_collection(self.collection_name)
            logger.info("Opened existing Chroma collection '%s' (persist_dir=%s)", self.collection_name, self.persist_dir)
        except Exception:
            self.collection = self.client.create_collection(self.collection_name)
            logger.info("Created new Chroma collection '%s'", self.collection_name)
        logger.info("Loading embedding model '%s' on device=%s", self.embedding_model_name, self.device)
        self.model = SentenceTransformer(self.embedding_model_name, device=self.device)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return one plain-Python float vector per text.

        Uses ndarray.tolist() so the vectors contain builtin floats; the
        previous list(vec.astype(float)) kept numpy float64 scalars, which
        can break JSON serialization downstream.
        """
        embs = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        return embs.astype(float).tolist()

    def upsert_batch(self, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict[str, Any]], documents: Optional[List[str]] = None):
        """Add one batch of (id, embedding, metadata, document) rows.

        When *documents* is omitted, each metadata's "preview" text is used.
        NOTE(review): collection.add raises on duplicate ids, so re-running
        with force_reembed=True against an existing collection can fail;
        newer chromadb exposes collection.upsert — confirm the installed
        version before switching.
        """
        docs = documents if documents is not None else [m.get("preview", "") for m in metadatas]
        self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=docs)
def main(chunks_dir: str, persist_dir: str, collection: str, model_name: str, batch_size: int, device: str, force_reembed: bool):
chunks_dir_path = BASE / chunks_dir
persist_dir_path = BASE / persist_dir
# persist_dir_path.mkdir(parents=True, exist_ok=True)
indexer = ChromaIndexer(str(persist_dir_path), collection, model_name, device=device)
to_process = []
for p in chunk_files_iter(chunks_dir_path):
try:
chunk = load_json(p)
except Exception as e:
logger.warning("Skip unreadable chunk %s: %s", p, e); continue
# skip embedded marker unless force
if chunk.get("_embedded", False) and not force_reembed:
continue
text = prepare_text_for_embedding(chunk)
if not text:
continue
to_process.append((p, chunk, text))
logger.info("Found %d chunks to embed", len(to_process))
if not to_process:
return
# process in batches
for i in range(0, len(to_process), batch_size):
batch = to_process[i:i+batch_size]
paths = [t[0] for t in batch]
chunks = [t[1] for t in batch]
texts = [t[2] for t in batch]
ids = [c.get("id") or c.get("checksum") or f"chunk-{idx}" for idx,c,idx in zip(paths, chunks, range(i, i+len(batch)))]
# compute embeddings
try:
start_time = time.time()
embeddings = indexer.embed_texts(texts)
logger.info(f"Embedding time: {time.time() - start_time} seconds")
except Exception as e:
logger.exception("Embedding failed for batch starting %d: %s", i, e)
raise
# prepare metadatas
metas = []
for c in chunks:
meta = {
"doc_id": c.get("doc_id"),
"source_filename": c.get("source_filename"),
"chapter": c.get("chapter"),
"article": c.get("article"),
"clause": c.get("clause"),
"point": c.get("point"),
"content_type": c.get("content_type"),
"table_id": c.get("table_id"),
"checksum": c.get("checksum"),
"path": c.get("path"),
"preview": (c.get("chunk_text") or "")[:2000],
"chunk_for_embedding": c.get("chunk_for_embedding"),
"token_count": c.get("token_count")
}
# Filter out None values and convert lists to strings as ChromaDB only accepts str, int, or float
filtered_meta = {}
for k, v in meta.items():
if v is not None:
if isinstance(v, list):
# Convert list to string representation
filtered_meta[k] = " | ".join(str(item) for item in v)
else:
filtered_meta[k] = v
metas.append(filtered_meta)
# upsert to chroma
try:
indexer.upsert_batch(ids, embeddings, metas, documents=[m["preview"] for m in metas])
except Exception as e:
logger.exception("Chroma upsert failed: %s", e)
raise
# mark chunks as embedded
for pth, ch in zip(paths, chunks):
ch["_embedded"] = True
save_json(pth, ch)
logger.info("Upserted batch %d -> %d vectors", i//batch_size + 1, len(batch))
logger.info("Done. Chroma persist dir: %s", persist_dir)
if __name__ == "__main__":
    import os

    # Resolve the project root (one level above this script's directory)
    # and keep the Chroma database alongside it.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    main(
        chunks_dir="chunks",
        persist_dir=os.path.join(project_root, "chroma_db"),
        collection="snote",
        model_name="AITeamVN/Vietnamese_Embedding_v2",
        batch_size=100,
        device="cpu",
        force_reembed=True,
    )