File size: 7,111 Bytes
f4cade0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
# Copyright (c) Meta Platforms, Inc. and affiliates
"""
This visualizer requires matplotlib to be installed.
Example usage:
ops = get_schedule_ops("InterleavedZeroBubble", 4, 8)
visualize_schedule(ops, "test.png")
"""
from typing import Optional, Union
from unittest import mock
from torch.distributed.pipelining.schedules import (
_Action,
_ComputationType,
_PipelineSchedule,
get_schedule_class,
PipelineScheduleMulti,
PipelineScheduleSingle,
)
from torch.distributed.pipelining.stage import PipelineStage
def get_schedule_ops(
schedule: Union[str, _PipelineSchedule],
pp_degree: int,
num_microbatches: int,
num_stages_per_rank: Optional[int] = None,
) -> list[list[Optional[_Action]]]:
"""
Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists
where each inner list represents a rank and each element in the inner list represents an action.
The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance.
"""
if isinstance(schedule, str):
schedule_class = get_schedule_class(schedule)
elif type(schedule) == _PipelineSchedule:
schedule_class = schedule
else:
raise ValueError(f"Invalid schedule: {schedule}")
# Create a mock of the PipelineStage class
mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True)
# Set the return values for group_rank and group_size methods
mock_pipeline_stage.group_rank = 0
mock_pipeline_stage.group_size = pp_degree
mock_pipeline_stage.submod = None
# Check num_stages_per_rank is valid
if issubclass(schedule_class, PipelineScheduleSingle):
if num_stages_per_rank is None:
num_stages_per_rank = 1
assert num_stages_per_rank == 1
stages = mock_pipeline_stage
stages.num_stages = num_stages_per_rank * pp_degree
elif issubclass(schedule_class, PipelineScheduleMulti):
if num_stages_per_rank is None:
num_stages_per_rank = 2
assert num_stages_per_rank >= 2
stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
for stage in stages:
stage.num_stages = num_stages_per_rank * pp_degree
else:
raise ValueError(f"Invalid schedule: {schedule_class}")
# Instantiate the schedule class
schedule_instance = schedule_class(stages, num_microbatches)
# Convert to List[List[_Action]]
all_actions = []
for rank in range(pp_degree):
all_actions.append(schedule_instance.pipeline_order[rank])
# Return the pipeline order
return all_actions
class _ComputationTypeColor:
    """Plain holder for how one computation type is drawn.

    The constructor stores its arguments unchanged on the instance:

    * ``color`` -- color name for the action's rectangle.
    * ``text``  -- human-readable label (empty by default).
    * ``width`` -- horizontal extent of the rectangle, in plot units
      (1 by default).
    """

    def __init__(self, color: str, text: str = "", width: int = 1):
        self.color, self.text, self.width = color, text, width
# Drawing style (fill color, legend label, bar width) for each computation
# type that has a dedicated appearance in the plot. FULL_BACKWARD is the only
# entry that overrides the default width of 1.
action_type_to_color_mapping = {
    _ComputationType.FORWARD: _ComputationTypeColor(
        color="blue", text="Forward"
    ),
    _ComputationType.BACKWARD_INPUT: _ComputationTypeColor(
        color="teal", text="Backward Input"
    ),
    _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor(
        color="green", text="Backward Weight"
    ),
    _ComputationType.FULL_BACKWARD: _ComputationTypeColor(
        color="orange", text="Full Backward", width=2
    ),
}
def visualize_schedule(
    schedule: list[list[Optional[_Action]]], filename: Optional[str] = None
) -> None:
    """
    Visualize the schedule using matplotlib.

    Each rank is drawn as one horizontal row (rank 0 at the top). Every action
    becomes a rectangle colored by its computation type and labeled with
    ``str(action)``; a ``None`` entry leaves a one-unit gap. Computation types
    without an entry in ``action_type_to_color_mapping`` are drawn in black
    and are not listed in the legend.

    Args:
        schedule: A list of lists where each inner list represents a rank and
            each element in the inner list is an action or ``None``.
        filename: If provided (and non-empty), the plot is saved to that file
            and the figure is closed; otherwise the plot is shown with
            ``plt.show()``.

    Raises:
        ValueError: If ``schedule`` contains no ranks.

    Note:
        This sets ``plt.rcParams["font.family"]`` globally.
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    if not schedule:
        raise ValueError("schedule must contain at least one rank")

    plt.rcParams["font.family"] = (
        "DejaVu Sans"  # or any other font available on your system
    )

    num_ranks = len(schedule)
    max_actions = max(len(rank) for rank in schedule)

    # Increase the figure size to provide more space for the legend
    fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2))

    # Calculate dynamic font size based on figure size
    font_size = min(max_actions, num_ranks) + 4

    # Style used for computation types that have no dedicated mapping entry.
    fallback_style = _ComputationTypeColor("black")
    used_computation = set()
    max_draw_position = 0

    for rank_idx, actions in enumerate(schedule):
        row_bottom = num_ranks - rank_idx - 1  # rank 0 is the top row
        draw_position = 0  # Initialize drawing position for each rank
        for action in actions:
            if action is None:
                draw_position += 1  # Leave a one-unit gap for an idle slot
                continue

            style = action_type_to_color_mapping.get(
                action.computation_type, fallback_style
            )
            used_computation.add(action.computation_type)

            # Draw the rectangle to represent the action duration
            ax.add_patch(
                Rectangle(
                    (draw_position, row_bottom),
                    style.width,
                    1,
                    facecolor=style.color,
                    edgecolor="black",
                )
            )
            # Draw the text centered within the rectangle
            ax.text(
                draw_position + style.width / 2,
                row_bottom + 0.5,
                str(action),
                ha="center",
                va="center",
                fontsize=font_size,
                color="white",
            )
            # Increment the drawing position by the width of the current action
            draw_position += style.width
        max_draw_position = max(max_draw_position, draw_position)

    ax.set_xlim(-0.5, max_draw_position + 1)
    ax.set_ylim(-0.5, num_ranks + 0.5)  # Add extra space at the top

    # Set y-ticks to be in the middle of each rank's row
    ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)])
    ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size)
    ax.set_xticklabels([])

    # Remove grid lines
    ax.grid(False)

    # Build the legend in the mapping's declaration order so it is stable
    # between runs, and only for types that have a mapping entry: indexing the
    # mapping with an unmapped type would raise KeyError.
    legend_elements = [
        Rectangle(
            (0, 0),
            1,
            1,
            facecolor=style.color,
            edgecolor="black",
            label=style.text,
        )
        for comp_type, style in action_type_to_color_mapping.items()
        if comp_type in used_computation
    ]
    ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size)

    # Save to file if filename is provided, otherwise display the plot
    if filename:
        fig.savefig(filename, bbox_inches="tight")
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
|