download
raw
9.48 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Tuple
import graphviz
import torch
import torch.fx
from magi_compiler.config import get_compile_config
from magi_compiler.utils import envs, magi_logger
class FX_NODE_OP:
PLACEHOLDER = "placeholder"
OUTPUT = "output"
CALL_MODULE = "call_module"
CALL_FUNCTION = "call_function"
CALL_METHOD = "call_method"
GET_ATTR = "get_attr"
DEFAULT = "default"
FIXED_NODE_STYLES = {
FX_NODE_OP.PLACEHOLDER: {"shape": "rectangle", "style": "filled,bold", "fillcolor": "#E8F4FD", "color": "#2196F3"},
FX_NODE_OP.OUTPUT: {"shape": "rectangle", "style": "filled,bold", "fillcolor": "#FCE4EC", "color": "#E91E63"},
FX_NODE_OP.CALL_MODULE: {"shape": "ellipse", "style": "filled,bold", "fillcolor": "#FFF8E1", "color": "#FFC107"},
FX_NODE_OP.CALL_FUNCTION: {"shape": "ellipse", "style": "filled", "fillcolor": "#E8F5E9", "color": "#4CAF50"},
FX_NODE_OP.CALL_METHOD: {"shape": "ellipse", "style": "filled", "fillcolor": "#F3E5F5", "color": "#9C27B0"},
FX_NODE_OP.GET_ATTR: {"shape": "ellipse", "style": "filled", "fillcolor": "#FFCCBC", "color": "#FF5722"},
FX_NODE_OP.DEFAULT: {"shape": "ellipse", "style": "filled", "fillcolor": "#F5F5F5", "color": "#666666"},
"call_function.linear": {"shape": "ellipse", "style": "filled,bold", "fillcolor": "#E8F5E9", "color": "#44FF00"},
}
def build_node_to_code_map(graph: torch.fx.Graph) -> Dict[torch.fx.Node, str]:
node_to_code = {}
python_code = graph.python_code(root_module="self", verbose=False)
code_lines = python_code.src.strip().split("\n")
lineno_map = python_code._lineno_map
node_index_map = {idx: node for idx, node in enumerate(graph.nodes)}
for line_num, node_idx in lineno_map.items():
if node_idx is None or node_idx not in node_index_map:
continue
node = node_index_map[node_idx]
if 0 <= line_num < len(code_lines):
code_line = code_lines[line_num].strip()
if code_line and not code_line.startswith(("wrap(", "#", "pass")):
node_to_code[node] = code_line
for node in graph.nodes:
if node not in node_to_code:
assert node.op == FX_NODE_OP.PLACEHOLDER, f"Unexpected missing code for {node.op=}, {node.target=}"
node_to_code[node] = ""
return node_to_code
def extract_fx_graph_structure(graph: torch.fx.Graph, simple_desc: bool = False) -> Tuple[List[Dict], List[Dict]]:
import textwrap
def wrap_str_to_multi_lines(text, width=30):
return textwrap.fill(text, width=width, break_long_words=True, replace_whitespace=False)
nodes, edges = [], []
node_to_code = build_node_to_code_map(graph)
for node in graph.nodes:
name_str = str(node.name)
name_str = wrap_str_to_multi_lines(name_str)
if node.op == FX_NODE_OP.GET_ATTR:
# torch._inductor.exc.InductorError: DataDependentOutputException: aten._local_scalar_dense.default
meta_str = f"get_attr: {node.target} (skip fake tensor)"
else:
tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val") or node.meta.get("example_value")
meta_str = tensor_meta_to_str(tensor_meta)
meta_str = wrap_str_to_multi_lines(meta_str)
target_str = target_to_str(node.target)
if hasattr(node, "original_target"):
original_target_str = target_to_str(node.original_target)
target_str += f"\nOriginal: {original_target_str}"
target_str = wrap_str_to_multi_lines(target_str)
node_code = node_to_code.get(node, "empty")
node_code = wrap_str_to_multi_lines(node_code)
style_info = FIXED_NODE_STYLES.get(node.op, FIXED_NODE_STYLES[FX_NODE_OP.DEFAULT])
if node.op == FX_NODE_OP.CALL_FUNCTION and f"call_function.{node.target.__name__}" in FIXED_NODE_STYLES:
style_info = FIXED_NODE_STYLES[f"call_function.{node.target.__name__}"]
node_label = f"Op: {node.op}\nTarget: {target_str}\nName: {name_str}\nMeta: {meta_str}\nCode: {node_code}"
if simple_desc:
node_label = f"Op: {node.op}\nTarget: {target_str}\nName: {name_str}"
nodes.append({"id": node.name, "style": style_info, "node_label": node_label})
def traverse_args(args_kwargs):
if isinstance(args_kwargs, (tuple, list)):
for arg in args_kwargs:
traverse_args(arg)
elif isinstance(args_kwargs, dict):
for val in args_kwargs.values():
traverse_args(val)
elif isinstance(args_kwargs, torch.fx.Node):
d = {"source": args_kwargs.name, "target": node.name}
edges.append(d) if d not in edges else None
traverse_args(node.args)
traverse_args(node.kwargs)
return nodes, edges
def target_to_str(target: Any) -> str:
res = str(target)
if isinstance(target, str):
res = target
elif hasattr(target, "_op"):
res = str(target._op)
elif callable(target):
res = getattr(target, "__name__")
return res
def tensor_meta_to_str(tensor_meta: Any) -> str:
if type(tensor_meta) in [int, float, str, bool]:
return str(tensor_meta)
elif isinstance(tensor_meta, (list, tuple)):
return f"[{', '.join([tensor_meta_to_str(t) for t in tensor_meta])}]"
elif isinstance(tensor_meta, torch.Tensor):
d = {}
d["shape"] = tensor_meta.shape if hasattr(tensor_meta, "shape") else "N/A"
d["size"] = tensor_meta.size() if hasattr(tensor_meta, "size") else "N/A"
d["ndim"] = tensor_meta.ndim if hasattr(tensor_meta, "ndim") else "N/A"
d["numel"] = tensor_meta.numel() if hasattr(tensor_meta, "numel") else "N/A"
d["stride"] = tensor_meta.stride if hasattr(tensor_meta, "stride") else "N/A"
d["stride"] = tensor_meta.stride() if hasattr(tensor_meta, "stride") and callable(tensor_meta.stride) else "N/A"
d["is_contiguous"] = tensor_meta.is_contiguous() if hasattr(tensor_meta, "is_contiguous") else "N/A"
d["dtype"] = str(tensor_meta.dtype) if hasattr(tensor_meta, "dtype") else "N/A"
d["device"] = str(tensor_meta.device) if hasattr(tensor_meta, "device") else "N/A"
return ", ".join([f"{k}: {v}" for k, v in d.items()])
else:
return str(tensor_meta)
def create_fx_graph_dot(nodes: list[Dict], edges: list[Dict]) -> graphviz.Digraph:
dot = graphviz.Digraph(
name="fx_graph",
format="pdf",
graph_attr={"rankdir": "TD", "nodesep": "0.1", "ranksep": "0.1", "overlap": "false", "splines": "spline"},
node_attr={
"fontname": "Helvetica",
"fontsize": "10",
"shape": "rect",
"style": "rounded,filled",
"fixedsize": "false",
"margin": "0.2",
},
)
for node in nodes:
dot.node(node["id"], node["node_label"], **node["style"])
for edge in edges:
dot.edge(edge["source"], edge["target"])
return dot
def get_fx_graph_path(sub_dir: str = "", filename: str = "") -> str:
cache_root_dir = get_compile_config().cache_root_dir
# Unify with magi_depyf output
fx_graph_dir = os.path.join(cache_root_dir, "magi_depyf", "visualizations")
os.makedirs(fx_graph_dir, exist_ok=True)
if sub_dir:
fx_graph_sub_dir = os.path.join(fx_graph_dir, sub_dir)
os.makedirs(fx_graph_sub_dir, exist_ok=True)
if filename:
return os.path.join(fx_graph_sub_dir, filename)
return fx_graph_sub_dir
if filename:
return os.path.join(fx_graph_dir, filename)
return fx_graph_dir
def save_fx_graph_visualization(graph: torch.fx.Graph, sub_dir: str = "", filename: str = "fx_graph"):
"""
Save FX graph visualization as PDF.
Args:
graph: The FX graph or GraphModule to visualize
sub_dir: Optional subdirectory under the fx_graph_views folder
filename: Filename for the output PDF (without extension)
"""
if not envs.MAGI_ENABLE_FX_GRAPH_VIZ:
magi_logger.info("FX graph visualization is disabled. Set MAGI_ENABLE_FX_GRAPH_VIZ=true to enable it.")
return
if isinstance(graph, torch.fx.GraphModule):
graph = graph.graph
assert envs.MAGI_FX_GRAPH_VIZ_NODE_DESC in {
"simple",
"detailed",
}, f"Invalid MAGI_FX_GRAPH_VIZ_NODE_DESC: {envs.MAGI_FX_GRAPH_VIZ_NODE_DESC}"
simple_desc = envs.MAGI_FX_GRAPH_VIZ_NODE_DESC == "simple"
file_path = get_fx_graph_path(sub_dir=sub_dir, filename=filename)
nodes, edges = extract_fx_graph_structure(graph, simple_desc=simple_desc)
dot = create_fx_graph_dot(nodes, edges)
dot.render(filename=file_path, view=False, cleanup=True)
magi_logger.info("FX graph visualization saved to: %s.pdf", file_path)

Xet Storage Details

Size:
9.48 kB
·
Xet hash:
0c07cc65437c69bcf18ccedfa8d1f4ed98fbf89bfbf6e61c4636930d1c6c3dbc

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.