import copy
import gc
from typing import Callable

import pytest
import torch

from fastgen.configs.discriminator import Discriminator_Wan_1_3B_Config
from fastgen.configs.methods.config_dmd2 import ModelConfig
from fastgen.configs.net import CausalWan_1_3B_Config, Wan_1_3B_Config
from fastgen.methods import CausVidModel
from fastgen.utils.io_utils import set_env_vars
from fastgen.utils.test_utils import RunIf, check_grad_zero


@pytest.fixture
def get_model_data():
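    """Build a minimal CausVid model and a random data batch sized to match its config."""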
    gc.collect()
    set_env_vars()

    config = ModelConfig()
    config.net = copy.deepcopy(CausalWan_1_3B_Config)
    config.net.chunk_size = 2
    config.teacher = Wan_1_3B_Config

    # Use a single-block discriminator to keep the test light.
    config.discriminator = Discriminator_Wan_1_3B_Config
    config.discriminator.num_blocks = 1

    # float32 on CPU, bfloat16 on GPU; no pretrained weights are loaded.
    config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.precision = "float32" if config.device == torch.device("cpu") else "bfloat16"
    config.pretrained_model_path = ""
    config.student_update_freq = 2
    config.student_sample_steps = 4

    # Keep sampled timesteps strictly inside (0, 1).
    config.sample_t_cfg.max_t = 0.999
    config.sample_t_cfg.min_t = 0.001

    # Tiny latent shape (channels, frames, height, width) keeps the test fast.
    config.input_shape = [16, 2, 4, 4]
    config.enable_preprocessors = True
    config.gan_loss_weight_gen = 0.001

    model = CausVidModel(config)
    # Truncate the student and teacher transformers to one block each.
    model.net.transformer.blocks = model.net.transformer.blocks[:1]
    model.teacher.transformer.blocks = model.teacher.transformer.blocks[:1]

    model.net.init_preprocessors()
    model.on_train_begin()
    model.init_optimizers()

    batch_size = 1
    channels, n_frames, height, width = config.input_shape

    # Random latents plus text-embedding-shaped positive and negative conditions.
    data = {
        "real": torch.randn(
            batch_size, channels, n_frames, height, width, device=model.device, dtype=model.precision
        ),
        "condition": torch.randn(batch_size, 512, 4096, device=model.device, dtype=model.precision),
        "neg_condition": torch.zeros(batch_size, 512, 4096, device=model.device, dtype=model.precision),
    }

    return model, data


@RunIf(min_gpus=1)
def test_causal_timestep_sampling(get_model_data):
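    """Inhomogeneous timesteps are sampled per chunk and stay constant within each chunk."""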
    model, data = get_model_data

    batch_size = data["real"].shape[0]
    num_frames = data["real"].shape[2]
    chunk_size = getattr(model.net, "chunk_size", 2)

    t_inhom, _ = model.teacher.noise_scheduler.sample_t_inhom(
        n=batch_size, seq_len=num_frames, chunk_size=chunk_size, sample_steps=4
    )

    # One sampled timestep per frame.
    assert t_inhom.shape == (batch_size, num_frames)

    if num_frames > chunk_size:
        first_chunk_t = t_inhom[:, :chunk_size]
        second_chunk_t = (
            t_inhom[:, chunk_size : 2 * chunk_size] if num_frames >= 2 * chunk_size else t_inhom[:, chunk_size:]
        )

        # All frames within a chunk must share the same timestep.
        assert torch.allclose(first_chunk_t[:, 0:1], first_chunk_t, atol=1e-6)
        if second_chunk_t.shape[1] > 1:
            assert torch.allclose(second_chunk_t[:, 0:1], second_chunk_t, atol=1e-6)
    else:
        # Only one chunk: every frame shares a single timestep.
        assert torch.allclose(t_inhom[:, 0:1], t_inhom, atol=1e-6)


@RunIf(min_gpus=1)
def test_single_train_step_student_update(get_model_data):
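    """A student (generator) step returns VSD and GAN generator losses plus sample hooks."""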
    model, data = get_model_data

    # Iteration 0 performs a student (generator) update.
    loss_map, outputs = model.single_train_step(data, 0)

    assert "total_loss" in loss_map
    assert "vsd_loss" in loss_map
    assert "gan_loss_gen" in loss_map
    assert "gen_rand" in outputs
    assert "input_rand" in outputs
    assert isinstance(outputs["gen_rand"], Callable)
    assert isinstance(outputs["input_rand"], torch.Tensor)

    # The deferred generator output must match the real-data shape.
    gen_tensor = outputs["gen_rand"]()
    assert gen_tensor.shape == data["real"].shape
    assert outputs["input_rand"].shape == data["real"].shape


@RunIf(min_gpus=1)
def test_single_train_step_fake_score_update(get_model_data):
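    """A fake-score step returns the critic losses, and all reported losses are finite."""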
    model, data = get_model_data

    # Iteration 1 performs a fake-score / discriminator update.
    loss_map, outputs = model.single_train_step(data, 1)

    assert "fake_score_loss" in loss_map
    assert "gan_loss_disc" in loss_map
    assert "gen_rand" in outputs
    assert "input_rand" in outputs
    assert isinstance(outputs["gen_rand"], Callable)
    assert isinstance(outputs["input_rand"], torch.Tensor)

    # Both losses must be finite.
    assert torch.isfinite(loss_map["total_loss"])
    assert torch.isfinite(loss_map["fake_score_loss"])


@RunIf(min_gpus=1)
def test_optimizers(get_model_data):
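    """Alternating optimizer steps run end to end, and zero_grad clears all gradients."""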
    model, data = get_model_data

    gc.collect()
    torch.cuda.empty_cache()

    # Alternate one student step (iteration 0) and one fake-score step (iteration 1).
    for iteration in range(2):
        model.optimizers_zero_grad(iteration)
        loss_map, _ = model.single_train_step(data, iteration)
        model.grad_scaler.scale(loss_map["total_loss"]).backward()
        model.optimizers_schedulers_step(iteration)

    # zero_grad on subsequent iterations must clear gradients for every sub-module.
    model.optimizers_zero_grad(2)
    check_grad_zero(model.net)
    model.optimizers_zero_grad(3)
    check_grad_zero(model.discriminator)
    check_grad_zero(model.fake_score)