Victarry committed on
Commit
2bb73ed
·
2 Parent(s): f300aa9 fa7e466

Merge branch 'main' into hf_space

Browse files

* main:
Make P2P in warmup/cooldown stage to sync comm.
Add support to set custom stage time.
Add visualization script and README for Pipeline Parallelism in Megatron-LM
Update link.

app.py CHANGED
@@ -292,6 +292,35 @@ timing_params_card = dbc.Card([
292
  ])
293
  ], style=card_style)
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  # Updated app layout with improved structure
296
  app.layout = html.Div([
297
  header,
@@ -346,6 +375,7 @@ app.layout = html.Div([
346
  basic_params_card,
347
  scheduling_params_card,
348
  timing_params_card,
 
349
 
350
  # Generate button with better styling
351
  dbc.Button([
@@ -398,15 +428,11 @@ app.layout = html.Div([
398
  html.A([
399
  html.I(className="bi bi-github me-2"),
400
  "View on GitHub"
401
- ], href="#", className="small text-muted d-block mb-2"),
402
- html.A([
403
- html.I(className="bi bi-book me-2"),
404
- "Documentation"
405
- ], href="#", className="small text-muted d-block mb-2"),
406
  html.A([
407
  html.I(className="bi bi-question-circle me-2"),
408
  "Report an Issue"
409
- ], href="#", className="small text-muted d-block")
410
  ])
411
  ], md=4)
412
  ]),
@@ -525,6 +551,75 @@ def toggle_advanced_options(n_clicks, is_open):
525
  return not is_open
526
  return is_open
527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  # --- Client-side Callback for Strategy Card Selection ---
529
  app.clientside_callback(
530
  """
@@ -580,12 +675,14 @@ app.clientside_callback(
580
  State('op_time_overlapped_fwd_bwd', 'value'),
581
  State('microbatch_group_size_per_vp_stage', 'value'),
582
  State('selected-strategies-store', 'data'),
 
 
583
  prevent_initial_call=True
584
  )
585
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
586
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
587
  op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
588
- selected_strategies):
589
 
590
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
591
 
@@ -673,14 +770,40 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
673
  if adjustment_msg not in automatic_adjustments:
674
  automatic_adjustments.append(adjustment_msg)
675
 
676
- op_times = { "forward": float(op_time_forward) * time_scale_factor }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
 
678
  if split_backward:
679
  op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
680
  op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
681
  op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
682
  else:
683
- op_times["backward"] = float(op_time_backward) * time_scale_factor
 
 
 
 
 
 
 
 
 
 
684
 
685
  if op_time_overlapped_fwd_bwd is not None:
686
  try:
 
292
  ])
293
  ], style=card_style)
294
 
295
+ # Per-stage timing configuration card
296
+ per_stage_timing_card = dbc.Card([
297
+ dbc.CardBody([
298
+ html.H5([
299
+ html.I(className="bi bi-list-ol section-icon"),
300
+ "Per-Stage Timing Configuration"
301
+ ], className="section-title"),
302
+
303
+ dbc.Button([
304
+ html.I(className="bi bi-sliders2 me-2"),
305
+ "Customize Per-Stage Timing"
306
+ ],
307
+ id="per-stage-timing-toggle",
308
+ color="light",
309
+ className="mb-3 w-100",
310
+ size="sm"
311
+ ),
312
+
313
+ dbc.Collapse([
314
+ dbc.Alert([
315
+ html.I(className="bi bi-info-circle-fill me-2"),
316
+ "Override global timing values for individual stages. Leave empty to use global values."
317
+ ], color="info", className="mb-3"),
318
+
319
+ html.Div(id='per-stage-inputs-container', children=[])
320
+ ], id="per-stage-timing-collapse", is_open=False)
321
+ ])
322
+ ], style=card_style)
323
+
324
  # Updated app layout with improved structure
325
  app.layout = html.Div([
326
  header,
 
375
  basic_params_card,
376
  scheduling_params_card,
377
  timing_params_card,
378
+ per_stage_timing_card,
379
 
380
  # Generate button with better styling
381
  dbc.Button([
 
428
  html.A([
429
  html.I(className="bi bi-github me-2"),
430
  "View on GitHub"
431
+ ], href="https://github.com/Victarry/PP-Schedule-Visualization", className="small text-muted d-block mb-2"),
 
 
 
 
432
  html.A([
433
  html.I(className="bi bi-question-circle me-2"),
434
  "Report an Issue"
435
+ ], href="https://github.com/Victarry/PP-Schedule-Visualization/issues", className="small text-muted d-block")
436
  ])
437
  ], md=4)
438
  ]),
 
551
  return not is_open
552
  return is_open
553
 
554
+ # --- Callback to toggle Per-Stage Timing Collapse ---
555
+ @app.callback(
556
+ Output("per-stage-timing-collapse", "is_open"),
557
+ Input("per-stage-timing-toggle", "n_clicks"),
558
+ State("per-stage-timing-collapse", "is_open"),
559
+ prevent_initial_call=True,
560
+ )
561
+ def toggle_per_stage_timing(n_clicks, is_open):
562
+ if n_clicks:
563
+ return not is_open
564
+ return is_open
565
+
566
+ # --- Callback to dynamically generate per-stage timing inputs ---
567
+ @app.callback(
568
+ Output("per-stage-inputs-container", "children"),
569
+ Input("num_stages", "value"),
570
+ )
571
+ def generate_per_stage_inputs(num_stages):
572
+ if num_stages is None or num_stages < 1:
573
+ return []
574
+
575
+ # Limit to reasonable number of stages for UI
576
+ num_stages = min(int(num_stages), 32)
577
+
578
+ stage_inputs = []
579
+ for stage_id in range(num_stages):
580
+ stage_inputs.append(
581
+ dbc.Row([
582
+ dbc.Col([
583
+ html.Strong(f"Stage {stage_id}", className="text-muted")
584
+ ], width=2, className="d-flex align-items-center"),
585
+ dbc.Col([
586
+ dbc.InputGroup([
587
+ dbc.InputGroupText("F", style={"minWidth": "30px"}),
588
+ dbc.Input(
589
+ id={"type": "stage-forward", "index": stage_id},
590
+ type="number",
591
+ placeholder="1.0",
592
+ min=0.01,
593
+ step=0.01,
594
+ size="sm"
595
+ ),
596
+ ], size="sm")
597
+ ], width=5),
598
+ dbc.Col([
599
+ dbc.InputGroup([
600
+ dbc.InputGroupText("B", style={"minWidth": "30px"}),
601
+ dbc.Input(
602
+ id={"type": "stage-backward", "index": stage_id},
603
+ type="number",
604
+ placeholder="1.0",
605
+ min=0.01,
606
+ step=0.01,
607
+ size="sm"
608
+ ),
609
+ ], size="sm")
610
+ ], width=5),
611
+ ], className="mb-2 g-2")
612
+ )
613
+
614
+ # Add header row
615
+ header = dbc.Row([
616
+ dbc.Col([html.Small("Stage", className="text-muted fw-bold")], width=2),
617
+ dbc.Col([html.Small("Forward Time", className="text-muted fw-bold")], width=5),
618
+ dbc.Col([html.Small("Backward Time", className="text-muted fw-bold")], width=5),
619
+ ], className="mb-2")
620
+
621
+ return [header] + stage_inputs
622
+
623
  # --- Client-side Callback for Strategy Card Selection ---
624
  app.clientside_callback(
625
  """
 
675
  State('op_time_overlapped_fwd_bwd', 'value'),
676
  State('microbatch_group_size_per_vp_stage', 'value'),
677
  State('selected-strategies-store', 'data'),
678
+ State({'type': 'stage-forward', 'index': ALL}, 'value'),
679
+ State({'type': 'stage-backward', 'index': ALL}, 'value'),
680
  prevent_initial_call=True
681
  )
682
  def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
683
  op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
684
  op_time_overlapped_fwd_bwd, microbatch_group_size_per_vp_stage,
685
+ selected_strategies, stage_forward_values, stage_backward_values):
686
 
687
  strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
688
 
 
770
  if adjustment_msg not in automatic_adjustments:
771
  automatic_adjustments.append(adjustment_msg)
772
 
773
+ # Check if per-stage timing values are provided
774
+ has_per_stage_forward = stage_forward_values and any(v is not None for v in stage_forward_values)
775
+ has_per_stage_backward = stage_backward_values and any(v is not None for v in stage_backward_values)
776
+
777
+ # Build forward timing - either per-stage dict or global value
778
+ if has_per_stage_forward:
779
+ forward_times = {}
780
+ for stage_id in range(current_num_stages):
781
+ if stage_id < len(stage_forward_values) and stage_forward_values[stage_id] is not None:
782
+ forward_times[stage_id] = float(stage_forward_values[stage_id]) * time_scale_factor
783
+ else:
784
+ # Use global value as fallback (default 1.0 if not specified)
785
+ forward_times[stage_id] = float(op_time_forward if op_time_forward else 1.0) * time_scale_factor
786
+ op_times = {"forward": forward_times}
787
+ else:
788
+ op_times = {"forward": float(op_time_forward) * time_scale_factor}
789
 
790
+ # Build backward timing
791
  if split_backward:
792
  op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
793
  op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
794
  op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
795
  else:
796
+ if has_per_stage_backward:
797
+ backward_times = {}
798
+ for stage_id in range(current_num_stages):
799
+ if stage_id < len(stage_backward_values) and stage_backward_values[stage_id] is not None:
800
+ backward_times[stage_id] = float(stage_backward_values[stage_id]) * time_scale_factor
801
+ else:
802
+ # Use global value as fallback (default 1.0 if not specified)
803
+ backward_times[stage_id] = float(op_time_backward if op_time_backward else 1.0) * time_scale_factor
804
+ op_times["backward"] = backward_times
805
+ else:
806
+ op_times["backward"] = float(op_time_backward) * time_scale_factor
807
 
808
  if op_time_overlapped_fwd_bwd is not None:
809
  try:
assets/dumped_example.jpg ADDED

Git LFS Details

  • SHA256: a12aecf1e52d57ef7f13c2b6bfd0a7320dd8ed389c8722a86fd999953b301a3b
  • Pointer size: 131 Bytes
  • Size of remote file: 796 kB
examples/megatron-lm/README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pipeline Parallelism Visualization for Megatron-LM
2
+
3
+ This tool provides visualization capabilities for Pipeline Parallelism (PP) scheduling in Megatron-LM training, helping you analyze load balancing issues and debug abnormal PP bubble problems that are difficult to inspect directly from Nsight Systems profiling.
4
+
5
+ ## Overview
6
+
7
+ The visualization tool offers intuitive visual representation of PP scheduling, making it easier to:
8
+ - Identify load balancing issues across pipeline stages
9
+ - Debug PP bubble problems
10
+ - Analyze pipeline efficiency and bottlenecks
11
+ - Optimize pipeline parallelism configurations
12
+
13
+ ## Prerequisites
14
+
15
+ - Megatron-LM with PP timer support
16
+ - Python environment with required dependencies
17
+ - UV package manager (recommended)
18
+
19
+ ## Usage
20
+
21
+ ### Step 1: Enable PP Timer in Megatron-LM
22
+
23
+ First, you need to apply the PP timer patch to your Megatron-LM installation:
24
+
25
+ 1. Cherry-pick the commit from the modified Megatron-LM repository:
26
+ ```bash
27
+ # Navigate to your Megatron-LM directory
28
+ cd /path/to/Megatron-LM
29
+
30
+ # Cherry-pick the PP timer commit
31
+ git remote add victarry https://github.com/Victarry/PP-Schedule-Visualization.git
32
+ git fetch victarry
33
+ git cherry-pick ad3bc3a22adc79827dc1b35619ad6813078e621b
34
+ ```
35
+
36
+ **Note**: The commit can be viewed at: https://github.com/Victarry/Megatron-LM/commit/ad3bc3a22adc79827dc1b35619ad6813078e621b
37
+
38
+ ### Step 2: Configure Environment Variables
39
+
40
+ Set the following environment variables before running your training script:
41
+
42
+ ```bash
43
+ # Enable PP timer functionality
44
+ export ENABLE_PP_TIMER=1
45
+
46
+ # Specify which iteration to dump (e.g., iteration 1)
47
+ export ENABLE_PP_TIMER_ITER=1
48
+
49
+ # Set directory to save the dumped timer results
50
+ export PP_TIMER_LOG_DIR=/path/to/save/timer/logs
51
+
52
+ # Run your training script
53
+ bash your_training_script.sh
54
+ ```
55
+
56
+ ### Step 3: Generate Visualization
57
+
58
+ Once you have collected the timer data, use the visualization script:
59
+
60
+ ```bash
61
+ # Navigate to the PP-Schedule-Visualization directory
62
+ cd /path/to/PP-Schedule-Visualization
63
+
64
+ # Set your configuration parameters
65
+ PP_SIZE=4 # Number of pipeline parallel stages
66
+ VPP_SIZE=1 # Virtual pipeline parallel size (usually 1)
67
+ DATA_DIR=/path/to/timer/logs # Directory containing the dumped timer data
68
+
69
+ # Run the visualization script
70
+ uv run examples/megatron-lm/plot.py --data-dir $DATA_DIR --pp-size $PP_SIZE --vpp-size $VPP_SIZE
71
+ ```
72
+
73
+ **Parameters:**
74
+ - `--data-dir`: Path to the directory containing PP timer log files
75
+ - `--pp-size`: Number of pipeline parallel stages in your training setup
76
+ - `--vpp-size`: Virtual pipeline parallel size (typically 1 unless using virtual PP)
77
+
78
+ ### Example Output
79
+
80
+ After running the visualization script, you will see a detailed PP schedule visualization similar to:
81
+
82
+ ![PP Schedule Visualization](../../assets/dumped_example.jpg)
83
+
84
+ The visualization shows:
85
+ - Timeline of each pipeline stage
86
+ - Forward and backward pass scheduling
87
+ - Bubble time and idle periods
88
+ - Communication overhead between stages
89
+
90
+ ## Known Issue
91
+ - If the global batch size is very large, it may take > 1 minute to see the visualization results.
92
+
93
+ ## Contributing
94
+
95
+ If you encounter issues or have suggestions for improvements, please open an issue or submit a pull request.
examples/megatron-lm/plot.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import argparse
4
+ import re
5
+ from collections import defaultdict
6
+ from src.execution_model import Schedule, ScheduleConfig, Operation
7
+ from src.visualizer import visualize_pipeline_parallelism_dash
8
+
9
+
10
+ def is_valid_event_filename(filename, pp_size, vpp_size):
11
+ """
12
+ Check if filename matches the expected format:
13
+ event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json
14
+
15
+ Returns True if valid, False otherwise.
16
+ """
17
+ # Create regex pattern for the expected format
18
+ pattern = rf"^event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_\d+_pp_rank_\d+_rank_\d+\.json$"
19
+ return bool(re.match(pattern, filename))
20
+
21
+
22
+ def parse_event_filename(filename):
23
+ """
24
+ Parse the event filename and extract rank information.
25
+
26
+ Expected format: event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json
27
+
28
+ Returns: (TPxCPxDP_rank, pp_rank, global_rank) or None if parsing fails
29
+ """
30
+ try:
31
+ # Remove .json extension
32
+ name_without_ext = filename.replace(".json", "")
33
+ parts = name_without_ext.split("_")
34
+
35
+ # Find the TPxCPxDP part and the rank values
36
+ tpxcpxdp_rank = None
37
+ pp_rank = None
38
+ global_rank = None
39
+
40
+ for i, part in enumerate(parts):
41
+ # Look for TPxCPxDP pattern followed by 'rank'
42
+ if part.startswith("TP") and "CP" in part and part.endswith("DP"):
43
+ if i + 2 < len(parts) and parts[i + 1] == "rank":
44
+ tpxcpxdp_rank = int(parts[i + 2])
45
+ # Look for 'pp_rank' pattern
46
+ elif part == "pp" and i + 2 < len(parts) and parts[i + 1] == "rank":
47
+ pp_rank = int(parts[i + 2])
48
+ # Look for the final 'rank' (global rank) - this should be the last rank in the filename
49
+ elif part == "rank" and i + 1 < len(parts) and i == len(parts) - 2:
50
+ global_rank = int(parts[i + 1])
51
+
52
+ if tpxcpxdp_rank is None or pp_rank is None or global_rank is None:
53
+ return None
54
+
55
+ return (tpxcpxdp_rank, pp_rank, global_rank)
56
+
57
+ except (ValueError, IndexError):
58
+ return None
59
+
60
+
61
+ def load_event_times_from_json(data_dir, pp_size, vpp_size):
62
+ """Load event times from JSON files in the specified directory."""
63
+ all_files = [f for f in os.listdir(data_dir) if f.endswith(".json")]
64
+
65
+ # Filter files that match the expected format
66
+ event_files = [
67
+ f for f in all_files if is_valid_event_filename(f, pp_size, vpp_size)
68
+ ]
69
+
70
+ if len(event_files) == 0:
71
+ print(f"Available files in {data_dir}:")
72
+ for f in all_files[:10]: # Show first 10 files for debugging
73
+ print(f" {f}")
74
+ raise ValueError(
75
+ f"No event files found matching pattern event_times_PP{pp_size}_VPP{vpp_size}_*.json"
76
+ )
77
+
78
+ print(f"Found {len(event_files)} matching event files")
79
+ event_times = {}
80
+
81
+ for file_name in event_files:
82
+ parsed_result = parse_event_filename(file_name)
83
+ if parsed_result is None:
84
+ print(f"Warning: Could not parse filename {file_name}")
85
+ continue
86
+
87
+ tpxcpxdp_rank, pp_rank, global_rank = parsed_result
88
+
89
+ if tpxcpxdp_rank == 0:
90
+ try:
91
+ with open(os.path.join(data_dir, file_name), "r") as f:
92
+ event_data = json.load(f)
93
+ event_times[(pp_rank, tpxcpxdp_rank)] = event_data
94
+ print(
95
+ f"Loaded data from {file_name}: global_rank={global_rank}, pp_rank={pp_rank}, tpxcpxdp_rank={tpxcpxdp_rank}"
96
+ )
97
+ except Exception as e:
98
+ print(f"Error loading {file_name}: {e}")
99
+
100
+ return event_times
101
+
102
+
103
+ def create_pp_schedule_from_event_times(event_times, pp_size):
104
+ """Create a Schedule object from event times data."""
105
+ # Determine number of devices/stages from the data
106
+ num_devices = pp_size
107
+
108
+ # Find the maximum batch ID by parsing event names
109
+ max_batch_id = 0
110
+ for events in event_times.values():
111
+ for event_name in events:
112
+ if event_name.startswith(("forward-", "backward-")):
113
+ parts = event_name.split("-")
114
+ if len(parts) >= 2 and parts[1].isdigit():
115
+ batch_id = int(parts[1])
116
+ max_batch_id = max(max_batch_id, batch_id)
117
+
118
+ num_batches = max_batch_id + 1
119
+
120
+ # Create a simple config (actual times will come from event data)
121
+ config = ScheduleConfig(
122
+ num_devices=num_devices,
123
+ num_stages=num_devices, # Assuming 1:1 mapping of devices to stages
124
+ num_batches=num_batches,
125
+ p2p_latency=0, # Will be implicit in the event timing
126
+ op_times={}, # Not needed as we'll use real timing data
127
+ placement_strategy="standard",
128
+ )
129
+
130
+ # Create a schedule
131
+ schedule = Schedule(config)
132
+
133
+ # Populate the schedule with operations based on event times
134
+ for (pp_rank, tpxcpxdp_rank), events in event_times.items():
135
+ # Process forward passes
136
+ for batch_id in range(num_batches):
137
+ forward_start_key = f"forward-{batch_id}-start"
138
+ forward_end_key = f"forward-{batch_id}-end"
139
+
140
+ if forward_start_key in events and forward_end_key in events:
141
+ # Create an operation and set its timing directly
142
+ forward_op = Operation(batch_id, pp_rank, "forward")
143
+ forward_op.execution_time = (
144
+ events[forward_end_key] - events[forward_start_key]
145
+ )
146
+ forward_op.start_time = events[forward_start_key]
147
+ forward_op.end_time = events[forward_end_key]
148
+
149
+ # Add to schedule
150
+ schedule.ops[(batch_id, pp_rank, "forward")] = forward_op
151
+ schedule.device_queues[pp_rank].add_operation(forward_op)
152
+
153
+ # Process backward passes
154
+ for batch_id in range(num_batches):
155
+ backward_start_key = f"backward-{batch_id}-start"
156
+ backward_end_key = f"backward-{batch_id}-end"
157
+
158
+ if backward_start_key in events and backward_end_key in events:
159
+ # Create an operation and set its timing directly
160
+ backward_op = Operation(batch_id, pp_rank, "backward")
161
+ backward_op.execution_time = (
162
+ events[backward_end_key] - events[backward_start_key]
163
+ )
164
+ backward_op.start_time = events[backward_start_key]
165
+ backward_op.end_time = events[backward_end_key]
166
+
167
+ # Add to schedule
168
+ schedule.ops[(batch_id, pp_rank, "backward")] = backward_op
169
+ schedule.device_queues[pp_rank].add_operation(backward_op)
170
+
171
+ return schedule
172
+
173
+
174
+ def create_vpp_schedule_from_event_times(event_times, pp_size, vpp_size):
175
+ """Create a VPP Schedule object from event times data."""
176
+ # Determine number of devices/stages from the data
177
+ # Find the maximum batch ID by parsing event names
178
+ max_batch_id = 0
179
+ for events in event_times.values():
180
+ for event_name in events:
181
+ if event_name.startswith(("forward-", "backward-")):
182
+ parts = event_name.split("-")
183
+ assert len(parts) == 4
184
+ assert parts[0] in ["forward", "backward"]
185
+ assert parts[1].isdigit() and parts[2].isdigit()
186
+ assert parts[3] in ["start", "end"]
187
+ batch_id = int(parts[2]) # backward-0-19-end
188
+ max_batch_id = max(max_batch_id, batch_id)
189
+
190
+ num_batches = max_batch_id + 1
191
+
192
+ # Create a simple config (actual times will come from event data)
193
+ config = ScheduleConfig(
194
+ num_devices=pp_size,
195
+ num_stages=pp_size * vpp_size,
196
+ num_batches=num_batches,
197
+ p2p_latency=0, # Will be implicit in the event timing
198
+ op_times={}, # Not needed as we'll use real timing data
199
+ placement_strategy="interleave",
200
+ )
201
+
202
+ # Create a schedule
203
+ schedule = Schedule(config)
204
+
205
+ # Populate the schedule with operations based on event times
206
+ for (pp_rank, tpxcpxdp_rank), events in event_times.items():
207
+ # Process forward passes
208
+ for model_chunk_id in range(vpp_size):
209
+ for batch_id in range(num_batches):
210
+ forward_start_key = f"forward-{model_chunk_id}-{batch_id}-start"
211
+ forward_end_key = f"forward-{model_chunk_id}-{batch_id}-end"
212
+
213
+ # Create an operation and set its timing directly
214
+ stage_id = pp_size * model_chunk_id + pp_rank
215
+ forward_op = Operation(batch_id, stage_id=stage_id, op_type="forward")
216
+ forward_op.execution_time = (
217
+ events[forward_end_key] - events[forward_start_key]
218
+ )
219
+ forward_op.start_time = events[forward_start_key]
220
+ forward_op.end_time = events[forward_end_key]
221
+
222
+ # Add to schedule
223
+ schedule.ops[(batch_id, stage_id, "forward")] = forward_op
224
+ schedule.device_queues[pp_rank].add_operation(forward_op)
225
+
226
+ # Process backward passes
227
+ for model_chunk_id in range(vpp_size):
228
+ for batch_id in range(num_batches):
229
+ backward_start_key = f"backward-{model_chunk_id}-{batch_id}-start"
230
+ backward_end_key = f"backward-{model_chunk_id}-{batch_id}-end"
231
+
232
+ stage_id = pp_size * model_chunk_id + pp_rank
233
+ if backward_start_key in events and backward_end_key in events:
234
+ # Create an operation and set its timing directly
235
+ backward_op = Operation(
236
+ batch_id, stage_id=stage_id, op_type="backward"
237
+ )
238
+ backward_op.execution_time = (
239
+ events[backward_end_key] - events[backward_start_key]
240
+ )
241
+ backward_op.start_time = events[backward_start_key]
242
+ backward_op.end_time = events[backward_end_key]
243
+
244
+ # Add to schedule
245
+ schedule.ops[(batch_id, stage_id, "backward")] = backward_op
246
+ schedule.device_queues[pp_rank].add_operation(backward_op)
247
+
248
+ return schedule
249
+
250
+
251
+ def main():
252
+ # Parse command-line arguments
253
+ parser = argparse.ArgumentParser(
254
+ description="Visualize pipeline parallelism from event data"
255
+ )
256
+ parser.add_argument(
257
+ "--data-dir",
258
+ type=str,
259
+ required=True,
260
+ help="Directory containing event_times_*.json files",
261
+ )
262
+ parser.add_argument(
263
+ "--pp-size", type=int, required=True, help="Pipeline parallelism size"
264
+ )
265
+ parser.add_argument(
266
+ "--vpp-size", type=int, required=True, help="Virtual pipeline parallelism size"
267
+ )
268
+ parser.add_argument(
269
+ "--port",
270
+ type=int,
271
+ default=8050,
272
+ help="Port for the visualization dashboard (default: 8050)",
273
+ )
274
+ args = parser.parse_args()
275
+
276
+ # Load event times from JSON files
277
+ event_times = load_event_times_from_json(args.data_dir, args.pp_size, args.vpp_size)
278
+
279
+ # Create schedule from event times
280
+ if args.vpp_size == 1:
281
+ schedule = create_pp_schedule_from_event_times(event_times, args.pp_size)
282
+ else:
283
+ schedule = create_vpp_schedule_from_event_times(
284
+ event_times, args.pp_size, args.vpp_size
285
+ )
286
+
287
+ # Calculate and print execution metrics
288
+ total_execution_time = max(
289
+ op.end_time for op in schedule.ops.values() if op.end_time is not None
290
+ )
291
+ print(f"Total execution time: {total_execution_time:.2f} ms")
292
+
293
+ # Calculate bubble time percentage
294
+ device_times = defaultdict(float)
295
+ for device_id, device_queue in enumerate(schedule.device_queues):
296
+ for op in device_queue.ops:
297
+ if op.start_time is not None and op.end_time is not None:
298
+ device_times[device_id] += op.end_time - op.start_time
299
+
300
+ # Print bubble percentage for each device
301
+ for device_id, active_time in device_times.items():
302
+ bubble_percentage = (
303
+ (total_execution_time - active_time) / total_execution_time * 100
304
+ )
305
+ print(f"Device {device_id} bubble: {bubble_percentage:.2f}%")
306
+
307
+ # Visualize the schedule
308
+ print("Launching visualization...")
309
+ visualize_pipeline_parallelism_dash(
310
+ schedule, schedule_type="1F1B-Imported", port=args.port
311
+ )
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
src/execution_model.py CHANGED
@@ -243,6 +243,39 @@ class Schedule:
243
  return None
244
  return self.ops[(batch_id, stage_id, op_type)]
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  def get_dependencies(self, op: Operation, include_device_dependency=True):
247
  deps = []
248
  if isinstance(op, OverlappedOperation):
@@ -327,7 +360,34 @@ class Schedule:
327
  if include_device_dependency:
328
  device_index = self.device_queues[op.device_id].ops.index(op)
329
  if device_index > 0:
330
- deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  return deps
332
 
333
  def show(self):
 
243
  return None
244
  return self.ops[(batch_id, stage_id, op_type)]
245
 
246
+ def get_p2p_receiver_op(self, sender_op: Operation) -> Optional[Operation]:
247
+ """
248
+ Get the operation that receives P2P data from sender_op.
249
+
250
+ For forward ops: sender on stage N sends to receiver on stage N+1
251
+ For backward ops: sender on stage N sends to receiver on stage N-1
252
+
253
+ Returns None if there is no P2P receiver (first/last stage).
254
+ """
255
+ if isinstance(sender_op, OverlappedOperation):
256
+ # For overlapped ops, return None (P2P is overlapped with computation)
257
+ return None
258
+
259
+ if sender_op.op_type == "forward":
260
+ # Forward sends to next stage
261
+ next_stage = sender_op.stage_id + 1
262
+ if next_stage >= self.config.num_stages:
263
+ return None # Last stage, no P2P
264
+ return self.get_op(sender_op.batch_id, next_stage, "forward", allow_none=True)
265
+
266
+ elif sender_op.op_type in ("backward", "backward_D"):
267
+ # Backward sends to previous stage
268
+ prev_stage = sender_op.stage_id - 1
269
+ if prev_stage < 0:
270
+ return None # First stage, no P2P
271
+ # Try backward_D first, then backward
272
+ receiver = self.get_op(sender_op.batch_id, prev_stage, "backward_D", allow_none=True)
273
+ if receiver is None:
274
+ receiver = self.get_op(sender_op.batch_id, prev_stage, "backward", allow_none=True)
275
+ return receiver
276
+
277
+ return None
278
+
279
  def get_dependencies(self, op: Operation, include_device_dependency=True):
280
  deps = []
281
  if isinstance(op, OverlappedOperation):
 
360
  if include_device_dependency:
361
  device_index = self.device_queues[op.device_id].ops.index(op)
362
  if device_index > 0:
363
+ prev_op = self.device_queues[op.device_id].ops[device_index - 1]
364
+
365
+ # Check if sync P2P should apply
366
+ # Sync P2P means sender waits for P2P transfer to complete before next op
367
+ # This adds p2p_latency to the device dependency gap
368
+ sync_p2p_gap = 0.0
369
+ if self.config.p2p_latency > 0:
370
+ is_prev_overlapped = isinstance(prev_op, OverlappedOperation)
371
+ is_current_overlapped = isinstance(op, OverlappedOperation)
372
+
373
+ # Only add sync P2P gap when:
374
+ # 1. Both ops are not overlapped (not in overlap schedule's steady state)
375
+ # 2. Both ops have the same base type (both forward or both backward)
376
+ # 3. Both ops are on the same stage (ensures we're in a pure warmup/cooldown
377
+ # sequence for a specific stage, avoiding cycles in interleaved schedules)
378
+ if not is_prev_overlapped and not is_current_overlapped:
379
+ prev_base_type = "backward" if prev_op.op_type.startswith("backward") else prev_op.op_type
380
+ curr_base_type = "backward" if op.op_type.startswith("backward") else op.op_type
381
+
382
+ if prev_base_type == curr_base_type and prev_op.stage_id == op.stage_id:
383
+ receiver_op = self.get_p2p_receiver_op(prev_op)
384
+ if receiver_op is not None and not isinstance(receiver_op, OverlappedOperation):
385
+ # Sync P2P: sender waits for P2P transfer to complete
386
+ # Current op starts after prev_op.end_time + p2p_latency
387
+ # (not after receiver completes, just after transfer completes)
388
+ sync_p2p_gap = self.config.p2p_latency
389
+
390
+ deps.append((prev_op, sync_p2p_gap))
391
  return deps
392
 
393
  def show(self):