Spaces:
Running on Zero
Running on Zero
| import torch | |
| from src.Device import Device | |
| from src.Model.ModelPatcher import ModelPatcher | |
| from src.cond import cond_util | |
| from src.sample.CFG import CFGGuider | |
class DummyDiffusionModel(torch.nn.Module):
    """Minimal stand-in diffusion model used to build ``ModelPatcher`` fixtures.

    Holds a single 4 -> 4 linear layer and reports a constant memory
    requirement regardless of the requested input shape.
    """

    def __init__(self):
        super(DummyDiffusionModel, self).__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)

    def memory_required(self, input_shape=None):
        """Return the constant memory estimate ``1``; ``input_shape`` is ignored."""
        del input_shape  # accepted but not used
        return 1
def test_model_function_wrappers_compose_in_application_order():
    """A wrapper registered later must run outside the one registered earlier.

    Two wrappers are installed through ``set_model_unet_function_wrapper`` and
    the composed ``model_options["model_function_wrapper"]`` is invoked with a
    trivial base function.

    The test expects the second-registered wrapper to be outermost. That gives
    the result ``((1 + 3) + 1) * 2 == 10`` and a trace nested as
    two -> one -> base -> one -> two.
    """
    patcher = ModelPatcher(
        DummyDiffusionModel(),
        load_device=torch.device("cpu"),
        offload_device=torch.device("cpu"),
    )
    trace = []

    def _forward(model_function, params):
        # Both wrappers forward to the next function with the same unpacking.
        return model_function(params["input"], params["timestep"], **params["c"])

    def add_one_wrapper(model_function, params):
        trace.append("wrapper_one_before")
        value = _forward(model_function, params)
        trace.append("wrapper_one_after")
        return value + 1

    def double_wrapper(model_function, params):
        trace.append("wrapper_two_before")
        value = _forward(model_function, params)
        trace.append("wrapper_two_after")
        return value * 2

    # Registration order matters: add_one first, then double.
    for wrapper in (add_one_wrapper, double_wrapper):
        patcher.set_model_unet_function_wrapper(wrapper)

    def base_model_function(input_x, timestep, **c_kwargs):
        trace.append("base")
        return input_x + c_kwargs["bias"]

    composed = patcher.model_options["model_function_wrapper"]
    params = {
        "input": torch.tensor([1.0]),
        "timestep": torch.tensor([0.0]),
        "c": {"bias": torch.tensor([3.0])},
    }
    result = composed(base_model_function, params)

    assert torch.equal(result, torch.tensor([10.0]))
    assert trace == [
        "wrapper_two_before",
        "wrapper_one_before",
        "base",
        "wrapper_one_after",
        "wrapper_two_after",
    ]
def test_sageattention_enabled_allows_compute_12_when_available(monkeypatch):
    """On compute capability 12.0, SageAttention is enabled and SpargeAttn is not.

    The test forces this environment:
    - a GPU CPU-state;
    - not DirectML, not Intel XPU, not ROCm;
    - both attention backends flagged as available;
    - a CUDA device reporting capability ``(12, 0)``.

    It then checks the two ``Device`` feature gates.
    """
    device_overrides = {
        "cpu_state": Device.CPUState.GPU,
        "directml_enabled": False,
        "SAGEATTENTION_IS_AVAILABLE": True,
        "SPARGEATTN_IS_AVAILABLE": True,
        "is_intel_xpu": lambda: False,
        "is_rocm": lambda: False,
    }
    for attr_name, fake_value in device_overrides.items():
        monkeypatch.setattr(Device, attr_name, fake_value)

    # Pretend a CUDA device with compute capability 12.0 is present.
    monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(
        torch.cuda, "get_device_capability", lambda *args, **kwargs: (12, 0)
    )

    assert Device.sageattention_enabled() is True
    assert Device.spargeattn_enabled() is False
def test_cfg_guider_reads_model_options_from_wrapped_model():
    """CFGGuider exposes the ``model_options`` of a wrapper's inner patcher.

    The wrapper object carries only ``model`` and ``load_device``. A sentinel
    option set on the inner ``ModelPatcher`` must be visible through
    ``guider.model_options``.
    """
    cpu = torch.device("cpu")
    patcher = ModelPatcher(DummyDiffusionModel(), load_device=cpu, offload_device=cpu)
    patcher.model_options["sentinel"] = "wrapped"

    class Wrapper:
        def __init__(self, model):
            self.model = model
            self.load_device = torch.device("cpu")

    guider = CFGGuider(Wrapper(patcher))

    assert guider.model_options["sentinel"] == "wrapped"
def test_prepare_sampling_accepts_wrapper_objects(monkeypatch):
    """prepare_sampling accepts an object wrapping a ModelPatcher.

    The wrapper exposes only ``model`` and ``load_device``.
    ``Device.load_models_gpu`` is replaced by a recorder, so nothing is
    actually loaded. The test checks three things:
    - the inner patcher is the first model handed to the loader;
    - ``prepare_sampling`` returns ``patcher.model`` and the same conds object;
    - the returned loaded-models list is empty.
    """
    patcher = ModelPatcher(
        DummyDiffusionModel(),
        load_device=torch.device("cpu"),
        offload_device=torch.device("cpu"),
    )
    conds = {"positive": [], "negative": []}
    load_calls = []

    def fake_load_models_gpu(models, memory_required, minimum_memory_required, force_full_load=False):
        # Record the call instead of touching any device.
        load_calls.append(
            {
                "models": models,
                "memory_required": memory_required,
                "minimum_memory_required": minimum_memory_required,
                "force_full_load": force_full_load,
            }
        )

    monkeypatch.setattr(Device, "load_models_gpu", fake_load_models_gpu)

    class Wrapper:
        def __init__(self, model):
            self.model = model
            self.load_device = torch.device("cpu")

    real_model, returned_conds, loaded_models = cond_util.prepare_sampling(
        Wrapper(patcher),
        noise_shape=(1, 4, 8, 8),
        conds=conds,
    )

    assert real_model is patcher.model
    assert returned_conds is conds
    assert loaded_models == []
    # Fail with a descriptive message rather than an IndexError if loading was skipped.
    assert load_calls, "prepare_sampling did not call Device.load_models_gpu"
    assert load_calls[0]["models"][0] is patcher
def test_prepare_sampling_keeps_direct_patcher_instead_of_unwrapping_to_raw_module(monkeypatch):
    """A bare ModelPatcher passed to prepare_sampling is loaded as the patcher itself.

    The first model handed to ``Device.load_models_gpu`` must be the patcher,
    not ``patcher.model``. ``prepare_sampling`` must still return
    ``patcher.model`` and the same conds object, and an empty loaded-models
    list.
    """
    patcher = ModelPatcher(
        DummyDiffusionModel(),
        load_device=torch.device("cpu"),
        offload_device=torch.device("cpu"),
    )
    conds = {"positive": [], "negative": []}
    load_calls = []

    def fake_load_models_gpu(models, memory_required, minimum_memory_required, force_full_load=False):
        # Only the model list matters for this test.
        load_calls.append(models)

    monkeypatch.setattr(Device, "load_models_gpu", fake_load_models_gpu)

    real_model, returned_conds, loaded_models = cond_util.prepare_sampling(
        patcher,
        noise_shape=(1, 4, 8, 8),
        conds=conds,
    )

    assert real_model is patcher.model
    assert returned_conds is conds
    assert loaded_models == []
    # Fail with a descriptive message rather than an IndexError if loading was skipped.
    assert load_calls, "prepare_sampling did not call Device.load_models_gpu"
    assert load_calls[0][0] is patcher
def test_cfg_guider_sample_uses_wrapped_model_load_device(monkeypatch):
    """CFGGuider.sample passes the inner patcher's load device to inner_sample.

    The wrapper here exposes only ``model`` (no ``load_device``). The device
    received by ``inner_sample`` is therefore expected to be CPU, the load
    device of the wrapped ``ModelPatcher``.

    Three collaborators are replaced with stubs:
    - ``cond_util.prepare_sampling``;
    - ``guider.inner_sample``;
    - the model cache.
    """
    patcher = ModelPatcher(
        DummyDiffusionModel(),
        load_device=torch.device("cpu"),
        offload_device=torch.device("cpu"),
    )

    class Wrapper:
        def __init__(self, model):
            self.model = model

    guider = CFGGuider(Wrapper(patcher))
    guider.original_conds = {"positive": [], "negative": []}

    def fake_prepare_sampling(model, noise_shape, conds):
        return patcher, conds, []

    monkeypatch.setattr(cond_util, "prepare_sampling", fake_prepare_sampling)

    seen = {}

    def fake_inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, pipeline=False):
        # Capture the device chosen by sample() and echo the noise back.
        seen["device"] = device
        return noise

    guider.inner_sample = fake_inner_sample

    class DummyCache:
        def prevent_model_cleanup(self, conds, loaded_models):
            return None

    monkeypatch.setattr("src.Device.ModelCache.get_model_cache", lambda: DummyCache())

    latent_shape = (1, 4, 8, 8)
    output = guider.sample(
        noise=torch.zeros(latent_shape, dtype=torch.float32),
        latent_image=torch.zeros(latent_shape, dtype=torch.float32),
        sampler=None,
        sigmas=torch.zeros((1,), dtype=torch.float32),
    )

    assert torch.equal(output, torch.zeros(latent_shape, dtype=torch.float32))
    assert seen["device"] == torch.device("cpu")