File size: 2,790 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import pytest

from src.cond import cond as cond_module
from src.cond.cond import CONDRegular, calc_cond_batch


class DummyModel:
    """Minimal model stand-in for exercising ``calc_cond_batch``."""

    def memory_required(self, input_shape):
        """Report a negligible memory cost.

        A constant tiny value keeps memory-based batching heuristics from
        interfering with what the tests are checking.
        """
        return 1

    def apply_model(self, *args, **kwargs):
        """Echo the model input back unchanged.

        The input is the first positional argument when one is given,
        otherwise the ``input`` keyword (``None`` if that is absent too).
        Not expected to be reached when the per-chunk fallback is patched out.
        """
        if args:
            return args[0]
        return kwargs.get("input")


class RecordingDummyModel(DummyModel):
    """DummyModel that records the batch size of every ``apply_model`` call."""

    def __init__(self):
        super().__init__()
        # Size of dim 0 of each input passed to apply_model, in call order.
        self.batch_sizes = []

    def apply_model(self, *args, **kwargs):
        """Record the input's leading dimension, then echo the input back.

        Input extraction is delegated to ``DummyModel.apply_model`` so both
        classes agree on how the model input is located in the call.
        """
        inp = super().apply_model(*args, **kwargs)
        self.batch_sizes.append(int(inp.shape[0]))
        return inp


def test_calc_cond_batch_fallback_on_transformer_options_mismatch(monkeypatch):
    """calc_cond_batch should take the ``_run_model_per_chunk`` fallback path
    when the supplied transformer_options do not match the input.

    The real ``_run_model_per_chunk`` is replaced by a spy that records that
    it ran and returns zero tensors shaped like each chunk input. The test
    then checks that the spy was used and that two outputs with ``x_in``'s
    shape come back.
    """
    called = {"flag": False}

    # Keep these parameter names in sync with cond._run_model_per_chunk in
    # case it is invoked with keyword arguments.
    def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options):
        called["flag"] = True
        # Return zeroed outputs with the same shapes as the chunk inputs.
        return [torch.zeros_like(x) for x in input_x_list]

    monkeypatch.setattr(cond_module, "_run_model_per_chunk", spy_run_model_per_chunk)

    # Input whose trailing two dims form a 4x4 grid (presumably NCHW).
    x_in = torch.zeros((1, 128, 4, 4))

    # Minimal positive condition. The transformer_options under test are
    # supplied through model_options below, not through this dict.
    cond_dict = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}

    # Pass conds as lists (positive and None for uncond).
    conds = [[cond_dict], None]

    # NOTE(review): img_h/img_w of 200 presumably disagree with the 4x4 grid
    # above and trigger the fallback; confirm which fields calc_cond_batch
    # compares.
    model_opts = {"transformer_options": {"img_h": 200, "img_w": 200}}
    out = calc_cond_batch(DummyModel(), conds, x_in, timestep=0, model_options=model_opts)

    assert called["flag"], "Expected _run_model_per_chunk to be called due to transformer_options mismatch"
    # out should be a list of two tensors matching the input shape.
    assert isinstance(out, list) and len(out) == 2
    assert out[0].shape == x_in.shape
    assert out[1].shape == x_in.shape


def test_calc_cond_batch_honors_batched_cfg_toggle():
    """``batched_cfg`` decides whether cond and uncond share one model call.

    With ``batched_cfg=True`` the positive and negative conditions are
    expected in a single apply_model call of batch size 2. With ``False``
    they are expected as two separate calls of batch size 1 each.
    """
    latent = torch.zeros((1, 4, 8, 8))
    shared_cond = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
    # The same condition object fills both the positive and negative slots.
    conds = [[shared_cond], [shared_cond]]

    recorded = {}
    for batched_cfg in (True, False):
        recorder = RecordingDummyModel()
        calc_cond_batch(
            recorder,
            conds,
            latent,
            timestep=0,
            model_options={"batched_cfg": batched_cfg},
        )
        recorded[batched_cfg] = recorder.batch_sizes

    assert recorded[True] == [2]
    assert recorded[False] == [1, 1]