Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import os | |
| import time | |
| import functools | |
| import argparse | |
| import logging | |
| import warnings | |
| from dataclasses import dataclass | |
# Quiet the DeepSpeed logger and ignore FutureWarning / DeprecationWarning before
# torch and diffusers are imported below.
logging.getLogger("DeepSpeed").disabled = True
for _ignored_category in (FutureWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
| import torch | |
| import diffusers | |
# Pass counts: trace_model() runs n_warmup eager passes before tracing and n_traces
# passes through the traced module; benchmark_model() runs n_runs passes in total.
n_warmup = 5
n_traces = 10
n_runs = 100
# Parsed CLI options, replaced by parser.parse_args() in the __main__ block. It starts
# as an empty Namespace (not a dict) because the rest of the script reads it with
# attribute access (args.type, args.model, args.save, ...).
args = argparse.Namespace()
# Diffusers pipeline; populated by load_model().
pipe = None
log = logging.getLogger("sd")
def setup_logging():
    """Send the ``sd`` logger to a rich console and quiet other libraries.

    Steps:
    - Set ``log`` to DEBUG and attach a ``RichHandler``.
    - Configure the root logger via ``basicConfig`` with only a ``NullHandler``.
    - Raise the ``diffusers`` and ``torch`` loggers to ERROR.
    - Ignore ``torch.jit.TracerWarning``.
    - Install rich tracebacks on the same console.
    """
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.theme import Theme
    from rich.traceback import install

    log.setLevel(logging.DEBUG)
    border_theme = Theme({
        "traceback.border": "black",
        "traceback.border.syntax_error": "black",
        "inspect.value.border": "black",
    })
    console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=border_theme)
    # Redirect the default (root) logger to a NullHandler.
    logging.basicConfig(
        level=logging.ERROR,
        format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s',
        handlers=[logging.NullHandler()],
    )
    rich_handler = RichHandler(
        show_time=True,
        omit_repeated_times=False,
        show_level=True,
        show_path=False,
        markup=False,
        rich_tracebacks=True,
        log_time_format='%H:%M:%S-%f',
        level=logging.DEBUG,
        console=console,
    )
    rich_handler.setLevel(logging.DEBUG)
    log.addHandler(rich_handler)
    for noisy_logger in ("diffusers", "torch"):
        logging.getLogger(noisy_logger).setLevel(logging.ERROR)
    warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)
    install(
        console=console,
        extra_lines=1,
        max_frames=10,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=[],
    )
def generate_inputs():
    """Build one random fp16 CUDA input batch for a UNet forward pass.

    The result is meant to be splatted positionally (``unet(*inputs)``):
    - ``sd15``: ``(sample, timestep, encoder_hidden_states)``.
    - ``sdxl``: the same three tensors plus a fourth ``text_embeds`` tensor.
    - Any other ``args.type``: ``None``.

    The timestep is a single random value scaled into ``[0, 999)``.
    """
    def half_randn(*shape):
        return torch.randn(*shape).half().cuda()

    model_type = args.type
    if model_type not in ('sd15', 'sdxl'):
        return None
    sample = half_randn(2, 4, 64, 64)
    timestep = torch.rand(1).half().cuda() * 999
    encoder_hidden_states = half_randn(2, 77, 768)
    if model_type == 'sd15':
        return sample, timestep, encoder_hidden_states
    # NOTE(review): for SDXL this looks wrong. The SDXL UNet presumably expects
    # 2048-wide encoder_hidden_states and `text_embeds`/`time_ids` passed through
    # `added_cond_kwargs`. A 4th positional argument would bind to `class_labels`
    # in diffusers' UNet2DConditionModel.forward. Verify before relying on --type sdxl.
    text_embeds = half_randn(1, 77, 2048)
    return sample, timestep, encoder_hidden_states, text_embeds
def load_model():
    """Load the single-file checkpoint ``args.model`` into the global ``pipe`` on CUDA.

    ``--type sd15`` selects ``StableDiffusionPipeline``; anything else selects
    ``StableDiffusionXLPipeline``. Weights are loaded as fp16. The function logs
    the torch/diffusers versions, the load time and the checkpoint file size in MiB.
    """
    global pipe  # pylint: disable=global-statement
    log.info(f'versions: torch={torch.__version__} diffusers={diffusers.__version__}')
    load_options = dict(
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
        load_safety_checker=False,
        load_connected_pipeline=True,
        use_safetensors=True,
    )
    if args.type == 'sd15':
        pipeline_cls = diffusers.StableDiffusionPipeline
    else:
        pipeline_cls = diffusers.StableDiffusionXLPipeline
    t0 = time.time()
    loaded = pipeline_cls.from_single_file(args.model, **load_options)
    pipe = loaded.to('cuda')
    size = os.path.getsize(args.model)
    log.info(f'load: model={args.model} type={args.type} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
def load_trace(fn: str):
    """Load a TorchScript UNet from *fn* and install it as ``pipe.unet``.

    The traced module takes only the positional ``(sample, timestep,
    encoder_hidden_states)`` tensors it was traced with and returns a tuple.
    It is therefore wrapped in a small ``nn.Module`` that accepts the
    ``UNet2DConditionModel``-style call and returns either an object with a
    ``.sample`` attribute or, with ``return_dict=False``, a 1-tuple.

    Args:
        fn: path of a saved TorchScript module (as written by ``trace_model``).
    """
    @dataclass
    class UNet2DConditionOutput:
        # Without @dataclass this class has no __init__ accepting `sample=...`,
        # and constructing it raises TypeError.
        sample: torch.FloatTensor

    class TracedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Copy attributes of the original UNet onto the wrapper; `config` is the
            # documented place for `in_channels` on diffusers models.
            self.in_channels = pipe.unet.config.in_channels
            self.device = pipe.unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states, return_dict=True, **kwargs):
            # The trace was recorded with exactly three positional tensors, so any extra
            # keyword arguments a pipeline passes (timestep_cond, cross_attention_kwargs,
            # added_cond_kwargs, ...) cannot be forwarded and are ignored here.
            del kwargs
            sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
            if not return_dict:
                return (sample,)
            return UNet2DConditionOutput(sample=sample)

    t0 = time.time()
    unet_traced = torch.jit.load(fn)
    pipe.unet = TracedUNet()
    size = os.path.getsize(fn)
    log.info(f'load: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
def trace_model():
    """JIT-trace ``pipe.unet`` on random inputs, then save it or install it.

    Steps:
    1. Disable autograd globally.
    2. Put the UNet in eval mode and make ``return_dict=False`` the default for
       its ``forward``.
    3. Run ``n_warmup`` eager passes.
    4. Trace with ``torch.jit.trace(..., check_trace=True)`` on the last warmup
       inputs.
    5. Run ``n_traces`` passes through the traced module.

    Returns:
        With ``--save``: the path of the ``.pt`` file written next to
        ``args.model``; ``pipe.unet`` is left unchanged. Otherwise ``None``,
        after assigning the traced module to ``pipe.unet``.
    """
    log.info(f'tracing model: {args.model}')
    torch.set_grad_enabled(False)
    unet = pipe.unet
    unet.eval()
    # Optional: unet.to(memory_format=torch.channels_last) would switch to channels_last.
    # Presumably done so the traced graph returns a plain tuple instead of an output object.
    unet.forward = functools.partial(unet.forward, return_dict=False)

    # eager warmup; the inputs from the final pass are reused as the trace example
    t0 = time.time()
    with torch.inference_mode():
        for _ in range(n_warmup):
            inputs = generate_inputs()
            _output = unet(*inputs)
    log.info(f'warmup: time={time.time() - t0:.3f}s passes={n_warmup}')

    t0 = time.time()
    unet_traced = torch.jit.trace(unet, inputs, check_trace=True)
    unet_traced.eval()
    log.info(f'trace: time={time.time() - t0:.3f}s')

    # a few passes through the traced module to let TorchScript optimize the graph
    t0 = time.time()
    with torch.inference_mode():
        for _ in range(n_traces):
            inputs = generate_inputs()
            _output = unet_traced(*inputs)
    log.info(f'optimize: time={time.time() - t0:.3f}s passes={n_traces}')

    if not args.save:
        pipe.unet = unet_traced
        return None

    t0 = time.time()
    stem = os.path.splitext(args.model)[0]
    fn = f"{stem}.pt"
    unet_traced.save(fn)
    size = os.path.getsize(fn)
    log.info(f'save: optimized={fn} time={time.time() - t0:.3f}s size={size / 1024 / 1024:.3f}mb')
    return fn
def benchmark_model(msg: str):
    """Time ``pipe.unet`` forward passes on one fixed random input batch.

    Of the ``n_runs`` passes, the first ``n_runs // 10`` are untimed warmup. The
    clock is started once, after a CUDA synchronize, so warmup work is not counted.
    It is stopped after a final synchronize, because CUDA kernel launches are
    asynchronous.

    Args:
        msg: label included in the log line (e.g. ``'original'`` or ``'traced'``).

    Returns:
        Wall-clock seconds spent on the timed passes.
    """
    warmup_passes = n_runs // 10
    timed_passes = n_runs - warmup_passes
    with torch.inference_mode():
        inputs = generate_inputs()
        for _ in range(warmup_passes):
            _output = pipe.unet(*inputs)
        # Drain queued warmup kernels before starting the clock. The previous code
        # reset t0 on every post-warmup iteration and so timed only the last pass.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(timed_passes):
            _output = pipe.unet(*inputs)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
    log.info(f"benchmark unet: {t1 - t0:.3f}s passes={timed_passes} type={msg}")
    return t1 - t0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SD.Next')
    parser.add_argument('--model', type=str, default='', required=True, help='model path')
    parser.add_argument('--type', type=str, default='sd15', choices=['sd15', 'sdxl'], required=False, help='model type, default: %(default)s')
    parser.add_argument('--benchmark', default=False, action='store_true', help="run benchmarks, default: %(default)s")
    parser.add_argument('--trace', default=True, action='store_true', help="run jit tracing, default: %(default)s")
    # `--trace` already defaults to True, so it cannot switch tracing off by itself;
    # `--no-trace` writes False to the same `trace` destination.
    parser.add_argument('--no-trace', dest='trace', action='store_false', help="skip jit tracing")
    parser.add_argument('--save', default=False, action='store_true', help="save optimized unet, default: %(default)s")
    args = parser.parse_args()
    setup_logging()
    log.info('sdnext model jit tracing')
    if not os.path.isfile(args.model):
        log.error(f"invalid model path: {args.model}")
        # SystemExit does not depend on the site-injected `exit` builtin.
        raise SystemExit(1)
    load_model()
    if args.benchmark:
        time0 = benchmark_model('original')
    if args.trace:
        unet_saved = trace_model()
        if unet_saved is not None:
            # The trace was saved to disk instead of installed; reload it into the pipeline.
            load_trace(unet_saved)
        if args.benchmark:
            time1 = benchmark_model('traced')
            log.info(f'benchmark speedup: {100 * (time0 - time1) / time0:.3f}%')