import os

# Pin the GPU before any CUDA-aware library (torch, pulled in transitively by
# the sentence-transformers embedding function) is imported: CUDA reads these
# environment variables once, at initialization time.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import argparse

import chromadb
from chromadb.utils import embedding_functions
from datasets import load_dataset

# CLI: which Wikipedia language edition to index (e.g. "pt", "en").
parser = argparse.ArgumentParser(
    description="Build a persistent Chroma vector DB from Wikipedia paragraphs."
)
parser.add_argument("--lang", type=str, default="pt", help="language code")
args = parser.parse_args()
lang_code = args.lang

# On-disk location of the persistent Chroma database, one directory per language.
db_path = f"/home/mshahidul/readctrl/data/vector_db/{lang_code}_v2"
|
|
| |
# Persistent (on-disk) Chroma client rooted at the per-language db_path.
client = chromadb.PersistentClient(path=db_path)

# SentenceTransformer embedding function running on the GPU selected via
# CUDA_VISIBLE_DEVICES above.  NOTE(review): Qwen3-Embedding-4B is a large
# model — confirm the selected GPU has enough memory before long runs.
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name='Qwen/Qwen3-Embedding-4B',
    device="cuda"
)

# Reuses the collection if it already exists, so re-running the script is safe.
collection = client.get_or_create_collection(name="wiki_collection", embedding_function=ef)
|
|
| |
import tqdm  # hoisted here from inside the loop body

# Only ingest when the collection is empty; otherwise this run is query-only.
if collection.count() == 0:
    print(f"Database empty. Processing Wikipedia ({lang_code})...")

    # Stream the dump so the full corpus never has to fit in memory.
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=True)

    batch_docs = []   # pending paragraph texts awaiting a flush
    batch_ids = []    # matching stable IDs for those paragraphs
    chunk_count = 0   # total paragraphs written to the collection so far

    # Cap how much of the (streamed) dump we ingest.
    max_articles = 500000
    for i, item in tqdm.tqdm(enumerate(ds.take(max_articles)), total=max_articles):
        text = item['text']

        # Split the article on blank lines; keep only paragraphs with some
        # substance (more than 20 whitespace-separated tokens).
        paragraphs = [p.strip() for p in text.split('\n\n') if len(p.split()) > 20]

        for p_idx, para in enumerate(paragraphs):
            batch_docs.append(para)
            # ID encodes article index and paragraph index, e.g. "art_12_p_3",
            # so re-adding the same dump yields the same IDs.
            batch_ids.append(f"art_{i}_p_{p_idx}")

            # Flush in batches of 100 so embedding runs on chunks, not
            # one paragraph at a time.
            if len(batch_docs) >= 100:
                collection.add(
                    documents=batch_docs,
                    ids=batch_ids
                )
                chunk_count += len(batch_docs)
                batch_docs = []
                batch_ids = []

        if i % 500 == 0:
            # collection.count() hits the DB; acceptable at this cadence.
            print(f"Processed {i} articles... Total chunks in DB: {collection.count()}")

    # Flush any remainder smaller than one full batch.
    if batch_docs:
        collection.add(documents=batch_docs, ids=batch_ids)
        chunk_count += len(batch_docs)  # was previously missed for the final flush

    print(f"Finished! Total documents in DB: {collection.count()}")
else:
    print(f"Database already exists with {collection.count()} documents. Loading...")
|
|
| |
# Smoke-test retrieval: embed one query and fetch the 3 nearest paragraphs.
query = "Tell me about history"
results = collection.query(
    query_texts=[query],
    n_results=3
)

print(f"\nQuery: {query}")
# results['documents'] is a list per query; [0] selects our single query's hits.
for i, doc in enumerate(results['documents'][0]):
    # Truncate each hit to 200 chars for readable console output.
    print(f"Result {i+1}: {doc[:200]}...")