File size: 18,256 Bytes
d8bf41d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
from types import SimpleNamespace

import pytest
import torch
from PIL import Image

from examples import example as runner
from examples import example_remote as remote_runner


class FakeImage:
    """Minimal stand-in for a PIL image that records ``save()`` calls instead of writing files."""

    def __init__(self, size):
        # Size tuple exposed like PIL's ``Image.size``.
        self.size = size
        # One ``(path, kwargs)`` entry per save() invocation, in call order.
        self.saved = []

    def save(self, path, **kwargs):
        """Remember where and how the image would have been saved; touch no files."""
        record = (path, kwargs)
        self.saved.append(record)


def _args(*extra):
    """Parse CLI args with a fixed ``--base-model model`` plus *extra*, then apply window-stride defaults."""
    argv = ["--base-model", "model"]
    argv.extend(extra)
    parsed = runner.parse_args(argv)
    return runner.apply_window_stride_defaults(parsed)


def test_quantization_parser_defaults_to_none():
    """Without quantization flags the global strategy is "none" and per-component overrides are unset."""
    parsed = _args()

    assert parsed.quantization == "none"
    for override in ("transformer_quantization", "text_encoder_quantization", "vae_quantization"):
        assert getattr(parsed, override) is None
    assert parsed.allow_tf32 is False


def test_allow_tf32_parser_flag():
    """Passing ``--allow-tf32`` turns the boolean option on."""
    assert _args("--allow-tf32").allow_tf32 is True


def test_configure_torch_skips_tf32_when_not_allowed(monkeypatch):
    """Even with CUDA reported as available, allow_tf32=False must not change matmul precision."""
    recorded = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", recorded.append)

    runner.configure_torch(allow_tf32=False)

    assert not recorded


def test_configure_torch_enables_tf32_when_allowed_on_cuda(monkeypatch):
    """With CUDA available and allow_tf32=True, TF32 is enabled through both precision knobs.

    ``torch.backends.fp32_precision`` is process-global state that monkeypatch does not cover
    here, so it is restored by hand afterwards. If the attribute did not exist before the test,
    anything ``configure_torch`` created is removed again, so later tests see the original
    module state instead of a leaked "tf32" value.
    """
    calls = []
    missing = object()
    original_fp32_precision = getattr(runner.torch.backends, "fp32_precision", missing)
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: True)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", calls.append)

    try:
        runner.configure_torch(allow_tf32=True)
        assert runner.torch.backends.fp32_precision == "tf32"
        assert calls == ["high"]
    finally:
        if original_fp32_precision is missing:
            # The attribute was absent before the test: drop whatever configure_torch added.
            # Best-effort cleanup; AttributeError means nothing was added.
            try:
                delattr(runner.torch.backends, "fp32_precision")
            except AttributeError:
                pass
        else:
            runner.torch.backends.fp32_precision = original_fp32_precision


def test_configure_torch_skips_tf32_without_cuda(monkeypatch):
    """allow_tf32=True leaves matmul precision untouched when CUDA is unavailable."""
    precision_calls = []
    monkeypatch.setattr(runner.torch.cuda, "is_available", lambda: False)
    monkeypatch.setattr(runner.torch, "set_float32_matmul_precision", precision_calls.append)

    runner.configure_torch(allow_tf32=True)

    assert len(precision_calls) == 0


def test_save_images_defaults_to_png_for_small_images(tmp_path):
    """A 4096x4096 image is written as a plain ``.png`` with no extra save options."""
    small = FakeImage((4096, 4096))
    stem = tmp_path / "output"

    runner.save_images([small], stem, num_images_per_prompt=1)

    assert small.saved == [(stem.with_suffix(".png"), {})]


def test_save_images_defaults_to_bigtiff_for_large_images(tmp_path):
    """An image wider than 4096 px defaults to a BigTIFF ``.tif`` file."""
    large = FakeImage((4097, 4096))
    expected_record = (tmp_path / "output.tif", {"format": "TIFF", "big_tiff": True})

    runner.save_images([large], tmp_path / "output", num_images_per_prompt=1)

    assert large.saved == [expected_record]


def test_save_images_switches_large_png_output_to_bigtiff(tmp_path):
    """An explicit ``.png`` target for an oversized image is rewritten to ``.tif``, with a warning."""
    oversized = FakeImage((4097, 4096))
    requested = tmp_path / "output.png"

    with pytest.warns(UserWarning, match="saving BigTIFF"):
        runner.save_images([oversized], requested, num_images_per_prompt=1)

    assert oversized.saved == [(requested.with_suffix(".tif"), {"format": "TIFF", "big_tiff": True})]


