File size: 7,086 Bytes
44c5827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import json
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
from sentence_transformers import SentenceTransformer
import torch
import time

# Project root: the parent of the directory containing this script.
# All relative chunk/persist paths passed to main() are resolved against it.
BASE = Path(__file__).resolve().parent.parent

# Root logging config (INFO) plus a dedicated logger for this script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("emb_chroma")


# chromadb is a hard requirement for this script: fail fast at import time
# with an actionable install hint instead of a NameError later in main().
try:
    import chromadb
    from chromadb.config import Settings
except Exception as e:
    raise RuntimeError("chromadb not installed. pip install chromadb") from e

# --- Helpers ---------------------------------------------------------------
def chunk_files_iter(chunks_dir: Path):
    """Yield the chunk JSON files under *chunks_dir* in sorted (stable) order."""
    yield from sorted(chunks_dir.glob("*.json"))

def load_json(path: Path) -> Dict[str, Any]:
    """Parse the UTF-8 JSON file at *path* and return the decoded object."""
    with path.open("r", encoding="utf-8") as fh:
        return json.load(fh)

def save_json(path: Path, obj: Dict[str, Any]) -> None:
    """Serialize *obj* to *path* as pretty-printed UTF-8 JSON (non-ASCII kept)."""
    with path.open("w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)

def prepare_text_for_embedding(chunk: Dict[str, Any]) -> str:
    """Return the stripped text to embed for *chunk*.

    Prefers the "chunk_for_embedding" field; falls back to "chunk_text";
    returns "" when neither holds a truthy value.
    """
    for key in ("chunk_for_embedding", "chunk_text"):
        value = chunk.get(key)
        if value:
            return value.strip()
    return ""

# --- Main ------------------------------------------------------------------

class ChromaIndexer:
    """Owns a persistent Chroma collection and the SentenceTransformer used
    to produce its vectors.

    Opens (or creates) the named collection under *persist_dir*, then loads
    the embedding model onto *device*.
    """

    def __init__(self, persist_dir: str, collection_name: str, embedding_model_name: str, device: str = "cpu"):
        self.persist_dir = persist_dir
        self.collection_name = collection_name
        self.embedding_model_name = embedding_model_name
        self.device = device

        # Persistent duckdb+parquet backend rooted at persist_dir.
        # NOTE(review): this Settings style is the legacy chromadb client
        # config — confirm it matches the installed chromadb version.
        self.client = chromadb.Client(
            Settings(chroma_db_impl="duckdb+parquet", persist_directory=self.persist_dir)
        )

        # Reuse the collection when it already exists; otherwise create it.
        try:
            self.collection = self.client.get_collection(self.collection_name)
            logger.info("Opened existing Chroma collection '%s' (persist_dir=%s)", self.collection_name, self.persist_dir)
        except Exception:
            self.collection = self.client.create_collection(self.collection_name)
            logger.info("Created new Chroma collection '%s'", self.collection_name)

        logger.info("Loading embedding model '%s' on device=%s", self.embedding_model_name, self.device)
        self.model = SentenceTransformer(self.embedding_model_name, device=self.device)

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return the vectors as plain Python lists."""
        matrix = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        vectors = []
        for row in matrix:
            vectors.append(list(row.astype(float)))
        return vectors

    def upsert_batch(self, ids: List[str], embeddings: List[List[float]], metadatas: List[Dict[str, Any]], documents: Optional[List[str]] = None):
        """Add one batch to the collection; documents default to each
        metadata's "preview" field."""
        if documents is None:
            documents = [m.get("preview", "") for m in metadatas]
        self.collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)
        

