"""Build and query a per-language ChromaDB vector store from a Wikipedia dump.

Streams the ``wikimedia/wikipedia`` dataset for the language given by
``--lang``, splits each article into paragraph chunks, embeds them with
Qwen3-Embedding-4B, and persists them in a Chroma collection. If the
collection already contains documents, ingestion is skipped. Finally runs
a sample similarity query and prints the top results.
"""
import argparse
import os

# Environment setup: must happen before any CUDA-using library is imported,
# otherwise the device selection is ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset
from tqdm import tqdm

# Upload every N chunks; keeps memory bounded and saves incrementally.
BATCH_SIZE = 100
# Cap on streamed articles to avoid massive processing times
# (500,000 articles can still yield millions of chunks).
MAX_ARTICLES = 500_000
# Paragraphs shorter than this many words are discarded as noise.
MIN_PARAGRAPH_WORDS = 20
# Print a progress line every N articles.
LOG_EVERY = 500

DB_ROOT = "/home/mshahidul/readctrl/data/vector_db"


def parse_args():
    """Parse the command line; only ``--lang`` (ISO language code) is accepted."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="pt", help="language code")
    return parser.parse_args()


def get_collection(lang_code):
    """Open (or create) the persistent Chroma collection for *lang_code*.

    Returns a collection backed by a SentenceTransformer embedding function.
    """
    db_path = f"{DB_ROOT}/{lang_code}_v2"
    client = chromadb.PersistentClient(path=db_path)
    # Qwen3-Embedding-4B is heavy; ensure your GPU has ~10GB+ VRAM.
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="Qwen/Qwen3-Embedding-4B",
        device="cuda",
    )
    return client.get_or_create_collection(
        name="wiki_collection", embedding_function=ef
    )


def chunk_article(text):
    """Split article *text* on blank lines, keeping only substantial paragraphs."""
    return [
        p.strip()
        for p in text.split("\n\n")
        if len(p.split()) > MIN_PARAGRAPH_WORDS
    ]


def ingest_wikipedia(collection, lang_code, max_articles=MAX_ARTICLES):
    """Stream the Wikipedia dump for *lang_code* and upload chunks to *collection*.

    Uses streaming mode so the whole dataset never has to fit in RAM; documents
    are uploaded in batches of ``BATCH_SIZE`` with ids of the form
    ``art_{article}_p_{paragraph}``.
    """
    ds = load_dataset(
        "wikimedia/wikipedia",
        f"20231101.{lang_code}",
        split="train",
        streaming=True,
    )

    batch_docs, batch_ids = [], []
    chunk_count = 0

    for i, item in tqdm(enumerate(ds.take(max_articles)), total=max_articles):
        for p_idx, para in enumerate(chunk_article(item["text"])):
            batch_docs.append(para)
            batch_ids.append(f"art_{i}_p_{p_idx}")
            # Flush once the batch is full to bound memory and save progress.
            if len(batch_docs) >= BATCH_SIZE:
                collection.add(documents=batch_docs, ids=batch_ids)
                chunk_count += len(batch_docs)
                batch_docs, batch_ids = [], []
        if i % LOG_EVERY == 0:
            print(f"Processed {i} articles... Total chunks in DB: {collection.count()}")

    # Flush any partial batch left over at the end of the stream.
    if batch_docs:
        collection.add(documents=batch_docs, ids=batch_ids)

    print(f"Finished! Total documents in DB: {collection.count()}")


def run_query(collection, query, n_results=3):
    """Run a similarity search for *query* and print the top *n_results* hits."""
    results = collection.query(query_texts=[query], n_results=n_results)
    print(f"\nQuery: {query}")
    for rank, doc in enumerate(results["documents"][0]):
        # Print only the first 200 chars of each hit to keep output readable.
        print(f"Result {rank + 1}: {doc[:200]}...")


def main():
    args = parse_args()
    collection = get_collection(args.lang)

    if collection.count() == 0:
        print(f"Database empty. Processing Wikipedia ({args.lang})...")
        ingest_wikipedia(collection, args.lang)
    else:
        print(f"Database already exists with {collection.count()} documents. Loading...")

    # Adjust the query text based on your target language.
    run_query(collection, "Tell me about history")


if __name__ == "__main__":
    main()