import torch
import glog
def get_graph_wrapper(cls):
    """Return a subclass of `cls` whose forward pass is captured into a
    CUDA graph on the first call and replayed on every call after that."""

    class GraphWrapper(cls):

        def __init__(self, config):
            super(GraphWrapper, self).__init__(config)
            self.built_graph = False

        def forward(self, *args, **kwargs):
            if not self.built_graph:
                # Keep references to the capture-time inputs; replays read
                # whatever is in these tensors, so later calls only need to
                # copy new data into them.
                self.static_args = args
                self.static_kwargs = kwargs

                # Warm-up pass on a side stream so one-time lazy setup
                # (cuBLAS handles, allocator pools, etc.) is not captured.
                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    super(GraphWrapper, self).forward(*self.static_args,
                                                      **self.static_kwargs)
                torch.cuda.current_stream().wait_stream(s)

                # Capture a single forward pass into the graph.
                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    self.static_output = super(GraphWrapper, self).forward(
                        *self.static_args, **self.static_kwargs)
                self.built_graph = True
                glog.info("Built CUDA graph of model.")

            # Copy the new inputs into the captured tensors in place.
            # These two loops take < 1e-4 seconds for llama2.
            for i in range(len(args)):
                if isinstance(args[i], torch.Tensor):
                    self.static_args[i].copy_(args[i])
            for kw in kwargs:
                if isinstance(kwargs[kw], torch.Tensor):
                    self.static_kwargs[kw].copy_(kwargs[kw])

            # Replay the captured kernels; results land in static_output.
            self.graph.replay()
            return self.static_output

        def reset(self):
            # Drop the captured inputs so the next forward call re-captures
            # the graph (e.g. after batch or sequence length changes).
            if self.built_graph:
                del self.static_args, self.static_kwargs
                self.built_graph = False

    return GraphWrapper
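

# A minimal usage sketch: wrap a toy module, let the first call warm up and
# capture the graph, and replay on the second. `TinyModel` and its dict
# config are hypothetical stand-ins; any `cls` that takes a config in
# __init__ and sees fixed input shapes across calls should work the same.
if __name__ == "__main__":

    class TinyModel(torch.nn.Module):

        def __init__(self, config):
            super().__init__()
            self.linear = torch.nn.Linear(config["dim"], config["dim"])

        def forward(self, x):
            return self.linear(x)

    GraphTinyModel = get_graph_wrapper(TinyModel)
    model = GraphTinyModel({"dim": 64}).cuda()

    with torch.no_grad():
        x = torch.randn(8, 64, device="cuda")
        out = model(x)  # first call: warm-up, graph capture, then replay
        out = model(torch.randn(8, 64, device="cuda"))  # replay only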