# PyTorch 2.10 Tensor Parallelism Fix
## Summary
PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking the
DTensor sharding logic in our `distributed/utils.py`. This document records
the root cause and every change made, including the final fix for QK clipping
under strided sharding.
---
## 1. Root Cause: `_StridedShard` class hierarchy change
| Version | MRO |
|---------|-----|
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` |
**Consequences:**
- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported
placement and raised `AssertionError`.
### When does `_StridedShard` appear?
When two parallelism dimensions shard the same tensor dimension.
For example, `fsdp+tp` or `hsdp+tp` configurations where both TP and
FSDP shard dimension 0 of Q/K/V projection weights:
```
TP  : Shard(0)        → each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top → each FSDP rank further splits those rows
```
PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)`
to indicate non-contiguous (interleaved) row ownership.
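The interleaved ownership can be modeled in plain Python. This is a sketch of `_StridedShard` semantics with toy sizes, not PyTorch's actual implementation; `strided_rows` is a hypothetical helper:

```python
def strided_rows(size: int, split_factor: int, num_chunks: int, rank: int) -> list[int]:
    """Global row indices owned by `rank` under a (modeled)
    _StridedShard(dim=0, split_factor=split_factor) across `num_chunks` ranks:
    the dim is viewed as `split_factor` contiguous parts, each part is chunked
    `num_chunks` ways, and the rank owns its chunk of every part."""
    part = size // split_factor        # rows per contiguous part
    piece = part // num_chunks         # rows per owned piece
    rows = []
    for p in range(split_factor):
        start = p * part + rank * piece
        rows.extend(range(start, start + piece))
    return rows

# Toy example: 16 rows, split_factor=2, 4 ranks.
print(strided_rows(16, 2, 4, 1))  # [2, 3, 10, 11] -> non-contiguous ownership
```

A plain `Shard` would give rank 1 the contiguous block `[4, 5, 6, 7]` instead; the gap between the owned pieces is exactly what breaks contiguous-offset assumptions downstream.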
---
## 2. Completed Fixes
### 2.1 `_is_shard()` helper (`distributed/utils.py`)
Added a helper that correctly identifies both `Shard` and `_StridedShard`:
```python
def _is_shard(placement: Placement) -> bool:
return isinstance(placement, (Shard, _StridedShard))
```
Used in `construct_shard_mesh()` where the old code called `placement.is_shard()`.
### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)
Old code assumed contiguous slicing (`start = rank * shard_size`), which is
wrong for `_StridedShard`.
New code uses PyTorch's own offset-computation methods:
| Placement type | API used |
|----------------|----------|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` |
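For intuition, the two return shapes can be mocked in pure Python (a toy reimplementation under the strided-layout model described above, not the actual PyTorch code; both function names are hypothetical):

```python
def shard_size_and_offset(size: int, chunks: int, rank: int) -> tuple[int, int]:
    # Plain Shard: torch.chunk-style split; chunk sizes are ceil(size/chunks),
    # so trailing ranks may get a short or empty chunk.
    chunk = -(-size // chunks)
    start = min(rank * chunk, size)
    return min(chunk, size - start), start          # one contiguous range

def strided_size_and_offsets(size, split_factor, chunks, rank):
    # Modeled _StridedShard: one start offset per contiguous part.
    part = size // split_factor
    piece = size // (split_factor * chunks)
    return piece * split_factor, [p * part + rank * piece for p in range(split_factor)]

print(shard_size_and_offset(2048, 4, 1))        # (512, 512): single offset
print(strided_size_and_offsets(2048, 2, 4, 1))  # (512, [256, 1280]): offset list
```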
Return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:
- `slice` for contiguous ranges (Shard or contiguous StridedShard result)
- `torch.LongTensor` of indices for non-contiguous ranges
Composed sharding (multiple placements on the same dim) is handled by
indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.
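The composition step can be sketched with integer indexing (shown with NumPy; the variable names mirror the description above but the concrete indices are hypothetical):

```python
import numpy as np

dim_size = 16
dim_indices = np.arange(dim_size)       # global indices for the sharded dim

# First placement: this rank's contiguous TP shard, global rows 8..15.
dim_indices = dim_indices[8:16]

# Second placement: a strided sub-selection of what remains, e.g. local
# pieces [0, 1] and [4, 5] of the block.
new_indices = np.array([0, 1, 4, 5])
dim_indices = dim_indices[new_indices]  # composed: global rows [8, 9, 12, 13]
print(dim_indices.tolist())
```

Each placement narrows the surviving global indices, so composing any mix of contiguous and strided shards reduces to repeated sub-indexing.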
### 2.3 Updated `numel_for_rank()` (`core.py`)
Now handles both `slice` and `torch.Tensor` index types:
```python
numel = 1
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        # normalize implicit/negative bounds, then count via ceil division
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        # explicit index tensor: one element per index
        numel *= len(idx)
```
### 2.4 Updated pipeline stages (`pipeline.py`)
- **`_gather_grads()`**: uses index assignment
  (`gathered_grads[id(p)][indices] = sg`) instead of a view-based copy, which
  works for both slice and tensor indices.
- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types.
- **`_update_params()` QK clipping**: Applies clipping on `p._local_tensor`
directly instead of through DTensor operations, avoiding sharding
propagation errors with `_StridedShard`.
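The property these pipeline changes rely on is that index assignment and read-back accept either a `slice` or an integer index tensor. A minimal round-trip check, shown with NumPy (the same indexing semantics hold for `torch.Tensor`):

```python
import numpy as np

full = np.zeros(8)
local = np.array([1.0, 2.0])

# Contiguous (slice) vs strided (index array) cases take the same code path.
for indices in (slice(2, 4), np.array([1, 5])):
    buf = full.copy()
    buf[indices] = local       # _gather_grads-style index assignment
    gathered = buf[indices]    # _scatter_us-style read-back
    assert (gathered == local).all()
print("both index types round-trip")
```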
---
## 3. Test Results After Fixes
| Test configuration | Status |
|--------------------|--------|
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |
All 24 tests pass (126s, `--skip-verify`).
---
## 4. Fixed: QK Clipping with Strided Sharding
### Problem
With strided (non-contiguous) sharding, local rows are **interleaved across
multiple heads**. For example with `fsdp+tp` (`dp_shard=2, tp=4`):
- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, `16 heads`
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for
  interleaved rows → shape mismatch error
### Fix Applied
For strided sharding, apply scales **per-row** based on each row's global
head index instead of using the head-block view:
```python
if isinstance(weight_indices[0], slice):
# Contiguous case: view-based approach still works
...
else:
# Strided case: per-row scaling
head_per_row = weight_indices[0] // ratio
row_scales = scales_full[head_per_row]
local_p.mul_(row_scales.view(-1, 1))
```
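A numeric sketch of the strided branch with toy sizes (here `ratio` plays the role of rows per head, and the concrete row indices are hypothetical):

```python
import numpy as np

head_dim = 4                                   # rows per head ("ratio")
scales_full = np.array([1.0, 0.5, 0.25, 2.0])  # one clip scale per head
weight_rows = np.array([0, 1, 8, 9])           # interleaved rows from heads 0 and 2
local_p = np.ones((4, 3))                      # this rank's local weight shard

head_per_row = weight_rows // head_dim         # [0, 0, 2, 2]
row_scales = scales_full[head_per_row]         # [1.0, 1.0, 0.25, 0.25]
local_p *= row_scales.reshape(-1, 1)           # per-row scaling, no head-block view
print(local_p[:, 0].tolist())                  # [1.0, 1.0, 0.25, 0.25]
```

Because each row looks up its own head's scale, the result is correct no matter how the locally owned rows are interleaved across heads.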
---
## 5. Files Modified
| File | Changes |
|------|---------|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |