| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | This is the script to build KNN index map from Training dataset to Retrieval dataset. |
| | For example, it maps chunk_id i from training dataset to K chunk ids in the nearest neighbor in the retrieval dataset. |
| | |
| | It requires the training text data to be converted into `bin` and `idx` files by `preprocess_data_for_megatron.py` script. |
| | It also requires the Faiss Index file for the Retrieval dataset built by `build_retrieval_index.py` script. |
| | |
Here is an example of how to use it:
| | |
| | ```python |
| | python scripts/nlp_language_modeling/build_knn_map_index.py \ |
| | --input_file=PATH_TO_INPUT_TRAINING_DATA \ |
| | --tokenizer-library=sentencepiece \ |
| | --tokenizer-model=tokenizer.model \ |
| | --process_chunk_size=51200 \ |
| | --K_neighbors=16 \ |
| | --faiss_index=PATH_TO_FAISS_INDEX_FILE \ |
| | --devices=0,1,2,3 \ |
| | --batch_size=1280 \ |
| | --remove_duplicate \ |
| | --output_file=knn_map.idx |
| | ``` |
| | Use `--remove_duplicate` flag if the data and retrieval dataset are the same. It will remove the neighbors from the same document. |
| | It creates a knn_map.idx KNNIndex file. |
| | During training of RETRO model, it can look up the KNN chunk ids of the |
| | DB dataset given the input training data chunk id. |
| | |
For large datasets, we can build the KNN index in multiple stages:
| | |
stage-1: build sharding indexes, each containing a fraction of the dataset. This can be done in parallel on several machines. For example:
| | |
| | ```python |
| | python scripts/nlp_language_modeling/build_knn_map_index.py \ |
| | --input_file=PATH_TO_INPUT_TRAINING_DATA \ |
| | --tokenizer-library=megatron \ |
| | --tokenizer-type=GPT2BPETokenizer \ |
| | --merge-file=/dataset/gpt2-merges.txt \ |
| | --vocab-file=/dataset/gpt2-vocab.json \ |
| | --process_chunk_size=10000 \ |
| | --K_neighbors=16 \ |
| | --remove_duplicate \ |
| | --workers=2 \ |
| | --shard_id=0 \ |
| | --total_shards=2 \ |
| | --devices=0,1,2 \ |
| | --stage=1 \ |
| | --nprobe=10 \ |
| | --output_file=knn_shard0.save \ |
| | --faiss_index=faiss.index |
| | ``` |
| | |
| | stage-2: merge the sharding indexes into one that is written directly to disk, example |
| | |
| | ```python |
| | python scripts/nlp_language_modeling/build_knn_map_index.py \ |
| | --stage=2 \ |
| | --output_file=knn_final.save \ |
| | --shard_index_input=knn_shard |
| | ``` |
| | """ |
| |
|
| | import argparse |
| | import multiprocessing |
| | import pathlib |
| | import sys |
| | import time |
| | from multiprocessing import Pool |
| |
|
| | import faiss |
| | import numpy as np |
| | import torch |
| | from numba import njit, prange |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import ( |
| | KNNIndex, |
| | MMapRetrievalIndexedDataset, |
| | merge_knn_files, |
| | ) |
| | from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer |
| | from nemo.utils import logging |
| |
|
# Maximum number of in-flight items between pipeline stages; bounds memory use
# since each item holds a whole batch of sentences or embeddings.
QUEUE_SIZE = 30

# Producer (process_sentence_chunks) -> embedding worker: (sentences, slice_id) tuples,
# terminated by a (None, None) sentinel.
queue = multiprocessing.Queue(QUEUE_SIZE)
# Embedding worker (calculate_embedding) -> writer loop: (embeddings, slice_id) tuples,
# terminated by a (None, None) sentinel.
emb_queue = multiprocessing.Queue(QUEUE_SIZE)
| |
|
| |
|
@njit(parallel=True)
def build_map(chunk_start, result, total_chunks, start_id, end_id):
    """
    Build the map from chunk_id to a range of chunk ids that are from the same document.
    The chunk_id is in range [start_id, end_id)

    Args:
        chunk_start: array whose i-th entry is the first chunk id of document i
            (presumably ds._index._chunk_id_start at the call site -- TODO confirm).
        result: pre-allocated (end_id - start_id, 2) array; row k receives the
            [beg, end) chunk-id range of the document containing chunk start_id + k.
        total_chunks: total number of chunks in the dataset; serves as the end
            bound of the last document.
        start_id: first chunk id (inclusive) covered by `result`.
        end_id: last chunk id (exclusive) covered by `result`.
    """
    size = len(chunk_start)
    for i in prange(size):  # one iteration per document; parallelized by numba
        beg = chunk_start[i]
        # a document's chunk range ends where the next document's chunks start
        end = chunk_start[i + 1] if i < size - 1 else total_chunks
        # only documents overlapping the [start_id, end_id) window are written;
        # slice clamping keeps the row range inside `result`
        if start_id < end and beg < end_id:
            result[max(beg - start_id, 0) : (end - start_id), 0] = beg
            result[max(beg - start_id, 0) : (end - start_id), 1] = end
| |
|
| |
|
@njit(parallel=True)
def _dedup(chunk_id_to_range, I, tmp_neighbors, chunk_id_start, offset):
    """
    Numba kernel behind `dedup`: copy each row of the KNN result `I` into
    `tmp_neighbors`, left-packed, skipping neighbors that fall inside the query
    chunk's own document range.

    `tmp_neighbors` is expected to be pre-filled with -1 by the caller, so the
    skipped slots remain -1 at the right end of each row.
    """
    for cid in prange(len(I)):  # one iteration per query chunk; parallelized
        # guard: only process rows whose absolute chunk id maps into chunk_id_to_range
        if chunk_id_start + cid - offset >= 0 and chunk_id_start + cid - offset < len(chunk_id_to_range):
            beg, end = chunk_id_to_range[chunk_id_start + cid - offset]
            position = 0
            for target_chunk_id in I[cid]:
                # skip neighbors that belong to the same document as the query chunk
                if beg <= target_chunk_id < end:
                    continue
                tmp_neighbors[cid, position] = target_chunk_id
                position += 1
| |
|
| |
|
def dedup(chunk_id_to_range, I, tmp_neighbors, chunk_id_start, offset):
    """
    Deduplicate the KNN neighbors that come from the same document as the data chunks.

    Args:
        chunk_id_to_range: computed by the `build_map` function; maps
            chunk_id - offset to the [beg, end) chunk-id range of that chunk's document.
        I: original KNN search result from Faiss, shape (num_queries, num_neighbors).
        tmp_neighbors: output array of the same shape as I, pre-filled with -1;
            receives the filtered neighbors, left-packed per row.
        chunk_id_start: absolute chunk id of the first row of I.
        offset: the map offset used when chunk_id_to_range was built.

    Raises:
        ValueError: if any row of I falls outside the range covered by
            chunk_id_to_range.
    """
    # Validate bounds here because the numba kernel does no checking of its own.
    if chunk_id_start < offset or chunk_id_start + len(I) - offset > len(chunk_id_to_range):
        raise ValueError('chunk_id_start outside the range')
    _dedup(chunk_id_to_range, I, tmp_neighbors, chunk_id_start, offset)
| |
|
| |
|
def get_tokenizer(args):
    """
    Build the NMT tokenizer from the parsed command-line arguments and make
    sure it exposes a usable pad token.

    Args:
        args: argparse namespace with tokenizer_library, tokenizer_type,
            tokenizer_model, vocab_file, merge_file and delimiter attributes.

    Returns:
        The constructed tokenizer, guaranteed to have a pad token registered.
    """
    tokenizer = get_nmt_tokenizer(
        library=args.tokenizer_library,
        model_name=args.tokenizer_type,
        tokenizer_model=args.tokenizer_model,
        vocab_file=args.vocab_file,
        merges_file=args.merge_file,
        delimiter=args.delimiter,
    )
    # A pad token is required downstream. Add one when the tokenizer has no
    # pad_id attribute at all, or when pad_id is unset/invalid (None or < 0).
    # (The original if/elif pair duplicated the same call; the short-circuit
    # `or` chain below is equivalent.)
    if not hasattr(tokenizer, "pad_id") or tokenizer.pad_id is None or tokenizer.pad_id < 0:
        tokenizer.add_special_tokens({'pad_token': '<pad>'})
    return tokenizer
| |
|
| |
|
def calculate_start_end(total_chunks, total_shards, shard_id):
    """
    Compute the [start, end) chunk-id range handled by one shard.

    The dataset is split into `total_shards` contiguous ranges of
    `total_chunks // total_shards` chunks each; the last shard additionally
    absorbs the remainder so the union of all shards covers every chunk.

    Args:
        total_chunks: total number of chunks in the dataset.
        total_shards: number of shards the work is divided into.
        shard_id: index of the shard to compute, in [0, total_shards).

    Returns:
        (start, end): start chunk id (inclusive) and end chunk id (exclusive).

    Raises:
        ValueError: if shard_id is out of range (including negative values,
            which previously produced a silently wrong range), or if
            total_chunks < total_shards so shards would be empty.
    """
    if not 0 <= shard_id < total_shards:
        raise ValueError(f'shard_id {shard_id} is out of range [0, {total_shards})')
    shard_size = total_chunks // total_shards
    if shard_size == 0:
        raise ValueError(f'total_chunks {total_chunks} is smaller than total_shards {total_shards}')
    start = shard_id * shard_size
    # The last shard runs to the end of the dataset to cover the remainder.
    end = total_chunks if shard_id == total_shards - 1 else start + shard_size
    return start, end
| |
|
| |
|
def process_sentence_chunks(
    ds: MMapRetrievalIndexedDataset,
    tokenizer,
    chunk_size: int,
    stage: int,
    workers: int,
    shard_id: int,
    total_shards: int,
):
    """
    This function takes chunked tokens from the retrieval dataset and maps them back to text.
    In stage 1, it divides the total work into `total_shards` and processes only the
    `shard_id` portion. If the stage is None, it processes all the chunks.

    Producer side of the pipeline: puts (sentences, (start, end)) tuples onto the
    module-level `queue`, terminated by a (None, None) sentinel for the consumer.
    """
    total_chunks = ds.chunks
    start = 0
    threshold = 0  # progress-logging threshold, advanced in 10% steps

    if stage == 1:
        # restrict the [start, total_chunks) window to this shard's slice
        start, total_chunks = calculate_start_end(
            total_chunks=total_chunks, total_shards=total_shards, shard_id=shard_id
        )
        logging.info(f'shard_id {shard_id}, create index from chunk {start} to {total_chunks}')

    # worker pool for detokenizing chunk token ids back into text in parallel
    with Pool(workers) as p:
        while start < total_chunks:
            if start / total_chunks > threshold:
                logging.info(f"sentence processing {start / total_chunks} is done")
                threshold += 0.1
            slice_id = (start, min(start + chunk_size, total_chunks))
            beg = time.time()
            # force_no_cont_ids=True: fetch raw chunk tokens without continuation
            # ids -- presumably; TODO confirm against MMapRetrievalIndexedDataset
            id_slices = ds.get_chunk(slice(*slice_id), force_no_cont_ids=True)
            end = time.time()
            logging.info(f"load {chunk_size} chunks takes {end-beg}")
            start = min(start + chunk_size, total_chunks)
            sentences = p.map(tokenizer.ids_to_text, id_slices)
            end2 = time.time()
            logging.info(f"tokenize {chunk_size} chunks takes {end2-end}")
            queue.put((sentences, slice_id))
    # end-of-stream sentinel for the consumer (calculate_embedding)
    queue.put((None, None))
| |
|
| |
|
def get_sentence_chunks():
    """Block until the producer publishes the next (sentences, slice_id) pair."""
    item = queue.get()
    return item
| |
|
| |
|
def calculate_embedding(pool, batch_size):
    """
    Consume sentence batches from the producer queue, encode them with the
    sentence-transformer multi-process pool, and publish (embeddings, slice_id)
    pairs on `emb_queue`, ending with a (None, None) sentinel.
    """
    # iter() with a sentinel stops as soon as the producer sends (None, None).
    for sentences, slice_id in iter(get_sentence_chunks, (None, None)):
        beg = time.time()
        emb = model.encode_multi_process(sentences=sentences, pool=pool, batch_size=batch_size)
        end = time.time()
        logging.info(f"one embedding {len(emb)} batch size takes {end-beg}")
        emb_queue.put((emb, slice_id))
    # Propagate the end-of-stream marker downstream to the writer loop.
    emb_queue.put((None, None))
| |
|
| |
|
def get_emb():
    """Block until the embedding worker publishes the next (embeddings, slice_id) pair."""
    item = emb_queue.get()
    return item
| |
|
| |
|
if __name__ == "__main__":
    # ---- command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description="build Faiss index",)
    parser.add_argument(
        '--input_file', type=str, required=False, help='Input file',
    )
    parser.add_argument("--faiss_index", type=str, required=False, help='faiss index file for retrieval dataset')
    parser.add_argument(
        '--process_chunk_size',
        type=int,
        default=10000,
        help='The sentences in chunks that is queries to build map index',
    )
    parser.add_argument(
        '--remove_duplicate',
        action='store_true',
        help='Remove the knn neighbors that is from the same document as the data.',
    )
    parser.add_argument(
        '--K_neighbors', type=int, default=16, help='The number of neighbors to query',
    )
    parser.add_argument(
        '--dedup_margin',
        type=int,
        default=2,
        help='extra neighbors to fill the spaces of the chunks in the duplicated documents',
    )
    parser.add_argument(
        '--sentence_transformer_model',
        type=str,
        default='bert-base-nli-mean-tokens',
        help='sentence transformer to load',
    )
    parser.add_argument('--shard_id', type=int, default=None, help='run the job to create the shard_id index')
    parser.add_argument('--total_shards', type=int, default=None, help='total number of knn index shards')
    parser.add_argument(
        '--output_file', type=str, required=True, help='Output KNN Map index file',
    )
    parser.add_argument(
        '--devices', type=str, default=None, help='delimited list input with cuda devices. Specify like 0,1,2'
    )
    parser.add_argument(
        "--batch_size", type=int, default=4000, help="Batch size for encoding. Use max according to GPU MEM"
    )
    # ---- tokenizer / staging arguments ------------------------------------
    group = parser.add_argument_group(title='tokenizer')
    group.add_argument(
        '--tokenizer-library',
        type=str,
        required=False,
        choices=['yttm', 'sentencepiece', 'megatron', 'huggingface', 'tabular'],
        help='What tokenizer library to use.',
    )
    group.add_argument(
        '--tokenizer-type', type=str, default=None, help='What type of tokenizer to use.',
    )
    group.add_argument(
        '--tokenizer-model', type=str, default=None, help='Path to tokenizer model.',
    )
    group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file (if necessary).')
    group.add_argument('--delimiter', type=str, default=None, help='delimiter used for tabular tokenizer')
    group.add_argument(
        '--stage',
        type=int,
        default=None,
        help='used for building the large knn index in multiple stages',
        choices=[1, 2],
    )
    group.add_argument('--workers', type=int, default=None, help='number of workers to run tokenizer')
    group.add_argument(
        '--nprobe',
        type=int,
        default=10,
        help='number of probes, higher number of probes renders better results but runs slower',
    )
    group.add_argument(
        '--shard_index_input',
        type=str,
        default=None,
        help='the knn sharding index files, which are created at stage 1',
    )

    args = parser.parse_args()

    # GPU search requires both a visible CUDA device and a GPU-enabled faiss build.
    has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu")

    if not hasattr(faiss, "index_gpu_to_cpu"):
        logging.warning(
            "faiss doesn't support gpu index. Please check https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
        )

    # ---- stage 2: merge shard index files and exit ------------------------
    if args.stage == 2:
        # combine shard index files into one final index file
        input_file = pathlib.Path(args.shard_index_input)
        path = input_file.parent
        fname = input_file.name
        # collect all shard files sharing the given filename prefix
        all_files = [str(i) for i in pathlib.Path(path).glob(fname + '*')]
        merge_knn_files(all_files, args.output_file)
        f = KNNIndex(args.output_file)
        logging.info(f'Write to {args.output_file}, Size of Index : {f.len}')
        logging.info(f'Index neighbors: {f.K}')
        logging.info(f'Index chunk start id: {f.chunk_start_id}')
        logging.info(f'Index chunk end id: {f.chunk_end_id}')
        sys.exit(0)

    # ---- stage 1 / single-stage: build KNN map ----------------------------
    model = SentenceTransformer(args.sentence_transformer_model)
    tokenizer = get_tokenizer(args)
    ds = MMapRetrievalIndexedDataset(args.input_file)

    # map "--devices 0,1,2" to ['cuda:0', 'cuda:1', 'cuda:2']; None means CPU
    if args.devices is None or not torch.cuda.is_available():
        device_list = None
    else:
        device_list = ['cuda:' + str(device) for device in args.devices.split(',')]

    index = faiss.read_index(args.faiss_index)
    if has_gpu:
        co = faiss.GpuMultipleClonerOptions()
        co.useFloat16 = True  # halve GPU memory use for the cloned index
        co.usePrecomputed = False
        co.shard = True  # shard the index across the listed GPUs
        # NOTE(review): device_list may still be None here when --devices is not
        # given even though a GPU is available; len(device_list) would then
        # raise TypeError -- confirm intended usage requires --devices with GPU.
        index = faiss.index_cpu_to_all_gpus(index, co, ngpu=len(device_list))

    index.nprobe = args.nprobe

    start = 0
    total_chunks = ds.chunks
    if args.stage == 1:
        # restrict this run to the shard's [start, total_chunks) window
        start, total_chunks = calculate_start_end(
            total_chunks=total_chunks, total_shards=args.total_shards, shard_id=args.shard_id
        )

    # producer: detokenizes chunks into sentences and feeds the queue
    process = multiprocessing.Process(
        target=process_sentence_chunks,
        args=(ds, tokenizer, args.process_chunk_size, args.stage, args.workers, args.shard_id, args.total_shards),
    )
    process.start()

    # sentence-transformers multi-process encoding pool over the chosen devices
    pool = model.start_multi_process_pool(device_list)

    # consumer/producer: encodes sentences and feeds the embedding queue
    emb_process = multiprocessing.Process(target=calculate_embedding, args=(pool, args.batch_size))
    emb_process.start()

    if ds._index.retrieval_db and args.remove_duplicate:
        # query extra neighbors so dropped same-document hits still leave
        # K_neighbors usable results per row
        neighbors = args.K_neighbors + args.dedup_margin
        # build the chunk_id -> same-document chunk-range map for dedup
        id_start = np.array(ds._index._chunk_id_start)
        chunk_id_to_doc_id_map = np.zeros((total_chunks - start, 2), dtype=np.int64)
        build_map(id_start, chunk_id_to_doc_id_map, ds.chunks, start, total_chunks)
    else:
        neighbors = args.K_neighbors

    # ---- main search/write loop -------------------------------------------
    chunk_id_start = start
    with KNNIndex.writer(args.output_file, args.K_neighbors, offset=start) as w:
        while True:
            emb, slice_id = get_emb()
            if emb is None:  # (None, None) sentinel: embedding stream exhausted
                break
            beg = time.time()
            D, I = index.search(emb, neighbors)
            end = time.time()
            logging.info(f'search {slice_id[0]} - {slice_id[1]} takes {end-beg}')
            # batches must arrive in order; writer appends sequentially
            assert chunk_id_start == slice_id[0]
            if ds._index.retrieval_db and args.remove_duplicate:
                beg = time.time()
                # -1 marks empty slots left after same-document neighbors are dropped
                tmp_neighbors = np.ones_like(I) * -1
                dedup(chunk_id_to_doc_id_map, I, tmp_neighbors, chunk_id_start, start)
                # keep only the first K_neighbors columns after left-packing
                I = tmp_neighbors[:, : args.K_neighbors]
                end = time.time()
                logging.info(f'dedup {slice_id[0]} - {slice_id[1]} takes {end-beg}')
            beg = time.time()
            w.write(I)
            end = time.time()
            logging.info(f'write {slice_id[0]} - {slice_id[1]} takes {end-beg}')
            chunk_id_start += len(I)

    process.join()
    emb_process.join()
    model.stop_multi_process_pool(pool)
| |
|