| from __future__ import annotations |
| import torch |
|
|
| import comfy.utils |
| from comfy.patcher_extension import WrappersMP |
| from typing import TYPE_CHECKING, Callable, Optional |
| if TYPE_CHECKING: |
| from comfy.model_patcher import ModelPatcher |
| from comfy.patcher_extension import WrapperExecutor |
|
|
|
|
# Identifier used both to register the compile wrapper on the APPLY_MODEL
# wrapper list and to remove any previously registered one.
COMPILE_KEY: str = "torch.compile"
# Entry in model.model_options where the torch.compile keyword arguments are recorded.
TORCH_COMPILE_KWARGS: str = "torch_compile_kwargs"
|
|
|
|
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
    '''
    Create a wrapper that temporarily swaps compiled modules into the model.

    While the wrapped executor runs, the original object at each attribute path
    in compiled_module_dict (a path understood by comfy.utils.get_attr/set_attr,
    e.g. "diffusion_model") on executor.class_obj is replaced by the mapped
    compiled callable. The original objects are restored afterwards, even if
    the executor raises.

    Args:
        compiled_module_dict: maps attribute paths on executor.class_obj to the
            callables that should stand in for them during execution.

    Returns:
        A wrapper function with signature (executor, *args, **kwargs) that
        returns whatever executor(*args, **kwargs) returns.
    '''
    def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
        # Only paths that were actually read are recorded here, so the finally
        # block never "restores" an attribute it did not swap.
        orig_modules: dict[str, object] = {}
        try:
            for key, value in compiled_module_dict.items():
                orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
                comfy.utils.set_attr(executor.class_obj, key, value)
            return executor(*args, **kwargs)
        finally:
            # Undo the swaps in reverse (LIFO) order. With overlapping paths such
            # as "diffusion_model" and "diffusion_model.out", the nested path
            # must be restored while its parent still resolves to the object it
            # was patched on; restoring the parent first would make the nested
            # restore write into the original module instead.
            for key, value in reversed(list(orig_modules.items())):
                comfy.utils.set_attr(executor.class_obj, key, value)
    return apply_torch_compile_wrapper
|
|
|
|
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
                              mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
                              keys: Optional[list[str]]=None, *args, **kwargs):
    '''
    Register a torch.compile'd version of model objects, swapped in at sample time.

    torch.compile is called here on each selected object of the model (looked up
    via model.get_model_object). The compiled objects are installed only while
    the APPLY_MODEL wrapper runs, and the originals are restored afterwards.

    When keys is None or empty, it defaults to ["diffusion_model"], compiling the
    whole diffusion_model. When a list of keys is provided, only those objects
    are compiled.

    Args:
        model: the ModelPatcher to attach the wrapper to.
        backend, options, mode, fullgraph, dynamic: forwarded unchanged to
            torch.compile.
        keys: attribute paths of the objects to compile.
        *args, **kwargs: accepted for call compatibility; unused.
    '''
    # Drop any APPLY_MODEL wrapper previously registered under COMPILE_KEY
    # before adding a new one under the same key below.
    model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)

    # None (the default) and an empty list both mean "compile the whole diffusion model".
    if not keys:
        keys = ["diffusion_model"]

    compile_kwargs = {
        "backend": backend,
        "options": options,
        "mode": mode,
        "fullgraph": fullgraph,
        "dynamic": dynamic,
    }

    compiled_modules = {
        key: torch.compile(model=model.get_model_object(key), **compile_kwargs)
        for key in keys
    }

    wrapper_func = apply_torch_compile_factory(
        compiled_module_dict=compiled_modules,
    )

    model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)

    # Record the compile settings on the model's options; presumably read
    # elsewhere to inspect/reuse them — verify against consumers.
    model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
|
|