def test_save_images_preserves_numbered_large_bigtiff_stems(tmp_path):
    """Several images per prompt get ``_<index>`` suffixed stems, each saved as BigTIFF."""
    images = [FakeImage((4097, 4096)) for _ in range(2)]

    runner.save_images(images, tmp_path / "output", num_images_per_prompt=2)

    for index, image in enumerate(images):
        expected_path = tmp_path / f"output_{index}.tif"
        assert image.saved == [(expected_path, {"format": "TIFF", "big_tiff": True})]


def test_quantization_flag_without_value_uses_float8_weight_only_default():
    """A bare ``--quantization`` flag falls back to the ``float8wo`` strategy."""
    parsed = _args("--quantization")
    assert parsed.quantization == "float8wo"


def test_quantization_parser_accepts_component_overrides():
    """Per-component quantization flags are parsed independently of the global strategy."""
    parsed = _args(
        "--quantization", "float8wo",
        "--transformer-quantization", "int4wo",
        "--text-encoder-quantization", "none",
        "--vae-quantization", "float8dyn",
    )

    expected = {
        "quantization": "float8wo",
        "transformer_quantization": "int4wo",
        "text_encoder_quantization": "none",
        "vae_quantization": "float8dyn",
    }
    for attribute, value in expected.items():
        assert getattr(parsed, attribute) == value


def test_resolve_quantization_mapping_applies_base_strategy_to_all_quantizable_components():
    """A lone base strategy is applied to the transformer, text encoder and VAE."""
    expected = dict.fromkeys(("transformer", "text_encoder", "vae"), "float8wo")

    assert runner.resolve_quantization_mapping("float8wo") == expected


def test_resolve_quantization_mapping_uses_component_overrides():
    """Overrides win over the base strategy; an override of "none" omits that component."""
    overrides = {
        "transformer_quantization": "int4wo",
        "text_encoder_quantization": "none",
        "vae_quantization": "float8dyn",
    }

    resolved = runner.resolve_quantization_mapping("float8wo", **overrides)

    assert resolved == {"transformer": "int4wo", "vae": "float8dyn"}


def test_build_quantization_config_returns_none_when_disabled():
    """The "none" strategy produces no quantization config at all."""
    config = runner.build_quantization_config("none")
    assert config is None


def test_build_quantization_config_uses_component_specific_torchao_configs():
    """Each component's strategy maps to its own torchao config class (skipped when torchao is absent)."""
    pytest.importorskip("torchao")

    config = runner.build_quantization_config(
        "float8wo",
        transformer_quantization="int4wo",
        text_encoder_quantization="int8wo",
        vae_quantization="float8dyn",
    )

    expected_class_names = {
        "transformer": "Int4WeightOnlyConfig",
        "text_encoder": "Int8WeightOnlyConfig",
        "vae": "Float8DynamicActivationFloat8WeightConfig",
    }
    assert set(config.quant_mapping) == set(expected_class_names)
    for component, class_name in expected_class_names.items():
        assert type(config.quant_mapping[component].quant_type).__name__ == class_name


@pytest.mark.parametrize("weighting_type", ["none", "linear", "cosine"])
def test_call_kwargs_forward_all_weighting_types(weighting_type):
    """Every weighting type is forwarded verbatim, together with the generator object.

    No model config is passed, so ``num_inference_steps`` must be left to the pipeline default.
    """
    sentinel_generator = object()

    kwargs = runner.build_call_kwargs(
        _args("--weighting-type", weighting_type), SimpleNamespace(), sentinel_generator
    )

    assert kwargs["weighting_type"] == weighting_type
    assert "num_inference_steps" not in kwargs
    assert kwargs["generator"] is sentinel_generator


def test_call_kwargs_forward_offset_and_panorama_options():
    """Window-stride offsets arrive as ints and the panorama switches as booleans."""
    cli = (
        "--window-stride-height-offset", "128",
        "--window-stride-width-offset", "256",
        "--panorama-width",
        "--panorama-height",
    )

    forwarded = runner.build_call_kwargs(_args(*cli), SimpleNamespace(), object())

    assert forwarded["window_stride_height_offset"] == 128
    assert forwarded["window_stride_width_offset"] == 256
    assert forwarded["panorama_width"] is True
    assert forwarded["panorama_height"] is True


