# PyTorch 2.10 Tensor Parallelism Fix

## Summary

PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking our `distributed/utils.py` code that handles DTensor sharding. This document records the root cause and the changes made to resolve it.

---

## 1. Root Cause: `_StridedShard` class hierarchy change

| Version | MRO |
|---------|-----|
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` |

**Consequences:**

- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported placement and raised `AssertionError`.

### When does `_StridedShard` appear?

When two parallelism dimensions shard the same tensor dimension — for example, `fsdp+tp` or `hsdp+tp` configurations where both TP and FSDP shard dimension 0 of the Q/K/V projection weights:

```
TP  : Shard(0)        → each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top → each FSDP rank further splits those rows
```

PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)` to indicate non-contiguous (interleaved) row ownership.

---

## 2. Completed Fixes

### 2.1 `_is_shard()` helper (`distributed/utils.py`)

Added a helper that correctly identifies both `Shard` and `_StridedShard`:

```python
from torch.distributed.tensor.placement_types import Placement, Shard, _StridedShard  # import path may vary by version

def _is_shard(placement: Placement) -> bool:
    return isinstance(placement, (Shard, _StridedShard))
```

Used in `construct_shard_mesh()` where the old code called `placement.is_shard()`.

### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)

The old code assumed contiguous slicing (`start = rank * shard_size`), which is wrong for `_StridedShard`.
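To illustrate why a single contiguous slice cannot describe strided ownership, here is a toy pure-Python sketch. The layout and numbers are illustrative only — the real offsets come from `local_shard_size_and_offset`, not this formula:

```python
# Toy model of strided (interleaved) row ownership. Illustration only --
# PyTorch computes the real layout via local_shard_size_and_offset.

def contiguous_rows(rank: int, shard_size: int) -> list[int]:
    # Old assumption: each rank owns one contiguous run of rows.
    start = rank * shard_size
    return list(range(start, start + shard_size))

def strided_rows(rank: int, num_ranks: int, dim_size: int, split_factor: int) -> list[int]:
    # Interleaved ownership: the dim is viewed as `split_factor`
    # contiguous blocks, and each rank owns a sub-run of every block.
    block = dim_size // split_factor   # rows per block
    sub = block // num_ranks           # rows this rank owns per block
    rows = []
    for b in range(split_factor):
        start = b * block + rank * sub
        rows.extend(range(start, start + sub))
    return rows

print(contiguous_rows(0, 8))         # [0, 1, 2, 3, 4, 5, 6, 7] -- one slice
print(strided_rows(0, 2, 16, 2))     # [0, 1, 2, 3, 8, 9, 10, 11] -- not a slice
```

The strided result cannot be expressed as a single `slice`, which is why the rewritten code has to return index tensors for the non-contiguous case.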
New code uses PyTorch's own offset-computation methods:

| Placement type | API used |
|----------------|----------|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` |

The return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:

- `slice` for contiguous ranges (a `Shard`, or a `_StridedShard` whose result happens to be contiguous)
- `torch.LongTensor` of indices for non-contiguous ranges

Composed sharding (multiple placements on the same dim) is handled by indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.

### 2.3 Updated `numel_for_rank()` (`core.py`)

Now handles both `slice` and `torch.Tensor` index types:

```python
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        numel *= len(idx)
```

### 2.4 Updated pipeline stages (`pipeline.py`)

- **`_gather_grads()`**: uses `gathered_grads[id(p)][indices] = sg` (index assignment) instead of a view-based copy; works for both slice and tensor indices.
- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types.
- **`_update_params()` QK clipping**: applies clipping to `p._local_tensor` directly instead of going through DTensor operations, avoiding sharding-propagation errors with `_StridedShard`.

---

## 3. Test Results After Fixes

| Test configuration | Status |
|--------------------|--------|
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |

All 24 tests pass (126 s, `--skip-verify`).

---

## 4. Fixed: QK Clipping with Strided Sharding

### Problem

With strided (non-contiguous) sharding, local rows are **interleaved across multiple heads**.
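A toy pure-Python check makes the head-spanning concrete (the row layout below is hypothetical, not the exact `_StridedShard` order):

```python
# Toy check: interleaved local rows belong to several different heads,
# so view(-1, head_dim, cols) cannot regroup them into whole heads.
# The row layouts are illustrative, not the exact _StridedShard order.

head_dim = 4

# A contiguous owner gets whole heads:
contig_rows = list(range(0, 8))                            # rows 0..7
contig_heads = sorted({r // head_dim for r in contig_rows})
print(contig_heads)   # [0, 1] -- two complete heads

# An interleaved owner gets fragments of many heads:
local_rows = [0, 1, 8, 9, 16, 17, 24, 25]
local_heads = sorted({r // head_dim for r in local_rows})
print(local_heads)    # [0, 2, 4, 6] -- four heads, two rows each
```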
For example, with `fsdp+tp` (`dp_shard=2, tp=4`):

- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, 16 heads
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for interleaved rows → shape mismatch error

### Fix Applied

For strided sharding, apply the scales **per-row**, based on each row's global head index, instead of using the head-block view:

```python
if isinstance(weight_indices[0], slice):
    # Contiguous case: the view-based approach still works
    ...
else:
    # Strided case: per-row scaling
    head_per_row = weight_indices[0] // ratio
    row_scales = scales_full[head_per_row]
    local_p.mul_(row_scales.view(-1, 1))
```

---

## 5. Files Modified

| File | Changes |
|------|---------|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |
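As a closing sanity check on the per-row scaling in section 4, a pure-Python sketch (toy sizes, made-up scale values, no torch dependency) showing that the per-row path agrees with the head-block path whenever rows *are* contiguous:

```python
# Sanity sketch: for contiguous rows, scaling whole head blocks and
# scaling row-by-row via each row's global head index give the same
# result. Toy sizes and hypothetical per-head scales.

head_dim = 2
rows = list(range(6))                             # contiguous rows -> heads 0, 1, 2
weight = [[float(r), float(r)] for r in rows]     # toy (6, 2) weight matrix
scales_full = [0.5, 1.0, 2.0]                     # one clip scale per head

# Head-block path: group rows into whole heads (the view(-1, head_dim, cols)
# analogue) and multiply each block by its head's scale.
blocked = []
for h, scale in enumerate(scales_full):
    block = weight[h * head_dim:(h + 1) * head_dim]
    blocked.extend([[v * scale for v in row] for row in block])

# Per-row path: look up the scale from each row's global index.
row_scales = [scales_full[r // head_dim] for r in rows]
per_row = [[v * s for v in w] for w, s in zip(weight, row_scales)]

assert blocked == per_row  # identical results in the contiguous case
```

Only the per-row path generalizes to interleaved rows, since it never assumes a row's neighbors belong to the same head.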