Commit 592da35 by Victarry · 1 parent: efe627a

Add visualization script and README for Pipeline Parallelism in Megatron-LM
assets/dumped_example.jpg ADDED

Git LFS Details

  • SHA256: a12aecf1e52d57ef7f13c2b6bfd0a7320dd8ed389c8722a86fd999953b301a3b
  • Pointer size: 131 Bytes
  • Size of remote file: 796 kB
examples/megatron-lm/README.md ADDED
@@ -0,0 +1,95 @@
# Pipeline Parallelism Visualization for Megatron-LM

This tool visualizes Pipeline Parallelism (PP) scheduling in Megatron-LM training, helping you analyze load-balancing issues and debug abnormal PP bubbles that are difficult to inspect directly from Nsight Systems profiling.

## Overview

The visualization tool offers an intuitive visual representation of PP scheduling, making it easier to:
- Identify load-balancing issues across pipeline stages
- Debug PP bubble problems
- Analyze pipeline efficiency and bottlenecks
- Optimize pipeline parallelism configurations

## Prerequisites

- Megatron-LM with PP timer support
- Python environment with required dependencies
- UV package manager (recommended)

## Usage

### Step 1: Enable PP Timer in Megatron-LM

First, apply the PP timer patch to your Megatron-LM installation by cherry-picking the commit from the modified Megatron-LM repository:

```bash
# Navigate to your Megatron-LM directory
cd /path/to/Megatron-LM

# Cherry-pick the PP timer commit
git remote add victarry https://github.com/Victarry/Megatron-LM.git
git fetch victarry
git cherry-pick ad3bc3a22adc79827dc1b35619ad6813078e621b
```

**Note**: The commit can be viewed at: https://github.com/Victarry/Megatron-LM/commit/ad3bc3a22adc79827dc1b35619ad6813078e621b

### Step 2: Configure Environment Variables

Set the following environment variables before running your training script:

```bash
# Enable PP timer functionality
export ENABLE_PP_TIMER=1

# Specify which iteration to dump (e.g., iteration 1)
export ENABLE_PP_TIMER_ITER=1

# Set directory to save the dumped timer results
export PP_TIMER_LOG_DIR=/path/to/save/timer/logs

# Run your training script
bash your_training_script.sh
```
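Each rank writes its timings to a JSON file in `PP_TIMER_LOG_DIR` that maps event names to timestamps. Based on how `plot.py` reads these files, keys look like `forward-{batch}-start` / `forward-{batch}-end` (with an extra model-chunk index when VPP is enabled, e.g. `forward-{chunk}-{batch}-start`). The values below are illustrative:

```python
# Illustrative contents of one dumped timer file (timestamp values are
# made up; plot.py treats them as milliseconds).
example = {
    "forward-0-start": 0.0,
    "forward-0-end": 3.2,
    "backward-0-start": 9.5,
    "backward-0-end": 15.1,
}

# plot.py derives each op's execution time as end - start:
duration = example["forward-0-end"] - example["forward-0-start"]
print(f"forward-0 took {duration:.1f} ms")  # forward-0 took 3.2 ms
```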
### Step 3: Generate Visualization

Once you have collected the timer data, run the visualization script:

```bash
# Navigate to the PP-Schedule-Visualization directory
cd /path/to/PP-Schedule-Visualization

# Set your configuration parameters
PP_SIZE=4                      # Number of pipeline parallel stages
VPP_SIZE=1                     # Virtual pipeline parallel size (usually 1)
DATA_DIR=/path/to/timer/logs   # Directory containing the dumped timer data

# Run the visualization script
uv run examples/megatron-lm/plot.py --data-dir $DATA_DIR --pp-size $PP_SIZE --vpp-size $VPP_SIZE
```

**Parameters:**
- `--data-dir`: Path to the directory containing PP timer log files
- `--pp-size`: Number of pipeline parallel stages in your training setup
- `--vpp-size`: Virtual pipeline parallel size (typically 1 unless using virtual PP)
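If the script reports that no event files were found, check that the files in `--data-dir` follow the naming pattern `plot.py` expects. A minimal sketch of the same check (the helper name and sample filename are illustrative):

```python
import re

def matches_timer_format(filename: str, pp_size: int, vpp_size: int) -> bool:
    """Mirror of the filename check in plot.py: files must look like
    event_times_PP{pp}_VPP{vpp}_TPxCPxDP_rank_{r}_pp_rank_{p}_rank_{g}.json"""
    pattern = (
        rf"^event_times_PP{pp_size}_VPP{vpp_size}"
        rf"_TPxCPxDP_rank_\d+_pp_rank_\d+_rank_\d+\.json$"
    )
    return bool(re.match(pattern, filename))

# Illustrative filename for PP_SIZE=4, VPP_SIZE=1
print(matches_timer_format(
    "event_times_PP4_VPP1_TPxCPxDP_rank_0_pp_rank_2_rank_2.json", 4, 1))  # True
```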
### Example Output

After running the visualization script, you will see a detailed PP schedule visualization similar to:

![PP Schedule Visualization](../../assets/dumped_example.jpg)

The visualization shows:
- Timeline of each pipeline stage
- Forward and backward pass scheduling
- Bubble time and idle periods
- Communication overhead between stages
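The per-device bubble percentages printed by the script come from comparing each device's active (forward plus backward) time against the total run time. A simplified sketch of that computation, with made-up intervals:

```python
# Simplified sketch of the bubble computation in plot.py: a device's bubble
# is the fraction of total run time it spends outside any forward/backward
# interval. The (start, end) pairs below are illustrative, in ms.
intervals = {
    0: [(0.0, 4.0), (6.0, 10.0)],   # device 0
    1: [(1.0, 5.0), (6.0, 9.0)],    # device 1
}

total_time = max(end for ops in intervals.values() for _, end in ops)
for device_id, ops in intervals.items():
    active = sum(end - start for start, end in ops)
    bubble = (total_time - active) / total_time * 100
    print(f"Device {device_id} bubble: {bubble:.2f}%")
# Device 0 bubble: 20.00%
# Device 1 bubble: 30.00%
```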
## Known Issue
- If the global batch size is very large, it may take more than a minute for the visualization results to appear.

## Contributing

If you encounter issues or have suggestions for improvements, please open an issue or submit a pull request.
examples/megatron-lm/plot.py ADDED
@@ -0,0 +1,315 @@
import json
import os
import argparse
import re
from collections import defaultdict
from src.execution_model import Schedule, ScheduleConfig, Operation
from src.visualizer import visualize_pipeline_parallelism_dash


def is_valid_event_filename(filename, pp_size, vpp_size):
    """
    Check if filename matches the expected format:
    event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json

    Returns True if valid, False otherwise.
    """
    # Create regex pattern for the expected format
    pattern = rf"^event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_\d+_pp_rank_\d+_rank_\d+\.json$"
    return bool(re.match(pattern, filename))


def parse_event_filename(filename):
    """
    Parse the event filename and extract rank information.

    Expected format: event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json

    Returns: (TPxCPxDP_rank, pp_rank, global_rank) or None if parsing fails
    """
    try:
        # Remove .json extension
        name_without_ext = filename.replace(".json", "")
        parts = name_without_ext.split("_")

        # Find the TPxCPxDP part and the rank values
        tpxcpxdp_rank = None
        pp_rank = None
        global_rank = None

        for i, part in enumerate(parts):
            # Look for TPxCPxDP pattern followed by 'rank'
            if part.startswith("TP") and "CP" in part and part.endswith("DP"):
                if i + 2 < len(parts) and parts[i + 1] == "rank":
                    tpxcpxdp_rank = int(parts[i + 2])
            # Look for 'pp_rank' pattern
            elif part == "pp" and i + 2 < len(parts) and parts[i + 1] == "rank":
                pp_rank = int(parts[i + 2])
            # Look for the final 'rank' (global rank) - this should be the last rank in the filename
            elif part == "rank" and i + 1 < len(parts) and i == len(parts) - 2:
                global_rank = int(parts[i + 1])

        if tpxcpxdp_rank is None or pp_rank is None or global_rank is None:
            return None

        return (tpxcpxdp_rank, pp_rank, global_rank)

    except (ValueError, IndexError):
        return None


def load_event_times_from_json(data_dir, pp_size, vpp_size):
    """Load event times from JSON files in the specified directory."""
    all_files = [f for f in os.listdir(data_dir) if f.endswith(".json")]

    # Filter files that match the expected format
    event_files = [
        f for f in all_files if is_valid_event_filename(f, pp_size, vpp_size)
    ]

    if len(event_files) == 0:
        print(f"Available files in {data_dir}:")
        for f in all_files[:10]:  # Show first 10 files for debugging
            print(f"  {f}")
        raise ValueError(
            f"No event files found matching pattern event_times_PP{pp_size}_VPP{vpp_size}_*.json"
        )

    print(f"Found {len(event_files)} matching event files")
    event_times = {}

    for file_name in event_files:
        parsed_result = parse_event_filename(file_name)
        if parsed_result is None:
            print(f"Warning: Could not parse filename {file_name}")
            continue

        tpxcpxdp_rank, pp_rank, global_rank = parsed_result

        # Only load one replica per pipeline stage (TPxCPxDP rank 0);
        # the other replicas run the same PP schedule
        if tpxcpxdp_rank == 0:
            try:
                with open(os.path.join(data_dir, file_name), "r") as f:
                    event_data = json.load(f)
                event_times[(pp_rank, tpxcpxdp_rank)] = event_data
                print(
                    f"Loaded data from {file_name}: global_rank={global_rank}, pp_rank={pp_rank}, tpxcpxdp_rank={tpxcpxdp_rank}"
                )
            except Exception as e:
                print(f"Error loading {file_name}: {e}")

    return event_times


def create_pp_schedule_from_event_times(event_times, pp_size):
    """Create a Schedule object from event times data."""
    # Determine number of devices/stages from the data
    num_devices = pp_size

    # Find the maximum batch ID by parsing event names
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                if len(parts) >= 2 and parts[1].isdigit():
                    batch_id = int(parts[1])
                    max_batch_id = max(max_batch_id, batch_id)

    num_batches = max_batch_id + 1

    # Create a simple config (actual times will come from event data)
    config = ScheduleConfig(
        num_devices=num_devices,
        num_stages=num_devices,  # Assuming 1:1 mapping of devices to stages
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="standard",
    )

    # Create a schedule
    schedule = Schedule(config)

    # Populate the schedule with operations based on event times
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        # Process forward passes
        for batch_id in range(num_batches):
            forward_start_key = f"forward-{batch_id}-start"
            forward_end_key = f"forward-{batch_id}-end"

            if forward_start_key in events and forward_end_key in events:
                # Create an operation and set its timing directly
                forward_op = Operation(batch_id, pp_rank, "forward")
                forward_op.execution_time = (
                    events[forward_end_key] - events[forward_start_key]
                )
                forward_op.start_time = events[forward_start_key]
                forward_op.end_time = events[forward_end_key]

                # Add to schedule
                schedule.ops[(batch_id, pp_rank, "forward")] = forward_op
                schedule.device_queues[pp_rank].add_operation(forward_op)

        # Process backward passes
        for batch_id in range(num_batches):
            backward_start_key = f"backward-{batch_id}-start"
            backward_end_key = f"backward-{batch_id}-end"

            if backward_start_key in events and backward_end_key in events:
                # Create an operation and set its timing directly
                backward_op = Operation(batch_id, pp_rank, "backward")
                backward_op.execution_time = (
                    events[backward_end_key] - events[backward_start_key]
                )
                backward_op.start_time = events[backward_start_key]
                backward_op.end_time = events[backward_end_key]

                # Add to schedule
                schedule.ops[(batch_id, pp_rank, "backward")] = backward_op
                schedule.device_queues[pp_rank].add_operation(backward_op)

    return schedule


def create_vpp_schedule_from_event_times(event_times, pp_size, vpp_size):
    """Create a VPP Schedule object from event times data."""
    # Find the maximum batch ID by parsing event names
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                assert len(parts) == 4
                assert parts[0] in ["forward", "backward"]
                assert parts[1].isdigit() and parts[2].isdigit()
                assert parts[3] in ["start", "end"]
                batch_id = int(parts[2])  # e.g. backward-0-19-end
                max_batch_id = max(max_batch_id, batch_id)

    num_batches = max_batch_id + 1

    # Create a simple config (actual times will come from event data)
    config = ScheduleConfig(
        num_devices=pp_size,
        num_stages=pp_size * vpp_size,
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="interleave",
    )

    # Create a schedule
    schedule = Schedule(config)

    # Populate the schedule with operations based on event times
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        # Process forward passes
        for model_chunk_id in range(vpp_size):
            for batch_id in range(num_batches):
                forward_start_key = f"forward-{model_chunk_id}-{batch_id}-start"
                forward_end_key = f"forward-{model_chunk_id}-{batch_id}-end"

                stage_id = pp_size * model_chunk_id + pp_rank
                if forward_start_key in events and forward_end_key in events:
                    # Create an operation and set its timing directly
                    forward_op = Operation(
                        batch_id, stage_id=stage_id, op_type="forward"
                    )
                    forward_op.execution_time = (
                        events[forward_end_key] - events[forward_start_key]
                    )
                    forward_op.start_time = events[forward_start_key]
                    forward_op.end_time = events[forward_end_key]

                    # Add to schedule
                    schedule.ops[(batch_id, stage_id, "forward")] = forward_op
                    schedule.device_queues[pp_rank].add_operation(forward_op)

        # Process backward passes
        for model_chunk_id in range(vpp_size):
            for batch_id in range(num_batches):
                backward_start_key = f"backward-{model_chunk_id}-{batch_id}-start"
                backward_end_key = f"backward-{model_chunk_id}-{batch_id}-end"

                stage_id = pp_size * model_chunk_id + pp_rank
                if backward_start_key in events and backward_end_key in events:
                    # Create an operation and set its timing directly
                    backward_op = Operation(
                        batch_id, stage_id=stage_id, op_type="backward"
                    )
                    backward_op.execution_time = (
                        events[backward_end_key] - events[backward_start_key]
                    )
                    backward_op.start_time = events[backward_start_key]
                    backward_op.end_time = events[backward_end_key]

                    # Add to schedule
                    schedule.ops[(batch_id, stage_id, "backward")] = backward_op
                    schedule.device_queues[pp_rank].add_operation(backward_op)

    return schedule


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Visualize pipeline parallelism from event data"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Directory containing event_times_*.json files",
    )
    parser.add_argument(
        "--pp-size", type=int, required=True, help="Pipeline parallelism size"
    )
    parser.add_argument(
        "--vpp-size", type=int, required=True, help="Virtual pipeline parallelism size"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8050,
        help="Port for the visualization dashboard (default: 8050)",
    )
    args = parser.parse_args()

    # Load event times from JSON files
    event_times = load_event_times_from_json(args.data_dir, args.pp_size, args.vpp_size)

    # Create schedule from event times
    if args.vpp_size == 1:
        schedule = create_pp_schedule_from_event_times(event_times, args.pp_size)
    else:
        schedule = create_vpp_schedule_from_event_times(
            event_times, args.pp_size, args.vpp_size
        )

    # Calculate and print execution metrics
    total_execution_time = max(
        op.end_time for op in schedule.ops.values() if op.end_time is not None
    )
    print(f"Total execution time: {total_execution_time:.2f} ms")

    # Calculate bubble time percentage
    device_times = defaultdict(float)
    for device_id, device_queue in enumerate(schedule.device_queues):
        for op in device_queue.ops:
            if op.start_time is not None and op.end_time is not None:
                device_times[device_id] += op.end_time - op.start_time

    # Print bubble percentage for each device
    for device_id, active_time in device_times.items():
        bubble_percentage = (
            (total_execution_time - active_time) / total_execution_time * 100
        )
        print(f"Device {device_id} bubble: {bubble_percentage:.2f}%")

    # Visualize the schedule
    print("Launching visualization...")
    visualize_pipeline_parallelism_dash(
        schedule, schedule_type="1F1B-Imported", port=args.port
    )


if __name__ == "__main__":
    main()