# tests/unit/test_calc_cond_batch_fallback.py
import torch

from src.cond import cond as cond_module
from src.cond.cond import CONDRegular, calc_cond_batch

class DummyModel:
    def memory_required(self, input_shape):
        # Return a tiny value so memory-based batching heuristics don't interfere.
        return 1

    def apply_model(self, *args, **kwargs):
        # Echo the input back; this should not be reached when the fallback is triggered.
        inp = args[0] if args else kwargs.get("input")
        return inp

class RecordingDummyModel(DummyModel):
    # Records the batch size of every apply_model call so tests can verify batching behavior.
    def __init__(self):
        self.batch_sizes = []

    def apply_model(self, *args, **kwargs):
        inp = args[0] if args else kwargs.get("input")
        self.batch_sizes.append(int(inp.shape[0]))
        return inp

def test_calc_cond_batch_fallback_on_transformer_options_mismatch(monkeypatch):
    called = {"flag": False}

    def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options):
        called["flag"] = True
        # Return zeroed outputs with the same shapes as the inputs.
        return [torch.zeros_like(x) for x in input_x_list]

    monkeypatch.setattr(cond_module, "_run_model_per_chunk", spy_run_model_per_chunk)

    # Input latent: one sample, 128 channels, on a 4x4 spatial grid.
    x_in = torch.zeros((1, 128, 4, 4))
    # Build a minimal condition; the mismatch is driven by the transformer_options
    # passed through model_options below.
    cond_dict = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
    # Pass conds as lists (positive cond, and None for uncond).
    conds = [[cond_dict], None]

    # Run calc_cond_batch and ensure the fallback path was taken (the spy was called).
    model_opts = {"transformer_options": {"img_h": 200, "img_w": 200}}
    out = calc_cond_batch(DummyModel(), conds, x_in, timestep=0, model_options=model_opts)
    assert called["flag"], "Expected _run_model_per_chunk to be called due to transformer_options mismatch"

    # The output should be a list of two tensors matching the input shape.
    assert isinstance(out, list) and len(out) == 2
    assert out[0].shape == x_in.shape
    assert out[1].shape == x_in.shape

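# A companion sketch for the opposite path (an assumption, not asserted by the test above):
# if calc_cond_batch only routes through _run_model_per_chunk when transformer_options
# disagree, then omitting them entirely should leave the spy untouched. The spy signature,
# shapes, and conds all mirror the mismatch test above.
def test_calc_cond_batch_no_fallback_without_transformer_options(monkeypatch):
    called = {"flag": False}

    def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options):
        called["flag"] = True
        return [torch.zeros_like(x) for x in input_x_list]

    monkeypatch.setattr(cond_module, "_run_model_per_chunk", spy_run_model_per_chunk)

    x_in = torch.zeros((1, 128, 4, 4))
    cond_dict = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
    conds = [[cond_dict], None]

    out = calc_cond_batch(DummyModel(), conds, x_in, timestep=0, model_options={})
    assert not called["flag"], "Did not expect the per-chunk fallback without transformer_options"
    assert isinstance(out, list) and len(out) == 2
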
def test_calc_cond_batch_honors_batched_cfg_toggle():
    x_in = torch.zeros((1, 4, 8, 8))
    cond_dict = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
    conds = [[cond_dict], [cond_dict]]

    # With batched_cfg enabled, both conds should be evaluated in a single batch of 2.
    batched_model = RecordingDummyModel()
    calc_cond_batch(
        batched_model,
        conds,
        x_in,
        timestep=0,
        model_options={"batched_cfg": True},
    )

    # With batched_cfg disabled, each cond should be evaluated in its own batch of 1.
    unbatched_model = RecordingDummyModel()
    calc_cond_batch(
        unbatched_model,
        conds,
        x_in,
        timestep=0,
        model_options={"batched_cfg": False},
    )

    assert batched_model.batch_sizes == [2]
    assert unbatched_model.batch_sizes == [1, 1]
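
# Usage note: this module can be run on its own with pytest, e.g.
#   pytest tests/unit/test_calc_cond_batch_fallback.py -q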