import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

@functools.cache
def test_xformers_backwards(size):
    # Probe whether xformers memory-efficient attention supports a backward
    # pass for this head dimension on the current GPU; the result is cached
    # per size.
    @torch.enable_grad()
    def _grad(size):
        q = torch.randn((1, 4, size), device="cuda")
        k = torch.randn((1, 4, size), device="cuda")
        v = torch.randn((1, 4, size), device="cuda")

        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        v = v.detach().requires_grad_()

        out = xformers.ops.memory_efficient_attention(q, k, v)
        loss = out.sum(2).mean(0).sum()

        return torch.autograd.grad(loss, v)

    try:
        _grad(size)
        print(size, "pass")
        return True
    except Exception:
        print(size, "fail")
        return False

def set_use_memory_efficient_attention_xformers(
    module: torch.nn.Module, valid: bool
) -> None:
    def fn_test_dim_head(module: torch.nn.Module):
        if isinstance(module, BasicTransformerBlock):
            # dim_head isn't stored on the attention module, so back-calculate
            # it from the value projection and the number of heads
            source = module.attn1.to_v
            if isinstance(source, LoraInjectedLinear):
                source = source.linear

            dim_head = source.out_features // module.attn1.heads

            result = test_xformers_backwards(dim_head)

            # if the backward probe failed for this head dim, switch xformers
            # back off for this block
            if not result:
                module.set_use_memory_efficient_attention_xformers(False)

        for child in module.children():
            fn_test_dim_head(child)

    if not is_xformers_available() and valid:
        print("XFormers is not available. Skipping.")
        return

    module.set_use_memory_efficient_attention_xformers(valid)

    if valid:
        fn_test_dim_head(module)
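# Usage sketch (an assumption, not part of this module): `unet` below stands
# for a diffusers UNet2DConditionModel loaded elsewhere.
#
#     set_use_memory_efficient_attention_xformers(unet, True)
#
# This enables xformers attention across the whole module tree, then walks it
# and switches xformers back off for any BasicTransformerBlock whose head
# dimension fails the backward-pass probe in test_xformers_backwards.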