def test_call_kwargs_forward_window_batch_size():
    """``--batch-size`` is passed to the pipeline as ``window_batch_size``."""
    parsed = _args("--batch-size", "6")

    forwarded = runner.build_call_kwargs(parsed, SimpleNamespace(), object())

    assert forwarded["window_batch_size"] == 6


def test_call_kwargs_rejects_invalid_window_batch_size():
    """A zero ``--batch-size`` raises ValueError with a message naming the flag."""
    with pytest.raises(ValueError, match="batch-size"):
        runner.build_call_kwargs(_args("--batch-size", "0"), SimpleNamespace(), object())


def test_csv_prompt_loads_regional_masks_in_csv_order(tmp_path):
    """Prompts and masks follow CSV row order, not filename order, after a leading empty prompt."""
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    for mask_name in ("00.png", "08.png"):
        Image.new("L", (32, 32), 255).save(masks_dir / mask_name)
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n08.png,"a forest"\n00.png,"a mountain"\n', encoding="utf-8")
    parsed = _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32")

    forwarded = runner.build_call_kwargs(parsed, SimpleNamespace(vae_scale_factor=8), object())

    assert forwarded["prompt"] == ["", "a forest", "a mountain"]
    assert torch.equal(forwarded["regional_masks"], torch.ones((2, 2, 2)))


def test_csv_prompt_requires_masks_folder(tmp_path):
    """A CSV prompt file without ``--masks`` raises ValueError mentioning the missing flag."""
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    cli = ("--prompt", str(csv_path), "--height", "32", "--width", "32")

    with pytest.raises(ValueError, match="--masks"):
        runner.build_call_kwargs(_args(*cli), SimpleNamespace(vae_scale_factor=8), object())


def test_regional_mask_loader_warns_for_non_binary_masks(tmp_path):
    """Grey (non 0/255) masks emit a "not binary" warning but are kept as fractional weights.

    The loader reduces the 32x32 mask to a 2x2 grid. If that reduction averages pixels,
    float32 rounding can make the result differ from ``128 / 255`` by an ulp. The value is
    therefore compared with a tolerance rather than bit-exact ``torch.equal``.
    """
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    Image.new("L", (32, 32), 128).save(masks_dir / "00.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a foggy valley"\n', encoding="utf-8")

    with pytest.warns(UserWarning, match="not binary"):
        call_kwargs = runner.build_call_kwargs(
            _args("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32"),
            SimpleNamespace(vae_scale_factor=8),
            object(),
        )

    masks = call_kwargs["regional_masks"]
    assert masks.shape == (1, 2, 2)
    torch.testing.assert_close(masks, torch.full((1, 2, 2), 128 / 255), check_dtype=False)


def test_regional_mask_loader_rejects_wrong_size(tmp_path):
    """A mask whose size does not match ``--width``/``--height`` raises a "must have size" error."""
    masks_dir = tmp_path / "masks"
    masks_dir.mkdir()
    # PIL sizes are (width, height): this mask is 16 px wide, narrower than the requested 32.
    Image.new("L", (16, 32), 255).save(masks_dir / "00.png")
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text('mask,prompt\n00.png,"a mountain"\n', encoding="utf-8")
    cli = ("--prompt", str(csv_path), "--masks", str(masks_dir), "--height", "32", "--width", "32")

    with pytest.raises(ValueError, match="must have size"):
        runner.build_call_kwargs(_args(*cli), SimpleNamespace(vae_scale_factor=8), object())


@pytest.mark.parametrize(
    ("config", "expected_class_name"),
    [
        ({"_class_name": "Flux2Pipeline"}, "Flux2MultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": True}, "Flux2KleinMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline", "is_distilled": False}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
        ({"_class_name": "Flux2KleinPipeline"}, "Flux2KleinBaseMultiDiffusionAutoBlocks"),
    ],
)
def test_multidiffusion_blocks_are_selected_from_loaded_pipeline_config(config, expected_class_name):
    """The AutoBlocks class is chosen from ``_class_name`` and, for Klein, ``is_distilled``.

    A Klein config without ``is_distilled`` is expected to select the non-distilled base blocks.
    """
    selected = runner.build_multidiffusion_blocks(config)
    assert type(selected).__name__ == expected_class_name


def test_explicit_num_inference_steps_are_forwarded():
    """An explicit ``--num-inference-steps`` wins, even for a distilled Klein model."""
    distilled_klein = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}

    forwarded = runner.build_call_kwargs(
        _args("--num-inference-steps", "7"), SimpleNamespace(), object(), distilled_klein
    )

    assert forwarded["num_inference_steps"] == 7


