
# PyTorch 2.10 Tensor Parallelism Fix

## Summary

PyTorch 2.10 changed the class hierarchy of `_StridedShard`, breaking the `DTensor` sharding handling in our `distributed/utils.py`. This document records the root cause and every change made; the last remaining issue (QK clipping under strided sharding, Section 4) has also been fixed.


## 1. Root Cause: `_StridedShard` class hierarchy change

| Version | MRO |
| --- | --- |
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` |

Consequences:

- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported placement and raised an `AssertionError`.

### When does `_StridedShard` appear?

When two parallelism dimensions shard the same tensor dimension, e.g. in fsdp+tp or hsdp+tp configurations where both TP and FSDP shard dimension 0 of the Q/K/V projection weights:

```
TP  : Shard(0)           →  each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top    →  each FSDP rank further splits those rows
```

PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)` to indicate non-contiguous (interleaved) row ownership.
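
A toy model of that ownership pattern (ours, not torch's actual implementation) shows why a rank's global rows end up non-contiguous: the dimension is first cut into `split_factor` contiguous sections, each section is split across the ranks, and a rank's pieces from every section interleave in the global tensor:

```python
# Toy model (not torch's implementation) of strided-shard ownership:
# the dim is cut into `split_factor` contiguous sections, each section is
# split across `num_ranks`, and a rank owns one piece from every section.
def strided_rows(dim_size: int, num_ranks: int, rank: int, split_factor: int) -> list[int]:
    section = dim_size // split_factor
    piece = section // num_ranks
    rows = []
    for s in range(split_factor):
        start = s * section + rank * piece
        rows.extend(range(start, start + piece))
    return rows

print(strided_rows(16, 2, 0, 2))  # [0, 1, 2, 3, 8, 9, 10, 11] (non-contiguous)
```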


## 2. Completed Fixes

### 2.1 `_is_shard()` helper (`distributed/utils.py`)

Added a helper that correctly identifies both `Shard` and `_StridedShard`:

```python
def _is_shard(placement: Placement) -> bool:
    return isinstance(placement, (Shard, _StridedShard))
```

This is used in `construct_shard_mesh()` where the old code called `placement.is_shard()`.

### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)

The old code assumed contiguous slicing (`start = rank * shard_size`), which is wrong for `_StridedShard`.

The new code uses PyTorch's own offset-computation methods:

| Placement type | API used | Returns |
| --- | --- | --- |
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` | `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` | `(size, offsets_list)` |

The return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:

- `slice` for contiguous ranges (`Shard`, or a `_StridedShard` result that happens to be contiguous)
- `torch.LongTensor` of indices for non-contiguous ranges

Composed sharding (multiple placements on the same dim) is handled by indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.
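
The composition step can be sketched as follows (toy sizes, ours): each placement narrows the set of global rows this rank owns by indexing into the running index tensor.

```python
import torch

# Each placement narrows the rank's ownership by fancy-indexing the
# running index tensor for the sharded dim.
dim_indices = [torch.arange(16)]            # global row indices before sharding

# First placement (contiguous Shard): this rank owns rows 0..7
dim_indices[0] = dim_indices[0][torch.arange(0, 8)]

# Second placement (strided): this rank keeps every other surviving row
new_indices = torch.arange(0, 8, 2)
dim_indices[0] = dim_indices[0][new_indices]

print(dim_indices[0].tolist())  # [0, 2, 4, 6]
```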

### 2.3 Updated `numel_for_rank()` (`core.py`)

Now handles both `slice` and `torch.Tensor` index types:

```python
numel = 1
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        # Number of elements the slice selects on this dim
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        # Explicit index tensor: one element per index
        numel *= len(idx)
```
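
The loop can be exercised as a small self-contained check (the wrapper function name is ours, not `core.py`'s):

```python
import torch

# Self-contained version of the numel loop above (wrapper name is ours).
def numel_for_indices(indices, shape) -> int:
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)  # explicit index tensor
    return numel

# 8 contiguous rows x 3 explicitly indexed columns
print(numel_for_indices((slice(0, 8), torch.tensor([0, 2, 4])), (16, 16)))  # 24
```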

### 2.4 Updated pipeline stages (`pipeline.py`)

- `_gather_grads()`: uses `gathered_grads[id(p)][indices] = sg` (index assignment) instead of a view-based copy; works for both slice and tensor indices.
- `_scatter_us()`: `u_full[indices].flatten()` works for both index types.
- `_update_params()` QK clipping: applies clipping to `p._local_tensor` directly instead of going through DTensor operations, avoiding sharding-propagation errors with `_StridedShard`.
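
Why index assignment generalizes can be shown with toy shapes (ours): the same `full[indices] = part` statement handles both contiguous (slice) and strided (tensor) ownership, whereas a view-based copy assumes the target rows are contiguous.

```python
import torch

# The same assignment form covers both index types.
full = torch.zeros(8)
full[slice(0, 4)] = torch.ones(4)                      # contiguous shard
full[torch.tensor([4, 6])] = torch.tensor([2.0, 2.0])  # strided shard

print(full.tolist())  # [1.0, 1.0, 1.0, 1.0, 2.0, 0.0, 2.0, 0.0]
```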

## 3. Test Results After Fixes

| Test configuration | Status |
| --- | --- |
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |

All 24 tests pass (126s, --skip-verify).


## 4. Fixed: QK Clipping with Strided Sharding

### Problem

With strided (non-contiguous) sharding, local rows are interleaved across multiple heads. For example, with fsdp+tp (`dp_shard=2`, `tp=4`):

- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, 16 heads
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks, which is wrong for interleaved rows and fails with a shape mismatch

### Fix Applied

For strided sharding, apply scales per row based on each row's global head index instead of using the head-block view:

```python
if isinstance(weight_indices[0], slice):
    # Contiguous case: view-based approach still works
    ...
else:
    # Strided case: per-row scaling.
    # weight_indices[0] holds each local row's global row index, so dividing
    # by the rows-per-head ratio yields the row's global head index.
    head_per_row = weight_indices[0] // ratio
    row_scales = scales_full[head_per_row]
    local_p.mul_(row_scales.view(-1, 1))

## 5. Files Modified

| File | Changes |
| --- | --- |
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| torch.Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |