Spaces:
Running on Zero
Running on Zero
File size: 10,008 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 | import contextlib
import functools
import importlib.util
import logging
import os
import traceback
from dataclasses import dataclass
import torch
try:
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig
from sfast.compilers.diffusion_pipeline_compiler import (
_enable_xformers,
_modify_model,
)
from sfast.cuda.graphs import make_dynamic_graphed_callable
from sfast.jit import utils as jit_utils
from sfast.jit.trace_helper import trace_with_kwargs
except Exception: # pragma: no cover - sfast optional dependency
CompilationConfig = None
_enable_xformers = None
_modify_model = None
make_dynamic_graphed_callable = None
jit_utils = None
trace_with_kwargs = None
def hash_arg(arg):
    """Reduce ``arg`` to a hashable value suitable for use inside a cache key.

    Reduction rules:

    * ``str``, ``int``, ``float`` and ``bytes`` values are returned unchanged
      (``bool`` is covered too, since it is a subclass of ``int``).
    * Tuples and lists become a tuple of recursively reduced items.
    * Dicts become a tuple of ``(reduced_key, reduced_value)`` pairs sorted
      by the reduced key, so insertion order does not affect the result.
    * Any other object is reduced to its type, i.e. its value is ignored.

    Args:
        arg: The object to reduce.

    Returns:
        A hashable representation built from the rules above.
    """
    # micro optimization: bool obj is an instance of int
    if isinstance(arg, (str, int, float, bytes)):
        return arg
    if isinstance(arg, (tuple, list)):
        return tuple(map(hash_arg, arg))
    if isinstance(arg, dict):
        pairs = [(hash_arg(k), hash_arg(v)) for k, v in arg.items()]
        try:
            ordered = sorted(pairs, key=lambda pair: pair[0])
        except TypeError:
            # Reduced keys that cannot be compared with each other (for
            # example an int next to a str, or keys that reduced to types)
            # would abort the sort; order them by repr() instead so the
            # result is still deterministic.
            ordered = sorted(pairs, key=lambda pair: repr(pair[0]))
        return tuple(ordered)
    return type(arg)
class ModuleFactory:
    """Base class exposing the keyword arguments prepared by a factory.

    This class never assigns ``converted_kwargs`` itself; the attribute has
    to be set on the instance (for example by a subclass) before
    ``get_converted_kwargs`` is called.
    """

    def get_converted_kwargs(self):
        """Return the ``converted_kwargs`` attribute stored on this instance."""
        prepared = self.converted_kwargs
        return prepared
class BaseModelApplyModelModule(torch.nn.Module):
    """``torch.nn.Module`` adapter around a model-apply callable.

    The instance keeps the callable and the object passed as ``module`` as
    attributes, and ``forward`` relays its arguments to the callable using a
    fixed keyword layout.
    """

    def __init__(self, func, module):
        """Store the callable to invoke and the module passed alongside it."""
        super().__init__()
        self.func = func
        self.module = module

    def forward(
        self,
        input_x,
        timestep,
        c_concat=None,
        c_crossattn=None,
        y=None,
        control=None,
        transformer_options={},
    ):
        """Call the wrapped callable and return its result.

        ``input_x`` and ``timestep`` are passed positionally; ``c_concat``,
        ``c_crossattn``, ``control`` and ``y`` are passed by keyword.

        The incoming ``transformer_options`` value is not forwarded: the
        wrapped callable always receives a new, empty dict under that name.
        """
        forwarded = dict(
            c_concat=c_concat,
            c_crossattn=c_crossattn,
            control=control,
            transformer_options={},
            y=y,
        )
        return self.func(input_x, timestep, **forwarded)
class BaseModelApplyModelModuleFactory(ModuleFactory):
    """Prepare the wrapper module, filtered kwargs and cache key for one call.

    ``callable`` must be a bound method: its ``__self__`` is used both to
    read ``model_config.unet_config`` and as the module handed to
    ``BaseModelApplyModelModule``.
    """

    # Only keyword arguments with these names are kept for the wrapped call.
    kwargs_name = (
        "input_x",
        "timestep",
        "c_concat",
        "c_crossattn",
        "y",
        "control",
    )

    def __init__(self, callable, kwargs) -> None:
        """Record the callable and kwargs, then pre-compute the filtered kwargs."""
        self.callable = callable
        self.unet_config = callable.__self__.model_config.unet_config
        self.kwargs = kwargs
        self.converted_kwargs = self.gen_converted_kwargs()

    def gen_converted_kwargs(self):
        """Return the entries of ``self.kwargs`` whose name is in ``kwargs_name``."""
        return {
            name: value
            for name, value in self.kwargs.items()
            if name in self.kwargs_name
        }

    def gen_cache_key(self):
        """Build the tuple used to look up an already traced module.

        The key is made of the qualified class name of the callable object,
        followed by ``hash_arg`` reductions of the unet config, of the
        filtered kwargs, and of an empty patch-module dict.
        """
        key_kwargs = dict(self.converted_kwargs)
        patch_module_cache_key = {}
        qualname = self.callable.__class__.__qualname__
        return (
            qualname,
            hash_arg(self.unet_config),
            hash_arg(key_kwargs),
            hash_arg(patch_module_cache_key),
        )

    @contextlib.contextmanager
    def converted_module_context(self):
        """Yield ``(wrapper_module, converted_kwargs)`` for the duration of a ``with`` block."""
        owner = self.callable.__self__
        wrapper = BaseModelApplyModelModule(self.callable, owner)
        yield (wrapper, self.converted_kwargs)
# Logger shared by this module. getLogger() without a name returns the root
# logger, so messages go through whatever configuration the root logger has.
logger = logging.getLogger()
@dataclass
class TracedModuleCacheItem:
    """Record grouping a module object with a patch id and a device name."""

    # The stored module; typed as a plain object, so any value is accepted.
    module: object
    # Integer identifier of the patch associated with the module.
    patch_id: int
    # Device the entry is associated with, kept as a string.
    device: str
class LazyTraceModule:
    """Callable that traces a model function on first use and reuses the result.

    Traced callables are cached in ``cuda_graph_modules`` under the key
    produced by ``BaseModelApplyModelModuleFactory.gen_cache_key``.
    """

    # Both dicts are class attributes, so they are shared by every instance.
    traced_modules = {}
    cuda_graph_modules = {}

    def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
        """Keep the settings and pre-bind the model modification options.

        ``config`` is dereferenced immediately, so despite the ``None``
        default a real config object has to be supplied. Extra keyword
        arguments are stored and later passed to ``trace_with_kwargs``.
        """
        self.config = config
        self.patch_id = patch_id
        self.kwargs_ = kwargs_
        model_options = dict(
            enable_cnn_optimization=config.enable_cnn_optimization,
            prefer_lowp_gemm=config.prefer_lowp_gemm,
            enable_triton=config.enable_triton,
            enable_triton_reshape=config.enable_triton,
            memory_format=config.memory_format,
        )
        self.modify_model = functools.partial(_modify_model, **model_options)

    def ts_compiler(self, m):
        """Post-process a traced module according to the config and return it."""
        settings = self.config
        with torch.jit.optimized_execution(True):
            if settings.enable_jit_freeze:
                # raw freeze causes Tensor reference leak
                # because the constant Tensors in the GraphFunction of
                # the compilation unit are never freed.
                m.eval()
                m = jit_utils.better_freeze(m)
            self.modify_model(m)
        if settings.enable_cuda_graph:
            m = make_dynamic_graphed_callable(m)
        return m

    @staticmethod
    @contextlib.contextmanager
    def _tensor_call_logging():
        """Log every ``torch.tensor`` call while active, if debugging is enabled.

        Only does anything when the environment variable
        ``STABLE_FAST_DEBUG_TENSOR`` equals ``"1"``; ``torch.tensor`` is then
        replaced by a logging wrapper and restored on exit, even on error.
        """
        if os.getenv("STABLE_FAST_DEBUG_TENSOR", "0") != "1":
            yield
            return

        original_tensor = torch.tensor

        def _debug_tensor(*args, **kwargs):
            stack = "".join(traceback.format_stack(limit=8))
            logger.warning(
                "Stable Fast trace hit torch.tensor with args=%s kwargs=%s\n%s",
                args,
                kwargs,
                stack,
            )
            return original_tensor(*args, **kwargs)

        torch.tensor = _debug_tensor
        try:
            yield
        finally:
            torch.tensor = original_tensor

    def __call__(self, model_function, /, **kwargs):
        """Run ``model_function`` through its traced counterpart.

        On a cache miss the function is wrapped, traced, post-processed by
        ``ts_compiler`` and stored; afterwards the cached callable is invoked
        with the filtered keyword arguments.
        """
        module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
        kwargs = module_factory.get_converted_kwargs()
        key = module_factory.gen_cache_key()

        cached = self.cuda_graph_modules.get(key)
        if cached is not None:
            return cached(**kwargs)

        with module_factory.converted_module_context() as (m_model, m_kwargs):
            logger.info(
                f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
            )
            with self._tensor_call_logging():
                traced_m, call_helper = trace_with_kwargs(
                    m_model, None, m_kwargs, **self.kwargs_
                )
            traced_m = self.ts_compiler(traced_m)
            traced_module = call_helper(traced_m)
            self.cuda_graph_modules[key] = traced_module
        return traced_module(**kwargs)
