# readctrl / code / vectordb_build / vector_db_build.py
"""Build a persistent Chroma vector database from a Wikipedia dump for one language."""
import os
# Environment Setup
# NOTE(review): these CUDA env vars are set before chromadb / the
# SentenceTransformer embedding function (which load torch) are imported —
# presumably so that device selection takes effect; keep this ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset
import argparse
# 1. Setup — single CLI flag selecting which Wikipedia language edition to index.
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="pt", help="language code")
args = parser.parse_args()
lang_code: str = args.lang
# One on-disk DB directory per language ("_v2" is presumably a dataset/schema
# version suffix — confirm against the consumer of this path).
db_path = f"/home/mshahidul/readctrl/data/vector_db/{lang_code}_v2"
# 2. Initialize Client and Embedding Function
client = chromadb.PersistentClient(path=db_path)
# Qwen3-Embedding-4B is heavy; ensure your GPU has ~10GB+ VRAM
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name='Qwen/Qwen3-Embedding-4B',
    device="cuda"
)
# get_or_create: reuse an existing collection on re-runs instead of failing.
collection = client.get_or_create_collection(name="wiki_collection", embedding_function=ef)
# 3. Logic to Add New Data
# Only (re)build when the collection is empty; on re-runs we skip straight to search.
if collection.count() == 0:
    print(f"Database empty. Processing Wikipedia ({lang_code})...")
    # Use streaming to avoid loading the whole dataset into RAM
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=True)
    batch_docs = []
    batch_ids = []
    chunk_count = 0
    # Cap the number of articles to avoid massive processing times:
    # 500,000 articles can still yield millions of paragraph chunks.
    max_articles = 500000
    # Local import kept here (matches original placement); tqdm is only needed
    # on the build path. `total=` gives the bar a real percentage.
    import tqdm
    for i, item in tqdm.tqdm(enumerate(ds.take(max_articles)), total=max_articles):
        text = item['text']
        # Simple paragraph chunking: split on blank lines and keep only
        # paragraphs longer than 20 words (drops headings and stubs).
        # Fix: no per-article tqdm here — the original wrapped this inner loop
        # in its own progress bar, spawning one bar per article.
        paragraphs = [p.strip() for p in text.split('\n\n') if len(p.split()) > 20]
        for p_idx, para in enumerate(paragraphs):
            batch_docs.append(para)
            # Deterministic unique id: article index + paragraph index.
            batch_ids.append(f"art_{i}_p_{p_idx}")
            # 4. Batch Upload to Chroma (Every 100 chunks)
            # This prevents memory overflow and allows for incremental saving
            if len(batch_docs) >= 100:
                collection.add(
                    documents=batch_docs,
                    ids=batch_ids
                )
                chunk_count += len(batch_docs)
                batch_docs = []
                batch_ids = []
                # Periodic status line on top of the progress bar.
                if i % 500 == 0:
                    print(f"Processed {i} articles... Total chunks in DB: {collection.count()}")
    # Add remaining documents (final partial batch under 100 chunks).
    if batch_docs:
        collection.add(documents=batch_docs, ids=batch_ids)
    print(f"Finished! Total documents in DB: {collection.count()}")
else:
    print(f"Database already exists with {collection.count()} documents. Loading...")
# 5. Search — sanity-check retrieval with a sample query.
query = "Tell me about history"  # Adjust based on your language
# The collection embeds the query with the same model and returns the
# 3 nearest chunks.
results = collection.query(query_texts=[query], n_results=3)
print(f"\nQuery: {query}")
top_docs = results['documents'][0]
for rank, doc in enumerate(top_docs, start=1):
    print(f"Result {rank}: {doc[:200]}...")  # Print first 200 chars