| import logging | |
| from dataclasses import dataclass | |
| from multiprocessing import shared_memory | |
| from pathlib import Path | |
| from typing import List, Optional | |
| import numpy as np | |
| import torch | |
| from sglang.srt.distributed.naive_distributed import get_naive_distributed | |
| from sglang.srt.utils import check_cuda_result | |
| logger = logging.getLogger(__name__) | |
| class HostSharedMemoryManager: | |
| def __init__(self, base_name: str): | |
| self._base_name = Path(base_name) | |
| self._operation_index = 0 | |
| self._records: List[_Record] = [] | |
| def malloc(self, *, shape, dtype): | |
| meta_tensor = torch.empty(size=shape, dtype=dtype, device="meta") | |
| raw = self._malloc_raw(num_bytes=meta_tensor.nbytes) | |
| return raw.view(dtype).view(*shape) | |
| def _malloc_raw(self, *, num_bytes: int) -> torch.Tensor: | |
| import cuda.bindings.runtime as cuda_rt | |
| self._operation_index += 1 | |
| shm_name = f"{self._base_name}_op{self._operation_index}" | |
| # TODO handle dispose | |
| if get_naive_distributed().get_rank() == 0: | |
| shm = shared_memory.SharedMemory(name=shm_name, create=True, size=num_bytes) | |
| get_naive_distributed().barrier() | |
| if get_naive_distributed().get_rank() != 0: | |
| shm = shared_memory.SharedMemory(name=shm_name) | |
| np_array = np.ndarray((num_bytes,), dtype=np.uint8, buffer=shm.buf) | |
| tensor = torch.from_numpy(np_array) | |
| check_cuda_result( | |
| cuda_rt.cudaHostRegister( | |
| tensor.data_ptr(), num_bytes, cuda_rt.cudaHostRegisterPortable | |
| ) | |
| ) | |
| get_naive_distributed().barrier() | |
| self._records.append( | |
| _Record( | |
| shm=shm, | |
| np_array=np_array, | |
| tensor=tensor, | |
| ) | |
| ) | |
| return tensor | |
@dataclass
class _Record:
    """References to the objects backing one shared-memory allocation.

    ``HostSharedMemoryManager._malloc_raw`` builds one record per call using
    keyword arguments, so this class must be a dataclass to get the generated
    ``__init__``.
    """

    # Handle of the named shared-memory segment.
    shm: shared_memory.SharedMemory
    # uint8 array constructed over ``shm.buf``.
    np_array: np.ndarray
    # Tensor created from ``np_array`` via ``torch.from_numpy``.
    tensor: torch.Tensor
# Module-level manager returned by get_host_shared_memory_manager(); stays
# None until set_host_shared_memory_manager() installs one.
# Can have multiple instances if needed
_instance: Optional[HostSharedMemoryManager] = None
def get_host_shared_memory_manager():
    """Return the module-level ``HostSharedMemoryManager``.

    Fails an assertion when ``set_host_shared_memory_manager`` has not
    installed a manager yet.
    """
    manager = _instance
    assert manager is not None
    return manager
def set_host_shared_memory_manager(instance: HostSharedMemoryManager):
    """Install ``instance`` as the module-level manager.

    Fails an assertion when a manager has already been installed.
    """
    global _instance
    previous = _instance
    assert previous is None
    _instance = instance
Xet Storage Details
- Size:
- 2.26 kB
- Xet hash:
- 8d74ead88668d10d494687f43d685ac0831439d657e3732b414ef9e392b6d017
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.