def test_distilled_klein_default_num_inference_steps_is_forwarded():
    """Without an explicit step count, a distilled Klein model gets a 4-step default."""
    model_config = {"_class_name": "Flux2KleinPipeline", "is_distilled": True}

    forwarded = runner.build_call_kwargs(_args(), SimpleNamespace(), object(), model_config)

    assert forwarded["num_inference_steps"] == 4


@pytest.mark.parametrize(
    "config",
    [
        {"_class_name": "Flux2Pipeline"},
        {"_class_name": "Flux2KleinPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinModularPipeline", "is_distilled": False},
        {"_class_name": "Flux2KleinPipeline"},
        {},
    ],
)
def test_non_distilled_models_keep_pipeline_num_inference_steps_default(config):
    """Non-distilled, unflagged or empty configs leave ``num_inference_steps`` to the pipeline."""
    forwarded = runner.build_call_kwargs(_args(), SimpleNamespace(), object(), config)

    assert "num_inference_steps" not in forwarded


def test_enable_tiling_and_slicing_call_vae_methods():
    """apply_vae_memory_options calls the VAE's enable_tiling and enable_slicing when requested."""

    class FakeVae:
        """Records which memory-saving toggles were invoked."""

        def __init__(self):
            self.tiling_enabled = False
            self.slicing_enabled = False

        def enable_tiling(self):
            self.tiling_enabled = True

        def enable_slicing(self):
            self.slicing_enabled = True

    pipe = SimpleNamespace(vae=FakeVae())

    runner.apply_vae_memory_options(pipe, enable_tiling=True, enable_slicing=True)

    assert pipe.vae.tiling_enabled
    assert pipe.vae.slicing_enabled


def test_init_selects_blocks_from_model_config_and_keeps_model_guidance_default(monkeypatch):
    """Blocks are built from the model config, and guidance_scale=None leaves the guider untouched.

    ``events`` receives the model config first, then the ``(base_model, components_manager)``
    pair passed to ``init_pipeline``, so the assertions also pin the call order.
    """
    events = []

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def update_components(self, **components):
            self.updated_components = components

    class FakeBlocks:
        def init_pipeline(self, base_model, components_manager):
            events.append((base_model, components_manager))
            return FakePipe()

    def fake_build_blocks(model_config):
        events.append(model_config)
        return FakeBlocks()

    monkeypatch.setattr(runner, "build_multidiffusion_blocks", fake_build_blocks)

    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    assert events[0] == {"_class_name": "Flux2KleinPipeline"}
    assert events[1][0] == "black-forest-labs/FLUX.2-klein-base-4B"
    assert pipe.updated_components is None


