File size: 4,753 Bytes
0c39e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Embed all chunks into ChromaDB in memory-safe mega-batches.

Processes chunks in mega-batches (default 10K) to avoid OOM:
  encode batch → store in ChromaDB → free memory → next batch

Usage:
    python scripts/embed_chunks.py                    # defaults
    python scripts/embed_chunks.py --mega-batch 5000  # smaller mega-batches
    python scripts/embed_chunks.py --encode-batch 256 # bigger GPU batches
"""

import argparse
import gc
import logging
import sys
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import get_config
from src.storage.chroma_store import ChromaStore
from src.storage.sqlite_db import SQLiteDB

# Configure root logging once at import time so that library loggers
# (sentence_transformers, chromadb) also emit timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)


def _build_chroma_payload(batch_chunks, embeddings):
    """Build parallel (ids, embeddings, metadatas) lists for ChromaDB.

    Chunk text is deliberately excluded from the payload — the text lives
    in SQLite only, ChromaDB stores just vectors + lookup metadata.

    Args:
        batch_chunks: list of chunk row dicts (must have "paper_id";
            "id", "chunk_type", "chunk_index", "year", "venue", "title"
            are optional with defaults).
        embeddings: 2-D array-like of row vectors, one per chunk.

    Returns:
        Tuple of (ids, emb_list, metadatas), equal-length lists.
    """
    ids = []
    emb_list = []
    metadatas = []
    for i, chunk in enumerate(batch_chunks):
        # Fall back to a synthesized id when the row carries no "id" key.
        chunk_id = str(
            chunk.get("id", f"{chunk['paper_id']}_chunk_{chunk['chunk_index']}")
        )
        ids.append(chunk_id)
        emb_list.append(embeddings[i].tolist())
        metadatas.append({
            "paper_id": chunk["paper_id"],
            "chunk_type": chunk.get("chunk_type", "unknown"),
            "chunk_index": chunk.get("chunk_index", 0),
            "year": chunk.get("year", 0),
            "venue": chunk.get("venue", ""),
            "title": chunk.get("title", ""),
        })
    return ids, emb_list, metadatas


def main():
    """Embed all SQLite chunks into ChromaDB in memory-bounded mega-batches.

    Cycle per mega-batch: encode on the model device, store in ChromaDB,
    release the batch's memory, continue. CLI flags control mega-batch
    size, GPU encode batch size, ChromaDB insert batch size, and whether
    the collection is reset first.
    """
    parser = argparse.ArgumentParser(description="Embed chunks into ChromaDB")
    parser.add_argument("--mega-batch", type=int, default=10000,
                        help="Chunks per mega-batch (encode+store cycle)")
    parser.add_argument("--encode-batch", type=int, default=128,
                        help="GPU encoding batch size")
    parser.add_argument("--chroma-batch", type=int, default=500,
                        help="ChromaDB insertion batch size")
    parser.add_argument("--no-reset", action="store_true",
                        help="Don't reset ChromaDB (resume mode)")
    args = parser.parse_args()

    config = get_config()
    db = SQLiteDB(config.sqlite_db_path)
    chroma = ChromaStore(config.chroma_db_path)

    # Reset ChromaDB unless resuming.
    # NOTE(review): with --no-reset this loop re-adds every chunk from the
    # start — it does not skip ids already present. Safe only if
    # ChromaStore.add_embeddings upserts/dedupes; confirm against its impl.
    if not args.no_reset:
        logger.info("Resetting ChromaDB collection...")
        chroma.reset()

    # Load all chunks from DB.
    # NOTE(review): this materializes every chunk row (including text) in
    # memory at once; only encode/store is batched. If the corpus outgrows
    # RAM, get_all_chunks would need pagination — confirm DB API options.
    logger.info("Loading chunks from SQLite...")
    all_chunks = db.get_all_chunks()
    total = len(all_chunks)
    logger.info("Total chunks: %d", total)

    # Nothing to do: skip the (slow) model load entirely.
    if total == 0:
        logger.warning("No chunks found in SQLite; nothing to embed.")
        return

    # Load model lazily — sentence_transformers import is heavy.
    logger.info("Loading embedding model...")
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(config.embedding_model)
    dim = model.get_sentence_embedding_dimension()
    logger.info("Model loaded on %s (dim=%d)", model.device, dim)

    # Process in mega-batches.
    start_time = time.time()
    total_stored = 0

    for mega_start in range(0, total, args.mega_batch):
        mega_end = min(mega_start + args.mega_batch, total)
        batch_chunks = all_chunks[mega_start:mega_end]
        batch_size = len(batch_chunks)

        logger.info("=== Mega-batch %d-%d / %d (%d chunks) ===",
                     mega_start, mega_end, total, batch_size)

        texts = [c["chunk_text"] for c in batch_chunks]

        # Encode on the model's device (GPU when available).
        t0 = time.time()
        embeddings = model.encode(
            texts,
            batch_size=args.encode_batch,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        encode_time = time.time() - t0
        # Guard against zero elapsed time (coarse timer + tiny batch)
        # so the rate log line cannot raise ZeroDivisionError.
        logger.info("Encoded %d chunks in %.1fs (%.0f chunks/sec)",
                     batch_size, encode_time,
                     batch_size / max(encode_time, 1e-9))

        ids, emb_list, metadatas = _build_chroma_payload(batch_chunks, embeddings)

        # Store in ChromaDB in sub-batches.
        t0 = time.time()
        chroma.add_embeddings(
            ids=ids,
            embeddings=emb_list,
            metadatas=metadatas,
            batch_size=args.chroma_batch,
        )
        store_time = time.time() - t0
        total_stored += batch_size
        logger.info("Stored in ChromaDB in %.1fs. Total stored: %d/%d",
                     store_time, total_stored, total)

        # Free the batch's memory before the next big allocation
        # to keep peak RSS bounded by one mega-batch.
        del texts, embeddings, ids, metadatas, emb_list, batch_chunks
        gc.collect()

    elapsed = time.time() - start_time
    final_count = chroma.count()
    logger.info("=== EMBEDDING COMPLETE ===")
    logger.info("Total: %d embeddings in %.1fs (%.1f chunks/sec)",
                final_count, elapsed, total / max(elapsed, 1e-9))
    logger.info("ChromaDB count: %d", final_count)


# Standard script entry guard: run only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()