Buckets:
rahul7star/MagiGPU-storage / MagiCompiler /magi_compiler /utils /visualize /joint_graph_visualizer.py
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import textwrap | |
| from typing import Dict, List, Optional, Set, Tuple | |
| import graphviz | |
| import torch | |
| import torch.fx as fx | |
| from magi_compiler.config import get_compile_config | |
| from magi_compiler.utils import envs, magi_logger | |
class NodeCategory:
    """String tags used to classify joint-graph nodes for coloring and cluster placement."""

    # Graph inputs: forward placeholders and tangent (gradient-input) placeholders.
    INPUT = "input"
    TANGENT = "tangent"
    # Compute nodes, by the pass they belong to.
    FWD = "fwd"
    BWD = "bwd"
    # Nodes that are saved tensors for backward.
    SAVE_TENSOR = "save_tensor"
    # Graph outputs.
    FWD_OUTPUT = "fwd_output"
    BWD_OUTPUT = "bwd_output"
# Graphviz node attributes per category, listed as
#   (category, shape, style, fillcolor, color)
# where "fillcolor" is the node fill and "color" is its outline.
NODE_CATEGORY_STYLES = {
    category: {"shape": shape, "style": style, "fillcolor": fill, "color": border}
    for category, shape, style, fill, border in (
        (NodeCategory.INPUT, "box", "filled,bold", "#E8F4FD", "#2196F3"),
        (NodeCategory.FWD, "ellipse", "filled", "#E8F5E9", "#4CAF50"),
        (NodeCategory.BWD, "ellipse", "filled", "#FFF3E0", "#FF9800"),
        (NodeCategory.SAVE_TENSOR, "hexagon", "filled,bold", "#FFEB3B", "#D32F2F"),
        (NodeCategory.TANGENT, "box", "filled,bold", "#B3E5FC", "#0288D1"),
        (NodeCategory.FWD_OUTPUT, "box", "filled,bold", "#C8E6C9", "#388E3C"),
        (NodeCategory.BWD_OUTPUT, "box", "filled,bold", "#FFCCBC", "#E64A19"),
    )
}
def get_graph_node_names(graph: fx.Graph) -> Set[str]:
    """Return the set of ``name`` values of every node in *graph* (duplicates collapse)."""
    names: Set[str] = set()
    for graph_node in graph.nodes:
        names.add(graph_node.name)
    return names
def is_tangent_node(node: fx.Node) -> bool:
    """
    Check if a node is a tangent node (gradient input for backward).

    Matching is purely name-based: any node whose name contains "tangent"
    (case-insensitive) qualifies. NOTE(review): this presumably relies on tangent
    placeholders being named like ``tangents_<i>`` by the graph producer; confirm.
    """
    # A case-insensitive substring test already covers names that start with "tangent",
    # so a separate startswith() check is redundant.
    return "tangent" in node.name.lower()
