| |
| |
|
|
| import os |
| import torch |
| import pytest |
|
|
| from fastgen.utils import instantiate |
|
|
| |
# Tame CUDA allocator fragmentation for the large-model tests in this file.
# setdefault keeps any value already supplied by the surrounding environment.
# NOTE(review): PyTorch reads this at allocator init — presumably this module
# is imported before any CUDA allocation; confirm test runner import order.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512")
|
|
|
|
| from fastgen.configs.net import ( |
| EDM_CIFAR10_Config, |
| EDM_ImageNet64_Config, |
| EDM2_IN64_S_Config, |
| DiT_IN256_XL_Config, |
| SD15Config, |
| FluxConfig, |
| CogVideoXConfig, |
| Wan_1_3B_Config, |
| CausalWan_1_3B_Config, |
| VACE_Wan_1_3B_Config, |
| Wan21_I2V_14B_480P_Config, |
| Wan22_I2V_5B_Config, |
| CausalWan22_I2V_5B_Config, |
| CausalWan21_I2V_14B_480P_Config, |
| CausalWan21_I2V_14B_720P_Config, |
| ) |
| from fastgen.configs.discriminator import ( |
| Discriminator_Wan_1_3B_Config, |
| Discriminator_EDM_CIFAR10_Config, |
| Discriminator_EDM_ImageNet64_Config, |
| ) |
| from fastgen.configs.config_utils import override_config_with_opts |
| from fastgen.utils.basic_utils import clear_gpu_memory |
| from fastgen.utils.test_utils import RunIf |
| from fastgen.utils.io_utils import set_env_vars |
| from unittest.mock import patch, MagicMock |
|
|
|
|
| def _validate_basic_scheduler_properties(scheduler, device): |
| """Test basic scheduler properties and structure.""" |
| |
| assert hasattr(scheduler, "max_sigma"), "Scheduler should have max_sigma attribute" |
| assert hasattr(scheduler, "min_t"), "Scheduler should have min_t attribute" |
| assert hasattr(scheduler, "max_t"), "Scheduler should have max_t attribute" |
| assert hasattr(scheduler, "alpha"), "Scheduler should have alpha method" |
| assert hasattr(scheduler, "sigma"), "Scheduler should have sigma method" |
| assert hasattr(scheduler, "sample_t"), "Scheduler should have sample_t method" |
|
|
| |
| assert scheduler.min_t < scheduler.max_t, f"min_t ({scheduler.min_t}) should be < max_t ({scheduler.max_t})" |
|
|
| if scheduler.min_t is not None: |
| assert scheduler.min_t >= 0, f"min_t ({scheduler.min_t}) should be non-negative" |
|
|
| if scheduler.max_t is not None: |
| |
| assert scheduler.max_t <= 10000, f"max_t ({scheduler.max_t}) should be reasonable (<=10000)" |
|
|
|
|
| def _validate_time_sampling_and_functions(scheduler, device): |
| """Test time sampling and alpha/sigma function behavior.""" |
| |
| batch_size = 2 |
| t = scheduler.sample_t(batch_size, time_dist_type="uniform") |
| assert t.shape == (batch_size,), f"Expected shape ({batch_size},), got {t.shape}" |
|
|
| |
| assert scheduler.is_t_valid(t), f"Sampled times {t} should be valid" |
|
|
| |
| alpha_t = scheduler.alpha(t) |
| sigma_t = scheduler.sigma(t) |
|
|
| assert alpha_t.shape == t.shape, f"alpha(t) shape {alpha_t.shape} should match t shape {t.shape}" |
| assert sigma_t.shape == t.shape, f"sigma(t) shape {sigma_t.shape} should match t shape {t.shape}" |
| assert torch.all(alpha_t >= 0), f"alpha(t) should be non-negative, got {alpha_t}" |
| assert torch.all(sigma_t >= 0), f"sigma(t) should be non-negative, got {sigma_t}" |
|
|
| return t, alpha_t, sigma_t |
|
|
|
|
| def _validate_boundary_values(scheduler, device): |
| """Test alpha and sigma at boundary time values.""" |
|
|
| |
| min_t_tensor = torch.tensor([scheduler.min_t]).to(device) |
| max_t_tensor = torch.tensor([scheduler.max_t]).to(device) |
|
|
| alpha_min = scheduler.alpha(min_t_tensor) |
| alpha_max = scheduler.alpha(max_t_tensor) |
| sigma_min = scheduler.sigma(min_t_tensor) |
| sigma_max = scheduler.sigma(max_t_tensor) |
|
|
| |
| max_sigma_tensor = torch.tensor(scheduler.max_sigma, device=device, dtype=sigma_max.dtype) |
|
|
| |
| rtol = 1e-2 if sigma_max.dtype == torch.bfloat16 else 1e-3 |
| atol = 1e-2 if sigma_max.dtype == torch.bfloat16 else 1e-3 |
|
|
| assert torch.allclose( |
| sigma_max, max_sigma_tensor, rtol=rtol, atol=atol |
| ), f"max_sigma ({scheduler.max_sigma}) should match sigma(max_t) ({sigma_max.item()}) within tolerance" |
|
|
| return alpha_min, alpha_max, sigma_min, sigma_max |
|
|
|
|
| def _validate_edm_schedule(scheduler, t, alpha_t, sigma_t, sigma_min, sigma_max): |
| """Validate EDM-specific schedule properties.""" |
| |
| assert torch.allclose(alpha_t, torch.ones_like(alpha_t), rtol=1e-4), f"EDM: alpha(t) should be 1, got {alpha_t}" |
| assert torch.allclose(sigma_t, t, rtol=1e-4), f"EDM: sigma(t) should equal t, got σ(t)={sigma_t}, t={t}" |
|
|
| |
| assert sigma_min <= sigma_max, ( |
| f"EDM: sigma should increase with t, got " |
| f"σ({scheduler.min_t})={sigma_min.item():.4f} " |
| f"> σ({scheduler.max_t})={sigma_max.item():.4f}" |
| ) |
|
|
| |
| assert torch.allclose( |
| torch.tensor(scheduler.max_sigma), torch.tensor(scheduler.max_t), rtol=1e-3 |
| ), f"EDM: max_sigma ({scheduler.max_sigma}) should equal max_t ({scheduler.max_t})" |
|
|
| if scheduler.max_t is not None: |
| assert scheduler.max_t <= 100.0, f"EDM: max_t should be reasonable (≤100), got {scheduler.max_t}" |
|
|
|
|
| def _validate_rectified_flow_schedule(scheduler, device, t, alpha_t, sigma_t): |
| """Validate Rectified Flow-specific mathematical properties.""" |
| |
|
|
| |
| assert torch.allclose( |
| alpha_t + sigma_t, torch.ones_like(alpha_t), rtol=1e-4, atol=1e-5 |
| ), f"RF: α(t) + σ(t) should equal 1, got α+σ = {(alpha_t + sigma_t).tolist()}" |
|
|
| |
| assert torch.allclose( |
| alpha_t, 1.0 - t, rtol=1e-4, atol=1e-5 |
| ), f"RF: α(t) should equal 1-t, got α(t)={alpha_t.tolist()}, expected={1.0 - t}" |
| assert torch.allclose( |
| sigma_t, t, rtol=1e-4, atol=1e-5 |
| ), f"RF: σ(t) should equal t, got σ(t)={sigma_t.tolist()}, expected={t.tolist()}" |
|
|
| |
| max_t_tensor = torch.tensor(scheduler.max_t, dtype=torch.float32) |
| max_sigma_tensor = torch.tensor(scheduler.max_sigma, dtype=torch.float32) |
| assert torch.allclose( |
| max_sigma_tensor, max_t_tensor, rtol=1e-3, atol=1e-4 |
| ), f"RF: max_sigma should equal max_t, got max_σ={scheduler.max_sigma}, max_t={scheduler.max_t}" |
|
|
| |
| if scheduler.min_t <= 0.01: |
| near_zero_t = torch.tensor([scheduler.min_t]).to(device) |
| alpha_zero = scheduler.alpha(near_zero_t) |
| sigma_zero = scheduler.sigma(near_zero_t) |
| expected_alpha_zero = 1.0 - scheduler.min_t |
| assert torch.allclose(alpha_zero, torch.tensor([expected_alpha_zero]).to(device), rtol=1e-2), ( |
| f"RF: α(t≈0) should ≈ 1-min_t, got " |
| f"α({scheduler.min_t})={alpha_zero.item():.4f}, " |
| f"expected {expected_alpha_zero:.4f}" |
| ) |
| assert torch.allclose(sigma_zero, near_zero_t, rtol=1e-2), ( |
| f"RF: σ(t≈0) should ≈ min_t, " |
| f"got σ({scheduler.min_t})={sigma_zero.item():.4f}, " |
| f"expected {scheduler.min_t}" |
| ) |
|
|
| |
| variance_sum = alpha_t**2 + sigma_t**2 |
| expected_variance = 1 - 2 * t + 2 * t**2 |
| assert torch.allclose(variance_sum, expected_variance, rtol=1e-3, atol=1e-4), ( |
| f"RF: α²(t) + σ²(t) should equal 1-2t+2t², got " |
| f"{variance_sum.tolist()}, expected {expected_variance.tolist()}" |
| ) |
|
|
|
|
| def _validate_ddpm_based_schedule(scheduler, alpha_t, sigma_t, alpha_min, alpha_max, sigma_min, sigma_max): |
| """Validate DDPM-based schedule properties (shared by SD, CogVideoX, Alphas).""" |
| |
| alpha_squared_plus_sigma_squared = alpha_t**2 + sigma_t**2 |
| expected_ones = torch.ones_like(alpha_squared_plus_sigma_squared) |
| assert torch.allclose( |
| alpha_squared_plus_sigma_squared, expected_ones, rtol=1e-2, atol=1e-3 |
| ), f"DDPM: α²(t) + σ²(t) should equal 1, got {alpha_squared_plus_sigma_squared}" |
|
|
| |
| assert alpha_min >= alpha_max, ( |
| f"DDPM: alpha should decrease with t, " |
| f"got α({scheduler.min_t})={alpha_min.item():.4f} < " |
| f"α({scheduler.max_t})={alpha_max.item():.4f}" |
| ) |
| assert sigma_min <= sigma_max, ( |
| f"DDPM: sigma should increase with t, " |
| f"got σ({scheduler.min_t})={sigma_min.item():.4f} " |
| f"> σ({scheduler.max_t})={sigma_max.item():.4f}" |
| ) |
|
|
|
|
def validate_noise_scheduler_properties(teacher, device, expected_schedule_type=None):
    """
    Comprehensive and consistent noise scheduler testing helper.

    Args:
        teacher: The instantiated network model.
        device: Device to run tests on.
        expected_schedule_type: Expected schedule type string (e.g., "edm", "rf", "sd").

    Raises:
        ValueError: If ``expected_schedule_type`` is not a recognized schedule family.
    """
    assert hasattr(teacher, "noise_scheduler"), "Model should have noise_scheduler attribute"

    scheduler = teacher.noise_scheduler
    assert teacher.schedule_type == expected_schedule_type
    assert scheduler.max_t is not None and scheduler.min_t is not None

    # Generic structure, sampling and boundary checks shared by all schedules.
    _validate_basic_scheduler_properties(scheduler, device)
    t, alpha_t, sigma_t = _validate_time_sampling_and_functions(scheduler, device)
    alpha_min, alpha_max, sigma_min, sigma_max = _validate_boundary_values(scheduler, device)

    # Family-specific invariants. "sd"/"sdxl"/"cogvideox"/"alphas" all share
    # the same variance-preserving (DDPM-style) checks.
    if expected_schedule_type == "edm":
        _validate_edm_schedule(scheduler, t, alpha_t, sigma_t, sigma_min, sigma_max)
    elif expected_schedule_type in ("rf", "rectified_flow"):
        _validate_rectified_flow_schedule(scheduler, device, t, alpha_t, sigma_t)
    elif expected_schedule_type in ("sd", "sdxl", "cogvideox", "alphas"):
        _validate_ddpm_based_schedule(scheduler, alpha_t, sigma_t, alpha_min, alpha_max, sigma_min, sigma_max)
    else:
        raise ValueError(f"Unrecognized schedule type: {expected_schedule_type}")
|
|
|
|
def test_network_edm_cifar10():
    """Smoke-test a tiny EDM CIFAR-10 network: scheduler invariants and forward pass."""
    teacher_config = EDM_CIFAR10_Config

    # Run on GPU with bf16 when available; fall back to CPU/fp32 otherwise.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Shrink the architecture so instantiation and the forward pass are fast.
    teacher_config = override_config_with_opts(
        teacher_config,
        ["-", "img_resolution=2", "model_channels=32", "channel_mult=[1]", "channel_mult_noise=1", "r_timestep=False"],
    )

    teacher = instantiate(teacher_config)
    teacher = teacher.to(device=device, dtype=dtype)

    # EDM schedule: alpha(t) == 1 and sigma(t) == t (checked by the helper).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="edm")

    batch_size = 1
    x = torch.randn(batch_size, 3, 2, 2, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="polynomial").to(device=device, dtype=dtype)
    labels = torch.randint(0, 10, (batch_size,), device=device)
    # CIFAR-10 conditioning is a one-hot class vector.
    labels = torch.nn.functional.one_hot(labels, num_classes=10).to(dtype=dtype)
    output = teacher(x, t, labels)

    assert output.shape == x.shape
    assert output.device == x.device

    # Early feature return with no requested indices should yield an empty list.
    output = teacher(x, t, labels, return_features_early=True, feature_indices=set())
    assert isinstance(output, list)
|
|
|
|
def test_network_edm_imagenet64():
    """Smoke-test a tiny EDM ImageNet-64 network: scheduler invariants and forward pass."""
    teacher_config = EDM_ImageNet64_Config

    # Run on GPU with bf16 when available; fall back to CPU/fp32 otherwise.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Shrink the architecture so instantiation and the forward pass are fast.
    teacher_config = override_config_with_opts(
        teacher_config,
        ["-", "img_resolution=2", "model_channels=32", "channel_mult=[1]", "num_blocks=1", "r_timestep=False"],
    )

    teacher = instantiate(teacher_config)
    teacher = teacher.to(device=device, dtype=dtype)

    # EDM schedule invariants (alpha == 1, sigma(t) == t).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="edm")

    batch_size = 1
    x = torch.randn(batch_size, 3, 2, 2, device=device, dtype=dtype)
    # Lognormal time sampling, as commonly used for EDM-style training.
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="lognormal").to(device=device, dtype=dtype)
    labels = torch.randint(0, 1000, (batch_size,), device=device)
    # ImageNet conditioning is a 1000-way one-hot class vector.
    labels = torch.nn.functional.one_hot(labels, num_classes=1000).to(dtype=dtype)

    output = teacher(x, t, labels)

    assert output.shape == x.shape
    assert output.device == x.device

    # Early feature return with no requested indices should yield an empty list.
    output = teacher(x, t, labels, return_features_early=True, feature_indices=set())
    assert isinstance(output, list)
|
|
|
|
def test_network_edm2_in64():
    """Smoke-test a tiny EDM2 ImageNet-64 network: scheduler, forward pass, feature hooks."""
    teacher_config = EDM2_IN64_S_Config

    # Run on GPU with bf16 when available; fall back to CPU/fp32 otherwise.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Shrink the architecture so instantiation and the forward pass are fast.
    teacher_config = override_config_with_opts(
        teacher_config, ["-", "img_resolution=2", "model_channels=32", "channel_mult=[1]", "num_blocks=1"]
    )

    teacher = instantiate(teacher_config)
    teacher = teacher.to(device=device, dtype=dtype)

    # EDM schedule invariants (alpha == 1, sigma(t) == t).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="edm")

    batch_size = 1
    x = torch.randn(batch_size, 3, 2, 2, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)
    labels = torch.randint(0, 1000, (batch_size,), device=device)
    # ImageNet conditioning is a 1000-way one-hot class vector.
    labels = torch.nn.functional.one_hot(labels, num_classes=1000).to(dtype=dtype)
    output = teacher(x, t, labels)

    assert output.shape == x.shape
    assert output.device == x.device

    # Empty feature_indices WITHOUT early return: normal denoised output.
    output_empty = teacher(x, t, labels, feature_indices=set())
    assert output_empty.shape == torch.Size([batch_size, 3, 2, 2])

    # Early return with no requested features yields an empty list.
    features = teacher(x, t, labels, return_features_early=True, feature_indices=set())
    assert isinstance(features, list)
    assert len(features) == 0
|
|
|
|
def test_network_dit_in256_xl():
    """
    Lightweight test that mocks the VAE to avoid downloading models.
    """
    teacher_config = DiT_IN256_XL_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # fp16 on purpose, even on CPU — presumably matching DiT's expected
    # inference dtype; NOTE(review): confirm fp16 CPU ops are supported here.
    dtype = torch.float16

    # Shrink the transformer so the test instantiates quickly.
    teacher_config = override_config_with_opts(
        teacher_config, ["-", "input_size=2", "hidden_size=32", "depth=1", "num_heads=1"]
    )

    # Patch out the pretrained VAE download; the mock only needs decode().
    with patch("diffusers.AutoencoderKL.from_pretrained") as mock_vae_from_pretrained:
        mock_vae = MagicMock()
        mock_vae.decode.return_value = torch.randn(1, 3, 16, 16, device=device, dtype=dtype)
        mock_vae_from_pretrained.return_value = mock_vae

        teacher = instantiate(teacher_config)
        teacher.vae = mock_vae

        teacher = teacher.to(device=device, dtype=dtype)

        batch_size = 1
        # DiT operates in a 4-channel latent space.
        x = torch.randn(batch_size, 4, 2, 2, device=device, dtype=dtype)
        t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform")
        t = t.to(device=device, dtype=dtype)

        # DiT uses a rectified-flow schedule.
        validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

        labels = torch.randint(0, 1000, (batch_size,), device=device)
        labels = torch.nn.functional.one_hot(labels, num_classes=1000).to(device=device, dtype=dtype)

        output = teacher(x, t, labels)
        assert output.shape == x.shape
        assert output.device == x.device

        # The logvar head returns one scalar uncertainty per sample.
        output, logvar = teacher(x, t, labels, return_logvar=True)
        assert output.shape == x.shape
        assert logvar.shape == torch.Size([batch_size, 1])
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_sd15():
    """GPU smoke-test for the Stable Diffusion 1.5 teacher: scheduler, text conditioning, forward pass."""
    teacher_config = SD15Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)
    teacher.init_preprocessors()
    # Move the backbone first; VAE and text encoder get the reduced dtype
    # separately while the backbone initially stays in its default dtype.
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # SD uses a DDPM-style (variance-preserving) schedule.
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="sd")

    batch_size = 1
    # SD latent space: 4 channels.
    x = torch.randn(batch_size, 4, 8, 8, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device)

    captions = ["a caption"]
    condition = teacher.text_encoder.encode(captions)

    # SD1.5 conditioning comes back as (embeddings, attention_mask).
    assert isinstance(condition, tuple) and len(condition) == 2
    embeddings, attention_mask = condition
    embeddings = embeddings.to(device=device, dtype=dtype)
    attention_mask = attention_mask.to(device=device, dtype=dtype)
    condition = (embeddings, attention_mask)

    # First forward: autocast with the backbone still in its default dtype.
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Second forward: the whole model cast to the reduced dtype, no autocast.
    teacher = teacher.to(device=device, dtype=dtype)
    output = teacher(x, t, condition=condition)

    # Early feature return with no requested indices yields an empty list.
    output = teacher(x, t, condition=condition, return_features_early=True, feature_indices=set())

    assert isinstance(output, list)
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_flux():
    """Test Flux network for text-to-image generation."""
    clear_gpu_memory()

    teacher_config = FluxConfig

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Flux is large; skip outright on GPUs that cannot hold it.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if total_memory < 40:
            pytest.skip(f"Test skipped: Flux model requires ~40GB GPU memory, but only {total_memory:.1f}GB available")

    # The checkpoint is gated on HuggingFace; skip when auth is missing.
    try:
        teacher = instantiate(teacher_config)
    except OSError as e:
        if "not a valid model identifier" in str(e) or "token" in str(e):
            pytest.skip(f"Test skipped: Flux model not accessible (requires HuggingFace authentication): {e}")
        raise
    teacher.init_preprocessors()
    # Backbone to device first; VAE/text-encoder get the reduced dtype now.
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # Flux uses a rectified-flow schedule.
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    # Flux latent space: 16 channels.
    x = torch.randn(batch_size, 16, 8, 8, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)

    captions = ["a caption"]
    condition = teacher.text_encoder.encode(captions)

    # Flux is guidance-distilled: guidance is passed as a per-sample tensor.
    guidance_scale = 3.5
    guidance_tensor = torch.full((batch_size,), guidance_scale, device=x.device, dtype=x.dtype)

    # Conditioning comes back as (pooled_prompt_embeds, prompt_embeds).
    assert isinstance(condition, tuple) and len(condition) == 2
    pooled_prompt_embeds, prompt_embeds = condition
    pooled_prompt_embeds = pooled_prompt_embeds.to(device=device, dtype=dtype)
    prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
    condition = (pooled_prompt_embeds, prompt_embeds)

    # First forward under autocast with the backbone in its default dtype.
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition, guidance=guidance_tensor)

    assert output.shape == x.shape
    assert output.device == x.device

    # Remaining forwards with the whole model cast to the reduced dtype.
    teacher = teacher.to(device=device, dtype=dtype)

    # logvar head: one scalar uncertainty per sample.
    output, logvar = teacher(x, t, condition, guidance=guidance_tensor, return_logvar=True)
    assert output.shape == x.shape
    assert logvar.shape == torch.Size([batch_size, 1])

    # Early return, no feature indices requested -> empty list.
    output = teacher(x, t, condition, guidance=guidance_tensor, return_features_early=True, feature_indices=set())
    assert isinstance(output, list)
    assert len(output) == 0

    # Full forward plus one requested feature -> [denoised, [feature]].
    output = teacher(x, t, condition, guidance=guidance_tensor, return_features_early=False, feature_indices={0})
    assert isinstance(output, list) and len(output) == 2
    assert output[0].shape == x.shape
    assert isinstance(output[1], list) and len(output[1]) == 1

    # Early return with one requested feature -> a single-element list.
    features = teacher(x, t, condition, guidance=guidance_tensor, return_features_early=True, feature_indices={0})
    assert isinstance(features, list)
    assert len(features) == 1

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_cogvideox():
    """GPU smoke-test for CogVideoX: scheduler, text conditioning, video-latent forward pass."""
    clear_gpu_memory()

    set_env_vars()
    teacher_config = CogVideoXConfig

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)

    teacher.init_preprocessors()
    # Backbone to device first; VAE/text-encoder get the reduced dtype now.
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # CogVideoX uses its own DDPM-style schedule variant.
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="cogvideox")

    batch_size = 1
    # Tiny 5-D video latent: (batch, channels, frames, height, width).
    C, T, H, W = 16, 2, 4, 4

    x = torch.randn(batch_size, C, T, H, W, device=device, dtype=dtype)

    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)

    captions = ["a caption"]
    condition = teacher.text_encoder.encode(captions)
    # Some encoders return (embeddings, mask); keep only the embeddings.
    if isinstance(condition, tuple):
        condition = condition[0]
    condition = condition.to(device=device, dtype=dtype)

    # First forward under autocast with the backbone in its default dtype.
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Remaining forwards with the whole model cast to the reduced dtype.
    teacher = teacher.to(device=device, dtype=dtype)

    # Full forward plus one requested feature -> [denoised, [feature]].
    output = teacher(x, t, condition=condition, return_features_early=False, feature_indices={0})
    assert output[0].shape == x.shape
    assert isinstance(output[1], list) and len(output[1]) == 1
    for feature in output[1]:
        # 480 is the expected intermediate feature channel width here.
        assert feature.shape == (batch_size, 480, T, H, W)

    # Early return with one requested feature.
    output = teacher(x, t, condition=condition, return_features_early=True, feature_indices={0})
    assert isinstance(output, list)
    for feature in output:
        expected_channels = 480
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_wan():
    """GPU smoke-test for Wan 1.3B: scheduler, text conditioning, feature extraction."""
    clear_gpu_memory()

    set_env_vars()
    teacher_config = Wan_1_3B_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)

    teacher.init_preprocessors()
    # Backbone to device first; VAE/text-encoder get the reduced dtype now.
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # Wan uses a rectified-flow schedule.
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    # Tiny video latent: (batch, 16, frames, height, width).
    T, H, W = 2, 4, 4

    x = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)

    condition = teacher.text_encoder.encode(["a caption"])
    # Some encoders return (embeddings, mask); keep only the embeddings.
    if isinstance(condition, tuple):
        condition = condition[0]
    condition = condition.to(device=device, dtype=dtype)

    # First forward under autocast with the backbone in its default dtype.
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Remaining forwards with the whole model cast to the reduced dtype.
    teacher = teacher.to(device=device, dtype=dtype)

    output = teacher(x, t, condition=condition)

    # Early return with one requested feature.
    output = teacher(x, t, condition=condition, return_features_early=True, feature_indices={0})

    assert isinstance(output, list)

    for feature in output:
        # 384 is the expected intermediate feature channel width here.
        expected_channels = 384
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Re-instantiate with an explicit "rf" schedule_type and re-validate.
    # NOTE(review): this mutates the shared module-level config object in
    # place — confirm instantiate() copies configs, otherwise the change
    # leaks into any test that reuses Wan_1_3B_Config afterwards.
    teacher_config.schedule_type = "rf"
    teacher = instantiate(teacher_config)
    teacher.init_preprocessors()
    teacher = teacher.to(device=device, dtype=dtype)

    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_vace_wan():
    """GPU smoke-test for VACE-Wan 1.3B: video-conditioned forward pass and features."""
    clear_gpu_memory()

    # Skip when not enough free GPU memory remains for this model.
    if torch.cuda.is_available():
        free_memory = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
        if free_memory < 20:
            pytest.skip(f"Test skipped: requires ~20GB free GPU memory, but only {free_memory:.1f}GB available")

    set_env_vars()
    teacher_config = VACE_Wan_1_3B_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)

    teacher.init_preprocessors()
    # Backbone to device first; VAE/text-encoder get the reduced dtype now.
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # VACE-Wan uses a rectified-flow schedule.
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    # Tiny video latent: (batch, 16, frames, height, width).
    x = torch.randn(batch_size, 16, 2, 4, 4, device=device, dtype=dtype)
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="logitnormal").to(device=device, dtype=dtype)

    captions = ["a caption"]
    text_embeds = teacher.text_encoder.encode(captions).to(device=device, dtype=dtype)

    # Synthetic conditioning video in pixel space, clamped to the
    # [-1, 1] range models conventionally expect for images/video.
    context_video = torch.randn(batch_size, 3, 2, 16, 16, device=device, dtype=dtype)
    context_video = torch.clamp(context_video, -1, 1)

    # Encode the conditioning video into the model's context representation.
    vid_context = teacher.prepare_vid_conditioning(context_video)

    # VACE conditioning is a dict of text embeddings plus video context.
    condition = {"text_embeds": text_embeds, "vid_context": vid_context}

    # First forward under autocast with the backbone in its default dtype.
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Remaining forwards with the whole model cast to the reduced dtype.
    teacher = teacher.to(device=device, dtype=dtype)

    output = teacher(x, t, condition=condition)

    # Early return with one requested feature.
    output = teacher(x, t, condition=condition, return_features_early=True, feature_indices={0})

    assert isinstance(output, list)

    # 384 is the expected intermediate feature channel width here.
    expected_channels = 384
    for feature in output:
        assert feature.shape == (batch_size, expected_channels, 2, 4, 4)

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
def test_network_discriminator_wan():
    """
    Lightweight unit test for Discriminator_Wan implementation.

    Tests core functionality with minimal memory usage.
    """
    set_env_vars()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    discriminator_config = Discriminator_Wan_1_3B_Config

    # Tiny synthetic feature volume: (batch, channels, frames, height, width).
    batch_size = 1
    inner_dim = 16
    T, H, W = 8, 8, 8

    dummy_features = [torch.randn(batch_size, inner_dim, T, H, W, device=device, dtype=dtype)]

    efficient_architectures = [
        "conv3d_down_mlp_efficient",
        "multiscale_down_mlp_efficient",
    ]

    for arch_name in efficient_architectures:
        # NOTE(review): discriminator_config is re-overridden cumulatively
        # across loop iterations (each call layers onto the previous result)
        # — presumably intentional since later opts replace earlier ones;
        # confirm override_config_with_opts semantics.
        discriminator_config = override_config_with_opts(
            discriminator_config,
            ["-", f"++disc_type={arch_name}", f"++inner_dim={inner_dim}", "++num_blocks=2", "++feature_indices=[0]"],
        )

        discriminator = instantiate(discriminator_config)

        discriminator = discriminator.to(device=device, dtype=dtype)

        with torch.no_grad():
            output = discriminator(dummy_features)

        # One logit per sample for a single feature head.
        expected_shape = torch.Size([batch_size, 1])
        assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"

        assert torch.isfinite(output).all(), "Output contains NaN or Inf values"

        # Parameter-count sanity bounds for the "efficient" variants.
        total_params = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
        assert total_params > 10, f"Too few parameters: {total_params}"
        assert total_params < 50_000_000, f"Too many parameters: {total_params}"

    # Multi-head variant: two feature indices -> two output logits.
    discriminator_config = override_config_with_opts(
        discriminator_config,
        [
            "-",
            "++disc_type=factorized_down_mlp_efficient",
            f"++inner_dim={inner_dim}",
            "++num_blocks=2",
            "++feature_indices=[0,1]",
        ],
    )

    multi_head_discriminator = instantiate(discriminator_config)
    multi_head_discriminator = multi_head_discriminator.to(device=device, dtype=dtype)

    # One feature tensor per configured head.
    multi_head_features = [torch.randn(batch_size, inner_dim, T, H, W, device=device, dtype=dtype) for _ in range(2)]

    with torch.no_grad():
        output = multi_head_discriminator(multi_head_features)

    expected_shape = torch.Size([batch_size, 2])
    assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"
    assert torch.isfinite(output).all(), "Multi-head output contains NaN or Inf values"
|
|
|
|
def test_network_discriminator_edm():
    """
    Lightweight unit test for Discriminator_EDM implementation.

    Tests core functionality with minimal memory usage for both CIFAR10 and ImageNet64 configurations.
    """
    set_env_vars()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Per-dataset discriminator setups; feature_indices=None means
    # "use the default head selection" (handled below).
    test_configs = [
        {
            "name": "CIFAR10",
            "config": Discriminator_EDM_CIFAR10_Config,
            "resolutions": [32, 16, 8],
            "in_channels": 256,
            "feature_indices": {0, 1, 2},
        },
        {
            "name": "ImageNet64",
            "config": Discriminator_EDM_ImageNet64_Config,
            "resolutions": [64, 32, 16, 8],
            "in_channels": 768,
            "feature_indices": None,
        },
    ]

    batch_size = 1

    for test_case in test_configs:
        print(f"Testing EDM Discriminator {test_case['name']} configuration...")

        config = test_case["config"]
        resolutions = test_case["resolutions"]
        in_channels = test_case["in_channels"]
        feature_indices = test_case["feature_indices"]

        # Cap channel width to keep the test small and fast.
        test_in_channels = min(in_channels, 128)

        config = override_config_with_opts(
            config,
            ["-", f"in_channels={test_in_channels}", f"all_res={resolutions}"],
        )

        discriminator = instantiate(config)
        discriminator = discriminator.to(device=device, dtype=dtype)

        if feature_indices is None:
            # Default case: exercise only the lowest-resolution head.
            test_feature_indices = [len(resolutions) - 1]
        else:
            # Keep only indices that exist for this resolution list.
            test_feature_indices = sorted([i for i in feature_indices if i < len(resolutions)])

        # One synthetic feature map per selected head, at that head's resolution.
        dummy_features = []
        for idx in test_feature_indices:
            res = resolutions[idx]
            feature = torch.randn(batch_size, test_in_channels, res, res, device=device, dtype=dtype)
            dummy_features.append(feature)

        with torch.no_grad():
            output = discriminator(dummy_features)

        # One logit per selected head.
        expected_num_heads = len(test_feature_indices)
        expected_shape = torch.Size([batch_size, expected_num_heads])
        assert output.shape == expected_shape, f"Expected {expected_shape}, got {output.shape}"

        assert torch.isfinite(output).all(), f"Output contains NaN or Inf values for {test_case['name']}"

        # Magnitude sanity bound on randomly-initialized outputs.
        assert output.abs().max() < 100, f"Output values seem too large for {test_case['name']}: {output}"
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_causal_wan():
    """
    Test CausalWan network, specifically the sample method.

    Exercises, in order: scheduler validation, a forward pass under autocast,
    a forward pass with KV-cache storage (``store_kv=True``), autoregressive
    sampling, inhomogeneous (per-frame) timesteps, early feature extraction,
    and sampling with 1-frame and 5-frame latents. The steps are
    order-dependent (caches/state from earlier calls may affect later ones).
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    # Skip early if there is not enough free memory for the 1.3B model.
    if torch.cuda.is_available():
        free_memory = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
        if free_memory < 20:
            pytest.skip(f"Test skipped: requires ~20GB free GPU memory, but only {free_memory:.1f}GB available")

    set_env_vars()
    teacher_config = CausalWan_1_3B_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Instantiate the network and its preprocessors; the VAE and text encoder
    # are cast to the working dtype, the transformer stays in its load dtype
    # for the autocast pass below.
    teacher = instantiate(teacher_config)
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    # Latent dims: C channels, T frames, HxW spatial (small for test speed).
    C, T, H, W = 16, 3, 4, 4

    x = torch.randn(batch_size, C, T, H, W, device=device, dtype=dtype)

    # Causal model takes a per-frame timestep: sample one t per batch element
    # and broadcast it across the T frames.
    t_1d = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)
    t = t_1d.unsqueeze(1).expand(batch_size, T)

    # Positive text condition; encode() may return a tuple — take the embeds.
    captions = ["a test caption for causal WAN"]
    condition = teacher.text_encoder.encode(captions)
    if isinstance(condition, tuple):
        condition = condition[0]
    condition = condition.to(device=device, dtype=dtype)

    # Negative (empty-prompt) condition for classifier-free guidance in sample().
    neg_condition = teacher.text_encoder.encode([""])
    if isinstance(neg_condition, tuple):
        neg_condition = neg_condition[0]
    neg_condition = neg_condition.to(device=device, dtype=dtype)

    # Forward pass under autocast (weights not yet cast to bf16).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
    assert output.device == x.device, f"Expected device {x.device}, got {output.device}"

    # Cast the full model and run a plain (non-autocast) smoke-test forward.
    teacher = teacher.to(device=device, dtype=dtype)
    output = teacher(x, t, condition=condition)

    # Forward with KV-cache storage — presumably populates the causal KV
    # caches; only the output shape is verified here.
    output_with_kv = teacher(x, t, condition=condition, store_kv=True)
    assert output_with_kv.shape == x.shape, f"Expected shape {x.shape}, got {output_with_kv.shape}"

    # Autoregressive sampling from pure noise must preserve shape/device/dtype.
    original_noise = torch.randn(batch_size, C, T, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        ar_output = teacher.sample(
            noise=original_noise,
            condition=condition,
            neg_condition=neg_condition,
        )

    assert ar_output.shape == original_noise.shape, f"Expected shape {original_noise.shape}, got {ar_output.shape}"
    assert ar_output.device == original_noise.device, f"Expected device {original_noise.device}, got {ar_output.device}"
    assert ar_output.dtype == original_noise.dtype, f"Expected dtype {original_noise.dtype}, got {ar_output.dtype}"

    # Inhomogeneous timesteps: one t per frame chunk; reshape to broadcast
    # over (B, C, T, H, W) for the forward diffusion process.
    t_inhom, idx = teacher.noise_scheduler.sample_t_inhom(batch_size, T, teacher.chunk_size, sample_steps=4)
    t_inhom = t_inhom.to(device=device, dtype=dtype)
    t_inhom_reshaped = t_inhom[:, None, :, None, None]

    eps_inhom = torch.randn_like(x)
    noisy = teacher.noise_scheduler.forward_process(x, eps_inhom, t_inhom_reshaped)
    assert noisy.shape == x.shape and noisy.device == x.device and noisy.dtype == x.dtype

    # Forward pass with the per-frame (inhomogeneous) timesteps.
    output_inhom = teacher(x, t_inhom, condition=condition)
    assert output_inhom.shape == x.shape

    # Early feature extraction at block index 0 returns a list of features.
    output_features = teacher(x, t, condition=condition, return_features_early=True, feature_indices={0})
    assert isinstance(output_features, list), "Feature extraction should return a list"

    # Causal-specific attribute: this config is expected to use chunk_size=3.
    assert hasattr(teacher, "chunk_size"), "CausalWan should have chunk_size attribute"
    assert teacher.chunk_size == 3, f"Expected chunk_size=3, got {teacher.chunk_size}"

    # Edge case: sampling a single-frame latent.
    single_frame_latents = torch.randn(batch_size, C, 1, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        single_frame_output = teacher.sample(
            noise=single_frame_latents,
            condition=condition,
            neg_condition=neg_condition,
        )
    assert single_frame_output.shape == single_frame_latents.shape

    # Edge case: frame count (5) not divisible by chunk_size (3).
    odd_frames_latents = torch.randn(batch_size, C, 5, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        odd_output = teacher.sample(
            noise=odd_frames_latents,
            condition=condition,
            neg_condition=neg_condition,
        )
    assert odd_output.shape == odd_frames_latents.shape

    # Free the model's GPU memory for subsequent tests.
    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_wan22_5b_i2v():
    """
    Test the Wan 2.2 I2V 5B network.

    Runs a forward pass (with and without autocast), early feature
    extraction, and a re-instantiation with an explicit "rf" schedule.
    The transformer is truncated to a single block to keep the test cheap.
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    set_env_vars()
    teacher_config = Wan22_I2V_5B_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)

    # Keep only the first transformer block so the forward pass is fast.
    teacher.transformer.blocks = teacher.transformer.blocks[:1]

    # VAE and text encoder are cast to the working dtype; the transformer
    # stays in its load dtype for the autocast pass below.
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    num_frames, height, width = 5, 32, 32
    # Latent dims: temporal compression x4, spatial compression x16.
    (
        T,
        H,
        W,
    ) = (num_frames + 3) // 4, height // 16, width // 16

    # 48 latent channels for the 5B TI2V VAE (matches the config used here).
    x = torch.randn(batch_size, 48, T, H, W, device=device, dtype=dtype)

    # One uniformly-sampled timestep per batch element.
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)

    text_embeds = teacher.text_encoder.encode(["a caption"])

    # First-frame conditioning: a single (zero) frame VAE-encoded to latents.
    image = torch.zeros(batch_size, 3, height, width, device=device, dtype=dtype)
    image = image.unsqueeze(2)
    first_frame_cond = image
    first_frame_cond = first_frame_cond.to(device=device, dtype=dtype)
    first_frame_cond = teacher.vae.encode(first_frame_cond)

    # encode() may return a tuple — keep only the embeddings tensor.
    if isinstance(text_embeds, tuple):
        text_embeds = text_embeds[0]
    text_embeds = text_embeds.to(device=device, dtype=dtype)
    condition = dict(
        text_embeds=text_embeds,
        first_frame_cond=first_frame_cond,
    )

    # Forward pass under autocast (transformer weights not yet cast).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Cast the full model and run a plain (non-autocast) smoke-test forward.
    teacher = teacher.to(device=device, dtype=dtype)

    output = teacher(x, t, condition=condition)

    # Early feature extraction at block index 0.
    output = teacher(
        x,
        t,
        condition=condition,
        return_features_early=True,
        feature_indices={0},
    )

    assert isinstance(output, list)

    for feature in output:
        # 768 is the expected feature width for this model size.
        expected_channels = 768
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Re-instantiate with an explicit "rf" schedule and re-validate.
    # NOTE(review): this mutates the module-level config object in place.
    teacher_config.schedule_type = "rf"
    teacher = instantiate(teacher_config)
    teacher.init_preprocessors()
    teacher = teacher.to(device=device, dtype=dtype)

    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    # Drop the reference so clear_gpu_memory can actually free the weights.
    del teacher

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_wan21_14b_i2v():
    """
    Test the Wan 2.1 I2V 14B 480P network.

    Runs a forward pass (with and without autocast), early feature
    extraction, and a re-instantiation with an explicit "rf" schedule.
    The transformer is truncated to a single block to keep the test cheap.
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    set_env_vars()
    teacher_config = Wan21_I2V_14B_480P_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Skip early: the 14B checkpoint needs ~80GB of device memory.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if total_memory < 79:
            pytest.skip(f"Test skipped: 14B model requires ~80GB GPU memory, but only {total_memory:.1f}GB available")

    teacher = instantiate(teacher_config)

    # Keep only the first transformer block so the forward pass is fast.
    teacher.transformer.blocks = teacher.transformer.blocks[:1]

    # VAE, text encoder, and image encoder are cast to the working dtype;
    # the transformer stays in its load dtype for the autocast pass below.
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)
    teacher.image_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    num_frames, height, width = 5, 32, 32
    # Latent dims: temporal compression x4, spatial compression x8.
    T, H, W = (num_frames + 3) // 4, height // 8, width // 8

    # 16 latent channels for the Wan 2.1 VAE (matches the config used here).
    x = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)

    # One uniformly-sampled timestep per batch element.
    t = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)

    text_embeds = teacher.text_encoder.encode(["a caption"])

    # Image-encoder embedding of the conditioning frame.
    image = torch.zeros(batch_size, 3, height, width, device=device, dtype=dtype)
    encoder_hidden_states_image = teacher.image_encoder.encode(image)

    # First-frame conditioning: the frame padded with zeros up to num_frames,
    # then VAE-encoded to latents.
    image = image.unsqueeze(2)
    first_frame_cond = torch.cat(
        [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
    )
    first_frame_cond = first_frame_cond.to(device=device, dtype=dtype)
    first_frame_cond = teacher.vae.encode(first_frame_cond)

    # encode() may return a tuple — keep only the embeddings tensor.
    if isinstance(text_embeds, tuple):
        text_embeds = text_embeds[0]
    text_embeds = text_embeds.to(device=device, dtype=dtype)
    condition = dict(
        text_embeds=text_embeds,
        first_frame_cond=first_frame_cond,
        encoder_hidden_states_image=encoder_hidden_states_image,
    )

    # Forward pass under autocast (transformer weights not yet cast).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition=condition)

    assert output.shape == x.shape
    assert output.device == x.device

    # Cast the full model and run a plain (non-autocast) smoke-test forward.
    teacher = teacher.to(device=device, dtype=dtype)

    output = teacher(x, t, condition=condition)

    # Early feature extraction at block index 0.
    output = teacher(
        x,
        t,
        condition=condition,
        return_features_early=True,
        feature_indices={0},
    )

    assert isinstance(output, list)

    # 1280 is the expected feature width for this model size.
    expected_channels = 1280
    for feature in output:
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Re-instantiate with an explicit "rf" schedule and re-validate.
    # NOTE(review): this mutates the module-level config object in place.
    teacher_config.schedule_type = "rf"
    teacher = instantiate(teacher_config)
    teacher.init_preprocessors()
    teacher = teacher.to(device=device, dtype=dtype)

    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    # Drop the reference so clear_gpu_memory can actually free the weights.
    del teacher

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_causal_wan22_5b_i2v():
    """
    Test CausalWanI2V network with Wan 2.2 TI2V 5B model.
    Tests forward pass, sample, and causal-specific features.

    Steps are order-dependent: a KV-cache-storing forward pass is followed by
    an explicit ``clear_caches()`` before autoregressive sampling. The
    transformer is truncated to a single block to keep the test cheap.
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    set_env_vars()
    teacher_config = CausalWan22_I2V_5B_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    teacher = instantiate(teacher_config)

    # Keep only the first transformer block so the forward pass is fast.
    teacher.transformer.blocks = teacher.transformer.blocks[:1]

    # VAE and text encoder are cast to the working dtype; the transformer
    # stays in its load dtype for the autocast pass below.
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    num_frames, height, width = 5, 32, 32
    # Latent dims: temporal compression x4, spatial compression x16.
    T, H, W = (num_frames + 3) // 4, height // 16, width // 16

    # 48 latent channels for the 5B TI2V VAE (matches the config used here).
    x = torch.randn(batch_size, 48, T, H, W, device=device, dtype=dtype)

    # Causal model takes a per-frame timestep: sample one t per batch element
    # and broadcast it across the T frames.
    t_1d = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)
    t = t_1d.unsqueeze(1).expand(batch_size, T)

    text_embeds = teacher.text_encoder.encode(["a caption"])

    # First-frame conditioning: a single (zero) frame VAE-encoded to latents.
    image = torch.zeros(batch_size, 3, height, width, device=device, dtype=dtype)
    image = image.unsqueeze(2)
    first_frame_cond = image.to(device=device, dtype=dtype)
    first_frame_cond = teacher.vae.encode(first_frame_cond)

    # encode() may return a tuple — keep only the embeddings tensor.
    if isinstance(text_embeds, tuple):
        text_embeds = text_embeds[0]
    text_embeds = text_embeds.to(device=device, dtype=dtype)
    condition = dict(
        text_embeds=text_embeds,
        first_frame_cond=first_frame_cond,
    )

    # Forward pass under autocast (transformer weights not yet cast).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition)

    assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
    assert output.device == x.device, f"Expected device {x.device}, got {output.device}"

    # Cast the full model for the remaining (non-autocast) calls.
    teacher = teacher.to(device=device, dtype=dtype)

    # Forward with KV-cache storage — presumably populates the causal KV
    # caches; only the output shape is verified here.
    output_with_kv = teacher(x, t, condition, store_kv=True)
    assert output_with_kv.shape == x.shape, f"Expected shape {x.shape}, got {output_with_kv.shape}"

    # Reset caches before autoregressive sampling.
    teacher.clear_caches()

    # Negative (empty-prompt) condition for classifier-free guidance.
    neg_text_embeds = teacher.text_encoder.encode([""])
    if isinstance(neg_text_embeds, tuple):
        neg_text_embeds = neg_text_embeds[0]
    neg_text_embeds = neg_text_embeds.to(device=device, dtype=dtype)
    neg_condition = dict(
        text_embeds=neg_text_embeds,
        first_frame_cond=first_frame_cond,
    )

    # Autoregressive sampling (2 steps for speed) must preserve
    # shape/device/dtype of the input noise.
    original_noise = torch.randn(batch_size, 48, T, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        ar_output = teacher.sample(
            noise=original_noise,
            condition=condition,
            neg_condition=neg_condition,
            sample_steps=2,
        )

    assert ar_output.shape == original_noise.shape, f"Expected shape {original_noise.shape}, got {ar_output.shape}"
    assert ar_output.device == original_noise.device
    assert ar_output.dtype == original_noise.dtype

    # Early feature extraction at block index 0.
    output_features = teacher(x, t, condition, return_features_early=True, feature_indices={0})
    assert isinstance(output_features, list), "Feature extraction should return a list"

    for feature in output_features:
        # 768 is the expected feature width for this model size.
        expected_channels = 768
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Causal-specific attribute: this config is expected to use chunk_size=3.
    assert hasattr(teacher, "chunk_size"), "CausalWanI2V should have chunk_size attribute"
    assert teacher.chunk_size == 3, f"Expected chunk_size=3, got {teacher.chunk_size}"

    # Drop the reference so clear_gpu_memory can actually free the weights.
    del teacher

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_causal_wan21_14b_480p_i2v():
    """
    Test CausalWanI2V network with Wan 2.1 I2V 14B 480P model.
    Tests forward pass, sample, and causal-specific features.

    Steps are order-dependent: a KV-cache-storing forward pass is followed by
    an explicit ``clear_caches()`` before autoregressive sampling. The
    transformer is truncated to a single block to keep the test cheap.
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    set_env_vars()
    teacher_config = CausalWan21_I2V_14B_480P_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Skip early: the 14B checkpoint needs ~80GB of device memory.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if total_memory < 79:
            pytest.skip(f"Test skipped: 14B model requires ~80GB GPU memory, but only {total_memory:.1f}GB available")

    teacher = instantiate(teacher_config)

    # Keep only the first transformer block so the forward pass is fast.
    teacher.transformer.blocks = teacher.transformer.blocks[:1]

    # VAE, text encoder, and image encoder are cast to the working dtype;
    # the transformer stays in its load dtype for the autocast pass below.
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)
    teacher.image_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    num_frames, height, width = 5, 32, 32
    # Latent dims: temporal compression x4, spatial compression x8.
    T, H, W = (num_frames + 3) // 4, height // 8, width // 8

    # 16 latent channels for the Wan 2.1 VAE (matches the config used here).
    x = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)

    # Causal model takes a per-frame timestep: sample one t per batch element
    # and broadcast it across the T frames.
    t_1d = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)
    t = t_1d.unsqueeze(1).expand(batch_size, T)

    text_embeds = teacher.text_encoder.encode(["a caption"])

    # Image-encoder embedding of the conditioning frame.
    image = torch.zeros(batch_size, 3, height, width, device=device, dtype=dtype)
    encoder_hidden_states_image = teacher.image_encoder.encode(image)

    # First-frame conditioning: the frame padded with zeros up to num_frames,
    # then VAE-encoded to latents.
    image = image.unsqueeze(2)
    first_frame_cond = torch.cat(
        [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
    )
    first_frame_cond = first_frame_cond.to(device=device, dtype=dtype)
    first_frame_cond = teacher.vae.encode(first_frame_cond)

    # encode() may return a tuple — keep only the embeddings tensor.
    if isinstance(text_embeds, tuple):
        text_embeds = text_embeds[0]
    text_embeds = text_embeds.to(device=device, dtype=dtype)
    condition = dict(
        text_embeds=text_embeds,
        first_frame_cond=first_frame_cond,
        encoder_hidden_states_image=encoder_hidden_states_image,
    )

    # Forward pass under autocast (transformer weights not yet cast).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition)

    assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
    assert output.device == x.device, f"Expected device {x.device}, got {output.device}"

    # Cast the full model and run a plain (non-autocast) smoke-test forward.
    teacher = teacher.to(device=device, dtype=dtype)
    output = teacher(x, t, condition)

    # Forward with KV-cache storage — presumably populates the causal KV
    # caches; only the output shape is verified here.
    output_with_kv = teacher(x, t, condition, store_kv=True)
    assert output_with_kv.shape == x.shape, f"Expected shape {x.shape}, got {output_with_kv.shape}"

    # Reset caches before autoregressive sampling.
    teacher.clear_caches()

    # Negative (empty-prompt) condition for classifier-free guidance.
    neg_text_embeds = teacher.text_encoder.encode([""])
    if isinstance(neg_text_embeds, tuple):
        neg_text_embeds = neg_text_embeds[0]
    neg_text_embeds = neg_text_embeds.to(device=device, dtype=dtype)
    neg_condition = dict(
        text_embeds=neg_text_embeds,
        first_frame_cond=first_frame_cond,
        encoder_hidden_states_image=encoder_hidden_states_image,
    )

    # Autoregressive sampling (2 steps for speed) must preserve
    # shape/device/dtype of the input noise.
    original_noise = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        ar_output = teacher.sample(
            noise=original_noise,
            condition=condition,
            neg_condition=neg_condition,
            sample_steps=2,
        )

    assert ar_output.shape == original_noise.shape, f"Expected shape {original_noise.shape}, got {ar_output.shape}"
    assert ar_output.device == original_noise.device
    assert ar_output.dtype == original_noise.dtype

    # Early feature extraction at block index 0.
    output_features = teacher(x, t, condition, return_features_early=True, feature_indices={0})
    assert isinstance(output_features, list), "Feature extraction should return a list"

    for feature in output_features:
        # 1280 is the expected feature width for this model size.
        expected_channels = 1280
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Causal-specific attribute: this config is expected to use chunk_size=3.
    assert hasattr(teacher, "chunk_size"), "CausalWanI2V should have chunk_size attribute"
    assert teacher.chunk_size == 3, f"Expected chunk_size=3, got {teacher.chunk_size}"

    # Drop the reference so clear_gpu_memory can actually free the weights.
    del teacher

    clear_gpu_memory()
|
|
|
|
@RunIf(min_gpus=1)
@pytest.mark.large_model
def test_network_causal_wan21_14b_720p_i2v():
    """
    Test CausalWanI2V network with Wan 2.1 I2V 14B 720P model.
    Tests forward pass, sample, and causal-specific features.

    Steps are order-dependent: a KV-cache-storing forward pass is followed by
    an explicit ``clear_caches()`` before autoregressive sampling. The
    transformer is truncated to a single block to keep the test cheap.
    """
    # Release cached GPU memory before loading the model.
    clear_gpu_memory()

    set_env_vars()
    teacher_config = CausalWan21_I2V_14B_720P_Config

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Skip early: the 14B checkpoint needs ~80GB of device memory.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if total_memory < 79:
            pytest.skip(f"Test skipped: 14B model requires ~80GB GPU memory, but only {total_memory:.1f}GB available")

    teacher = instantiate(teacher_config)

    # Keep only the first transformer block so the forward pass is fast.
    teacher.transformer.blocks = teacher.transformer.blocks[:1]

    # VAE, text encoder, and image encoder are cast to the working dtype;
    # the transformer stays in its load dtype for the autocast pass below.
    teacher.init_preprocessors()
    teacher = teacher.to(device=device)
    teacher.vae.to(device=device, dtype=dtype)
    teacher.text_encoder.to(device=device, dtype=dtype)
    teacher.image_encoder.to(device=device, dtype=dtype)

    # Shared scheduler sanity checks (rectified-flow schedule expected).
    validate_noise_scheduler_properties(teacher, device, expected_schedule_type="rf")

    batch_size = 1
    num_frames, height, width = 5, 32, 32
    # Latent dims: temporal compression x4, spatial compression x8.
    T, H, W = (num_frames + 3) // 4, height // 8, width // 8

    # 16 latent channels for the Wan 2.1 VAE (matches the config used here).
    x = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)

    # Causal model takes a per-frame timestep: sample one t per batch element
    # and broadcast it across the T frames.
    t_1d = teacher.noise_scheduler.sample_t(batch_size, time_dist_type="uniform").to(device=device, dtype=dtype)
    t = t_1d.unsqueeze(1).expand(batch_size, T)

    text_embeds = teacher.text_encoder.encode(["a caption"])

    # Image-encoder embedding of the conditioning frame.
    image = torch.zeros(batch_size, 3, height, width, device=device, dtype=dtype)
    encoder_hidden_states_image = teacher.image_encoder.encode(image)

    # First-frame conditioning: the frame padded with zeros up to num_frames,
    # then VAE-encoded to latents.
    image = image.unsqueeze(2)
    first_frame_cond = torch.cat(
        [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
    )
    first_frame_cond = first_frame_cond.to(device=device, dtype=dtype)
    first_frame_cond = teacher.vae.encode(first_frame_cond)

    # encode() may return a tuple — keep only the embeddings tensor.
    if isinstance(text_embeds, tuple):
        text_embeds = text_embeds[0]
    text_embeds = text_embeds.to(device=device, dtype=dtype)
    condition = dict(
        text_embeds=text_embeds,
        first_frame_cond=first_frame_cond,
        encoder_hidden_states_image=encoder_hidden_states_image,
    )

    # Forward pass under autocast (transformer weights not yet cast).
    with torch.autocast(device_type=device.type, dtype=dtype):
        output = teacher(x, t, condition)

    assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
    assert output.device == x.device, f"Expected device {x.device}, got {output.device}"

    # Cast the full model and run a plain (non-autocast) smoke-test forward.
    teacher = teacher.to(device=device, dtype=dtype)
    output = teacher(x, t, condition)

    # Forward with KV-cache storage — presumably populates the causal KV
    # caches; only the output shape is verified here.
    output_with_kv = teacher(x, t, condition, store_kv=True)
    assert output_with_kv.shape == x.shape, f"Expected shape {x.shape}, got {output_with_kv.shape}"

    # Reset caches before autoregressive sampling.
    teacher.clear_caches()

    # Negative (empty-prompt) condition for classifier-free guidance.
    neg_text_embeds = teacher.text_encoder.encode([""])
    if isinstance(neg_text_embeds, tuple):
        neg_text_embeds = neg_text_embeds[0]
    neg_text_embeds = neg_text_embeds.to(device=device, dtype=dtype)
    neg_condition = dict(
        text_embeds=neg_text_embeds,
        first_frame_cond=first_frame_cond,
        encoder_hidden_states_image=encoder_hidden_states_image,
    )

    # Autoregressive sampling (2 steps for speed) must preserve
    # shape/device/dtype of the input noise.
    original_noise = torch.randn(batch_size, 16, T, H, W, device=device, dtype=dtype)
    with torch.no_grad():
        ar_output = teacher.sample(
            noise=original_noise,
            condition=condition,
            neg_condition=neg_condition,
            sample_steps=2,
        )

    assert ar_output.shape == original_noise.shape, f"Expected shape {original_noise.shape}, got {ar_output.shape}"
    assert ar_output.device == original_noise.device
    assert ar_output.dtype == original_noise.dtype

    # Early feature extraction at block index 0.
    output_features = teacher(x, t, condition, return_features_early=True, feature_indices={0})
    assert isinstance(output_features, list), "Feature extraction should return a list"

    for feature in output_features:
        # 1280 is the expected feature width for this model size.
        expected_channels = 1280
        assert feature.shape == (batch_size, expected_channels, T, H, W)

    # Causal-specific attribute: this config is expected to use chunk_size=3.
    assert hasattr(teacher, "chunk_size"), "CausalWanI2V should have chunk_size attribute"
    assert teacher.chunk_size == 3, f"Expected chunk_size=3, got {teacher.chunk_size}"

    # Drop the reference so clear_gpu_memory can actually free the weights.
    del teacher

    clear_gpu_memory()
|
|