| |
|
|
| """ |
| Utility helpers for broadcasting nested dictionaries of tensors across tensor-parallel ranks. |
| |
| """ |
|
|
| from typing import Any, Dict, List, Tuple |
|
|
| import torch |
|
|
| from megatron.core import mpu, tensor_parallel |
|
|
|
|
def flatten(
    nested: Dict[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[Tuple[str, ...], torch.Tensor]]:
    """Recursively flatten a nested dict into ``[(key_path, tensor), ...]``.

    Args:
        nested: Dictionary whose values are either further dictionaries or
            ``torch.Tensor`` leaves.
        prefix: Key path accumulated so far; callers normally leave this at
            its default and only the recursion passes a non-empty value.

    Returns:
        A list of ``(key_path, tensor)`` pairs in dictionary iteration order,
        where ``key_path`` is the tuple of keys leading to the tensor.

    Raises:
        ValueError: If a value is neither a ``dict`` nor a ``torch.Tensor``.
    """
    flat = []
    for k, v in nested.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            flat.extend(flatten(v, path))
        elif isinstance(v, torch.Tensor):
            flat.append((path, v))
        else:
            # The two literals are concatenated implicitly, so the separator
            # between the sentences has to be part of the first one.
            raise ValueError(
                f"Unsupported value type: {type(v)} for key {k}. "
                f"In nested dictionary, leaf nodes must contain tensors"
            )
    return flat
|
|
|
|
def regroup(flat: List[Tuple[Tuple[str, ...], torch.Tensor]]) -> Dict[str, Any]:
    """Reassemble a nested dictionary from ``(key_path, tensor)`` pairs.

    Args:
        flat: Pairs of a key-path tuple and the tensor stored at that path.

    Returns:
        A nested dictionary in which each tensor sits under the chain of keys
        given by its path; intermediate dictionaries are created on demand.
    """
    tree = {}
    for key_path, leaf in flat:
        parents = key_path[:-1]
        node = tree
        # Walk down to the parent of the leaf, creating levels as needed.
        for part in parents:
            node = node.setdefault(part, {})
        leaf_key = key_path[-1]
        node[leaf_key] = leaf
    return tree
|
|
|
|
def broadcast_nested_data_batch(nested_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Broadcast a nested dict of tensors over the tensor-model-parallel group.

    The rank whose tensor-model-parallel rank is 0 flattens ``nested_dict``
    and publishes the key paths and dtypes with
    ``torch.distributed.broadcast_object_list``. Tensors are then sent with one
    ``tensor_parallel.broadcast_data`` call per dtype, and the nested layout is
    rebuilt from the results.

    Args:
        nested_dict: Nested dictionary whose leaves are tensors. It is only
            read on the rank whose tensor-model-parallel rank is 0; other
            ranks ignore its contents.

    Returns:
        A nested dictionary with the same key paths as the source rank's
        input, holding the tensors returned by ``broadcast_data``.

    Raises:
        ValueError: Propagated from ``flatten`` when a leaf is not a tensor.
    """
    tp_group = mpu.get_tensor_model_parallel_group()
    src = mpu.get_tensor_model_parallel_src_rank()
    is_first_tp_rank = mpu.get_tensor_model_parallel_rank() == 0

    if is_first_tp_rank:
        flat = flatten(nested_dict)
        paths = [path for path, _ in flat]
        tensors = [tensor for _, tensor in flat]
        dtypes = [tensor.dtype for tensor in tensors]
    else:
        paths, tensors, dtypes = [], [], []

    # Share the structure (paths + dtypes) so every rank issues the same
    # sequence of broadcast_data calls below.
    obj_list = [[paths, dtypes]]
    torch.distributed.broadcast_object_list(obj_list, src=src, group=tp_group)
    paths, dtypes = obj_list[0]

    # Use the position in the flattened list as the broadcast key instead of
    # ".".join(path): joined keys cannot be split back correctly when a key
    # itself contains ".", and distinct paths such as ("a.b",) and ("a", "b")
    # would collide. Positional keys are unique by construction.
    groups = {}
    for index, (path, dt) in enumerate(zip(paths, dtypes)):
        groups.setdefault(dt, []).append((str(index), tuple(path)))

    if is_first_tp_rank:
        data_dict = {str(index): t.cuda() for index, t in enumerate(tensors)}
    else:
        # NOTE(review): presumably broadcast_data only reads the data dict on
        # the source rank; confirm against the Megatron implementation.
        data_dict = {}

    flat_out = []
    for dt, members in groups.items():
        keys = [key for key, _ in members]
        out = tensor_parallel.broadcast_data(keys, data_dict, dt)
        flat_out.extend((path, out[key]) for key, path in members)

    return regroup(flat_out)