# IRIS-FLOWER-CLASSIFICATION-using-machine-learning-models/transformers/examples/research_projects/rag/distributed_pytorch_retriever.py
import logging
import os
from typing import List, Optional, Tuple

import numpy as np
import psutil
import torch
import torch.distributed as dist

from transformers import RagRetriever
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class RagPyTorchDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``torch.distributed`` communication package. During training all
    workers initialize their own instance of the retriever, however, only the main worker loads the index into memory.
    The index is stored in cpu memory. The retriever also works in a non-distributed setup.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
            ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
            generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        index (:class:`~transformers.models.rag.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration.
    """

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
        # ``init_retrieval=False`` presumably makes the base class skip index initialization; the index is loaded
        # later in ``init_retrieval`` (only on the main worker when running distributed).
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        # Gloo process group used for retrieval communication; created in ``init_retrieval``.
        self.process_group = None

    def init_retrieval(self, distributed_port: int):
        """
        Retriever initialization function, needs to be called from the training process. The function sets some
        common parameters and environment variables. On top of that, (only) the main process in the process group
        loads the index into memory.

        Args:
            distributed_port (:obj:`int`):
                The port on which the main communication of the training run is carried out. We set the port for
                retrieval-related communication as ``distributed_port + 1``.
        """
        logger.info("initializing retrieval")

        # initializing a separate process group for retrieval as the default
        # nccl backend doesn't support gather/scatter operations while gloo
        # is too slow to replace nccl for the core gpu communication
        if dist.is_initialized():
            logger.info("dist initialized")
            # gloo needs to be told which network interface to use. Respect a value the user already exported and
            # only fall back to the name heuristic otherwise. The heuristic may find nothing, and assigning ``None``
            # to ``os.environ`` would raise a TypeError, so in that case leave the variable unset and warn.
            if "GLOO_SOCKET_IFNAME" not in os.environ:
                ifname = self._infer_socket_ifname()
                if ifname is not None:
                    os.environ["GLOO_SOCKET_IFNAME"] = ifname
                else:
                    logger.warning(
                        "Could not infer a network interface for gloo; set GLOO_SOCKET_IFNAME manually if "
                        "retrieval communication fails."
                    )
            # avoid clash with the NCCL port
            os.environ["MASTER_PORT"] = str(distributed_port + 1)
            self.process_group = dist.new_group(ranks=None, backend="gloo")

        # initialize retriever only on the main worker
        if not dist.is_initialized() or self._is_main():
            logger.info("dist not initialized / main")
            self.index.init_index()

        # all processes wait until the retriever is initialized by the main process
        if dist.is_initialized():
            dist.barrier(group=self.process_group)

    def _is_main(self) -> bool:
        """Return True on rank 0 of the retrieval process group, the rank that loads the index and runs retrieval."""
        return dist.get_rank(group=self.process_group) == 0

    def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
        """
        Receive this rank's chunk of ``scatter_list`` sent from rank 0 of the retrieval process group.

        ``scatter_list`` only needs to be populated on rank 0 (other ranks pass an empty list). Every rank gets back a
        newly allocated tensor of ``target_shape`` / ``target_type`` holding its chunk.
        """
        target_tensor = torch.empty(target_shape, dtype=target_type)
        dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
        return target_tensor

    def _infer_socket_ifname(self) -> Optional[str]:
        """
        Guess the network interface gloo should bind to.

        Returns the first interface name reported by ``psutil.net_if_addrs()`` that starts with ``"e"`` (e.g. names
        like ``eth0`` / ``eno1``), or ``None`` when no such interface exists.
        """
        # a hacky way to deal with varying network interface names
        return next((name for name in psutil.net_if_addrs() if name.startswith("e")), None)

    def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for specified ``question_hidden_states``. The main process, which has access to the index
        stored in memory, gathers queries from all the processes in the main training process group, performs the
        retrieval and scatters back the results.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Returns:
            :obj:`Tuple[np.ndarray, np.ndarray, List[dict]]`:

            - **retrieved_doc_embeds** (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`) -- The
              retrieval embeddings of the retrieved docs per query.
            - **doc_ids** (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`) -- The ids of the documents in
              the index.
            - **doc_dicts** (:obj:`List[dict]`) -- The retrieved documents, as returned by the index's
              ``get_doc_dicts``.

        Note:
            In the distributed setting every worker is assumed to call this with the same ``batch_size``: rank 0
            sizes its gather buffers and scatter chunks from its own local batch.
        """
        # single GPU training
        if not dist.is_initialized():
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
            return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

        # distributed training
        world_size = dist.get_world_size(group=self.process_group)

        # gather logic: rank 0 collects every worker's query vectors. The local tensor is cast to float32 so that its
        # dtype matches the float32 receive buffers allocated on rank 0.
        local_queries = torch.tensor(question_hidden_states, dtype=torch.float32)
        gather_list = None
        if self._is_main():
            gather_list = [torch.empty(local_queries.shape, dtype=torch.float32) for _ in range(world_size)]
        dist.gather(local_queries, dst=0, gather_list=gather_list, group=self.process_group)

        # scatter logic: rank 0 retrieves for all gathered queries and sends each worker its own slice.
        n_queries = question_hidden_states.shape[0]
        scatter_ids = []
        scatter_vectors = []
        if self._is_main():
            assert len(gather_list) == world_size
            ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
            # cast to the dtypes of the receive buffers allocated by ``_scattered`` below
            ids = torch.tensor(ids, dtype=torch.int64)
            vectors = torch.tensor(vectors, dtype=torch.float32)
            # ``_chunk_tensor`` is inherited from the base class; assumed to split along dim 0 into n_queries rows
            scatter_ids = self._chunk_tensor(ids, n_queries)
            scatter_vectors = self._chunk_tensor(vectors, n_queries)
        doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
        retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])

        # hand ``get_doc_dicts`` a numpy array, exactly as the single-process branch above does
        doc_ids = doc_ids.numpy()
        return retrieved_doc_embeds.numpy(), doc_ids, self.index.get_doc_dicts(doc_ids)