"""Embed all chunks into ChromaDB in memory-safe mega-batches.
Processes chunks in mega-batches (default 10K) to avoid OOM:
encode batch → store in ChromaDB → free memory → next batch
Usage:
python scripts/embed_chunks.py # defaults
python scripts/embed_chunks.py --mega-batch 5000 # smaller mega-batches
python scripts/embed_chunks.py --encode-batch 256 # bigger GPU batches
"""
import argparse
import gc
import logging
import sys
import time
from pathlib import Path
# Make the project root importable so `src.*` resolves when run as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.config import get_config
from src.storage.chroma_store import ChromaStore
from src.storage.sqlite_db import SQLiteDB
# Timestamped, level-tagged log lines for long-running batch progress.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
def _chunk_id(chunk):
    """Return a stable string ID for a chunk row.

    Prefers the row's own ``id``; falls back to ``<paper_id>_chunk_<index>``
    when the key is absent.
    """
    return str(chunk.get("id", f"{chunk['paper_id']}_chunk_{chunk['chunk_index']}"))


def _chunk_metadata(chunk):
    """Build the ChromaDB metadata dict for one chunk row.

    NULL columns from SQLite arrive as ``None`` — ``dict.get(key, default)``
    does NOT substitute the default for an existing None value, and ChromaDB
    rejects None metadata, so each field is coerced to a typed fallback.
    """
    return {
        "paper_id": chunk["paper_id"],
        "chunk_type": chunk.get("chunk_type") or "unknown",
        "chunk_index": chunk.get("chunk_index") or 0,
        "year": chunk.get("year") or 0,
        "venue": chunk.get("venue") or "",
        "title": chunk.get("title") or "",
    }


def main():
    """Embed all SQLite chunks into ChromaDB in memory-safe mega-batches.

    Each mega-batch cycle: encode on GPU -> store in ChromaDB -> free memory.
    CLI flags control the three batch sizes and whether the collection is
    reset before embedding (``--no-reset`` resumes instead).
    """
    parser = argparse.ArgumentParser(description="Embed chunks into ChromaDB")
    parser.add_argument("--mega-batch", type=int, default=10000,
                        help="Chunks per mega-batch (encode+store cycle)")
    parser.add_argument("--encode-batch", type=int, default=128,
                        help="GPU encoding batch size")
    parser.add_argument("--chroma-batch", type=int, default=500,
                        help="ChromaDB insertion batch size")
    parser.add_argument("--no-reset", action="store_true",
                        help="Don't reset ChromaDB (resume mode)")
    args = parser.parse_args()

    config = get_config()
    db = SQLiteDB(config.sqlite_db_path)
    chroma = ChromaStore(config.chroma_db_path)

    # Reset ChromaDB unless resuming
    if not args.no_reset:
        logger.info("Resetting ChromaDB collection...")
        chroma.reset()

    # Load all chunks from DB
    logger.info("Loading chunks from SQLite...")
    all_chunks = db.get_all_chunks()
    total = len(all_chunks)
    logger.info("Total chunks: %d", total)

    # Nothing to embed: skip the (expensive) model load entirely.
    if total == 0:
        logger.info("No chunks to embed; exiting.")
        return

    # Load model lazily — sentence_transformers import is slow, so it is
    # deferred until we know there is work to do.
    logger.info("Loading embedding model...")
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(config.embedding_model)
    dim = model.get_sentence_embedding_dimension()
    logger.info("Model loaded on %s (dim=%d)", model.device, dim)

    # Process in mega-batches
    start_time = time.time()
    total_stored = 0
    for mega_start in range(0, total, args.mega_batch):
        mega_end = min(mega_start + args.mega_batch, total)
        batch_chunks = all_chunks[mega_start:mega_end]
        batch_size = len(batch_chunks)
        logger.info("=== Mega-batch %d-%d / %d (%d chunks) ===",
                    mega_start, mega_end, total, batch_size)

        # Extract texts (chunk text lives only in SQLite, not in ChromaDB)
        texts = [c["chunk_text"] for c in batch_chunks]

        # Encode on GPU
        t0 = time.time()
        embeddings = model.encode(
            texts,
            batch_size=args.encode_batch,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        encode_time = time.time() - t0
        # max() guards against ZeroDivisionError when timing rounds to 0.
        logger.info("Encoded %d chunks in %.1fs (%.0f chunks/sec)",
                    batch_size, encode_time, batch_size / max(encode_time, 1e-9))

        # Prepare ChromaDB data (no documents — text lives in SQLite only)
        ids = [_chunk_id(c) for c in batch_chunks]
        metadatas = [_chunk_metadata(c) for c in batch_chunks]
        emb_list = [embeddings[i].tolist() for i in range(batch_size)]

        # Store in ChromaDB in sub-batches
        t0 = time.time()
        chroma.add_embeddings(
            ids=ids,
            embeddings=emb_list,
            metadatas=metadatas,
            batch_size=args.chroma_batch,
        )
        store_time = time.time() - t0
        total_stored += batch_size
        logger.info("Stored in ChromaDB in %.1fs. Total stored: %d/%d",
                    store_time, total_stored, total)

        # Free memory before the next mega-batch to keep peak RSS down.
        del texts, embeddings, ids, metadatas, emb_list, batch_chunks
        gc.collect()

    elapsed = time.time() - start_time
    final_count = chroma.count()
    logger.info("=== EMBEDDING COMPLETE ===")
    logger.info("Total: %d embeddings in %.1fs (%.1f chunks/sec)",
                final_count, elapsed, total / max(elapsed, 1e-9))
    logger.info("ChromaDB count: %d", final_count)
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()