fastgen-offline / FastGen /tests /test_dmd2model.py
taohu's picture
Upload folder using huggingface_hub
0839907 verified
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import gc
import torch
import pytest
from fastgen.methods import DMD2Model
from fastgen.configs.methods.config_dmd2 import ModelConfig
from fastgen.configs.config_utils import override_config_with_opts
from fastgen.utils.test_utils import check_grad_zero
@pytest.fixture
def get_model_data():
gc.collect() # https://github.com/pytest-dev/pytest/discussions/10387
instance = ModelConfig()
opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
instance.net = override_config_with_opts(instance.net, opts)
opts_discriminator = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
instance.discriminator = override_config_with_opts(instance.discriminator, opts_discriminator)
instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16"
instance.pretrained_model_path = "" # disable ckpt loading
instance.student_update_freq = 2
instance.input_shape = [3, 8, 8]
model = DMD2Model(instance)
model.on_train_begin()
model.init_optimizers()
batch_size = 1
labels = torch.randint(0, 10, (batch_size,))
labels = torch.nn.functional.one_hot(labels, num_classes=10)
neg_condition = torch.zeros(batch_size, 10)
# Create mock data
data = {
"real": torch.randn(batch_size, 3, 8, 8).to(model.device, model.precision),
"condition": labels.to(model.device, model.precision),
"neg_condition": neg_condition.to(model.device, model.precision),
}
return model, data
def test_single_train_step_student_update(get_model_data):
model, data = get_model_data
# Run the training step
loss_map, outputs = model.single_train_step(data, 0)
# Assertions
assert "total_loss" in loss_map
assert "vsd_loss" in loss_map
assert "gan_loss_gen" in loss_map
assert "gen_rand" in outputs
assert "input_rand" in outputs
assert isinstance(outputs["gen_rand"], torch.Tensor)
assert isinstance(outputs["input_rand"], torch.Tensor)
def test_single_train_step_fake_score_update(get_model_data):
model, data = get_model_data
model.precision = torch.float32
model.on_train_begin()
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.to(model.precision)
# Run the training step
loss_map, outputs = model.single_train_step(data, 1)
# Assertions
assert "fake_score_loss" in loss_map
assert "gan_loss_disc" in loss_map
assert "gen_rand" in outputs
assert "input_rand" in outputs
assert isinstance(outputs["gen_rand"], torch.Tensor)
assert isinstance(outputs["input_rand"], torch.Tensor)
assert loss_map["total_loss"].dtype == torch.float32
def test_optimizers(get_model_data):
model, data = get_model_data
# Test for both student and fake_score optimizer
for iteration in range(2):
model.optimizers_zero_grad(iteration)
loss_map, _ = model.single_train_step(data, iteration)
model.grad_scaler.scale(loss_map["total_loss"]).backward()
model.optimizers_schedulers_step(iteration)
# Test for both student and fake_score optimizer zero grad, after at least one backward pass
model.optimizers_zero_grad(2)
check_grad_zero(model.net)
model.optimizers_zero_grad(3)
check_grad_zero(model.discriminator)
check_grad_zero(model.fake_score)