Spaces:
Running on Zero
Running on Zero
| import torch | |
| import pytest | |
| from src.cond import cond as cond_module | |
| from src.cond.cond import CONDRegular, calc_cond_batch | |
class DummyModel:
    """Bare-bones model stand-in exposing ``memory_required`` and ``apply_model``."""

    def memory_required(self, input_shape):
        """Return a constant ``1`` regardless of ``input_shape``.

        The value is kept tiny so batching heuristics don't interfere.
        """
        return 1

    def apply_model(self, *args, **kwargs):
        """Echo the input back unchanged.

        The input is the first positional argument when any positional
        arguments are given; otherwise it is the ``input`` keyword argument
        (``None`` when that is absent too).
        """
        # Should not be called if fallback is triggered.
        if args:
            return args[0]
        return kwargs.get("input")
class RecordingDummyModel(DummyModel):
    """``DummyModel`` variant that logs the leading dimension of every input."""

    def __init__(self):
        # One entry is appended per ``apply_model`` call, in call order.
        self.batch_sizes = []

    def apply_model(self, *args, **kwargs):
        """Record ``shape[0]`` of the input, then return the input unchanged.

        The input is the first positional argument if present, otherwise the
        ``input`` keyword argument.
        """
        if args:
            tensor = args[0]
        else:
            tensor = kwargs.get("input")
        leading_dim = int(tensor.shape[0])
        self.batch_sizes.append(leading_dim)
        return tensor
def test_calc_cond_batch_fallback_on_transformer_options_mismatch(monkeypatch):
    """Check that ``calc_cond_batch`` goes through ``_run_model_per_chunk``.

    ``_run_model_per_chunk`` on the cond module is swapped for a spy that
    notes it was invoked and returns zero tensors shaped like its inputs.
    The test then asserts the spy ran and that the result is a two-element
    list of tensors with the input's shape.
    """
    spy_hit = False

    def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options):
        nonlocal spy_hit
        spy_hit = True
        # Zeroed outputs, one per input chunk, with matching shapes.
        return list(map(torch.zeros_like, input_x_list))

    monkeypatch.setattr(cond_module, "_run_model_per_chunk", spy_run_model_per_chunk)

    # Input whose trailing two dimensions form a 4x4 grid.
    latent = torch.zeros((1, 128, 4, 4))

    # Minimal condition carrying only a c_crossattn entry.
    cross_attn = CONDRegular(torch.zeros((1, 1, 1, 1)))
    positive = {"model_conds": {"c_crossattn": cross_attn}}

    # Conds are passed as lists: one positive entry, and None for uncond.
    cond_pair = [[positive], None]

    # transformer_options advertise a 200x200 image alongside the 4x4 input.
    options = {"transformer_options": {"img_h": 200, "img_w": 200}}
    result = calc_cond_batch(DummyModel(), cond_pair, latent, timestep=0, model_options=options)

    assert spy_hit, "Expected _run_model_per_chunk to be called due to transformer_options mismatch"

    # The result must be a list of two tensors matching the input shape.
    assert isinstance(result, list)
    assert len(result) == 2
    for produced in result:
        assert produced.shape == latent.shape
def test_calc_cond_batch_honors_batched_cfg_toggle():
    """Check that the ``batched_cfg`` model option changes how the model is called.

    The same two-condition input is run once with ``batched_cfg=True`` and
    once with ``batched_cfg=False``, each against a fresh
    ``RecordingDummyModel``. The recorded leading dimensions must be ``[2]``
    for the first run and ``[1, 1]`` for the second.
    """
    latent = torch.zeros((1, 4, 8, 8))
    shared_cond = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
    # The very same condition dict is used in both positions.
    cond_pair = [[shared_cond], [shared_cond]]

    def recorded_batch_sizes(batched_cfg):
        # A fresh recorder per run keeps the two recordings independent.
        recorder = RecordingDummyModel()
        calc_cond_batch(
            recorder,
            cond_pair,
            latent,
            timestep=0,
            model_options={"batched_cfg": batched_cfg},
        )
        return recorder.batch_sizes

    # Run order matches the assertions below: batched first, then unbatched.
    sizes_when_batched = recorded_batch_sizes(True)
    sizes_when_unbatched = recorded_batch_sizes(False)

    assert sizes_when_batched == [2]
    assert sizes_when_unbatched == [1, 1]