# Visualization script for Pipeline Parallelism in Megatron-LM.
# (Imported from commit 592da35: "Add visualization script and README for
# Pipeline Parallelism in Megatron-LM", by Victarry.)
import json
import os
import argparse
import re
from collections import defaultdict
from src.execution_model import Schedule, ScheduleConfig, Operation
from src.visualizer import visualize_pipeline_parallelism_dash
def is_valid_event_filename(filename, pp_size, vpp_size):
    """Return True iff *filename* follows the expected event-dump naming scheme.

    Expected format:
        event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json
    """
    expected = (
        rf"^event_times_PP{pp_size}_VPP{vpp_size}"
        r"_TPxCPxDP_rank_\d+_pp_rank_\d+_rank_\d+\.json$"
    )
    return re.match(expected, filename) is not None
def parse_event_filename(filename):
    """Parse an event filename and extract rank information.

    Expected format:
        event_times_PP{pp_size}_VPP{vpp_size}_TPxCPxDP_rank_{rank}_pp_rank_{pp_rank}_rank_{final_rank}.json

    Returns:
        (TPxCPxDP_rank, pp_rank, global_rank) as ints, or None if the
        filename does not match the expected format.
    """
    # One anchored regex replaces the previous ad-hoc split()-based scan.
    # This also fixes the old `filename.replace(".json", "")`, which removed
    # ".json" anywhere in the name rather than only as the suffix, and makes
    # malformed names return None instead of partially parsed values.
    match = re.match(
        r"^event_times_PP[^_]+_VPP[^_]+"
        r"_TPxCPxDP_rank_(\d+)_pp_rank_(\d+)_rank_(\d+)\.json$",
        filename,
    )
    if match is None:
        return None
    tpxcpxdp_rank, pp_rank, global_rank = (int(g) for g in match.groups())
    return (tpxcpxdp_rank, pp_rank, global_rank)
def load_event_times_from_json(data_dir, pp_size, vpp_size):
    """Load per-rank event-timing dicts from JSON dumps in *data_dir*.

    Only files matching the expected naming scheme for the given pp/vpp
    sizes are considered, and only the tpxcpxdp_rank == 0 file of each
    pipeline rank is loaded (presumably timing is representative across
    TPxCPxDP replicas — TODO confirm with the dump producer).

    Args:
        data_dir: directory containing event_times_*.json files.
        pp_size: pipeline-parallel size used in the filename pattern.
        vpp_size: virtual-pipeline-parallel size used in the filename pattern.

    Returns:
        dict mapping (pp_rank, tpxcpxdp_rank) -> {event_name: timestamp}.

    Raises:
        ValueError: if no file in *data_dir* matches the expected pattern.
    """
    all_files = [f for f in os.listdir(data_dir) if f.endswith(".json")]
    # Keep only files that match the expected naming format exactly.
    event_files = [
        f for f in all_files if is_valid_event_filename(f, pp_size, vpp_size)
    ]
    if not event_files:
        print(f"Available files in {data_dir}:")
        for f in all_files[:10]:  # Show first 10 files for debugging
            print(f"  {f}")
        raise ValueError(
            f"No event files found matching pattern event_times_PP{pp_size}_VPP{vpp_size}_*.json"
        )
    print(f"Found {len(event_files)} matching event files")

    event_times = {}
    for file_name in event_files:
        parsed_result = parse_event_filename(file_name)
        if parsed_result is None:
            print(f"Warning: Could not parse filename {file_name}")
            continue
        tpxcpxdp_rank, pp_rank, global_rank = parsed_result
        if tpxcpxdp_rank != 0:
            # One replica per pipeline rank is enough for visualization.
            continue
        try:
            with open(os.path.join(data_dir, file_name), "r") as f:
                event_data = json.load(f)
        # Narrowed from a blanket `except Exception`: only I/O errors and
        # JSON decode failures (json.JSONDecodeError is a ValueError) are
        # expected here and safely skippable; anything else should surface.
        except (OSError, ValueError) as e:
            print(f"Error loading {file_name}: {e}")
            continue
        event_times[(pp_rank, tpxcpxdp_rank)] = event_data
        print(
            f"Loaded data from {file_name}: global_rank={global_rank}, pp_rank={pp_rank}, tpxcpxdp_rank={tpxcpxdp_rank}"
        )
    return event_times
def create_pp_schedule_from_event_times(event_times, pp_size):
    """Create a Schedule object from measured event times (no virtual stages).

    Args:
        event_times: dict mapping (pp_rank, tpxcpxdp_rank) to a dict of
            "{forward|backward}-{batch_id}-{start|end}" -> timestamp entries.
        pp_size: pipeline-parallel size; one device per stage.

    Returns:
        Schedule populated with Operations whose start/end/execution times
        come directly from the measured events.
    """
    num_devices = pp_size

    # Derive the number of microbatches from the largest batch id seen in
    # any forward/backward event name.
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                if len(parts) >= 2 and parts[1].isdigit():
                    max_batch_id = max(max_batch_id, int(parts[1]))
    num_batches = max_batch_id + 1

    # Create a simple config (actual times come from the event data).
    config = ScheduleConfig(
        num_devices=num_devices,
        num_stages=num_devices,  # Assuming 1:1 mapping of devices to stages
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="standard",
    )
    schedule = Schedule(config)

    # Forward and backward passes share identical bookkeeping, so handle
    # both op types in one loop instead of two duplicated blocks.
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        for op_type in ("forward", "backward"):
            for batch_id in range(num_batches):
                start_key = f"{op_type}-{batch_id}-start"
                end_key = f"{op_type}-{batch_id}-end"
                if start_key not in events or end_key not in events:
                    continue
                # Create an operation and set its timing directly.
                op = Operation(batch_id, pp_rank, op_type)
                op.start_time = events[start_key]
                op.end_time = events[end_key]
                op.execution_time = op.end_time - op.start_time
                # Register with both the flat op map and the per-device queue.
                schedule.ops[(batch_id, pp_rank, op_type)] = op
                schedule.device_queues[pp_rank].add_operation(op)
    return schedule
