# PyTorch 2.10 Tensor Parallelism Fix

## Summary

PyTorch 2.10 changed the class hierarchy of `_StridedShard`, breaking the
DTensor sharding logic in our `distributed/utils.py`. This document records
the root cause and every change made to resolve it.

---

## 1. Root Cause: `_StridedShard` class hierarchy change

| Version | MRO |
|---------|-----|
| PyTorch < 2.10 | `_StridedShard -> Shard -> Placement` |
| PyTorch 2.10   | `_StridedShard -> StridedShard -> Placement` |

**Consequences:**

- `isinstance(strided_shard, Shard)` returns `False`
- `strided_shard.is_shard()` returns `False`
- Our `construct_shard_mesh()` treated `_StridedShard` as an unsupported
  placement and raised `AssertionError`.

### When does `_StridedShard` appear?

It appears when two parallelism dimensions shard the same tensor dimension,
e.g. in `fsdp+tp` or `hsdp+tp` configurations where both TP and FSDP shard
dimension 0 of the Q/K/V projection weights:

```
TP  : Shard(0)           →  each TP rank gets 2048/4 = 512 rows
FSDP: Shard(0) on top    →  each FSDP rank further splits those rows
```

PyTorch represents the second sharding as `_StridedShard(dim=0, split_factor=N)`
to indicate non-contiguous (interleaved) row ownership.
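
The interleaving can be modeled with a toy helper. This is an illustrative
sketch of the strided layout only (`strided_shard_rows` is a hypothetical
name, not a PyTorch API): the dimension is cut into
`num_ranks * split_factor` equal pieces and each rank takes every
`num_ranks`-th piece, so its rows are non-contiguous in the global tensor.

```python
def strided_shard_rows(dim_size: int, num_ranks: int, rank: int,
                       split_factor: int) -> list[int]:
    """Global row indices owned by `rank` under a strided shard (toy model)."""
    piece = dim_size // (num_ranks * split_factor)
    rows = []
    for s in range(split_factor):
        # Rank `rank` takes piece `rank`, then piece `rank + num_ranks`, etc.
        start = (s * num_ranks + rank) * piece
        rows.extend(range(start, start + piece))
    return rows

# 8 rows, 2 ranks, split_factor=2: ownership is interleaved, not contiguous.
print(strided_shard_rows(8, 2, 0, 2))  # [0, 1, 4, 5]
print(strided_shard_rows(8, 2, 1, 2))  # [2, 3, 6, 7]
```

With `split_factor=1` the same helper degenerates to a plain contiguous
`Shard(0)` block, which is why only the composed case needs index lists.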

---

## 2. Completed Fixes

### 2.1 `_is_shard()` helper  (`distributed/utils.py`)

Added a helper that correctly identifies both `Shard` and `_StridedShard`:

```python
# Placement-type import paths can shift between PyTorch versions; in recent
# releases these classes live in torch.distributed.tensor.placement_types.
from torch.distributed.tensor.placement_types import Placement, Shard, _StridedShard


def _is_shard(placement: Placement) -> bool:
    return isinstance(placement, (Shard, _StridedShard))
```

Used in `construct_shard_mesh()` where the old code called `placement.is_shard()`.

### 2.2 Rewritten `get_slices_of_dtensor()` (`distributed/utils.py`)

Old code assumed contiguous slicing (`start = rank * shard_size`), which is
wrong for `_StridedShard`.

New code uses PyTorch's own offset-computation methods:

| Placement type | API used | Returns |
|----------------|----------|---------|
| `Shard` | `Shard.local_shard_size_and_offset(size, chunks, rank)` | `(size, offset)` |
| `_StridedShard` | `_StridedShard.local_shard_size_and_offset(instance, size, chunks, rank, return_first_offset=False)` | `(size, offsets_list)` |

Return type changed from `tuple[slice, ...]` to `tuple[slice | torch.Tensor, ...]`:
- `slice` for contiguous ranges (Shard or contiguous StridedShard result)
- `torch.LongTensor` of indices for non-contiguous ranges

Composed sharding (multiple placements on the same dim) is handled by
indexing: `dim_indices[shard_dim] = dim_indices[shard_dim][new_indices]`.
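
The index-composition step can be sketched end-to-end. The shard layout and
variable names below are illustrative, not the real helper's internals: each
placement narrows the set of global rows this rank owns by indexing into the
result of the previous placement.

```python
import torch

dim_size = 16
dim_indices = torch.arange(dim_size)      # start: rank "owns" every row

# First placement: Shard(0) over 4 TP ranks; TP rank 1 -> rows 4..7.
dim_indices = dim_indices[4:8]

# Second placement: strided shard over 2 FSDP ranks; suppose this rank's
# local positions within the remaining block are [0, 2] (interleaved).
new_indices = torch.tensor([0, 2])
dim_indices = dim_indices[new_indices]    # composes to global rows 4 and 6

print(dim_indices.tolist())  # [4, 6]
```

Because the composed result `[4, 6]` is not expressible as a `slice`, the
function must be able to return a `torch.Tensor` of indices for that dim.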

### 2.3 Updated `numel_for_rank()` (`core.py`)

Now handles both `slice` and `torch.Tensor` index types:

```python
for idx, dim_size in zip(indices, param.shape):
    if isinstance(idx, slice):
        start, stop, step = idx.indices(dim_size)
        numel *= max(0, (stop - start + (step - 1)) // step)
    else:
        numel *= len(idx)
```
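
For forward (positive-step) slices, the `slice` branch reproduces
`len(range(start, stop, step))` exactly; a quick sanity check of the
arithmetic, factored into a small standalone function:

```python
def slice_len(idx: slice, dim_size: int) -> int:
    # Same arithmetic as the slice branch of numel_for_rank() above.
    start, stop, step = idx.indices(dim_size)
    return max(0, (stop - start + (step - 1)) // step)

# Agrees with Python's own range length for forward slices.
for s in (slice(None), slice(2, 9), slice(1, 10, 3), slice(5, 5)):
    assert slice_len(s, 12) == len(range(*s.indices(12)))
```

Sharding never produces negative-step slices, so that case is out of scope
for the formula.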

### 2.4 Updated pipeline stages (`pipeline.py`)

- **`_gather_grads()`**: Uses `gathered_grads[id(p)][indices] = sg` (index
  assignment) instead of view-based copy, works for both slice and tensor
  indices.
- **`_scatter_us()`**: `u_full[indices].flatten()` works for both index types.
- **`_update_params()` QK clipping**: Applies clipping on `p._local_tensor`
  directly instead of through DTensor operations, avoiding sharding
  propagation errors with `_StridedShard`.
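
The reason a single code path suffices in `_gather_grads()` is that PyTorch
indexing accepts both a `slice` and a `LongTensor` on the left-hand side of
an assignment. A minimal sketch (shapes and values illustrative):

```python
import torch

full = torch.zeros(8)

# Contiguous shard: slice index assignment.
full[slice(0, 4)] = torch.ones(4)

# Strided shard: LongTensor index assignment (advanced indexing).
full[torch.tensor([4, 6])] = torch.full((2,), 2.0)

print(full.tolist())  # [1.0, 1.0, 1.0, 1.0, 2.0, 0.0, 2.0, 0.0]
```

The same symmetry makes `u_full[indices].flatten()` in `_scatter_us()` work
unchanged for both index kinds on the read side.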

---

## 3. Test Results After Fixes

| Test configuration | Status |
|--------------------|--------|
| base               | PASS   |
| fsdp               | PASS   |
| hsdp               | PASS   |
| tp                  | PASS   |
| hsdp+tp (QK clip off) | PASS |
| hsdp+tp (QK clip on)  | PASS |
| fsdp+tp (QK clip off) | PASS |
| fsdp+tp (QK clip on)  | PASS |

All 24 tests pass (126s, `--skip-verify`).

---

## 4. Fixed: QK Clipping with Strided Sharding

### Problem

With strided (non-contiguous) sharding, local rows are **interleaved across
multiple heads**.  For example with `fsdp+tp` (`dp_shard=2, tp=4`):

- Q/K projection global shape: `(2048, 2048)`, `head_dim=128`, `16 heads`
- Each rank owns 256 rows, but they span 4 heads with 64 rows per head
- `view(-1, head_dim, cols)` assumes contiguous head blocks → wrong for
  interleaved rows → shape mismatch error

### Fix Applied

For strided sharding, apply scales **per-row** based on each row's global
head index instead of using the head-block view:

```python
if isinstance(weight_indices[0], slice):
    # Contiguous case: the view-based (head-block) approach still works
    ...
else:
    # Strided case: scale each local row by the scale of the global head
    # it belongs to; `ratio` is the number of weight rows per head.
    head_per_row = weight_indices[0] // ratio
    row_scales = scales_full[head_per_row]
    local_p.mul_(row_scales.view(-1, 1))
```
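
A numeric sketch of the per-row path, assuming `ratio` is the number of
weight rows per head (`head_dim` for a Q/K projection) and reusing the
snippet's names as hypothetical stand-ins. A toy layout: 8 global rows,
`head_dim=2` (so 4 heads), with this rank owning interleaved rows 0, 1, 4, 5.

```python
import torch

head_dim = 2                                        # rows per head ("ratio")
weight_rows = torch.tensor([0, 1, 4, 5])            # weight_indices[0]
scales_full = torch.tensor([0.5, 1.0, 0.25, 1.0])   # one scale per global head

head_per_row = weight_rows // head_dim              # global heads [0, 0, 2, 2]
row_scales = scales_full[head_per_row]              # [0.5, 0.5, 0.25, 0.25]

local_p = torch.ones(4, 3)                          # local shard of the weight
local_p.mul_(row_scales.view(-1, 1))                # broadcast across columns

print(local_p[:, 0].tolist())  # [0.5, 0.5, 0.25, 0.25]
```

Note the per-row map is computed from *global* row indices, so it is correct
regardless of how the local rows interleave across heads.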

---

## 5. Files Modified

| File | Changes |
|------|---------|
| `torch-ext/optimizer/distributed/utils.py` | Added `_is_shard()`, rewrote `get_slices_of_dtensor()`, fixed `construct_shard_mesh()` |
| `torch-ext/optimizer/core.py` | Updated `numel_for_rank()` for `slice \| Tensor` indices |
| `torch-ext/optimizer/pipeline.py` | Updated `_gather_grads()`, `_scatter_us()`, `_update_params()` QK clipping |