Spaces:
Running
Running
Add support for DualPipe.
Browse files- .gitignore +1 -0
- README.md +21 -6
- assets/dualpipe.png +3 -0
- conf/config.yaml +3 -0
- main.py +23 -0
- src/execution_model.py +81 -19
- src/strategies.py +227 -2
- src/visualizer.py +2 -12
.gitignore
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
uv.lock
|
| 4 |
outputs/
|
| 5 |
.cursor/*
|
|
|
|
| 6 |
|
| 7 |
# Uncomment below if you want to include these files
|
| 8 |
# !assets/*.png
|
|
|
|
| 3 |
uv.lock
|
| 4 |
outputs/
|
| 5 |
.cursor/*
|
| 6 |
+
*.json
|
| 7 |
|
| 8 |
# Uncomment below if you want to include these files
|
| 9 |
# !assets/*.png
|
README.md
CHANGED
|
@@ -18,6 +18,7 @@ Pipeline parallelism is a technique used to train large models by partitioning t
|
|
| 18 |
- Zero-Bubble 1F1B (ZB-1P)
|
| 19 |
- 1F1B with computation-communication overlap
|
| 20 |
- Interleaved 1F1B with computation-communication overlap
|
|
|
|
| 21 |
|
| 22 |
- **Visualization**:
|
| 23 |
- Interactive visualization dashboard using Plotly/Dash
|
|
@@ -56,6 +57,12 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
|
|
| 56 |
```
|
| 57 |

|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
### Running for 1F1B-batch-overlap strategy:
|
| 60 |
```bash
|
| 61 |
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
|
@@ -68,10 +75,24 @@ uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=
|
|
| 68 |
```
|
| 69 |

|
| 70 |
|
|
|
|
| 71 |
## Configuration
|
| 72 |
|
| 73 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
### Using Different Configuration Files
|
| 76 |
|
| 77 |
You can use different configuration files with Hydra in several ways:
|
|
@@ -90,12 +111,6 @@ You can use different configuration files with Hydra in several ways:
|
|
| 90 |
uv run python main.py --config-name=model_A
|
| 91 |
```
|
| 92 |
|
| 93 |
-
#### Override Specific Parameters
|
| 94 |
-
|
| 95 |
-
You can also override specific parameters at runtime:
|
| 96 |
-
```bash
|
| 97 |
-
uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
|
| 98 |
-
```
|
| 99 |
|
| 100 |
## Project Structure
|
| 101 |
|
|
|
|
| 18 |
- Zero-Bubble 1F1B (ZB-1P)
|
| 19 |
- 1F1B with computation-communication overlap
|
| 20 |
- Interleaved 1F1B with computation-communication overlap
|
| 21 |
+
- DualPipe (Bidirectional pipeline parallelism with full forward-backward overlap)
|
| 22 |
|
| 23 |
- **Visualization**:
|
| 24 |
- Interactive visualization dashboard using Plotly/Dash
|
|
|
|
| 57 |
```
|
| 58 |

|
| 59 |
|
| 60 |
+
### Running for DualPipe strategy:
|
| 61 |
+
```bash
|
| 62 |
+
uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
|
| 63 |
+
```
|
| 64 |
+

|
| 65 |
+
|
| 66 |
### Running for 1F1B-batch-overlap strategy:
|
| 67 |
```bash
|
| 68 |
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
|
|
|
| 75 |
```
|
| 76 |

|
| 77 |
|
| 78 |
+
|
| 79 |
## Configuration
|
| 80 |
|
| 81 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
| 82 |
|
| 83 |
+
#### Override Specific Parameters
|
| 84 |
+
|
| 85 |
+
You can override specific parameters at runtime:
|
| 86 |
+
```bash
|
| 87 |
+
uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
Use DualPipe as an example, you can manually set different time for forward/backward/backward_D/backward_W/overlapped_forward_backward:
|
| 91 |
+
```bash
|
| 92 |
+
uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
|
| 96 |
### Using Different Configuration Files
|
| 97 |
|
| 98 |
You can use different configuration files with Hydra in several ways:
|
|
|
|
| 111 |
uv run python main.py --config-name=model_A
|
| 112 |
```
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
## Project Structure
|
| 116 |
|
assets/dualpipe.png
ADDED
|
Git LFS Details
|
conf/config.yaml
CHANGED
|
@@ -11,6 +11,9 @@ op_times:
|
|
| 11 |
# Option 1: Simple configuration (same time for all stages)
|
| 12 |
forward: 1.0
|
| 13 |
backward: 2.0
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Option 2: Commented example of stage-specific configuration
|
| 16 |
# forward:
|
|
|
|
| 11 |
# Option 1: Simple configuration (same time for all stages)
|
| 12 |
forward: 1.0
|
| 13 |
backward: 2.0
|
| 14 |
+
backward_D: 1.0
|
| 15 |
+
backward_W: 1.0
|
| 16 |
+
overlapped_forward_backward: 2.0
|
| 17 |
|
| 18 |
# Option 2: Commented example of stage-specific configuration
|
| 19 |
# forward:
|
main.py
CHANGED
|
@@ -5,6 +5,7 @@ from src.strategies import (
|
|
| 5 |
generate_1f1b_overlap_schedule,
|
| 6 |
generate_1f1b_schedule,
|
| 7 |
generate_zero_bubble_1p_schedule,
|
|
|
|
| 8 |
)
|
| 9 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 10 |
import hydra
|
|
@@ -26,6 +27,8 @@ def main(cfg: DictConfig) -> None:
|
|
| 26 |
run_1f1b_overlap(cfg)
|
| 27 |
elif cfg.strategy == "1f1b_interleave_overlap":
|
| 28 |
run_1f1b_interleave_overlap(cfg)
|
|
|
|
|
|
|
| 29 |
else:
|
| 30 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 31 |
|
|
@@ -129,5 +132,25 @@ def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
|
|
| 129 |
schedule.execute()
|
| 130 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if __name__ == "__main__":
|
| 133 |
main()
|
|
|
|
| 5 |
generate_1f1b_overlap_schedule,
|
| 6 |
generate_1f1b_schedule,
|
| 7 |
generate_zero_bubble_1p_schedule,
|
| 8 |
+
generate_dualpipe_schedule,
|
| 9 |
)
|
| 10 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 11 |
import hydra
|
|
|
|
| 27 |
run_1f1b_overlap(cfg)
|
| 28 |
elif cfg.strategy == "1f1b_interleave_overlap":
|
| 29 |
run_1f1b_interleave_overlap(cfg)
|
| 30 |
+
elif cfg.strategy == "dualpipe":
|
| 31 |
+
run_dualpipe(cfg)
|
| 32 |
else:
|
| 33 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 34 |
|
|
|
|
| 132 |
schedule.execute()
|
| 133 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 134 |
|
| 135 |
+
def run_dualpipe(cfg: DictConfig) -> None:
|
| 136 |
+
"""Run DualPipe pipeline parallelism simulation."""
|
| 137 |
+
# Convert OmegaConf to dict for op_times if it exists
|
| 138 |
+
op_times = (
|
| 139 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
schedule_config = ScheduleConfig(
|
| 143 |
+
num_devices=cfg.num_devices,
|
| 144 |
+
num_stages=cfg.num_stages,
|
| 145 |
+
num_batches=cfg.num_batches,
|
| 146 |
+
p2p_latency=cfg.p2p_latency,
|
| 147 |
+
op_times=op_times,
|
| 148 |
+
split_backward=True,
|
| 149 |
+
placement_strategy="dualpipe",
|
| 150 |
+
)
|
| 151 |
+
schedule = generate_dualpipe_schedule(schedule_config)
|
| 152 |
+
schedule.execute()
|
| 153 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 154 |
+
|
| 155 |
if __name__ == "__main__":
|
| 156 |
main()
|
src/execution_model.py
CHANGED
|
@@ -69,7 +69,7 @@ class DeviceQueue:
|
|
| 69 |
def add_operation(self, op: Operation):
|
| 70 |
assert op.stage_id in self.stages
|
| 71 |
self.ops.append(op)
|
| 72 |
-
assert op.device_id is None
|
| 73 |
op.device_id = self.device_id
|
| 74 |
|
| 75 |
|
|
@@ -97,6 +97,7 @@ class ScheduleConfig:
|
|
| 97 |
"forward": 1.0,
|
| 98 |
"backward_D": 1.0,
|
| 99 |
"backward_W": 1.0,
|
|
|
|
| 100 |
}
|
| 101 |
else:
|
| 102 |
self.op_times = {
|
|
@@ -128,9 +129,14 @@ class ScheduleConfig:
|
|
| 128 |
self.num_stages_per_device = num_stages // num_devices
|
| 129 |
|
| 130 |
self.init_device_to_stages()
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def init_device_to_stages(self):
|
| 136 |
if self.placement_strategy == "standard":
|
|
@@ -145,14 +151,27 @@ class ScheduleConfig:
|
|
| 145 |
for i in range(self.num_stages):
|
| 146 |
device_to_put = i % self.num_devices
|
| 147 |
self.device_to_stages[device_to_put].append(i)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
else:
|
| 149 |
raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
|
| 150 |
|
| 151 |
def get_op_time(self, op_type: str, stage_id: int):
|
| 152 |
# For overlapped operations, extract the original operation types
|
| 153 |
if op_type.startswith("overlapped_"):
|
| 154 |
-
if op_type in self.op_times
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
else:
|
| 157 |
op_parts = op_type.split("_")[1:]
|
| 158 |
if len(op_parts) >= 2:
|
|
@@ -173,20 +192,25 @@ class ScheduleConfig:
|
|
| 173 |
|
| 174 |
|
| 175 |
class Schedule:
|
| 176 |
-
def __init__(self, config: ScheduleConfig):
|
| 177 |
self.ops = {} # (batch_id, stage_id, op_type) -> Operation
|
| 178 |
self.device_queues: List[DeviceQueue] = []
|
| 179 |
for dev_id in range(config.num_devices):
|
| 180 |
self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
|
| 181 |
self.config = config
|
| 182 |
|
| 183 |
-
|
|
|
|
| 184 |
self.op_to_overlapped = {}
|
| 185 |
|
| 186 |
def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
|
| 187 |
for op in overlapped_op.operations:
|
| 188 |
self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
| 189 |
self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
def init_operations(self):
|
| 192 |
op_types = ["forward", "backward"]
|
|
@@ -199,9 +223,12 @@ class Schedule:
|
|
| 199 |
batch_id, stage_id, op_type
|
| 200 |
)
|
| 201 |
|
| 202 |
-
def get_op(self, batch_id: int, stage_id: int, op_type: str):
|
| 203 |
if (batch_id, stage_id, op_type) in self.op_to_overlapped:
|
| 204 |
return self.op_to_overlapped[(batch_id, stage_id, op_type)]
|
|
|
|
|
|
|
|
|
|
| 205 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 206 |
|
| 207 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
|
@@ -226,20 +253,55 @@ class Schedule:
|
|
| 226 |
if self.config.split_backward:
|
| 227 |
if op.op_type == "backward_D":
|
| 228 |
if op.stage_id < self.config.num_stages - 1:
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
)
|
| 234 |
-
)
|
| 235 |
elif op.op_type == "backward_W":
|
| 236 |
if op.stage_id < self.config.num_stages - 1:
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
)
|
| 242 |
-
)
|
| 243 |
else:
|
| 244 |
if op.op_type == "backward":
|
| 245 |
if op.stage_id < self.config.num_stages - 1:
|
|
|
|
| 69 |
def add_operation(self, op: Operation):
|
| 70 |
assert op.stage_id in self.stages
|
| 71 |
self.ops.append(op)
|
| 72 |
+
assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
|
| 73 |
op.device_id = self.device_id
|
| 74 |
|
| 75 |
|
|
|
|
| 97 |
"forward": 1.0,
|
| 98 |
"backward_D": 1.0,
|
| 99 |
"backward_W": 1.0,
|
| 100 |
+
"backward": 2.0,
|
| 101 |
}
|
| 102 |
else:
|
| 103 |
self.op_times = {
|
|
|
|
| 129 |
self.num_stages_per_device = num_stages // num_devices
|
| 130 |
|
| 131 |
self.init_device_to_stages()
|
| 132 |
+
if self.placement_strategy == "dualpipe":
|
| 133 |
+
assert (
|
| 134 |
+
sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
assert (
|
| 138 |
+
sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
|
| 139 |
+
)
|
| 140 |
|
| 141 |
def init_device_to_stages(self):
|
| 142 |
if self.placement_strategy == "standard":
|
|
|
|
| 151 |
for i in range(self.num_stages):
|
| 152 |
device_to_put = i % self.num_devices
|
| 153 |
self.device_to_stages[device_to_put].append(i)
|
| 154 |
+
elif self.placement_strategy == "dualpipe":
|
| 155 |
+
# For DualPipe, each device has two stages
|
| 156 |
+
assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
|
| 157 |
+
assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
|
| 158 |
+
self.device_to_stages = defaultdict(list)
|
| 159 |
+
for i in range(self.num_stages):
|
| 160 |
+
self.device_to_stages[i] = [i, self.num_stages - i - 1]
|
| 161 |
else:
|
| 162 |
raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
|
| 163 |
|
| 164 |
def get_op_time(self, op_type: str, stage_id: int):
|
| 165 |
# For overlapped operations, extract the original operation types
|
| 166 |
if op_type.startswith("overlapped_"):
|
| 167 |
+
if op_type in self.op_times:
|
| 168 |
+
if isinstance(self.op_times[op_type], dict):
|
| 169 |
+
if stage_id in self.op_times[op_type]:
|
| 170 |
+
return self.op_times[op_type][stage_id]
|
| 171 |
+
else:
|
| 172 |
+
raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
|
| 173 |
+
else:
|
| 174 |
+
return self.op_times[op_type]
|
| 175 |
else:
|
| 176 |
op_parts = op_type.split("_")[1:]
|
| 177 |
if len(op_parts) >= 2:
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
class Schedule:
|
| 195 |
+
def __init__(self, config: ScheduleConfig, init_ops: bool = True):
|
| 196 |
self.ops = {} # (batch_id, stage_id, op_type) -> Operation
|
| 197 |
self.device_queues: List[DeviceQueue] = []
|
| 198 |
for dev_id in range(config.num_devices):
|
| 199 |
self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
|
| 200 |
self.config = config
|
| 201 |
|
| 202 |
+
if init_ops:
|
| 203 |
+
self.init_operations()
|
| 204 |
self.op_to_overlapped = {}
|
| 205 |
|
| 206 |
def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
|
| 207 |
for op in overlapped_op.operations:
|
| 208 |
self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
| 209 |
self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
|
| 210 |
+
|
| 211 |
+
def register_operation(self, op: Operation):
|
| 212 |
+
assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
|
| 213 |
+
self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
|
| 214 |
|
| 215 |
def init_operations(self):
|
| 216 |
op_types = ["forward", "backward"]
|
|
|
|
| 223 |
batch_id, stage_id, op_type
|
| 224 |
)
|
| 225 |
|
| 226 |
+
def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
|
| 227 |
if (batch_id, stage_id, op_type) in self.op_to_overlapped:
|
| 228 |
return self.op_to_overlapped[(batch_id, stage_id, op_type)]
|
| 229 |
+
if allow_none:
|
| 230 |
+
if (batch_id, stage_id, op_type) not in self.ops:
|
| 231 |
+
return None
|
| 232 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 233 |
|
| 234 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
|
|
|
| 253 |
if self.config.split_backward:
|
| 254 |
if op.op_type == "backward_D":
|
| 255 |
if op.stage_id < self.config.num_stages - 1:
|
| 256 |
+
op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
|
| 257 |
+
if op_bwd_d is not None:
|
| 258 |
+
deps.append(
|
| 259 |
+
(
|
| 260 |
+
op_bwd_d,
|
| 261 |
+
self.config.p2p_latency,
|
| 262 |
+
)
|
| 263 |
+
)
|
| 264 |
+
else:
|
| 265 |
+
deps.append(
|
| 266 |
+
(
|
| 267 |
+
self.get_op(op.batch_id, op.stage_id + 1, "backward"),
|
| 268 |
+
self.config.p2p_latency,
|
| 269 |
+
)
|
| 270 |
)
|
|
|
|
| 271 |
elif op.op_type == "backward_W":
|
| 272 |
if op.stage_id < self.config.num_stages - 1:
|
| 273 |
+
op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
|
| 274 |
+
if op_bwd_d is not None:
|
| 275 |
+
deps.append(
|
| 276 |
+
(
|
| 277 |
+
op_bwd_d,
|
| 278 |
+
self.config.p2p_latency,
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
deps.append(
|
| 283 |
+
(
|
| 284 |
+
self.get_op(op.batch_id, op.stage_id, "backward"),
|
| 285 |
+
self.config.p2p_latency,
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
elif op.op_type == "backward":
|
| 289 |
+
if op.stage_id < self.config.num_stages - 1:
|
| 290 |
+
op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
|
| 291 |
+
if op_bwd is not None:
|
| 292 |
+
deps.append(
|
| 293 |
+
(
|
| 294 |
+
op_bwd,
|
| 295 |
+
self.config.p2p_latency,
|
| 296 |
+
)
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
deps.append(
|
| 300 |
+
(
|
| 301 |
+
self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
|
| 302 |
+
self.config.p2p_latency,
|
| 303 |
+
)
|
| 304 |
)
|
|
|
|
| 305 |
else:
|
| 306 |
if op.op_type == "backward":
|
| 307 |
if op.stage_id < self.config.num_stages - 1:
|
src/strategies.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig
|
| 3 |
|
| 4 |
|
| 5 |
def generate_1f1b_schedule(config: ScheduleConfig):
|
|
@@ -43,6 +43,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 43 |
schedule = Schedule(config)
|
| 44 |
total_batches = config.num_batches
|
| 45 |
assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
|
|
|
|
| 46 |
|
| 47 |
for i in range(config.num_devices):
|
| 48 |
fwd_batch_id = 0
|
|
@@ -354,3 +355,227 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
|
| 354 |
|
| 355 |
|
| 356 |
return schedule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict, deque
|
| 2 |
+
from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
|
| 3 |
|
| 4 |
|
| 5 |
def generate_1f1b_schedule(config: ScheduleConfig):
|
|
|
|
| 43 |
schedule = Schedule(config)
|
| 44 |
total_batches = config.num_batches
|
| 45 |
assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
|
| 46 |
+
assert config.split_backward, "ZB-1P requires split_backward=True"
|
| 47 |
|
| 48 |
for i in range(config.num_devices):
|
| 49 |
fwd_batch_id = 0
|
|
|
|
| 355 |
|
| 356 |
|
| 357 |
return schedule
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
|
| 361 |
+
"""
|
| 362 |
+
Helper function to create overlapped operations correctly.
|
| 363 |
+
This handles the underlying operation creation and registration to avoid device_id issues.
|
| 364 |
+
"""
|
| 365 |
+
# Get the operations from the schedule
|
| 366 |
+
op1 = schedule.ops[(batch_id1, stage_id, type1)]
|
| 367 |
+
op2 = schedule.ops[(batch_id2, stage_id, type2)]
|
| 368 |
+
|
| 369 |
+
# Create the overlapped operation
|
| 370 |
+
overlapped_op = OverlappedOperation([op1, op2])
|
| 371 |
+
|
| 372 |
+
# Register in the schedule to ensure proper tracking
|
| 373 |
+
schedule.register_overlapped_operation(overlapped_op)
|
| 374 |
+
|
| 375 |
+
return overlapped_op
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def generate_dualpipe_schedule(config: ScheduleConfig):
|
| 379 |
+
"""
|
| 380 |
+
Implements the DualPipe scheduling strategy.
|
| 381 |
+
|
| 382 |
+
DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
|
| 383 |
+
and backward computation-communication phases and reduces pipeline bubbles.
|
| 384 |
+
|
| 385 |
+
The DualPipe strategy has the following characteristics:
|
| 386 |
+
1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
|
| 387 |
+
2. Each device handles both a forward stage and a reverse stage
|
| 388 |
+
3. Overlaps forward and backward operations to reduce bubble size
|
| 389 |
+
4. Assumes config.num_batches corresponds to half the total microbatches in original paper (M).
|
| 390 |
+
5. Currently only supports split_backward=True.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
config: The scheduling configuration
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
A Schedule object with the DualPipe scheduling
|
| 397 |
+
"""
|
| 398 |
+
# Ensure placement strategy is set for Schedule initialization
|
| 399 |
+
assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
|
| 400 |
+
# Assertions based on DualPipe requirements
|
| 401 |
+
assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
|
| 402 |
+
assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
|
| 403 |
+
assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
|
| 404 |
+
# Assertion based on original implementation: num_chunks >= num_ranks * 2
|
| 405 |
+
# Here, M (config.num_batches) corresponds to half_num_chunks
|
| 406 |
+
assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
|
| 407 |
+
assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
|
| 408 |
+
|
| 409 |
+
schedule = Schedule(config, init_ops=False)
|
| 410 |
+
|
| 411 |
+
num_stages = config.num_stages
|
| 412 |
+
num_devices = config.num_devices
|
| 413 |
+
# config.num_batches is M in the original paper, which corresponds to half_num_chunks
|
| 414 |
+
half_num_chunks = config.num_batches // 2
|
| 415 |
+
num_half_ranks = num_devices // 2
|
| 416 |
+
|
| 417 |
+
fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
|
| 418 |
+
bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
|
| 419 |
+
|
| 420 |
+
waited_weight_grad = [deque() for _ in range(num_devices)] # (device_id, ) -> List[(stage_id, batch_id)]
|
| 421 |
+
|
| 422 |
+
for device_id in range(num_devices):
|
| 423 |
+
is_in_second_half = device_id >= num_half_ranks
|
| 424 |
+
if is_in_second_half:
|
| 425 |
+
fwd_batch_ids[device_id, 1] = 0
|
| 426 |
+
fwd_batch_ids[device_id, 0] = config.num_batches // 2
|
| 427 |
+
bwd_d_batch_ids[device_id, 1] = 0
|
| 428 |
+
bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
|
| 429 |
+
else:
|
| 430 |
+
fwd_batch_ids[device_id, 0] = 0
|
| 431 |
+
fwd_batch_ids[device_id, 1] = config.num_batches // 2
|
| 432 |
+
bwd_d_batch_ids[device_id, 0] = 0
|
| 433 |
+
bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
|
| 434 |
+
def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
|
| 435 |
+
stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
|
| 436 |
+
stage_rev_dir = num_stages - 1 - device_id # Stage handled when moving backward (N-1 to 0)
|
| 437 |
+
if not is_in_second_half:
|
| 438 |
+
# First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
|
| 439 |
+
return stage_fwd_dir if phase == 0 else stage_rev_dir
|
| 440 |
+
else:
|
| 441 |
+
# Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
|
| 442 |
+
return stage_rev_dir if phase == 0 else stage_fwd_dir
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def add_op_to_queue(device_id, stage_id, op_type, batch_id):
|
| 446 |
+
# Retrieve the correct pre-initialized Operation object
|
| 447 |
+
op = Operation(batch_id, stage_id, op_type)
|
| 448 |
+
schedule.register_operation(op)
|
| 449 |
+
# Add to the device queue
|
| 450 |
+
schedule.device_queues[device_id].add_operation(op)
|
| 451 |
+
|
| 452 |
+
def _schedule_forward_chunk(device_id, phase, is_in_second_half):
|
| 453 |
+
"""Schedules a forward compute operation."""
|
| 454 |
+
stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
|
| 455 |
+
batch_id = fwd_batch_ids[device_id, phase]
|
| 456 |
+
add_op_to_queue(device_id, stage_id, "forward", batch_id)
|
| 457 |
+
fwd_batch_ids[device_id, phase] += 1
|
| 458 |
+
|
| 459 |
+
def _schedule_backward_chunk(device_id, phase, is_in_second_half):
|
| 460 |
+
"""Schedules a backward_D with backward_W compute operation."""
|
| 461 |
+
stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
|
| 462 |
+
batch_id = bwd_d_batch_ids[device_id, phase]
|
| 463 |
+
add_op_to_queue(device_id, stage_id, "backward", batch_id)
|
| 464 |
+
bwd_d_batch_ids[device_id, phase] += 1
|
| 465 |
+
|
| 466 |
+
def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
|
| 467 |
+
"""Schedules a backward_D compute operation."""
|
| 468 |
+
stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
|
| 469 |
+
batch_id = bwd_d_batch_ids[device_id, phase]
|
| 470 |
+
add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
|
| 471 |
+
bwd_d_batch_ids[device_id, phase] += 1
|
| 472 |
+
waited_weight_grad[device_id].append((stage_id, batch_id))
|
| 473 |
+
|
| 474 |
+
def _schedule_backward_weight_chunk(device_id):
|
| 475 |
+
"""Schedules a backward_W compute operation."""
|
| 476 |
+
stage_id, batch_id = waited_weight_grad[device_id].popleft()
|
| 477 |
+
add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
|
| 478 |
+
|
| 479 |
+
def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
|
| 480 |
+
"""Schedules an overlapped forward and backward_D compute operation."""
|
| 481 |
+
fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
|
| 482 |
+
bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
|
| 483 |
+
|
| 484 |
+
fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
|
| 485 |
+
|
| 486 |
+
fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
|
| 487 |
+
schedule.register_operation(fwd_op)
|
| 488 |
+
fwd_batch_ids[device_id, fwd_phase] += 1
|
| 489 |
+
|
| 490 |
+
bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
|
| 491 |
+
bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
|
| 492 |
+
schedule.register_operation(bwd_op)
|
| 493 |
+
bwd_d_batch_ids[device_id, bwd_phase] += 1
|
| 494 |
+
|
| 495 |
+
# Create and register the overlapped operation
|
| 496 |
+
overlapped_op = OverlappedOperation([fwd_op, bwd_op])
|
| 497 |
+
schedule.register_overlapped_operation(overlapped_op)
|
| 498 |
+
|
| 499 |
+
# Add the overlapped operation to the queue
|
| 500 |
+
schedule.device_queues[device_id].add_operation(overlapped_op)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
# Process each device (rank in original code)
|
| 504 |
+
for device_id in range(num_devices):
|
| 505 |
+
half_rank = min(device_id, num_devices - 1 - device_id)
|
| 506 |
+
is_in_second_half = device_id >= num_half_ranks
|
| 507 |
+
is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
|
| 508 |
+
|
| 509 |
+
# Map original steps to operation additions
|
| 510 |
+
# Step 1: nF0
|
| 511 |
+
step_1_count = (num_half_ranks - half_rank - 1) * 2
|
| 512 |
+
for _ in range(step_1_count):
|
| 513 |
+
_schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 514 |
+
|
| 515 |
+
# Step 2: nF0F1
|
| 516 |
+
step_2_count = half_rank + 1
|
| 517 |
+
for i in range(step_2_count):
|
| 518 |
+
_schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 519 |
+
_schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
|
| 520 |
+
|
| 521 |
+
# Step 3: nB1W1F1
|
| 522 |
+
step_3_count = num_half_ranks - half_rank - 1
|
| 523 |
+
for _ in range(step_3_count):
|
| 524 |
+
_schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
|
| 525 |
+
_schedule_backward_weight_chunk(device_id,) # W1
|
| 526 |
+
_schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
|
| 527 |
+
|
| 528 |
+
# Step 4 (Main step): nF0B1F1B0
|
| 529 |
+
step_4_count = half_num_chunks - num_devices + half_rank + 1
|
| 530 |
+
for i in range(step_4_count):
|
| 531 |
+
# if i == 0 and is_middle_rank:
|
| 532 |
+
# Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
|
| 533 |
+
# _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 534 |
+
# _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
|
| 535 |
+
# _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
|
| 536 |
+
# else:
|
| 537 |
+
# Overlap F0 and B1_D, then schedule W1
|
| 538 |
+
_schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half) # F0+B1
|
| 539 |
+
|
| 540 |
+
# Overlap F1 and B0_D, then schedule W0
|
| 541 |
+
_schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
|
| 542 |
+
|
| 543 |
+
# Step 5: nB1F1B0
|
| 544 |
+
step_5_count = num_half_ranks - half_rank - 1
|
| 545 |
+
for _ in range(step_5_count):
|
| 546 |
+
_schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
|
| 547 |
+
_schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half) # F1+B0
|
| 548 |
+
|
| 549 |
+
# Step 6: nB1B0
|
| 550 |
+
step_6_count = half_rank + 1
|
| 551 |
+
enable_zb = False
|
| 552 |
+
for i in range(step_6_count):
|
| 553 |
+
if i == step_6_count // 2 and half_rank % 2 == 1:
|
| 554 |
+
enable_zb = True
|
| 555 |
+
if enable_zb:
|
| 556 |
+
_schedule_backward_input_chunk(device_id, 1, is_in_second_half)
|
| 557 |
+
else:
|
| 558 |
+
_schedule_backward_chunk(device_id, 1, is_in_second_half)
|
| 559 |
+
if i == step_6_count // 2 and half_rank % 2 == 0:
|
| 560 |
+
enable_zb = True
|
| 561 |
+
if enable_zb:
|
| 562 |
+
_schedule_backward_input_chunk(device_id, 0, is_in_second_half)
|
| 563 |
+
else:
|
| 564 |
+
_schedule_backward_chunk(device_id, 0, is_in_second_half)
|
| 565 |
+
|
| 566 |
+
# Step 7: nWB0
|
| 567 |
+
step_7_count = num_half_ranks - half_rank - 1
|
| 568 |
+
for _ in range(step_7_count):
|
| 569 |
+
_schedule_backward_weight_chunk(device_id) # W1 (use gradient from B1_D scheduled previously)
|
| 570 |
+
_schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
|
| 571 |
+
|
| 572 |
+
# Step 8: nW
|
| 573 |
+
step_8_count = half_rank + 1
|
| 574 |
+
for _ in range(step_8_count):
|
| 575 |
+
# W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
|
| 576 |
+
# W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
|
| 577 |
+
# The last W0 gradients correspond to B0_D from step 6 or 7.
|
| 578 |
+
_schedule_backward_weight_chunk(device_id) # W0 (use gradient from B0_D scheduled previously)
|
| 579 |
+
|
| 580 |
+
return schedule
|
| 581 |
+
|
src/visualizer.py
CHANGED
|
@@ -89,11 +89,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
| 89 |
|
| 90 |
# Improved teal/turquoise palette with low saturation and high brightness
|
| 91 |
backward_d_colors = [
|
| 92 |
-
"#ccffff", # Very light cyan
|
| 93 |
-
"#b3ffff", # Pale cyan
|
| 94 |
-
"#99ffff", # Light cyan
|
| 95 |
-
"#80ffff", # Cyan
|
| 96 |
-
"#66e6e6", # Soft teal
|
| 97 |
"#4dcccc", # Light teal
|
| 98 |
"#33b3b3", # Teal
|
| 99 |
"#009999", # Medium teal
|
|
@@ -102,12 +97,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
| 102 |
|
| 103 |
# Improved green palette with low saturation and high brightness
|
| 104 |
backward_w_colors = [
|
| 105 |
-
"#ccffe6", # Very light mint
|
| 106 |
-
"#b3ffd9", # Pale mint
|
| 107 |
-
"#99ffcc", # Light mint
|
| 108 |
-
"#80ffbf", # Mint green
|
| 109 |
-
"#66e6a6", # Soft green
|
| 110 |
-
"#4dcc8c", # Light green
|
| 111 |
"#33b373", # Medium green
|
| 112 |
"#009959", # Forest green
|
| 113 |
"#008040", # Dark green
|
|
@@ -162,7 +151,8 @@ def create_pipeline_figure(
|
|
| 162 |
max_batch = max(max_batch, task["batch"])
|
| 163 |
|
| 164 |
# Flag to determine whether to show text labels
|
| 165 |
-
|
|
|
|
| 166 |
|
| 167 |
# Create a figure
|
| 168 |
fig = go.Figure()
|
|
|
|
| 89 |
|
| 90 |
# Improved teal/turquoise palette with low saturation and high brightness
|
| 91 |
backward_d_colors = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
"#4dcccc", # Light teal
|
| 93 |
"#33b3b3", # Teal
|
| 94 |
"#009999", # Medium teal
|
|
|
|
| 97 |
|
| 98 |
# Improved green palette with low saturation and high brightness
|
| 99 |
backward_w_colors = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"#33b373", # Medium green
|
| 101 |
"#009959", # Forest green
|
| 102 |
"#008040", # Dark green
|
|
|
|
| 151 |
max_batch = max(max_batch, task["batch"])
|
| 152 |
|
| 153 |
# Flag to determine whether to show text labels
|
| 154 |
+
num_operations_per_device = len(schedule_data[0])
|
| 155 |
+
show_text_labels = num_operations_per_device <= 64
|
| 156 |
|
| 157 |
# Create a figure
|
| 158 |
fig = go.Figure()
|