import argparse
import os

# GPU selection must happen before any CUDA-aware library (torch, via
# sentence-transformers) is imported, so these stay above the imports below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import chromadb
import tqdm
from chromadb.utils import embedding_functions
from datasets import load_dataset
| |
|
| | |
# Parse the target Wikipedia language code from the command line.
cli = argparse.ArgumentParser()
cli.add_argument("--lang", type=str, default="pt", help="language code")
args = cli.parse_args()
lang_code = args.lang

# On-disk location of the persistent Chroma store for this language.
db_path = f"/home/mshahidul/readctrl/data/vector_db/{lang_code}_v2"
| |
|
| | |
# Open (or create) the on-disk Chroma store at db_path.
client = chromadb.PersistentClient(path=db_path)

# Embed documents with Qwen3-Embedding-4B on the GPU selected above.
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    device="cuda",
    model_name='Qwen/Qwen3-Embedding-4B',
)

# All Wikipedia paragraph chunks live in one named collection.
collection = client.get_or_create_collection(
    name="wiki_collection",
    embedding_function=ef,
)
| |
|
| | |
if collection.count() == 0:
    print(f"Database empty. Processing Wikipedia ({lang_code})...")

    # Stream the dump so the full corpus never has to fit in memory.
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{lang_code}", split='train', streaming=True)

    BATCH_SIZE = 100       # paragraphs per collection.add() call
    max_articles = 500000  # upper bound on articles ingested per run

    batch_docs = []
    batch_ids = []
    chunk_count = 0  # paragraphs flushed to the DB so far

    # Single progress bar over articles; `total` gives tqdm an ETA.
    # (The original also opened a fresh per-article bar over paragraphs,
    # which spammed the terminal — removed.)
    for i, item in tqdm.tqdm(enumerate(ds.take(max_articles)), total=max_articles):
        text = item['text']

        # Split on blank lines; keep only substantive paragraphs (>20 words).
        paragraphs = [p.strip() for p in text.split('\n\n') if len(p.split()) > 20]

        for p_idx, para in enumerate(paragraphs):
            batch_docs.append(para)
            # Deterministic id: article index + paragraph index within it.
            batch_ids.append(f"art_{i}_p_{p_idx}")

            # Flush in fixed-size batches so embedding runs on sane chunks.
            if len(batch_docs) >= BATCH_SIZE:
                collection.add(
                    documents=batch_docs,
                    ids=batch_ids
                )
                chunk_count += len(batch_docs)
                batch_docs = []
                batch_ids = []

        if i % 500 == 0:
            print(f"Processed {i} articles... Total chunks in DB: {collection.count()}")

    # Flush the final partial batch (and count it — the original forgot to).
    if batch_docs:
        collection.add(documents=batch_docs, ids=batch_ids)
        chunk_count += len(batch_docs)

    print(f"Finished! Total documents in DB: {collection.count()}")
else:
    print(f"Database already exists with {collection.count()} documents. Loading...")
| |
|
| | |
# Smoke-test the index with a sample semantic search.
query = "Tell me about history"
results = collection.query(query_texts=[query], n_results=3)

print(f"\nQuery: {query}")
# results['documents'][0] holds the documents for the first (only) query.
for rank, doc in enumerate(results['documents'][0], start=1):
    print(f"Result {rank}: {doc[:200]}...")