PyTorch 2.10 Tensor Parallelism Fix
Summary
PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking our
`distributed/utils.py` code that handles DTensor sharding. This document
records the root cause, every change made, and the fix for the final
QK-clipping issue.
1. Root Cause: _StridedShard class hierarchy change
| Version | MRO |
|---|---|
| PyTorch < 2.10 | _StridedShard -> Shard -> Placement |
| PyTorch 2.10 | _StridedShard -> StridedShard -> Placement |
Consequences:
- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported placement and raised `AssertionError`.
When does _StridedShard appear?
When two parallelism dimensions shard the same tensor dimension.
For example, fsdp+tp or hsdp+tp configurations where both TP and
FSDP shard dimension 0 of Q/K/V projection weights:
TP : Shard(0) → each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top → each FSDP rank further splits those rows
PyTorch represents the second sharding as _StridedShard(dim=0, split_factor=N)
to indicate non-contiguous (interleaved) row ownership.
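The interleaving can be illustrated with a small pure-Python sketch (no torch). The helper name `strided_shard_rows` and the block-interleaving rule below are assumed semantics for illustration, not PyTorch's actual `_StridedShard` implementation: the dimension is cut into `num_shards * split_factor` equal blocks, and rank `r` takes every `num_shards`-th block starting at block `r`.

```python
def strided_shard_rows(n_rows, num_shards, rank, split_factor):
    """Illustrative ownership under strided sharding (assumed semantics):
    cut n_rows into num_shards * split_factor equal blocks; rank r owns
    blocks r, r + num_shards, r + 2*num_shards, ..."""
    block = n_rows // (num_shards * split_factor)
    rows = []
    for s in range(split_factor):
        start = (s * num_shards + rank) * block
        rows.extend(range(start, start + block))
    return rows

# split_factor=1 degenerates to a plain contiguous Shard:
print(strided_shard_rows(16, 4, 1, 1))  # [4, 5, 6, 7]
# split_factor=2 gives the interleaved (non-contiguous) pattern:
print(strided_shard_rows(16, 4, 1, 2))  # [2, 3, 10, 11]
```

With `split_factor > 1` a rank's rows are no longer one contiguous range, which is exactly why the contiguous-offset assumptions below break.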
2. Completed Fixes
2.1 _is_shard() helper (distributed/utils.py)
Added a helper that correctly identifies both Shard and _StridedShard:
def _is_shard(placement: Placement) -> bool:
return isinstance(placement, (Shard, _StridedShard))
Used in construct_shard_mesh() where the old code called placement.is_shard().
2.2 Rewritten get_slices_of_dtensor() (distributed/utils.py)
Old code assumed contiguous slicing (start = rank * shard_size), which is
wrong for _StridedShard.
New code uses PyTorch's own offset-computation methods:
| Placement type | API used |
|---|---|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` |
Return type changed from tuple[slice, ...] to tuple[slice | torch.Tensor, ...]:
- `slice` for contiguous ranges (`Shard`, or a contiguous `_StridedShard` result)
- `torch.LongTensor` of indices for non-contiguous ranges
Composed sharding (multiple placements on the same dim) is handled by
indexing: dim_indices[shard_dim] = dim_indices[shard_dim][new_indices].
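The composition step can be sketched with plain lists, which behave the same way as torch fancy indexing for this purpose (the concrete sizes below are illustrative, not taken from the real configs):

```python
# Emulating dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]:
# each placement narrows the set of global row ids this rank owns.
dim_indices = list(range(16))        # global row ids before any sharding
# first placement: this rank's contiguous shard -> global rows 4..7
dim_indices = dim_indices[4:8]
# second placement: a strided shard selects local positions 0 and 2
new_indices = [0, 2]
dim_indices = [dim_indices[i] for i in new_indices]
print(dim_indices)  # [4, 6] -- global rows, now non-contiguous
```

Because the second indexing step is applied to the *output* of the first, the final global indices can be non-contiguous even though each individual placement looks simple.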
2.3 Updated numel_for_rank() (core.py)
Now handles both slice and torch.Tensor index types:
numel = 1
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:  # torch.Tensor of indices
        numel *= len(idx)
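A quick worked example of that loop, wrapped in a standalone function for illustration (plain index lists stand in for torch index tensors; the wrapper name is hypothetical):

```python
def numel_for_rank_sketch(indices, shape):
    """Count elements selected by a per-dim mix of slices and index lists."""
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)  # resolve None/negatives
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:  # list of indices (stands in for a torch.Tensor)
            numel *= len(idx)
    return numel

# contiguous shard of a (2048, 2048) weight: 512 full rows
print(numel_for_rank_sketch((slice(0, 512), slice(None)), (2048, 2048)))  # 1048576
# non-contiguous rows selected by an index list
print(numel_for_rank_sketch(([0, 2, 5], slice(None)), (8, 4)))  # 12
```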
2.4 Updated pipeline stages (pipeline.py)
- `_gather_grads()`: uses `gathered_grads[id(p)][indices] = sg` (index assignment) instead of a view-based copy; works for both slice and tensor indices.
- `_scatter_us()`: `u_full[indices].flatten()` works for both index types.
- `_update_params()` QK clipping: applies clipping on `p._local_tensor` directly instead of through DTensor operations, avoiding sharding-propagation errors with `_StridedShard`.
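The key property relied on above is that index assignment writes through both index kinds. A pure-Python emulation (the `assign` helper is hypothetical; lists stand in for tensors, and an index list stands in for a `LongTensor`):

```python
def assign(buf, idx, vals):
    """Write vals into buf at idx, for either a slice or an index list,
    emulating gathered_grads[id(p)][indices] = sg."""
    if isinstance(idx, slice):
        buf[idx] = vals
    else:  # non-contiguous positions
        for pos, v in zip(idx, vals):
            buf[pos] = v

buf = [0] * 8
assign(buf, slice(2, 5), [1, 2, 3])   # contiguous (plain Shard) case
assign(buf, [0, 6], [9, 9])           # interleaved (_StridedShard) case
print(buf)  # [9, 0, 1, 2, 3, 0, 9, 0]
```

With torch tensors both cases are a single `buf[idx] = vals` statement, which is why the index-assignment form subsumes the old view-based copy.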
3. Test Results After Fixes
| Test configuration | Status |
|---|---|
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |
All 24 tests pass (126s, --skip-verify).
4. Fixed: QK Clipping with Strided Sharding
Problem
With strided (non-contiguous) sharding, local rows are interleaved across
multiple heads. For example with fsdp+tp (dp_shard=2, tp=4):
- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, 16 heads
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for interleaved rows → shape mismatch error
Fix Applied
For strided sharding, apply scales per-row based on each row's global head index instead of using the head-block view:
if isinstance(weight_indices[0], slice):
    # Contiguous case: view-based approach still works
    ...
else:
    # Strided case: per-row scaling by each row's global head index
    head_per_row = weight_indices[0] // ratio  # ratio: rows per head
    row_scales = scales_full[head_per_row]
    local_p.mul_(row_scales.view(-1, 1))
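A pure-Python sketch of the per-row idea, using the doc's fsdp+tp numbers (the interleaving pattern and scale values below are illustrative, not taken from a real run): a rank owns 64 rows from each of 4 heads, and each row still picks up the scale of its own head.

```python
head_dim = 128
# illustrative interleaved ownership: 64 rows from each of heads 0..3
rows = [h * head_dim + r for h in range(4) for r in range(64)]
# one scale per head (16 heads); values are made up for the example
scales_full = [1.0, 0.5, 1.0, 0.25] + [1.0] * 12
head_per_row = [r // head_dim for r in rows]       # ratio = head_dim here
row_scales = [scales_full[h] for h in head_per_row]
# every owned row gets its own head's scale, regardless of interleaving
print(len(rows), sorted(set(row_scales)))  # 256 [0.25, 0.5, 1.0]
```

The division `r // head_dim` is what replaces the broken `view(-1, head_dim, cols)` head-block assumption: it recovers the head index from the global row id directly.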
5. Files Modified
| File | Changes |
|---|---|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| torch.Tensor` index types |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |