# PyTorch 2.10 Tensor Parallelism Fix
## Summary
PyTorch 2.10 changed the class hierarchy for `_StridedShard`, breaking our
`distributed/utils.py` code that handles DTensor sharding. This document
records the root cause, every change made so far, and the one remaining issue.

---
## 1. Root Cause: `_StridedShard` class hierarchy change
| Version | MRO |
|---------|-----|
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10 | `_StridedShard -> StridedShard -> Placement` |
**Consequences:**
- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported
placement and raised `AssertionError`.
### When does `_StridedShard` appear?
`_StridedShard` appears when two parallelism dimensions shard the same tensor
dimension, for example in `fsdp+tp` or `hsdp+tp` configurations where both TP
and FSDP shard dimension 0 of the Q/K/V projection weights:
```
TP : Shard(0) → each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top → each FSDP rank further splits those rows
```
PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)`
to indicate non-contiguous (interleaved) row ownership.
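To make the interleaving concrete, here is a minimal pure-Python sketch of which global rows a rank owns under a strided split. This is a simplified model for illustration, not PyTorch's internal implementation:

```python
# Simplified model of strided ownership (illustrative, not PyTorch internals):
# with split_factor S, a rank's rows come from S separate regions of the dim
# instead of one contiguous block.
def strided_rows(rank: int, num_ranks: int, dim_size: int, split_factor: int) -> list[int]:
    """Global row indices owned by `rank` under an interleaved split."""
    block = dim_size // (num_ranks * split_factor)
    rows = []
    for piece in range(split_factor):
        start = (piece * num_ranks + rank) * block
        rows.extend(range(start, start + block))
    return rows

# One TP rank's 512 rows, strided-split across dp_shard=2 with split_factor=4:
# rank 0 owns rows [0, 64) U [128, 192) U [256, 320) U [384, 448).
rows = strided_rows(rank=0, num_ranks=2, dim_size=512, split_factor=4)
print(len(rows), rows[0], rows[64], rows[128], rows[192])  # 256 0 128 256 384
```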

---
## 2. Completed Fixes
### 2.1 `_is_shard()` helper (`distributed/utils.py`)
Added a helper that correctly identifies both `Shard` and `_StridedShard`:
```python
def _is_shard(placement: Placement) -> bool:
return isinstance(placement, (Shard, _StridedShard))
```
This replaces the `placement.is_shard()` call in `construct_shard_mesh()`.
### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)
Old code assumed contiguous slicing (`start = rank * shard_size`), which is
wrong for `_StridedShard`.
New code uses PyTorch's own offset-computation methods:

| Placement type | API used |
|----------------|----------|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` → `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` → `(size, offsets_list)` |
Return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:
- `slice` for contiguous ranges (Shard or contiguous StridedShard result)
- `torch.LongTensor` of indices for non-contiguous ranges
Composed sharding (multiple placements on the same dim) is handled by
indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.
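The composition step can be sketched with plain lists standing in for index tensors (illustrative only; the real code indexes `torch` tensors, which behave the same way under integer-list indexing):

```python
# Sketch of composed sharding: the second placement's local indices select
# *within* the rows already selected by the first placement.
dim_indices = [list(range(16))]           # global indices of dim 0, unsharded
first = list(range(8))                    # first placement keeps rows 0..7
dim_indices[0] = [dim_indices[0][i] for i in first]
second = [0, 1, 4, 5]                     # second, strided placement: interleaved
dim_indices[0] = [dim_indices[0][i] for i in second]
print(dim_indices[0])  # -> [0, 1, 4, 5]
```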
### 2.3 Updated `numel_for_rank()` (`core.py`)
Now handles both `slice` and `torch.Tensor` index types:
```python
for idx, dim_size in zip(indices, param.shape):
if isinstance(idx, slice):
start, stop, step = idx.indices(dim_size)
numel *= max(0, (stop - start + (step - 1)) // step)
else:
numel *= len(idx)
```
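A self-contained, runnable version of the loop above (a sketch with a hypothetical `local_numel` wrapper; the real method lives inside `numel_for_rank()`):

```python
# Counts the local element count per rank for a mix of slice and index-list
# selections along each dimension.
def local_numel(indices, shape) -> int:
    numel = 1
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            # ceil((stop - start) / step) elements in the slice
            start, stop, step = idx.indices(dim_size)
            numel *= max(0, (stop - start + (step - 1)) // step)
        else:
            numel *= len(idx)  # explicit index list / tensor
    return numel

# 4 contiguous rows x 3 interleaved columns -> 12 local elements
print(local_numel((slice(0, 4), [0, 5, 9]), (8, 16)))  # -> 12
```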
### 2.4 Updated pipeline stages (`pipeline.py`)
- **`_gather_grads()`**: Uses `gathered_grads[id(p)][indices] = sg` (index
assignment) instead of view-based copy, works for both slice and tensor
indices.
- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types.
- **`_update_params()` QK clipping**: Applies clipping on `p._local_tensor`
directly instead of through DTensor operations, avoiding sharding
propagation errors with `_StridedShard`.
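Why a single assignment path covers both cases: tensor fancy indexing accepts a slice or an index tensor uniformly. A pure-Python sketch (hypothetical `positions` helper; plain lists need the normalization that tensor indexing does implicitly):

```python
# Normalize a slice or an index list to explicit positions, then assign.
def positions(idx, size) -> list[int]:
    return list(range(*idx.indices(size))) if isinstance(idx, slice) else list(idx)

full = [0.0] * 8
local = [1.0, 2.0, 3.0]
for pos, val in zip(positions([1, 4, 6], 8), local):  # non-contiguous case
    full[pos] = val                                   # slice(0, 3) works the same
print(full)  # -> [0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 0.0]
```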

---
## 3. Test Results After Fixes
| Test configuration | Status |
|--------------------|--------|
| base | PASS |
| fsdp | PASS |
| hsdp | PASS |
| tp | PASS |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on) | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on) | PASS |
All 24 tests pass (126s, `--skip-verify`).

---
## 4. Fixed: QK Clipping with Strided Sharding
### Problem
With strided (non-contiguous) sharding, local rows are **interleaved across
multiple heads**. For example with `fsdp+tp` (`dp_shard=2, tp=4`):
- Q/K projection global shape: `(2048, 2048)` with `head_dim=128` (16 heads)
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for
interleaved rows → shape mismatch error
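A quick check of the numbers above, using a hypothetical interleaved layout that matches the described ownership (64 rows from each of 4 heads per rank):

```python
# Each local row maps to a global head via integer division by head_dim;
# with interleaved rows the heads are not one contiguous block, so a
# contiguous view(-1, head_dim, cols) mislabels rows.
head_dim = 128
local_rows = [h * head_dim + r for h in range(4) for r in range(64)]
assert len(local_rows) == 256                      # each rank owns 256 rows
heads_touched = sorted({row // head_dim for row in local_rows})
print(heads_touched)  # -> [0, 1, 2, 3]: rows span 4 heads
```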
### Fix Applied
For strided sharding, apply scales **per-row** based on each row's global
head index instead of using the head-block view:
```python
if isinstance(weight_indices[0], slice):
# Contiguous case: view-based approach still works
...
else:
# Strided case: per-row scaling
head_per_row = weight_indices[0] // ratio
row_scales = scales_full[head_per_row]
local_p.mul_(row_scales.view(-1, 1))
```
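A toy, self-contained version of the strided branch (all names and sizes hypothetical; `ratio` in the snippet above is assumed to be the number of rows per head):

```python
# Per-row scaling: look up each local row's head, then gather that head's scale.
head_dim = 4                                   # rows per head (toy size)
scales_full = [0.5, 1.0, 2.0]                  # one clipping scale per head
local_row_indices = [0, 1, 4, 5, 8, 9]         # interleaved: 2 rows per head
head_per_row = [r // head_dim for r in local_row_indices]
row_scales = [scales_full[h] for h in head_per_row]
print(row_scales)  # -> [0.5, 0.5, 1.0, 1.0, 2.0, 2.0]
```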
---
## 5. Files Modified
| File | Changes |
|------|---------|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for mixed `slice`/`Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |