| # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py | |
| from typing import Any, Dict, Optional, Union | |
| import torch | |
| import torch.distributed | |
| from .parallel_state import get_tp_group | |
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` over the tensor model parallel group.

    Thin wrapper: the work is delegated to the ``all_reduce`` method of the
    group object returned by ``get_tp_group()``.

    Args:
        input_: Tensor handed to the group's ``all_reduce``.

    Returns:
        Whatever the group's ``all_reduce`` returns for ``input_``.
    """
    tp_group = get_tp_group()
    reduced = tp_group.all_reduce(input_)
    return reduced
def tensor_model_parallel_all_gather(
    input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """All-gather ``input_`` over the tensor model parallel group.

    Thin wrapper: the work is delegated to the ``all_gather`` method of the
    group object returned by ``get_tp_group()``.

    Args:
        input_: Tensor handed to the group's ``all_gather``.
        dim: Dimension argument forwarded unchanged (defaults to ``-1``).

    Returns:
        Whatever the group's ``all_gather`` returns for these arguments.
    """
    tp_group = get_tp_group()
    gathered = tp_group.all_gather(input_, dim)
    return gathered
def tensor_model_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather ``input_`` over the tensor model parallel group.

    Thin wrapper: the work is delegated to the ``gather`` method of the
    group object returned by ``get_tp_group()``.

    Args:
        input_: Tensor handed to the group's ``gather``.
        dst: Destination argument forwarded unchanged (defaults to ``0``).
        dim: Dimension argument forwarded unchanged (defaults to ``-1``).

    Returns:
        Whatever the group's ``gather`` returns; the annotation allows
        ``None``.
        NOTE(review): presumably ``None`` is returned on ranks other than
        ``dst`` -- confirm against the group implementation.
    """
    tp_group = get_tp_group()
    gathered = tp_group.gather(input_, dst, dim)
    return gathered
def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast ``tensor_dict`` over the tensor model parallel group.

    When ``torch.distributed`` has not been initialized there is nothing to
    broadcast to, so the argument is handed back untouched.

    Args:
        tensor_dict: Dictionary to broadcast; may be ``None`` (the default).
        src: Source argument forwarded unchanged to the group's
            ``broadcast_tensor_dict`` (defaults to ``0``).

    Returns:
        ``tensor_dict`` itself if ``torch.distributed`` is not initialized;
        otherwise whatever the ``broadcast_tensor_dict`` method of the group
        returned by ``get_tp_group()`` produces.
    """
    if torch.distributed.is_initialized():
        tp_group = get_tp_group()
        return tp_group.broadcast_tensor_dict(tensor_dict, src)
    # No process group is set up: pass the input straight through.
    return tensor_dict
Xet Storage Details
- Size:
- 1.13 kB
- Xet hash:
- 7bb0bed86f70208b649e716ebbe7ae039abadf21df354a949be1446ee6aede58
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.