import torch
from src.Utilities.Latent import Flux2 as Flux2LatentFormat
def _make_tiny_flux2(dtype=torch.float32):
    """Construct a minimal Flux2 model that is cheap enough for unit tests.

    The configuration uses a single transformer block of each kind and a
    tiny hidden size so instantiation and a forward pass run quickly.
    """
    from src.NeuralNetwork.flux2.model import Flux2, Flux2Params

    tiny_config = Flux2Params(
        context_in_dim=16,
        vec_in_dim=16,
        hidden_size=64,
        num_heads=1,
        axes_dim=(16, 16, 16, 16),
        depth=1,
        depth_single_blocks=1,
    )
    return Flux2(params=tiny_config, dtype=dtype)
def test_patchify_from_vae_even_dims():
    """With even H/W, 2x2 patchify quadruples channels and halves each spatial dim."""
    layout = Flux2LatentFormat()
    sample = torch.randn(1, 32, 64, 64)
    patched = layout.patchify_from_vae(sample)
    # 32 channels * 4 = 128 ; 64 / 2 = 32 on both axes
    assert patched.shape == (1, 128, 32, 32)
def test_patchify_from_vae_odd_dims_pads():
    """Odd H/W must be padded internally so 2x2 patchify still succeeds."""
    layout = Flux2LatentFormat()
    odd_sample = torch.randn(1, 32, 67, 65)
    patched = layout.patchify_from_vae(odd_sample)
    # height: 67 -> padded 68 -> 34 patches ; width: 65 -> padded 66 -> 33 patches
    assert patched.shape == (1, 128, 34, 33)
def test_patchify_unpatchify_roundtrip_preserves_original_region():
    """patchify -> unpatchify must reproduce the original latent in its region.

    Flux2.forward crops back to the original size, so only the region up to
    the original H/W needs to match; padded rows/cols are discarded.
    """
    layout = Flux2LatentFormat()
    original = torch.randn(1, 32, 67, 65)
    roundtrip = layout.unpatchify_for_vae(layout.patchify_from_vae(original))
    height, width = original.shape[2], original.shape[3]
    restored = roundtrip[:, :, :height, :width]
    assert restored.shape == original.shape
    assert torch.allclose(restored, original)
def test_flux2_apply_model_accepts_vae_latent_with_odd_spatial_dims():
    """Simulate the exact call-site used by calc_cond_batch/_run_model_per_chunk
    to ensure Flux2.apply_model (and Latent.patchify_from_vae) accept odd H/W.
    """
    tiny_model = _make_tiny_flux2(dtype=torch.float32)
    # VAE-format latent whose height is odd, forcing internal padding
    latent = torch.randn(1, 32, 67, 64)
    # a single-element timestep tensor is a supported input
    timestep = torch.tensor([0.5])
    # zero text conditioning sized to the model's context_in_dim
    conditioning = torch.zeros(1, 1, tiny_model.params.context_in_dim)
    result = tiny_model.apply_model(latent, timestep, c_crossattn=conditioning)
    # batch and spatial dims of the output must mirror the input latent
    assert result.shape[0] == latent.shape[0]
    assert result.shape[2] == latent.shape[2]
    assert result.shape[3] == latent.shape[3]
def test_flux2_apply_model_pads_or_crops_to_transformer_options():
    """When explicit transformer_options img_h/img_w are provided, the model
    must pad/crop the incoming latent so positional ids (RoPE) align.
    """
    tiny_model = _make_tiny_flux2(dtype=torch.float32)
    timestep = torch.tensor([0.5])
    conditioning = torch.zeros(1, 1, tiny_model.params.context_in_dim)

    # Case A: latent smaller than the requested img_h/img_w -> internal padding,
    # but the output is cropped back to the input latent's spatial dims.
    small_latent = torch.randn(1, 32, 96, 38)
    small_out = tiny_model.apply_model(
        small_latent,
        timestep,
        c_crossattn=conditioning,
        transformer_options={"img_h": 512, "img_w": 512},
    )
    assert small_out.shape[2] == small_latent.shape[2]
    assert small_out.shape[3] == small_latent.shape[3]

    # Case B: latent larger than the requested img_h/img_w -> safe cropping,
    # output spatial dims still match the input latent.
    large_latent = torch.randn(1, 32, 160, 160)
    large_out = tiny_model.apply_model(
        large_latent,
        timestep,
        c_crossattn=conditioning,
        transformer_options={"img_h": 256, "img_w": 256},
    )
    assert large_out.shape[2] == large_latent.shape[2]
    assert large_out.shape[3] == large_latent.shape[3]