# fastgen-offline / FastGen / tests / test_sftmodel.py
# NOTE(review): the lines above/below were HuggingFace Hub page residue
# ("taohu's picture", "Upload folder using huggingface_hub", commit 0839907);
# converted to comments so this file is valid Python.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import gc
import torch
import pytest
from fastgen.methods import SFTModel
from fastgen.configs.methods.config_sft import ModelConfig
from fastgen.configs.config_utils import override_config_with_opts
@pytest.fixture
def get_model_data():
    """Build a tiny SFTModel plus a one-sample mock batch for the tests below."""
    # Encourage release of lingering tensors between tests:
    # https://github.com/pytest-dev/pytest/discussions/10387
    gc.collect()

    cfg = ModelConfig()
    overrides = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1", "r_timestep=False"]
    cfg.net = override_config_with_opts(cfg.net, overrides)
    cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg.precision = "float32" if cfg.device == torch.device("cpu") else "bfloat16"
    cfg.pretrained_model_path = ""  # disable ckpt loading
    cfg.input_shape = [3, 8, 8]
    # SFT-specific configurations
    cfg.cond_dropout_prob = 0.1
    cfg.cond_keys_no_dropout = []
    cfg.guidance_scale = None

    model = SFTModel(cfg)
    model.on_train_begin()
    model.init_optimizers()

    batch_size = 1
    # Use one-hot encoded labels as EDM network expects them
    class_ids = torch.randint(0, 10, (batch_size,))
    one_hot_labels = torch.nn.functional.one_hot(class_ids, num_classes=10).float()
    uncond = torch.zeros(batch_size, 10)
    # Create mock data
    data = {
        "real": torch.randn(batch_size, 3, 8, 8).to(model.device, model.precision),
        "condition": one_hot_labels.to(model.device, model.precision),
        "neg_condition": uncond.to(model.device, model.precision),
    }
    return model, data
def test_single_train_step(get_model_data):
    """Test the single training step of SFT model."""
    model, data = get_model_data

    # Run the training step
    loss_map, outputs = model.single_train_step(data, 0)

    # The loss map is a dict of non-negative scalar tensors.
    assert isinstance(loss_map, dict)
    for key in ("total_loss", "dsm_loss"):
        assert key in loss_map
        assert isinstance(loss_map[key], torch.Tensor)
        assert loss_map[key].detach().item() >= 0.0
    # For SFT, total_loss should equal dsm_loss
    assert torch.allclose(loss_map["total_loss"], loss_map["dsm_loss"])

    # Outputs carry the noised input tensor and a deferred generator.
    assert isinstance(outputs, dict)
    assert "gen_rand" in outputs
    assert "input_rand" in outputs
    assert isinstance(outputs["input_rand"], torch.Tensor)
    assert outputs["input_rand"].shape == data["real"].shape
    # gen_rand should be a callable (partial function)
    assert callable(outputs["gen_rand"])
def test_mix_condition_tensor(get_model_data):
    """Test the _mix_condition method with tensor inputs."""
    model, data = get_model_data
    cond = data["condition"]
    neg = data["neg_condition"]

    # Dropout disabled (None): the condition passes through unchanged.
    model.config.cond_dropout_prob = None
    assert torch.allclose(model._mix_condition(cond, neg), cond)

    # Dropout probability 1.0: every sample is swapped for the negative condition.
    model.config.cond_dropout_prob = 1.0
    assert torch.allclose(model._mix_condition(cond, neg), neg)

    # Intermediate probability: only the shape is deterministic, so check that.
    model.config.cond_dropout_prob = 0.5
    assert model._mix_condition(cond, neg).shape == cond.shape
def test_mix_condition_dict(get_model_data):
    """Test the _mix_condition method with dictionary inputs."""
    model, data = get_model_data

    # Dict-valued conditions: every key should be mixed independently.
    cond = {"text_embeds": data["condition"], "other_info": torch.ones_like(data["condition"])}
    neg = {"text_embeds": data["neg_condition"], "other_info": torch.zeros_like(data["neg_condition"])}

    # Dropout disabled (None): all entries pass through unchanged.
    model.config.cond_dropout_prob = None
    mixed = model._mix_condition(cond, neg)
    for key in ("text_embeds", "other_info"):
        assert torch.allclose(mixed[key], cond[key])

    # Dropout probability 1.0: all entries come from the negative condition.
    model.config.cond_dropout_prob = 1.0
    mixed = model._mix_condition(cond, neg)
    for key in ("text_embeds", "other_info"):
        assert torch.allclose(mixed[key], neg[key])
def test_mix_condition_with_no_dropout_keys(get_model_data):
    """Test the _mix_condition method with keys that should not be dropped."""
    model, data = get_model_data

    # Protect "other_info" from conditioning dropout.
    model.config.cond_keys_no_dropout = {"other_info"}
    cond = {"text_embeds": data["condition"], "other_info": torch.ones_like(data["condition"])}
    neg = {"text_embeds": data["neg_condition"], "other_info": torch.zeros_like(data["neg_condition"])}

    # Force dropout on every sample so the outcome is deterministic.
    model.config.cond_dropout_prob = 1.0
    mixed = model._mix_condition(cond, neg)
    # The droppable key is replaced by its negative counterpart ...
    assert torch.allclose(mixed["text_embeds"], neg["text_embeds"])
    # ... while the protected key keeps its original value.
    assert torch.allclose(mixed["other_info"], cond["other_info"])
def test_generator_fn(get_model_data):
    """Test the static generator_fn method.

    A mock network stands in for the real one, so this only verifies the
    wiring of ``SFTModel.generator_fn`` (argument passing and output shape),
    not sample quality.
    """
    # Fix: the unpacked model was never used; only the data dict is needed.
    _, data = get_model_data

    class MockNet(torch.nn.Module):
        """Minimal stand-in exposing the ``sample`` interface generator_fn expects."""

        def sample(self, noise, condition=None, neg_condition=None, guidance_scale=None, **kwargs):
            # Shape-preserving dummy output, mirroring a sampler's contract.
            return torch.randn_like(noise)

    noise = torch.randn_like(data["real"])
    # Test generator function
    result = SFTModel.generator_fn(
        net=MockNet(),
        noise=noise,
        condition=data["condition"],
        neg_condition=data["neg_condition"],
        guidance_scale=None,
    )
    # generator_fn must return a tensor shaped like the input noise.
    assert isinstance(result, torch.Tensor)
    assert result.shape == noise.shape
def test_optimizers(get_model_data):
    """Test optimizer functionality."""
    model, data = get_model_data

    # Zero-grad, forward, and scaled backward should all run without raising.
    model.optimizers_zero_grad(0)
    loss_map, _ = model.single_train_step(data, 0)
    assert loss_map["total_loss"].requires_grad
    model.grad_scaler.scale(loss_map["total_loss"]).backward()

    # At least one trainable parameter must have received a gradient.
    has_gradients = any(
        param.grad is not None
        for param in model.net.parameters()
        if param.requires_grad
    )
    assert has_gradients, "No gradients found after backward pass"

    # Test that optimizer step runs without errors
    model.optimizers_schedulers_step(0)
def test_loss_computation(get_model_data):
    """Test that loss computation produces reasonable values."""
    model, data = get_model_data
    loss_map, _ = model.single_train_step(data, 0)

    # Both reported losses must be differentiable and finite.
    for loss in (loss_map["total_loss"], loss_map["dsm_loss"]):
        assert loss.requires_grad
        assert not torch.isnan(loss)
        assert not torch.isinf(loss)