def categorize_joint_nodes(
    joint_graph: fx.Graph, fwd_graph: fx.Graph, bwd_graph: fx.Graph, save_tensor_nodes: Optional[List[fx.Node]] = None
) -> Tuple[Dict[str, str], Set[str]]:
    """
    Assign a ``NodeCategory`` to every node of the joint graph.

    Placeholders become TANGENT or INPUT (decided by ``is_tangent_node``) and the output
    node becomes BWD_OUTPUT. Every other node is resolved by name with priority
    save_tensor > fwd > bwd; a node found in neither the forward nor the backward graph
    defaults to FWD.

    Args:
        joint_graph: Graph whose nodes are categorized.
        fwd_graph: Forward graph; node names found here mark FWD nodes.
        bwd_graph: Backward graph; node names found here (and not in fwd) mark BWD nodes.
        save_tensor_nodes: Nodes saved for backward, matched to joint nodes by name.

    Returns:
        (node_categories, input_save_tensors): a dict mapping node name to category, and
        the names of non-tangent placeholders that also appear in ``save_tensor_nodes``.
    """
    fwd_names = get_graph_node_names(fwd_graph)
    bwd_names = get_graph_node_names(bwd_graph)
    saved_names: Set[str] = set()
    if save_tensor_nodes:
        saved_names = {saved.name for saved in save_tensor_nodes}

    categories: Dict[str, str] = {}
    saved_inputs: Set[str] = set()
    for current in joint_graph.nodes:
        key = current.name
        if current.op == "placeholder":
            if is_tangent_node(current):
                categories[key] = NodeCategory.TANGENT
                continue
            categories[key] = NodeCategory.INPUT
            # A graph input may itself be saved for backward; remember it so it can be labeled.
            if key in saved_names:
                saved_inputs.add(key)
            continue
        if current.op == "output":
            categories[key] = NodeCategory.BWD_OUTPUT
        elif key in saved_names:
            categories[key] = NodeCategory.SAVE_TENSOR
        elif key in bwd_names and key not in fwd_names:
            categories[key] = NodeCategory.BWD
        else:
            # Forward nodes, plus anything present in neither graph.
            categories[key] = NodeCategory.FWD
    return categories, saved_inputs
def extract_joint_graph_structure(
    graph: fx.Graph, node_categories: Dict[str, str], input_save_tensors: Optional[Set[str]] = None
) -> Tuple[List[Dict], List[Dict]]:
    """
    Extract nodes and edges from joint graph with category-based styling.

    Args:
        graph: The FX graph to walk.
        node_categories: Mapping from node name to a ``NodeCategory`` value; nodes missing
            from the mapping are treated as ``NodeCategory.FWD``.
        input_save_tensors: Names of nodes whose label gets an extra "(SaveTensor)" line
            (the caller passes placeholders that are also saved for backward).

    Returns:
        ``(nodes, edges)``. Each node dict has keys ``id``, ``style``, ``node_label`` and
        ``category``; each edge dict has keys ``source`` and ``target`` (node names).
        Every (source, target) pair appears at most once, even if the target uses the
        same input several times.
    """

    def wrap_str(text, width=40):
        return textwrap.fill(text, width=width, break_long_words=True, replace_whitespace=False)

    def format_target(target) -> str:
        # Prefer a readable name for functions/op overloads; fall back to str().
        if callable(target):
            return getattr(target, "__name__", str(target))
        if hasattr(target, "_op"):
            return str(target._op)
        return str(target)

    input_save_tensors = input_save_tensors or set()
    default_style = NODE_CATEGORY_STYLES[NodeCategory.FWD]
    nodes, edges = [], []
    for node in graph.nodes:
        name_str = wrap_str(str(node.name))
        target_str = wrap_str(format_target(node.target))
        category = node_categories.get(node.name, NodeCategory.FWD)
        style_info = NODE_CATEGORY_STYLES.get(category, default_style)
        node_label = f"{name_str}\n[{target_str}]"
        # Add annotation if input node is also a save tensor
        if node.name in input_save_tensors:
            node_label += "\n(SaveTensor)"
        nodes.append({"id": node.name, "style": style_info, "node_label": node_label, "category": category})
        # all_input_nodes lists each Node referenced from args/kwargs (including nested
        # containers) exactly once, in order of first appearance. Node names are unique
        # within a graph, so this dedups edges without an O(E) scan of `edges` per edge.
        edges.extend({"source": src.name, "target": node.name} for src in node.all_input_nodes)
    return nodes, edges
