|
|
|
|
|
import re
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
|
|
import torch.fx
|
|
|
from torch.fx.node import map_arg
|
|
|
from torch.fx.passes.split_module import split_module
|
|
|
|
|
|
|
|
|
# Public API of this module.
__all__ = [
    "FoldedGraphModule",
    "get_unique_attr_name_in_module",
    "split_const_subgraphs",
]
|
|
|
|
|
|
|
|
|
class FoldedGraphModule(torch.fx.GraphModule):
|
|
|
"""
|
|
|
FoldedGraphModule is a GraphModule which also contains another
|
|
|
`const_subgraph_module` representing a subgraph which has all const attr
|
|
|
inputs and which can be run once before running the main standard
|
|
|
`graph`. The `const_output_names` are the ordered list names of attrs which
|
|
|
represent what each respective output from the const_subgraph should be set
|
|
|
on which attrs.
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
root: torch.nn.Module,
|
|
|
graph: torch.fx.Graph,
|
|
|
const_subgraph: Optional[torch.fx.Graph] = None,
|
|
|
fx_const_folded_attrs_name: Optional[str] = None,
|
|
|
device_for_folded_attrs: str = "cuda",
|
|
|
):
|
|
|
super().__init__(root, graph)
|
|
|
self.const_subgraph_module = (
|
|
|
None
|
|
|
if const_subgraph is None
|
|
|
else torch.fx.GraphModule(root, const_subgraph)
|
|
|
)
|
|
|
self.has_folding_been_run = False
|
|
|
self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
|
|
|
self.device_for_folded_attrs = device_for_folded_attrs
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
if not self.has_folding_been_run:
|
|
|
self.run_folding()
|
|
|
return super().__call__(*args)
|
|
|
|
|
|
def run_folding(self):
|
|
|
|
|
|
|
|
|
if (
|
|
|
self.const_subgraph_module is None
|
|
|
or self.fx_const_folded_attrs_name is None
|
|
|
):
|
|
|
return
|
|
|
|
|
|
assert not self.has_folding_been_run
|
|
|
self.has_folding_been_run = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
folded_attrs = self.const_subgraph_module()
|
|
|
|
|
|
def _create_param(i):
|
|
|
return torch.nn.Parameter(
|
|
|
i.detach().clone()
|
|
|
if not isinstance(i, int)
|
|
|
else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
|
|
|
requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
|
|
|
)
|
|
|
|
|
|
params = (
|
|
|
torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
|
|
|
if isinstance(folded_attrs, tuple)
|
|
|
else _create_param(folded_attrs)
|
|
|
)
|
|
|
setattr(self, self.fx_const_folded_attrs_name, params)
|
|
|
|
|
|
|
|
|
def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
|
|
|
"""
|
|
|
Given `gm` and some graph module which is called with target name `inline_mod_name`,
|
|
|
this helper will inline all of the nodes from that called graph module into `gm`.
|
|
|
"""
|
|
|
|
|
|
inline_mod = dict(gm.named_modules())[inline_mod_name]
|
|
|
assert isinstance(inline_mod, torch.fx.GraphModule)
|
|
|
call_mod_node_to_replace = None
|
|
|
for node in gm.graph.nodes:
|
|
|
if node.op == "call_module" and node.target == inline_mod_name:
|
|
|
call_mod_node_to_replace = node
|
|
|
break
|
|
|
assert call_mod_node_to_replace is not None
|
|
|
|
|
|
|
|
|
|
|
|
call_mod_args = call_mod_node_to_replace.args
|
|
|
call_mod_kwargs = call_mod_node_to_replace.kwargs
|
|
|
|
|
|
replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {}
|
|
|
ph_count = 0
|
|
|
|
|
|
def replacement_fn(node):
|
|
|
new_node = replacement_mapping[node]
|
|
|
new_node.meta = node.meta.copy()
|
|
|
return new_node
|
|
|
|
|
|
for inline_node in inline_mod.graph.nodes:
|
|
|
if inline_node.op == "placeholder":
|
|
|
replacement_mapping[inline_node] = (
|
|
|
call_mod_kwargs[inline_node.name]
|
|
|
if inline_node.name in call_mod_kwargs
|
|
|
else call_mod_args[ph_count]
|
|
|
)
|
|
|
|
|
|
ph_count += 1
|
|
|
continue
|
|
|
|
|
|
if inline_node.op == "output":
|
|
|
outputs = inline_node.args[0]
|
|
|
output_replacements = map_arg(outputs, replacement_fn)
|
|
|
call_mod_node_to_replace.replace_all_uses_with(output_replacements)
|
|
|
continue
|
|
|
|
|
|
with gm.graph.inserting_before(call_mod_node_to_replace):
|
|
|
new_node = gm.graph.node_copy(inline_node, replacement_fn)
|
|
|
replacement_mapping[inline_node] = new_node
|
|
|
|
|
|
gm.graph.eliminate_dead_code()
|
|
|
|
|
|
|
|
|
def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
    """
    Make sure the name is unique (in a module) and can represent an attr.

    Characters illegal in a Python identifier are replaced with "_", a leading
    digit (or an empty result) gets a "_" prefix, and a numeric suffix is
    appended/incremented until the name does not collide with an existing
    attribute on `mod_traced`.
    """
    # Delete all characters that are illegal in a Python identifier.
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    # Bug fix: the sanitized name can be empty (input made entirely of illegal
    # characters), in which case `name[0]` raised IndexError. An identifier
    # also may not start with a digit; both cases get a "_" prefix.
    if not name or name[0].isdigit():
        name = f"_{name}"
    # Now make sure it is in fact unique to the module by incrementing a
    # numeric suffix until no attribute with that name exists.
    while hasattr(mod_traced, name):
        match = re.match(r"(.*)_(\d+)$", name)
        if match is None:
            name = name + "_1"
        else:
            base, num = match.group(1, 2)
            name = f"{base}_{int(num) + 1}"

    return name
|
|
|
|
|
|
|
|
|
def split_const_subgraphs(
    module: Union[torch.nn.Module, torch.fx.GraphModule],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all constant attribute inputs
    and separates them out into their own constant subgraph, and returns a
    FoldedGraphModule which runs that constant subgraph on the first run to set
    attributes on the module prior to running the non-constant portion of the
    graph.
    """
    # Symbolically trace plain nn.Modules; an existing GraphModule is used as-is.
    if not isinstance(module, torch.fx.GraphModule):
        mod_traced = torch.fx.symbolic_trace(module)
    else:
        mod_traced = module

    # Build up a set of nodes that are transitively computable from get_attr
    # nodes only. Iteration order of graph nodes is topological, so a node's
    # inputs are classified before the node itself.
    const_nodes: set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip placeholders/outputs: they are never foldable.
        if node.op in {"placeholder", "output"}:
            continue

        # A node is const only if it is a get_attr or every input is const.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # Caller-supplied veto for nodes that must not be folded.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Impure (side-effecting) nodes cannot be safely pre-computed.
        if node.is_impure():
            continue

        const_nodes.add(node)
        # Folding only pays off if some actual computation (not just a raw
        # get_attr) is constant.
        if node.op != "get_attr":
            found_const_folding = True

    # Nothing to fold: return the traced module unchanged, with no subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)

    # Partition 0 holds the constant nodes, partition 1 everything else.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)

    # split_module names the partitions submod_<partition_id>.
    const_mod_name, non_const_mod_name = "submod_0", "submod_1"

    # The non-const submodule may not exist if everything was constant.
    const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)

    # Hoist any submodules called from inside the partitions up onto `split`
    # so they remain reachable after the partitions are inlined/replaced.
    for node in non_const_gm.graph.nodes if non_const_gm else []:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))

    # Record the args passed to the const partition's call site; these are the
    # get_attr nodes feeding its placeholders.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module":
            if node.target == const_mod_name:
                call_const_gm_args = node.args
                break
    assert call_const_gm_args is not None

    # Re-root the const graph on `split` so its get_attrs resolve there.
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)

    # Replace each placeholder of the const graph with a get_attr for the
    # corresponding attribute (positionally matched via call_const_gm_args),
    # so the const graph can run standalone with no inputs.
    ph_idx = 0
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            # Remember whether the const subgraph produces one or many values;
            # this decides Parameter vs ParameterList below.
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        assert ph_idx < len(call_const_gm_args)
        in_node = call_const_gm_args[ph_idx]
        ph_idx += 1
        # Every input to the const partition must itself be an attribute.
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
            new_node.meta = node.meta.copy()
            node.replace_all_uses_with(new_node)
            root_const_gm.graph.erase_node(node)
    # Guard: the loop must have visited an output node, otherwise
    # `multiple_outputs` was never bound.
    assert "multiple_outputs" in locals()

    # Reserve a collision-free attribute name on the module to hold the folded
    # result(s), and install an (empty, pre-folding) placeholder value.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        mod_traced, "_FX_CONST_FOLDED_ATTRS"
    )
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
    )
    # In the top-level graph, replace the call to the const partition with a
    # get_attr of the folded-attrs name; run_folding() will populate it later.
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
                folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break

    # Inline the non-const partition back into the top-level graph so the
    # final module has a single flat graph (plus the side const subgraph).
    if hasattr(split, non_const_mod_name):
        _inline_module(split, non_const_mod_name)

    # Drop the now-unused call to the const partition and any other dead nodes.
    split.graph.eliminate_dead_code()

    return FoldedGraphModule(
        split,
        split.graph,
        root_const_gm.graph,
        fx_const_folded_attrs_name,
        device_for_folded_attrs,
    )
|
|
|
|