Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import collections | |
| import io | |
| import pickle | |
| from typing import Any | |
| import torch | |
| import torch.distributed as dist | |
| # https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29 | |
def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcasts a picklable object from ``src_rank`` to every rank in ``group``.

    It will be sending the object if called from the source rank and receiving
    the object otherwise. The object is serialized with ``torch.save`` on the
    source rank; its byte length is broadcast first so receivers can allocate
    an exactly-sized buffer, then the payload itself is broadcast and
    deserialized.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).
    Returns:
        The broadcasted object.

    .. warning::
        Deserialization uses ``torch.load(..., weights_only=False)``, i.e.
        full unpickling — every rank in ``group`` must be trusted, since a
        malicious source rank could execute arbitrary code on receivers.
    """
    if dist.get_rank() == src_rank:
        # Serialize, then broadcast length followed by payload so receivers
        # know how many bytes to expect before the second broadcast.
        buffer = io.BytesIO()
        torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        data = bytearray(buffer.getbuffer())
        # torch.tensor(...) replaces the legacy LongTensor/ByteTensor
        # constructors and places the tensor on `device` directly.
        length_tensor = torch.tensor([len(data)], dtype=torch.long, device=device)
        data_send_tensor = torch.tensor(data, dtype=torch.uint8, device=device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the payload length, size a buffer, receive, deserialize.
        length_tensor = torch.zeros(1, dtype=torch.long, device=device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj
| def _recursive_copy_to_device( | |
| value: Any, | |
| non_blocking: bool, | |
| device: torch.device, | |
| ) -> Any: | |
| r""" | |
| Recursively searches lists, tuples, dicts and copies tensors to device if possible. | |
| Non-tensor values are passed as-is in the result. | |
| .. note: These are all copies, so if there are two objects that reference | |
| the same object, then after this call, there will be two different objects | |
| referenced on the device. | |
| """ | |
| if isinstance(value, torch.Tensor): | |
| return value.to(device, non_blocking=non_blocking) | |
| if isinstance(value, (list, tuple)): | |
| values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] | |
| return values if isinstance(value, list) else tuple(values) | |
| if isinstance(value, collections.abc.Mapping): | |
| return { | |
| key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() | |
| } | |
| return value | |