def test_init_applies_explicit_guidance_scale(monkeypatch):
    """An explicit guidance_scale replaces the guider with one created from its component spec."""

    class FakeGuiderSpec:
        @staticmethod
        def create(guidance_scale):
            return {"guidance_scale": guidance_scale}

    class FakePipe:
        def __init__(self):
            self.guider = object()
            self.updated_components = None

        def load_components(self, **kwargs):
            self.load_kwargs = kwargs

        def get_component_spec(self, name):
            # Only the guider spec should ever be requested here.
            assert name == "guider"
            return FakeGuiderSpec()

        def update_components(self, **components):
            self.updated_components = components

    def fake_build_blocks(model_config):
        return SimpleNamespace(init_pipeline=lambda base_model, components_manager: FakePipe())

    monkeypatch.setattr(runner, "build_multidiffusion_blocks", fake_build_blocks)

    pipe = runner.init_modular_pipeline(
        base_model="black-forest-labs/FLUX.2-klein-base-4B",
        model_config={"_class_name": "Flux2KleinPipeline"},
        guidance_scale=2.5,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    assert pipe.updated_components == {"guider": {"guidance_scale": 2.5}}


def test_remote_parser_requires_hub_repo_id():
    """``--repo-id`` is parsed into ``args.repo_id`` unchanged."""
    hub_repo = "user/multidiff-modular"
    assert remote_runner.parse_args(["--repo-id", hub_repo]).repo_id == hub_repo


def test_remote_init_uses_modular_from_pretrained_with_remote_code(monkeypatch):
    """The remote runner loads through ModularPipeline.from_pretrained with remote code enabled.

    local_files_only must reach both from_pretrained and load_components, and components are
    loaded with the dtype that runner.DTYPE_MAP gives for "float32".
    """
    recorded = {}
    hub_repo = "user/multidiff-modular"

    class FakePipe:
        def __init__(self):
            self.guider = None

        def load_components(self, **kwargs):
            recorded["load_components"] = kwargs

    def fake_from_pretrained(repo_id, **kwargs):
        recorded.update(repo_id=repo_id, from_pretrained=kwargs)
        return FakePipe()

    monkeypatch.setattr(remote_runner.ModularPipeline, "from_pretrained", fake_from_pretrained)

    pipe = remote_runner.init_remote_modular_pipeline(
        repo_id=hub_repo,
        guidance_scale=None,
        dtype="float32",
        device="cpu",
        local_files_only=True,
        compile=False,
    )

    assert pipe.guider is None
    assert recorded["repo_id"] == hub_repo
    pretrained_kwargs = recorded["from_pretrained"]
    assert pretrained_kwargs["trust_remote_code"] is True
    assert pretrained_kwargs["local_files_only"] is True
    component_kwargs = recorded["load_components"]
    assert component_kwargs["torch_dtype"] is runner.DTYPE_MAP["float32"]
    assert component_kwargs["local_files_only"] is True


def test_image_conditioning_loads_and_forwards_image(monkeypatch):
    """``--image-conditioning`` is loaded via load_image and forwarded as the ``image`` kwarg."""
    sentinel_image = object()
    requested_paths = []

    def fake_load_image(path):
        requested_paths.append(path)
        return sentinel_image

    monkeypatch.setattr(runner, "load_image", fake_load_image)

    forwarded = runner.build_call_kwargs(
        _args("--image-conditioning", "conditioning.png"), SimpleNamespace(config={}), object()
    )

    assert requested_paths == ["conditioning.png"]
    assert forwarded["image"] is sentinel_image


def test_image_img2img_prepares_latents_and_forwards_strength(monkeypatch):
    """``--image-img2img`` preprocesses and encodes the init image into packed latents.

    The requested 65x81 size reaches the image processor as 64x80. The encoder sub-pipeline
    receives the pipeline's VAE, its latents are packed via ``_pack_latents``, and
    ``--strength`` is forwarded unchanged.
    """
    init_image = object()
    encoded_latents = object()
    generator = object()
    vae = object()

    class FakeImageProcessor:
        def __init__(self):
            self.calls = []

        def preprocess(self, image, *, height, width, resize_mode):
            self.calls.append((image, height, width, resize_mode))
            return "preprocessed-image"

    class FakeEncoderPipe:
        def __init__(self):
            self.updated_components = None
            self.calls = []

        def update_components(self, **components):
            self.updated_components = components

        def __call__(self, *, condition_images, generator):
            self.calls.append((condition_images, generator))
            return SimpleNamespace(image_latents=[encoded_latents])

    encoder_pipe = FakeEncoderPipe()
    image_processor = FakeImageProcessor()
    # Nested sub-block layout: blocks -> "vae_encoder" -> "img_conditioning" -> "encode".
    # Presumably build_call_kwargs walks this path to reach the encoder; verify against runner.
    encode_block = SimpleNamespace(init_pipeline=lambda: encoder_pipe)
    vae_encoder_block = SimpleNamespace(
        sub_blocks={"img_conditioning": SimpleNamespace(sub_blocks={"encode": encode_block})}
    )
    pipe = SimpleNamespace(
        config={},
        vae=vae,
        vae_scale_factor=8,
        blocks=SimpleNamespace(sub_blocks={"vae_encoder": vae_encoder_block}),
        image_processor=image_processor,
    )

    monkeypatch.setattr(runner, "load_image", lambda path: init_image)
    # Tag the packed result so it is distinguishable from the raw encoder output.
    monkeypatch.setattr(
        runner.Flux2PrepareImageLatentsStep,
        "_pack_latents",
        staticmethod(lambda latents: ("packed", latents)),
    )

    cli = _args("--image-img2img", "init.png", "--height", "65", "--width", "81", "--strength", "0.42")
    forwarded = runner.build_call_kwargs(cli, pipe, generator)

    assert image_processor.calls == [(init_image, 64, 80, "default")]
    assert encoder_pipe.updated_components == {"vae": vae}
    assert encoder_pipe.calls == [(["preprocessed-image"], generator)]
    assert forwarded["image_img2img"] == ("packed", encoded_latents)
    assert forwarded["strength"] == 0.42