# PyTorch 2.10 Tensor Parallelism Fix

## Summary

PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking our
`distributed/utils.py` code that handles DTensor sharding. This document
records the root cause, every change made, and the QK-clipping issue that was
subsequently fixed (section 4).

---

## 1. Root Cause: `_StridedShard` class hierarchy change

| Version | MRO |
|---------|-----|
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` |

**Consequences:**

- `isinstance(strided_shard, Shard)` now returns `False`
- `strided_shard.is_shard()` now returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported
  placement and raised `AssertionError`.

### When does `_StridedShard` appear?

It appears when two parallelism dimensions shard the same tensor dimension,
for example in `fsdp+tp` or `hsdp+tp` configurations where both TP and
FSDP shard dimension 0 of the Q/K/V projection weights:

```
TP  : Shard(0) → each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top → each FSDP rank further splits those rows
```

PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)`
to indicate non-contiguous (interleaved) row ownership.
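The interleaving can be pictured with a small pure-Python toy model. This is an
illustration only: `strided_rows` and its piece-wise layout are a simplification
under the assumption that `split_factor` groups are each split into
`num_shards` contiguous pieces, not a copy of PyTorch's internal chunking.

```python
def strided_rows(dim_size: int, num_shards: int, rank: int, split_factor: int) -> list[int]:
    """Toy model of the rows a rank owns under strided sharding: the dim is
    cut into split_factor * num_shards equal pieces and each rank owns one
    piece per group of num_shards pieces, so its rows are interleaved
    (non-contiguous) in the global tensor. Illustration only."""
    piece = dim_size // (split_factor * num_shards)
    rows: list[int] = []
    for group in range(split_factor):
        start = (group * num_shards + rank) * piece
        rows.extend(range(start, start + piece))
    return rows

print(strided_rows(16, num_shards=2, rank=0, split_factor=2))  # [0, 1, 2, 3, 8, 9, 10, 11]
```

Note that rank 0's rows are not one contiguous block, which is exactly why a
`start = rank * shard_size` slice cannot describe them.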
---

## 2. Completed Fixes

### 2.1 `_is_shard()` helper (`distributed/utils.py`)

Added a helper that correctly identifies both `Shard` and `_StridedShard`:

```python
def _is_shard(placement: Placement) -> bool:
    return isinstance(placement, (Shard, _StridedShard))
```

It is used in `construct_shard_mesh()` in place of the old
`placement.is_shard()` call.

### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)

The old code assumed contiguous slicing (`start = rank * shard_size`), which is
wrong for `_StridedShard`. The new code uses PyTorch's own offset-computation
methods:

| Placement type | API used |
|----------------|----------|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` |

The return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:

- `slice` for contiguous ranges (`Shard`, or a `_StridedShard` result that happens to be contiguous)
- `torch.LongTensor` of indices for non-contiguous ranges

Composed sharding (multiple placements on the same dim) is handled by
indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.
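A minimal sketch of that composition step, using plain Python lists in place of
`torch.LongTensor` indexing (toy sizes and made-up index values):

```python
# Toy stand-in for composing two shardings on the same dim; the real code
# does the same with LongTensor indexing on dim_indices[shard_dim].
dim_indices = list(range(16))                        # global row ids on dim 0
outer = range(0, 16, 2)                              # rows kept by the first placement
dim_indices = [dim_indices[i] for i in outer]        # -> [0, 2, 4, ..., 14]
new_indices = [0, 1, 4, 5]                           # local picks of the second placement
dim_indices = [dim_indices[i] for i in new_indices]  # -> [0, 2, 8, 10]
print(dim_indices)
```

Each placement narrows the surviving global indices, so the final list maps
local positions back to global rows regardless of how many placements stacked.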
### 2.3 Updated `numel_for_rank()` (`core.py`)

It now handles both `slice` and `torch.Tensor` index types:

```python
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        numel *= len(idx)
```
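The `slice` branch reproduces `len(range(start, stop, step))` via ceiling
division, which is valid for the positive steps that sharding produces. A quick
self-contained check with hypothetical slices (`slice_len` is just the branch
factored into a helper for illustration):

```python
def slice_len(idx: slice, dim_size: int) -> int:
    """Number of elements `idx` selects from a dim of size dim_size,
    assuming a positive step, via the same ceiling-division formula."""
    start, stop, step = idx.indices(dim_size)
    return max(0, (stop - start + (step - 1)) // step)

print(slice_len(slice(0, 10, 3), 16))  # 4, matches len(range(0, 10, 3))
print(slice_len(slice(None), 16))      # 16, a full dimension
print(slice_len(slice(8, 4), 16))      # 0, empty slice clamps to zero
```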
### 2.4 Updated pipeline stages (`pipeline.py`)

- **`_gather_grads()`**: uses `gathered_grads[id(p)][indices] = sg` (index
  assignment) instead of a view-based copy, which works for both slice and
  tensor indices.
- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types.
- **`_update_params()` QK clipping**: applies clipping on `p._local_tensor`
  directly instead of through DTensor operations, avoiding sharding
  propagation errors with `_StridedShard`.
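Index assignment covers both cases because a Python `slice` and a
`torch.LongTensor` are each valid left-hand-side indices. A toy sketch (the
tensor sizes and values are made up):

```python
import torch

full = torch.zeros(8)
full[slice(0, 4)] = torch.ones(4)                    # contiguous shard writes a block
full[torch.tensor([4, 6])] = torch.tensor([2., 2.])  # strided shard writes interleaved rows
print(full)  # tensor([1., 1., 1., 1., 2., 0., 2., 0.])
```

A view-based copy, by contrast, would need `full[4:7:2]`-style strides or fail
outright for irregular index patterns, so uniform index assignment is simpler.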
---

## 3. Test Results After Fixes

| Test configuration | Status |
|--------------------|--------|
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |

All 24 tests pass (126s, `--skip-verify`).

---

## 4. Fixed: QK Clipping with Strided Sharding

### Problem

With strided (non-contiguous) sharding, local rows are **interleaved across
multiple heads**. For example with `fsdp+tp` (`dp_shard=2, tp=4`):

- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, 16 heads
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for
  interleaved rows → shape mismatch error
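A toy version of the mismatch, with made-up row numbers and `head_dim=4`
instead of 128:

```python
head_dim = 4                           # toy value; the real case uses 128
rows = [0, 1, 8, 9, 16, 17, 24, 25]    # interleaved global rows on one rank
heads = [r // head_dim for r in rows]  # global head index of each local row
print(heads)  # [0, 0, 2, 2, 4, 4, 6, 6]
# The 8 local rows touch 4 distinct heads with only 2 rows each, so no
# view(-1, head_dim, cols) over them can group whole contiguous heads.
```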
### Fix Applied

For strided sharding, apply scales **per-row** based on each row's global
head index instead of using the head-block view:

```python
if isinstance(weight_indices[0], slice):
    # Contiguous case: view-based approach still works
    ...
else:
    # Strided case: per-row scaling
    head_per_row = weight_indices[0] // ratio
    row_scales = scales_full[head_per_row]
    local_p.mul_(row_scales.view(-1, 1))
```
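A self-contained sketch of the strided branch with toy values. The concrete
numbers here are made up; in the real code `ratio` is the rows-per-head
factor, `weight_indices[0]` holds this rank's global row indices, and
`scales_full` holds one clip scale per global head.

```python
import torch

ratio = 4                                    # toy rows-per-head
weight_rows = torch.tensor([0, 1, 8, 9])     # interleaved global rows on this rank
scales_full = torch.tensor([0.5, 1.0, 2.0])  # one clip scale per global head
head_per_row = weight_rows // ratio          # tensor([0, 0, 2, 2])
row_scales = scales_full[head_per_row]       # tensor([0.5, 0.5, 2.0, 2.0])
local_p = torch.ones(4, 2)
local_p.mul_(row_scales.view(-1, 1))         # each row scaled by its head's factor
print(local_p)
```

Because every row carries its own scale, correctness no longer depends on the
local rows forming whole head blocks.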
---

## 5. Files Modified

| File | Changes |
|------|---------|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |