"""Tests for Flux2 latent patchify/unpatchify and Flux2.apply_model spatial handling.

Covers: even/odd VAE-latent spatial dims (internal padding), patchify round-trip
fidelity, and apply_model's pad/crop behavior under explicit transformer_options.
"""
import torch

from src.Utilities.Latent import Flux2 as Flux2LatentFormat


def _make_tiny_flux2(dtype=torch.float32):
    """Build a minimal Flux2 model (1 head, 1 block of each kind) for fast tests."""
    # Imported lazily so merely collecting latent-format tests does not require
    # the full model package.
    from src.NeuralNetwork.flux2.model import Flux2, Flux2Params

    params = Flux2Params(
        context_in_dim=16,
        vec_in_dim=16,
        hidden_size=64,
        num_heads=1,
        axes_dim=(16, 16, 16, 16),
        depth=1,
        depth_single_blocks=1,
    )
    return Flux2(params=params, dtype=dtype)


def test_patchify_from_vae_even_dims():
    fmt = Flux2LatentFormat()
    latent = torch.randn(1, 32, 64, 64)
    out = fmt.patchify_from_vae(latent)
    assert out.shape == (1, 128, 32, 32)


def test_patchify_from_vae_odd_dims_pads():
    fmt = Flux2LatentFormat()
    # Odd height/width should be padded internally so the operation succeeds
    latent = torch.randn(1, 32, 67, 65)
    out = fmt.patchify_from_vae(latent)
    # padded_h = 68 -> 68//2 = 34 ; padded_w = 66 -> 66//2 = 33
    assert out.shape == (1, 128, 34, 33)


def test_patchify_unpatchify_roundtrip_preserves_original_region():
    fmt = Flux2LatentFormat()
    latent = torch.randn(1, 32, 67, 65)
    patched = fmt.patchify_from_vae(latent)
    unpatched = fmt.unpatchify_for_vae(patched)
    # Flux2.forward crops back to original size; ensure round-trip preserves
    # the original region
    cropped = unpatched[:, :, :latent.shape[2], :latent.shape[3]]
    assert cropped.shape == latent.shape
    assert torch.allclose(cropped, latent)


def test_flux2_apply_model_accepts_vae_latent_with_odd_spatial_dims():
    """Simulate the exact call-site used by calc_cond_batch/_run_model_per_chunk
    to ensure Flux2.apply_model (and Latent.patchify_from_vae) accept odd H/W.
    """
    model = _make_tiny_flux2(dtype=torch.float32)
    # VAE-format latent with odd height (will require internal padding)
    vae_latent = torch.randn(1, 32, 67, 64)
    # timestep vector (single value is supported)
    t = torch.tensor([0.5])
    # dummy text conditioning matching model.params.context_in_dim
    txt = torch.zeros(1, 1, model.params.context_in_dim)
    out = model.apply_model(vae_latent, t, c_crossattn=txt)
    # output should match input spatial dims after model processing
    assert out.shape[0] == vae_latent.shape[0]
    assert out.shape[2] == vae_latent.shape[2]
    assert out.shape[3] == vae_latent.shape[3]


def test_flux2_apply_model_pads_or_crops_to_transformer_options():
    """When explicit transformer_options img_h/img_w are provided, the model
    must pad/crop the incoming latent so positional ids (RoPE) align.
    """
    model = _make_tiny_flux2(dtype=torch.float32)
    # Case A: VAE latent smaller than transformer_options -> should pad
    vae_latent_small = torch.randn(1, 32, 96, 38)
    t = torch.tensor([0.5])
    txt = torch.zeros(1, 1, model.params.context_in_dim)
    out_small = model.apply_model(
        vae_latent_small,
        t,
        c_crossattn=txt,
        transformer_options={"img_h": 512, "img_w": 512},
    )
    # model should return same spatial shape as input VAE latent
    # (crop back behavior)
    assert out_small.shape[2] == vae_latent_small.shape[2]
    assert out_small.shape[3] == vae_latent_small.shape[3]
    # Case B: VAE latent larger than transformer_options -> should crop safely
    vae_latent_large = torch.randn(1, 32, 160, 160)
    out_large = model.apply_model(
        vae_latent_large,
        t,
        c_crossattn=txt,
        transformer_options={"img_h": 256, "img_w": 256},
    )
    assert out_large.shape[2] == vae_latent_large.shape[2]
    assert out_large.shape[3] == vae_latent_large.shape[3]