import torch
from src.Utilities.Latent import Flux2 as Flux2LatentFormat


def _make_tiny_flux2(dtype=torch.float32):
    from src.NeuralNetwork.flux2.model import Flux2, Flux2Params

    params = Flux2Params(
        context_in_dim=16,
        vec_in_dim=16,
        hidden_size=64,
        num_heads=1,
        axes_dim=(16, 16, 16, 16),
        depth=1,
        depth_single_blocks=1,
    )
    return Flux2(params=params, dtype=dtype)


def test_patchify_from_vae_even_dims():
    fmt = Flux2LatentFormat()
    latent = torch.randn(1, 32, 64, 64)
    out = fmt.patchify_from_vae(latent)
    assert out.shape == (1, 128, 32, 32)
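

# Hedged companion check: assuming unpatchify_for_vae is the exact inverse of
# patchify_from_vae when no padding is needed (even H/W), a patchified tensor
# should map straight back to the 32-channel VAE layout.
def test_unpatchify_for_vae_even_dims():
    fmt = Flux2LatentFormat()
    patched = torch.randn(1, 128, 32, 32)
    out = fmt.unpatchify_for_vae(patched)
    assert out.shape == (1, 32, 64, 64)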


def test_patchify_from_vae_odd_dims_pads():
    fmt = Flux2LatentFormat()
    # Odd height/width should be padded internally so the operation succeeds
    latent = torch.randn(1, 32, 67, 65)
    out = fmt.patchify_from_vae(latent)
    # padded_h = 68 -> 68//2 = 34 ; padded_w = 66 -> 66//2 = 33
    assert out.shape == (1, 128, 34, 33)


def test_patchify_unpatchify_roundtrip_preserves_original_region():
    fmt = Flux2LatentFormat()
    latent = torch.randn(1, 32, 67, 65)
    patched = fmt.patchify_from_vae(latent)
    unpatched = fmt.unpatchify_for_vae(patched)
    # Flux2.forward crops back to original size; ensure round-trip preserves original region
    cropped = unpatched[:, :, :latent.shape[2], :latent.shape[3]]
    assert cropped.shape == latent.shape
    assert torch.allclose(cropped, latent)
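

# Hedged sketch: with even H/W no padding should be applied, so the round trip
# is expected to be exact with no cropping step. This assumes
# patchify_from_vae/unpatchify_for_vae are pure reshapes in the unpadded case.
def test_patchify_unpatchify_roundtrip_even_dims_is_exact():
    fmt = Flux2LatentFormat()
    latent = torch.randn(1, 32, 64, 64)
    roundtrip = fmt.unpatchify_for_vae(fmt.patchify_from_vae(latent))
    assert roundtrip.shape == latent.shape
    assert torch.allclose(roundtrip, latent)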


def test_flux2_apply_model_accepts_vae_latent_with_odd_spatial_dims():
    """Simulate the exact call-site used by calc_cond_batch/_run_model_per_chunk
    to ensure Flux2.apply_model (and Latent.patchify_from_vae) accept odd H/W.
    """
    model = _make_tiny_flux2(dtype=torch.float32)
    # VAE-format latent with odd height (will require internal padding)
    vae_latent = torch.randn(1, 32, 67, 64)
    # timestep vector (single value is supported)
    t = torch.tensor([0.5])
    # dummy text conditioning matching model.params.context_in_dim
    txt = torch.zeros(1, 1, model.params.context_in_dim)

    out = model.apply_model(vae_latent, t, c_crossattn=txt)
    # the model pads internally and crops back, so the output's batch and
    # spatial dims should match the input's
    assert out.shape[0] == vae_latent.shape[0]
    assert out.shape[2] == vae_latent.shape[2]
    assert out.shape[3] == vae_latent.shape[3]


def test_flux2_apply_model_pads_or_crops_to_transformer_options():
    """When explicit transformer_options img_h/img_w are provided, the model
    must pad/crop the incoming latent so positional ids (RoPE) align.
    """
    model = _make_tiny_flux2(dtype=torch.float32)

    # Case A: VAE latent smaller than transformer_options -> should pad
    vae_latent_small = torch.randn(1, 32, 96, 38)
    t = torch.tensor([0.5])
    txt = torch.zeros(1, 1, model.params.context_in_dim)
    out_small = model.apply_model(vae_latent_small, t, c_crossattn=txt, transformer_options={"img_h": 512, "img_w": 512})
    # the model should return the same spatial shape as the input VAE latent
    # (crop-back behavior)
    assert out_small.shape[2] == vae_latent_small.shape[2]
    assert out_small.shape[3] == vae_latent_small.shape[3]

    # Case B: VAE latent larger than transformer_options -> should crop safely
    vae_latent_large = torch.randn(1, 32, 160, 160)
    out_large = model.apply_model(vae_latent_large, t, c_crossattn=txt, transformer_options={"img_h": 256, "img_w": 256})
    assert out_large.shape[2] == vae_latent_large.shape[2]
    assert out_large.shape[3] == vae_latent_large.shape[3]