def create_joint_graph_dot(nodes: List[Dict], edges: List[Dict]) -> graphviz.Digraph:
    """
    Create a graphviz Digraph for joint graph visualization.

    Layout (using rankdir=BT, bottom-to-top):
    - Top: BWD output (gradients), pinned to rank="max"
    - Middle-Left: FWD cluster (inputs, fwd ops, save_tensors)
    - Middle-Right: BWD cluster (bwd ops)
    - Bottom: Tangent inputs and save_tensors, pinned to rank="min"

    Args:
        nodes: Node dicts with keys ``id``, ``node_label``, ``style`` and ``category``.
        edges: Edge dicts with keys ``source`` and ``target`` (node ids).

    Returns:
        The unrendered ``graphviz.Digraph`` (output format "pdf").
    """
    dot = graphviz.Digraph(
        name="joint_graph",
        format="pdf",
        graph_attr={
            "rankdir": "BT",
            "nodesep": "0.4",
            "ranksep": "0.6",
            "overlap": "false",
            "splines": "true",
            "newrank": "true",
            "label": "Joint Graph Visualization\\n"
            "Blue: Input | Cyan: Tangent | Green: FWD | Yellow: SaveTensor | Orange: BWD | Top: BWD Output",
            "labelloc": "t",
            "fontsize": "12",
        },
        node_attr={"fontname": "Helvetica", "fontsize": "9", "fixedsize": "false", "margin": "0.12"},
    )

    # Bucket nodes by category, keeping their relative order. Categories without a bucket
    # of their own (e.g. FWD_OUTPUT or unknown values) are drawn as ordinary FWD nodes.
    buckets: Dict[str, List[Dict]] = {
        category: []
        for category in (
            NodeCategory.INPUT,
            NodeCategory.FWD,
            NodeCategory.SAVE_TENSOR,
            NodeCategory.TANGENT,
            NodeCategory.BWD,
            NodeCategory.BWD_OUTPUT,
        )
    }
    for node in nodes:
        buckets.get(node["category"], buckets[NodeCategory.FWD]).append(node)
    input_nodes = buckets[NodeCategory.INPUT]
    fwd_nodes = buckets[NodeCategory.FWD]
    save_tensor_nodes = buckets[NodeCategory.SAVE_TENSOR]
    tangent_nodes = buckets[NodeCategory.TANGENT]
    bwd_nodes = buckets[NodeCategory.BWD]
    bwd_output_nodes = buckets[NodeCategory.BWD_OUTPUT]

    def add_styled_nodes(target_graph, group):
        # Declare each node with its label and category style on the given (sub)graph.
        for node in group:
            target_graph.node(node["id"], node["node_label"], **node["style"])

    with dot.subgraph(name="cluster_fwd") as fwd_cluster:
        fwd_cluster.attr(label="Forward Pass", style="rounded,dashed", color="#4CAF50", bgcolor="#F1F8E9", penwidth="2")
        add_styled_nodes(fwd_cluster, input_nodes + fwd_nodes + save_tensor_nodes)
    with dot.subgraph(name="cluster_bwd") as bwd_cluster:
        bwd_cluster.attr(label="Backward Pass", style="rounded,dashed", color="#FF9800", bgcolor="#FFF8E1", penwidth="2")
        add_styled_nodes(bwd_cluster, bwd_nodes)
    # Tangents and gradient outputs are drawn outside both clusters.
    add_styled_nodes(dot, tangent_nodes + bwd_output_nodes)

    # With rankdir=BT, rank="min" is the bottom row and rank="max" is the top row.
    with dot.subgraph() as s:
        s.attr(rank="min")
        for node in tangent_nodes + save_tensor_nodes:
            s.node(node["id"])
    with dot.subgraph() as s:
        s.attr(rank="max")
        for node in bwd_output_nodes:
            s.node(node["id"])

    # Invisible ranking edge from the first tangent to the first bwd op; affects layout only.
    if bwd_nodes and tangent_nodes:
        dot.edge(tangent_nodes[0]["id"], bwd_nodes[0]["id"], style="invis", constraint="true")

    for edge in edges:
        dot.edge(edge["source"], edge["target"])
    return dot