def build_lazy_trace_module(config, device, patch_id):
    """Create a ``LazyTraceModule`` for ``config`` on the given device.

    ``config`` is modified in place: ``enable_cuda_graph`` only stays truthy
    when ``device.type`` is ``"cuda"``. When ``config.enable_xformers`` is
    set, ``_enable_xformers(None)`` is called before the module is built.
    The module is created with ``check_trace=True`` and ``strict=True`` as
    extra tracing options.
    """
    cuda_graph_allowed = config.enable_cuda_graph and device.type == "cuda"
    config.enable_cuda_graph = cuda_graph_allowed

    if config.enable_xformers:
        _enable_xformers(None)

    trace_options = {"check_trace": True, "strict": True}
    return LazyTraceModule(config=config, patch_id=patch_id, **trace_options)
def gen_stable_fast_config(enable_cuda_graph=False):
    """
    Build the StableFast compilation config.

    With sfast available, the default ``CompilationConfig`` is used, xformers
    is switched on when the ``xformers`` package can be found, and CUDA graph
    compilation follows the `enable_cuda_graph` parameter.
    Without sfast, a warning is logged and a fallback object with every
    feature disabled is returned; `enable_cuda_graph` is ignored in that case.
    """
    if CompilationConfig is not None:
        config = CompilationConfig.Default()

        if importlib.util.find_spec("xformers") is None:
            print("xformers not installed, skip")
        else:
            config.enable_xformers = True

        # CUDA Graph is suggested for small batch sizes and can improve performance.
        # When enabled, the model is specialized for a fixed image size after capture.
        config.enable_cuda_graph = enable_cuda_graph
        return config

    logger.warning("StableFast: optional 'sfast' dependency not available; using fallback no-op config")

    class _FallbackConfig:
        """Stand-in config carrying the attributes read elsewhere, all disabled."""

        def __init__(self):
            for flag in (
                "enable_xformers",
                "enable_cuda_graph",
                "enable_jit_freeze",
                "enable_cnn_optimization",
                "prefer_lowp_gemm",
                "enable_triton",
            ):
                setattr(self, flag, False)
            self.memory_format = None

    return _FallbackConfig()
class StableFastPatch:
    """Function wrapper that routes model calls through a lazily built trace module."""

    def __init__(self, model, config):
        """Remember the model and config; the trace module is created on first call."""
        self.model = model
        self.config = config
        self.stable_fast_model = None

    def __call__(self, model_function, params):
        """Invoke ``model_function`` through the trace module.

        ``params`` is read for the keys ``"input"``, ``"timestep"`` and
        ``"c"``; the mapping under ``"c"`` is expanded into keyword
        arguments. The trace module is built on the first call, using the
        device of the input and ``id(self)`` as the patch id.
        """
        input_x = params.get("input")
        timestep_ = params.get("timestep")
        c = params.get("c")

        lazy_module = self.stable_fast_model
        if lazy_module is None:
            lazy_module = build_lazy_trace_module(
                self.config, input_x.device, id(self)
            )
            self.stable_fast_model = lazy_module

        return lazy_module(model_function, input_x=input_x, timestep=timestep_, **c)

    def to(self, device):
        """Discard the trace module when moved to the CPU, then return ``self``.

        Nothing happens unless ``device`` is a ``torch.device`` of type
        ``"cpu"`` and the config enables CUDA graphs or JIT freezing. In that
        case the trace module is dropped and a warning is printed.
        """
        if not isinstance(device, torch.device):
            return self

        holds_compiled_state = (
            self.config.enable_cuda_graph or self.config.enable_jit_freeze
        )
        if holds_compiled_state and device.type == "cpu":
            # Remove the attribute first, then reset it to the "not built" state.
            del self.stable_fast_model
            self.stable_fast_model = None
            print(
                "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model. If you experience a noticeable delay every time you start sampling, please consider disable enable_cuda_graph.\33[0m"
            )
        return self
class ApplyStableFastUnet:
    """Entry point that attaches a ``StableFastPatch`` to a model."""

    def apply_stable_fast(self, model, enable_cuda_graph):
        """Return a one-element tuple holding the patched (or original) model.

        When any of the sfast components this file relies on is missing, a
        warning is logged and ``(model,)`` is returned untouched, so callers
        can index ``[0]`` either way. Otherwise the config is generated, the
        underlying model is converted to the config's memory format when one
        is set, and a clone of ``model`` with the patch installed as its unet
        function wrapper is returned.
        """
        required_components = (CompilationConfig, trace_with_kwargs, _modify_model)
        if any(component is None for component in required_components):
            logger.warning(
                "StableFast.apply_stable_fast: sfast optional components missing; skipping stable-fast patch and returning original model"
            )
            return (model,)

        config = gen_stable_fast_config(enable_cuda_graph)

        memory_format = config.memory_format
        if memory_format is not None:
            model.model.to(memory_format=memory_format)

        patch = StableFastPatch(model, config)
        patched_model = model.clone()
        patched_model.set_model_unet_function_wrapper(patch)
        return (patched_model,)
|