|
|
import torch |
|
|
from torch.fx import GraphModule |
|
|
from torch.nn import Module |
|
|
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs |
|
|
from torch.multiprocessing.reductions import StorageWeakRef |
|
|
from torch.utils._pytree import tree_map |
|
|
import torchdynamo |
|
|
from torchdynamo.optimizations.training import AOTAutogradStrategy |
|
|
|
|
|
import operator |
|
|
from collections import defaultdict |
|
|
from typing import Set |
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: the dynamo backend entry point assigned at the
# bottom of the file.
__all__ = ['aot_autograd_cudagraphs']
|
|
|
|
|
def cloner(t):
    """Return a detached copy of *t* if it is a tensor, else *t* unchanged.

    Used with ``tree_map`` to clone the static output buffers of a captured
    CUDA graph so callers cannot alias (and later observe mutation of) the
    graph's internal storage.
    """
    return t.clone() if isinstance(t, torch.Tensor) else t
|
|
|
|
|
|
|
|
class CudaGraphModule(Module):
    """Wraps a GraphModule so repeated calls replay a captured CUDA graph.

    Per-instance lifecycle, driven entirely by ``__call__``:

      1. First call: run the wrapped module eagerly on a side stream
         (warmup; CUDA graph capture requires prior warmup runs so that
         lazy kernel/allocator initialization happens outside capture).
      2. Second call: capture a ``torch.cuda.CUDAGraph`` against static
         input/output buffers, then replay it once.
      3. Subsequent calls: copy the fresh arguments into the static input
         buffers and replay the captured graph.

    NOTE(review): this assumes every element of ``args`` is a CUDA tensor
    of fixed shape across calls — confirm with callers.
    """

    # The partitioned FX submodule this wrapper executes.
    gm: GraphModule
    # Indices of inputs that the wrapped graph mutates in place; after each
    # replay their updated values are copied back into the caller's tensors,
    # since the graph only ever writes the internal static buffers.
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    # Class-level defaults; each is shadowed by an instance attribute once
    # ``__call__`` advances the warmup/capture state machine.
    warmed_up = False

    # ``graph``/``static_inputs``/``static_outputs`` are either all None
    # (not yet captured) or all populated (captured and replayable).
    graph = None
    static_inputs = None
    static_outputs = None

    def __call__(self, *args):
        # State 3: graph already captured — copy args in, replay, copy
        # mutated inputs back out, and return clones of the static outputs
        # (clones so callers don't alias the graph's internal buffers).
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        # State 2: warmed up but not yet captured — snapshot the inputs
        # into static buffers and capture the graph, then replay once to
        # actually produce this call's results.
        elif self.warmed_up:
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # Capture itself does not execute the kernels; replay to
            # compute real outputs for this invocation.
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        # State 1: first call — eager warmup on a side stream, synchronized
        # with the current stream on entry and exit.
        else:
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_input_mutations(g):
    """Return the indices of placeholders whose storage is written in place.

    Walks the FX graph *g*, grouping placeholder indices by the storage of
    their ``meta['fake_result']`` tensor, then scans every ``call_function``
    node's operator schema for write-aliased arguments and collects the
    input indices sharing storage with those arguments.

    NOTE(review): assumes each relevant node carries a tensor in
    ``node.meta['fake_result']`` — populated by an earlier fake-tensor
    propagation pass; confirm with the partitioner.
    """
    FK = 'fake_result'
    # Map each distinct input storage to the set of placeholder indices
    # backed by it (several placeholders may alias one storage).
    storage_to_inputs = defaultdict(set)
    mutated = set()
    placeholder_idx = 0
    for node in g.nodes:
        if node.op == 'placeholder':
            ref = StorageWeakRef(node.meta[FK].storage())
            storage_to_inputs[ref].add(placeholder_idx)
            placeholder_idx += 1
        elif node.op == 'call_function':
            # getitem has no operator schema; it also cannot mutate.
            if node.target is operator.getitem:
                continue
            for pos, schema_arg in enumerate(node.target._schema.arguments):
                # Resolve the actual value bound to this schema slot:
                # positional first, then keyword, else it was defaulted.
                if pos < len(node.args):
                    bound = node.args[pos]
                elif schema_arg.name in node.kwargs:
                    bound = node.kwargs[schema_arg.name]
                else:
                    continue
                alias = schema_arg.alias_info
                if alias and alias.is_write:
                    # This argument is mutated; every input sharing its
                    # storage counts as mutated.
                    ref = StorageWeakRef(bound.meta[FK].storage())
                    mutated |= storage_to_inputs[ref]
    return mutated
|
|
|
|
|
|
|
|
|
|
|
def apply_cuda_graphs(gm):
    """Replace each call_module submodule of *gm* with a CudaGraphModule.

    Mutates *gm* in place: every submodule invoked via a ``call_module``
    node is removed and re-registered under the same target, wrapped so it
    executes through a captured CUDA graph.
    """
    # Collect targets first so submodule replacement cannot interact with
    # graph iteration.
    module_nodes = [n for n in gm.graph.nodes if n.op == 'call_module']
    for node in module_nodes:
        # The partitioner emits purely positional calls into submodules.
        assert not node.kwargs
        submod = gm.get_submodule(node.target)
        gm.delete_submodule(node.target)
        mutated = find_input_mutations(submod.graph)
        gm.add_submodule(node.target, CudaGraphModule(submod, mutated))
|
|
|
|
|
|
|
|
|
|
|
def cudagraphs(model, inputs):
    """Compiler entry point: partition *model* for CUDA graphs and wrap it.

    Partitions the traced module into CUDA-graph-safe submodules, swaps
    each one for a CudaGraphModule in place, and returns the partitioned
    module.
    """
    partitioned = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(partitioned)
    return partitioned
|
|
|
|
|
|
|
|
def raw_aot_autograd_cudagraphs(model, inputs):
    """Compile *model* with AOT Autograd, CUDA-graphing both directions.

    Both the forward and backward compilers are ``cudagraphs``; the
    backward compiler's output is additionally wrapped in
    ``torchdynamo.disable`` so dynamo does not re-trace compiled
    backward code.

    NOTE(review): *inputs* is accepted for signature compatibility but is
    not consumed here — aot_module_simplified handles example inputs
    internally; confirm against the strategy base class.
    """
    from functorch.compile import aot_module_simplified

    def _wrapped_bw_compiler(*args, **kw):
        # Keep dynamo from tracing into the compiled backward function.
        return torchdynamo.disable(bw_compiler(*args, **kw))

    kwargs = {
        # TODO: stop hardcoding cuda device index
        "fw_compiler": cudagraphs,
        "bw_compiler": cudagraphs,
    }
    # Fall back to the forward compiler when no backward compiler is set,
    # then interpose the dynamo-disabling wrapper.
    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
    kwargs["bw_compiler"] = _wrapped_bw_compiler

    return aot_module_simplified(model, **kwargs)
|
|
|
|
|
|
|
|
class AOTAutogradCudaGraphs(AOTAutogradStrategy):
    """AOT Autograd training strategy that runs partitions via CUDA graphs."""

    def candidate(self):
        # Compile the traced graph module captured by the strategy base
        # class, using CUDA-graph forward/backward compilers.
        compiled = raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)
        return compiled
|
|
|
|
|
|
|
|
# The backend callable exported via __all__. ``compile_fn`` is presumably
# provided by the AOTAutogradStrategy base class — TODO confirm in
# torchdynamo.optimizations.training.
aot_autograd_cudagraphs = AOTAutogradCudaGraphs.compile_fn
|
|
|