def get_joint_graph_path(sub_dir: str = "", filename: str = "") -> str:
    """
    Get the path for saving joint graph visualization.

    The directory ``<cache_root_dir>/joint_graph_views/rank_<rank>[/<sub_dir>]`` is
    created if missing. The rank is the torch.distributed rank when a process group is
    initialized, otherwise 0.

    Args:
        sub_dir: Optional sub-directory under the per-rank directory.
        filename: Optional file name to join onto the directory.

    Returns:
        The directory path, or the directory joined with ``filename`` when one is given.
    """
    cache_root_dir = get_compile_config().cache_root_dir
    # torch.distributed.is_initialized is only defined when torch is built with
    # distributed support; is_available() is always present and guards that case.
    dist = torch.distributed
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    target_dir = os.path.join(cache_root_dir, "joint_graph_views", f"rank_{rank}")
    if sub_dir:
        target_dir = os.path.join(target_dir, sub_dir)
    # makedirs creates missing parents too, so one call covers both directory levels.
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, filename) if filename else target_dir
def joint_graph_vis(
    joint_module: fx.GraphModule,
    fwd_module: fx.GraphModule,
    bwd_module: fx.GraphModule,
    save_tensor_nodes: Optional[List[fx.Node]] = None,
    file_path: Optional[str] = None,
):
    """
    Visualize joint graph with coloring priority: save_tensor > fwd > bwd.

    Does nothing (besides logging) unless ``envs.MAGI_ENABLE_FX_GRAPH_VIZ`` is truthy.

    Layout (bottom-to-top flow):
    - Top: BWD output (gradient outputs)
    - Middle-Left: FWD cluster (inputs + fwd ops + save_tensors)
    - Middle-Right: BWD cluster (bwd ops)
    - Bottom: Tangent inputs and save_tensors
    Node colors:
    - Blue box: Input nodes (fwd placeholders)
    - Cyan box: Tangent nodes (bwd gradient inputs)
    - Green ellipse: FWD nodes
    - Yellow hexagon: SaveTensor nodes (in FWD cluster)
    - Orange ellipse: BWD nodes
    - Orange box: BWD output (gradient outputs, at top)

    Args:
        joint_module: The joint graph module containing both fwd and bwd (a bare
            ``fx.Graph`` is also accepted).
        fwd_module: The forward graph module (or ``fx.Graph``).
        bwd_module: The backward graph module (or ``fx.Graph``).
        save_tensor_nodes: List of nodes that are saved tensors for backward.
        file_path: Optional output path without extension, passed to ``Digraph.render``.
            If None, uses the default path from ``get_joint_graph_path``.
    """
    if not envs.MAGI_ENABLE_FX_GRAPH_VIZ:
        magi_logger.info("Joint graph visualization is disabled. Set MAGI_ENABLE_FX_GRAPH_VIZ=true to enable it.")
        return

    def as_graph(module):
        # Accept either a GraphModule or a bare fx.Graph.
        return module.graph if isinstance(module, fx.GraphModule) else module

    joint_graph = as_graph(joint_module)
    fwd_graph = as_graph(fwd_module)
    bwd_graph = as_graph(bwd_module)

    node_categories, input_save_tensors = categorize_joint_nodes(joint_graph, fwd_graph, bwd_graph, save_tensor_nodes)
    nodes, edges = extract_joint_graph_structure(joint_graph, node_categories, input_save_tensors)
    dot = create_joint_graph_dot(nodes, edges)

    if file_path is None:
        file_path = get_joint_graph_path(filename="joint_graph")
    # render() returns the path of the file it actually produced; log that rather than
    # reconstructing the name by hand.
    rendered_path = dot.render(filename=file_path, view=False, cleanup=True)
    magi_logger.info("Joint graph visualization saved to: %s", rendered_path)
Xet Storage Details
- Size:
- 12.2 kB
- Xet hash:
- dd32db8223ef60ebf03731ce6e05f4e5ef7ce5008d3c8bc3555847f400642b48
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.