def main(chunks_dir: str, persist_dir: str, collection: str, model_name: str, batch_size: int, device: str, force_reembed: bool):
    """Embed all pending chunk JSON files and upsert them into Chroma.

    Args:
        chunks_dir: Directory of per-chunk *.json files, resolved against BASE.
        persist_dir: Chroma persistence directory (absolute paths pass through
            the ``BASE / persist_dir`` join unchanged).
        collection: Chroma collection name.
        model_name: SentenceTransformer model identifier.
        batch_size: Number of chunks embedded and upserted per batch.
        device: Torch device string for the embedding model (e.g. "cpu").
        force_reembed: Re-embed chunks already marked ``"_embedded": true``.

    Side effects: writes vectors/metadata to Chroma and rewrites each chunk
    file with an ``"_embedded": true`` marker after a successful upsert.
    """
    chunks_dir_path = BASE / chunks_dir
    persist_dir_path = BASE / persist_dir
    # Fix: the mkdir was commented out; ensure the persist dir exists before
    # handing it to Chroma.
    persist_dir_path.mkdir(parents=True, exist_ok=True)
    indexer = ChromaIndexer(str(persist_dir_path), collection, model_name, device=device)

    # Collect (path, chunk, text) triples that still need embedding.
    to_process = []
    for p in chunk_files_iter(chunks_dir_path):
        try:
            chunk = load_json(p)
        except Exception as e:
            logger.warning("Skip unreadable chunk %s: %s", p, e)
            continue
        if chunk.get("_embedded", False) and not force_reembed:
            continue  # already embedded; skip unless forced
        text = prepare_text_for_embedding(chunk)
        if not text:
            continue  # nothing to embed
        to_process.append((p, chunk, text))

    logger.info("Found %d chunks to embed", len(to_process))
    if not to_process:
        return

    # Process in batches of batch_size.
    for i in range(0, len(to_process), batch_size):
        batch = to_process[i:i + batch_size]
        paths = [t[0] for t in batch]
        chunks = [t[1] for t in batch]
        texts = [t[2] for t in batch]
        # Bug fix: the original comprehension unpacked "for idx, c, idx in
        # zip(paths, chunks, range(...))", binding idx twice (path, then the
        # range int) — it worked only because the second binding won.
        # enumerate makes the fallback-id index explicit and drops the unused
        # path from the zip entirely.
        ids = [
            c.get("id") or c.get("checksum") or f"chunk-{idx}"
            for idx, c in enumerate(chunks, start=i)
        ]

        # Compute embeddings for the whole batch in one model call.
        try:
            start_time = time.time()
            embeddings = indexer.embed_texts(texts)
            # Lazy %-formatting instead of an eager f-string in the log call.
            logger.info("Embedding time: %s seconds", time.time() - start_time)
        except Exception as e:
            logger.exception("Embedding failed for batch starting %d: %s", i, e)
            raise

        metas = [_chunk_metadata(c) for c in chunks]

        # Upsert vectors + metadata; the chunk preview doubles as the document.
        try:
            indexer.upsert_batch(ids, embeddings, metas, documents=[m["preview"] for m in metas])
        except Exception as e:
            logger.exception("Chroma upsert failed: %s", e)
            raise

        # Persist the embedded marker so re-runs skip these chunks.
        for pth, ch in zip(paths, chunks):
            ch["_embedded"] = True
            save_json(pth, ch)

        logger.info("Upserted batch %d -> %d vectors", i // batch_size + 1, len(batch))

    logger.info("Done. Chroma persist dir: %s", persist_dir)


def _chunk_metadata(c: Dict[str, Any]) -> Dict[str, Any]:
    """Build a Chroma-safe metadata dict for one chunk.

    Drops None values and joins list values with " | " since Chroma metadata
    only accepts str/int/float. "preview" is always present (possibly "").
    """
    meta = {
        "doc_id": c.get("doc_id"),
        "source_filename": c.get("source_filename"),
        "chapter": c.get("chapter"),
        "article": c.get("article"),
        "clause": c.get("clause"),
        "point": c.get("point"),
        "content_type": c.get("content_type"),
        "table_id": c.get("table_id"),
        "checksum": c.get("checksum"),
        "path": c.get("path"),
        "preview": (c.get("chunk_text") or "")[:2000],
        "chunk_for_embedding": c.get("chunk_for_embedding"),
        "token_count": c.get("token_count"),
    }
    filtered: Dict[str, Any] = {}
    for k, v in meta.items():
        if v is None:
            continue
        filtered[k] = " | ".join(str(item) for item in v) if isinstance(v, list) else v
    return filtered


if __name__ == "__main__":
    import os

    # The chroma_db directory lives next to this script's parent directory.
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    main(
        chunks_dir="chunks",
        persist_dir=os.path.join(os.path.dirname(_script_dir), "chroma_db"),
        collection="snote",
        model_name="AITeamVN/Vietnamese_Embedding_v2",
        batch_size=100,
        device="cpu",
        force_reembed=True,
    )