def create_vpp_schedule_from_event_times(event_times, pp_size, vpp_size):
    """Create an interleaved (VPP) Schedule object from measured event times.

    Args:
        event_times: dict mapping (pp_rank, tpxcpxdp_rank) to a dict of
            "{forward|backward}-{chunk_id}-{batch_id}-{start|end}" -> timestamp.
        pp_size: pipeline-parallel size (number of devices).
        vpp_size: number of virtual model chunks per device.

    Returns:
        Schedule with num_stages == pp_size * vpp_size, populated with
        Operations carrying the measured start/end/execution times.
    """
    # Derive the number of microbatches, validating every forward/backward
    # event name against the "{type}-{chunk}-{batch}-{start|end}" format.
    max_batch_id = 0
    for events in event_times.values():
        for event_name in events:
            if event_name.startswith(("forward-", "backward-")):
                parts = event_name.split("-")
                assert len(parts) == 4
                assert parts[0] in ["forward", "backward"]
                assert parts[1].isdigit() and parts[2].isdigit()
                assert parts[3] in ["start", "end"]
                max_batch_id = max(max_batch_id, int(parts[2]))  # e.g. backward-0-19-end
    num_batches = max_batch_id + 1

    # Create a simple config (actual times come from the event data).
    config = ScheduleConfig(
        num_devices=pp_size,
        num_stages=pp_size * vpp_size,
        num_batches=num_batches,
        p2p_latency=0,  # Will be implicit in the event timing
        op_times={},  # Not needed as we'll use real timing data
        placement_strategy="interleave",
    )
    schedule = Schedule(config)

    # Forward and backward passes share identical bookkeeping; handle both
    # op types in one loop instead of two duplicated blocks.
    for (pp_rank, tpxcpxdp_rank), events in event_times.items():
        for op_type in ("forward", "backward"):
            for model_chunk_id in range(vpp_size):
                for batch_id in range(num_batches):
                    start_key = f"{op_type}-{model_chunk_id}-{batch_id}-start"
                    end_key = f"{op_type}-{model_chunk_id}-{batch_id}-end"
                    # Bug fix: the forward loop previously indexed `events`
                    # unconditionally and raised KeyError when an event was
                    # absent; now both op types skip missing events, as the
                    # backward loop already did.
                    if start_key not in events or end_key not in events:
                        continue
                    # Interleaved placement: chunk c on device d is stage
                    # pp_size * c + d.
                    stage_id = pp_size * model_chunk_id + pp_rank
                    op = Operation(batch_id, stage_id=stage_id, op_type=op_type)
                    op.start_time = events[start_key]
                    op.end_time = events[end_key]
                    op.execution_time = op.end_time - op.start_time
                    schedule.ops[(batch_id, stage_id, op_type)] = op
                    schedule.device_queues[pp_rank].add_operation(op)
    return schedule
def main():
    """Entry point: parse CLI arguments, rebuild a schedule from event-time
    JSON dumps, print makespan/bubble statistics, and launch the Dash app."""
    parser = argparse.ArgumentParser(
        description="Visualize pipeline parallelism from event data"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        required=True,
        help="Directory containing event_times_*.json files",
    )
    parser.add_argument(
        "--pp-size", type=int, required=True, help="Pipeline parallelism size"
    )
    parser.add_argument(
        "--vpp-size", type=int, required=True, help="Virtual pipeline parallelism size"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8050,
        help="Port for the visualization dashboard (default: 8050)",
    )
    args = parser.parse_args()

    event_times = load_event_times_from_json(args.data_dir, args.pp_size, args.vpp_size)

    # vpp_size == 1 means plain PP event names; otherwise use the
    # interleaved (VPP) event-name format.
    if args.vpp_size == 1:
        schedule = create_pp_schedule_from_event_times(event_times, args.pp_size)
    else:
        schedule = create_vpp_schedule_from_event_times(
            event_times, args.pp_size, args.vpp_size
        )

    # Makespan = latest recorded end time across all operations.
    total_execution_time = max(
        op.end_time for op in schedule.ops.values() if op.end_time is not None
    )
    print(f"Total execution time: {total_execution_time:.2f} ms")

    # Sum each device's active (busy) time; everything else up to the
    # makespan counts as bubble.
    device_times = defaultdict(float)
    for device_id, device_queue in enumerate(schedule.device_queues):
        for op in device_queue.ops:
            if op.start_time is not None and op.end_time is not None:
                device_times[device_id] += op.end_time - op.start_time

    for device_id, active_time in device_times.items():
        idle_fraction = (total_execution_time - active_time) / total_execution_time
        print(f"Device {device_id} bubble: {idle_fraction * 100:.2f}%")

    print("Launching visualization...")
    visualize_pipeline_parallelism_dash(
        schedule, schedule_type="1F1B-Imported", port=args.port
    )


if __name__ == "__main__":
    main()