# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for fastgen trainer callbacks (EMA, schedules, logging, etc.)."""

import copy
import os
import gc
import tempfile

import pytest
import numpy as np
import torch
from omegaconf import DictConfig
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

from fastgen.configs.methods.config_dmd2 import create_config
from fastgen.configs.config_utils import override_config_with_opts
from fastgen.configs.net import EDM2_IN64_S_Config
from fastgen.methods import DMD2Model
from fastgen.trainer import Trainer
from fastgen.utils import instantiate
from fastgen.utils.io_utils import set_env_vars
from fastgen.configs.callbacks import (
    CTSchedule_CALLBACK,
    GradClip_CALLBACK,
    ParamCount_CALLBACK,
    WANDB_CALLBACK,
    EMA_CALLBACK,
    TrainProfiler_CALLBACK,
    GPUStats_CALLBACK,
    ForcedWeightNorm_CALLBACK,
)
from fastgen.callbacks.callback import CallbackDict
from fastgen.utils.test_utils import RunIf, run_distributed_test


@pytest.fixture
def get_model_data():
    """Build a tiny DMD2 model (8x8 images) plus a mock data batch and its config."""
    gc.collect()  # https://github.com/pytest-dev/pytest/discussions/10387
    dmd_config = create_config()
    dmd_config.log_config.name = "test"
    instance = dmd_config.model
    # Shrink the network so tests run fast on CPU/GPU alike.
    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    instance.net = override_config_with_opts(instance.net, opts)
    opts_discriminator = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
    instance.discriminator = override_config_with_opts(instance.discriminator, opts_discriminator)
    instance.use_ema = True
    instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16"
    instance.pretrained_model_path = ""  # disable ckpt loading
    instance.input_shape = [3, 8, 8]
    dmd_model = DMD2Model(instance)
    dmd_model.on_train_begin()
    dmd_model.init_optimizers()

    batch_size = 1
    labels = torch.randint(0, 10, (batch_size,))
    labels = torch.nn.functional.one_hot(labels, num_classes=10)
    neg_condition = torch.zeros(batch_size, 10)

    # Create mock data
    data = {
        "real": torch.randn(batch_size, 3, 8, 8).to(dmd_model.device, dmd_model.precision),
        "condition": labels.to(dmd_model.device, dmd_model.precision),
        "neg_condition": neg_condition.to(dmd_model.device, dmd_model.precision),
    }
    return dmd_model, data, dmd_config


def test_ema_callback(get_model_data):
    """Test EMA callback basic functionality (non-FSDP mode)."""
    model, data, config = get_model_data
    for callback_name, callback_config in EMA_CALLBACK.items():
        assert callback_name == "ema"
        assert model.ema is not None
        ema_callback = instantiate(callback_config)
        ema_callback.config = config
        # Call on_app_begin to initialize _is_fsdp flag (should be False for non-FSDP)
        ema_callback.on_app_begin()
        assert ema_callback._is_fsdp is False
        assert ema_callback.beta == 0.9999
        assert ema_callback.type == "constant"
        assert ema_callback.gamma == 16.97
        assert ema_callback.ema_halflife_kimg == 500
        assert ema_callback.ema_rampup_ratio == 0.05
        ema_callback.on_model_init_end(model)
        assert ema_callback._enabled is True

        # EMA should be initialized from net during model.build_model()
        ema_state = model.ema.state_dict()
        net_state = model.net.state_dict()
        assert all(torch.allclose(net_state[k], p_ema) for k, p_ema in ema_state.items())
        assert not any(p_ema.requires_grad for p_ema in ema_state.values())

        # Modify network parameters and compute expected EMA update
        buffers = [k for k, _ in model.net.named_buffers()]
        expected_ema_state = {}
        for k, p_net in net_state.items():
            torch.nn.init.normal_(p_net)
            if k in buffers:
                # Buffers are copied verbatim, not lerped.
                expected_ema_state[k] = p_net.detach().clone()
            else:
                expected_ema_state[k] = torch.lerp(ema_state[k], p_net.detach(), 1.0 - ema_callback.beta)

        # Run EMA update step
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=1,
        )

        # Verify EMA was updated correctly
        new_ema_state = model.ema.state_dict()
        assert all(torch.allclose(expected_ema_state[k], p_ema) for k, p_ema in new_ema_state.items())
        assert not any(p_ema.requires_grad for p_ema in new_ema_state.values())

        # Test that EMA update is skipped when ema is None
        model.ema = None
        ema_callback.on_model_init_end(model)
        assert ema_callback._enabled is False
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=1,
        )
        assert model.ema is None


def test_ema_initialization_after_build(get_model_data):
    """Test that EMA is correctly initialized from net state during model build."""
    model, data, config = get_model_data

    # Verify EMA exists and matches net state
    assert model.ema is not None
    assert model.use_ema == ["ema"]
    ema_state = model.ema.state_dict()
    net_state = model.net.state_dict()

    # All EMA parameters should match net parameters exactly after initialization
    for k in net_state.keys():
        assert k in ema_state, f"Key {k} not found in EMA state"
        assert torch.allclose(net_state[k], ema_state[k]), f"EMA state mismatch for {k}"

    # EMA should not require gradients
    assert not any(p.requires_grad for p in model.ema.parameters())
    assert model.ema.training is False  # EMA should be in eval mode


def test_ema_callback_multiple_steps(get_model_data):
    """Test EMA callback over multiple training steps to verify accumulation."""
    model, data, config = get_model_data
    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config
    ema_callback.on_app_begin()
    beta = ema_callback.beta

    # Store initial EMA state
    initial_ema_state = {k: v.clone() for k, v in model.ema.state_dict().items()}
    buffers = [k for k, _ in model.net.named_buffers()]

    # Run multiple EMA update steps
    for iteration in range(1, 5):
        # Modify network parameters
        for p in model.net.parameters():
            torch.nn.init.normal_(p)

        # Update expected EMA
        net_state = model.net.state_dict()
        for k in initial_ema_state.keys():
            if k in buffers:
                initial_ema_state[k] = net_state[k].clone()
            else:
                initial_ema_state[k].lerp_(net_state[k], 1.0 - beta)

        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=iteration,
        )

    # Verify final EMA state
    final_ema_state = model.ema.state_dict()
    for k, expected in initial_ema_state.items():
        assert torch.allclose(expected, final_ema_state[k], atol=1e-6), f"Mismatch at {k}"


@RunIf(min_gpus=1)
def test_ema_callback_fsdp_mode_mocked(get_model_data):
    """Test EMA callback FSDP mode behavior with mocked FSDP tensors.

    This test mocks the FSDP behavior by adding a `full_tensor()` method to parameters.
    In real FSDP, parameters are DTensors with `full_tensor()` that gathers from all ranks.
    """
    model, data, config = get_model_data

    # Mock FSDP by adding full_tensor method to parameters
    # In real FSDP, this gathers the full tensor from all shards
    for param in model.net.parameters():
        # Add a mock full_tensor method that returns the parameter itself.
        # Bind `param` as a default argument to avoid late-binding closure bugs.
        param.full_tensor = lambda p=param: p.data.clone()

    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config

    # Simulate FSDP mode
    config.trainer.fsdp = True
    ema_callback.on_app_begin()
    assert ema_callback._is_fsdp is True

    # Get initial EMA state
    initial_ema_state = {k: v.clone() for k, v in model.ema.state_dict().items()}
    buffers = [k for k, _ in model.net.named_buffers()]

    # Modify network parameters
    for p in model.net.parameters():
        torch.nn.init.normal_(p)

    # Compute expected EMA update
    expected_ema_state = {}
    net_state = model.net.state_dict()
    for k in initial_ema_state.keys():
        if k in buffers:
            expected_ema_state[k] = net_state[k].clone()
        else:
            expected_ema_state[k] = torch.lerp(initial_ema_state[k], net_state[k], 1.0 - ema_callback.beta)

    # Run EMA update (should use full_tensor() for FSDP)
    ema_callback.on_training_step_end(
        model,
        data_batch=None,
        output_batch=None,
        loss_dict=None,
        iteration=1,
    )

    # Verify EMA was updated correctly
    final_ema_state = model.ema.state_dict()
    for k, expected in expected_ema_state.items():
        assert torch.allclose(expected, final_ema_state[k], atol=1e-6), f"Mismatch at {k}"

    # Reset config
    config.trainer.fsdp = False


# =============================================================================
# True FSDP EMA Test Implementation
# =============================================================================


def _test_ema_callback_fsdp_distributed_impl(rank: int, world_size: int) -> dict:
    """Test EMA callback with real FSDP in a distributed setting using EDM model.

    This test uses the same EDM model architecture as other callback tests
    to ensure we're testing the actual model code paths. It verifies that:
    1. EMA callback correctly gathers full tensors from FSDP-sharded parameters
    2. EMA state remains consistent after update
    3. Synchronization barriers work correctly

    Args:
        rank: Process rank
        world_size: Total number of processes

    Returns:
        dict with test results
    """
    from fastgen.callbacks.ema import EMACallback
    from fastgen.configs.methods.config_dmd2 import create_config
    from fastgen.configs.config_utils import override_config_with_opts
    from fastgen.utils.distributed import synchronize, is_rank0

    device_mesh = init_device_mesh("cuda", (world_size,))
    device = torch.cuda.current_device()

    # Create EDM network using the same configuration as other tests
    # Use small resolution and simple architecture for fast testing
    dmd_config = create_config()
    instance = dmd_config.model
    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    instance.net = override_config_with_opts(instance.net, opts)
    instance.device = torch.device(f"cuda:{rank}")
    instance.precision = "float32"
    instance.pretrained_model_path = ""  # disable ckpt loading

    # Instantiate the network (EDM architecture)
    net = instantiate(instance.net).to(device)

    # Get state dict before FSDP sharding for broadcast
    if is_rank0():
        broadcast_state_dict = copy.deepcopy(net.state_dict())
    else:
        broadcast_state_dict = None
    synchronize()

    # Apply FSDP sharding using the network's fully_shard method
    # This follows the same pattern as test_fsdp.py
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.float32,
        reduce_dtype=torch.float32,
        output_dtype=torch.float32,
        cast_forward_inputs=True,
    )
    net.fully_shard(mesh=device_mesh, mp_policy=mp_policy)

    # Materialize meta tensors and reset parameters (following test_fsdp.py pattern)
    net.model.to_empty(device=device)
    if hasattr(net, "reset_parameters"):
        net.reset_parameters()
    synchronize()

    # Broadcast state dict from rank 0 (following test_fsdp.py pattern)
    # Extract only the inner model's state dict since that's what's sharded
    if broadcast_state_dict is not None:
        inner_model_prefix = "model."
        inner_broadcast_state_dict = {
            k[len(inner_model_prefix) :]: v
            for k, v in broadcast_state_dict.items()
            if k.startswith(inner_model_prefix)
        }
    else:
        inner_broadcast_state_dict = None
    options = StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
        cpu_offload=False,
    )
    set_model_state_dict(net.model, model_state_dict=inner_broadcast_state_dict, options=options)
    synchronize()

    # Create EMA model (matching production behavior)
    ema_init_state = {}
    for name, p in net.named_parameters():
        if hasattr(p, "full_tensor"):
            # All ranks must participate in full_tensor() gather
            full_p = p.full_tensor().detach().clone()
        else:
            full_p = p.detach().clone()
        ema_init_state[name] = full_p
    for name, buf in net.named_buffers():
        ema_init_state[name] = buf.detach().clone()

    # Create a fresh EDM network for EMA
    ema = instantiate(instance.net).to(device)
    ema.eval()
    for p in ema.parameters():
        p.requires_grad = False
    ema.load_state_dict(ema_init_state)
    initial_ema_state = {k: v.clone() for k, v in ema.state_dict().items()}
    synchronize()

    # Create EMA callback and configure for FSDP mode
    ema_callback = EMACallback(
        type="constant",
        beta=0.9,  # Use larger learning rate for visible updates
        fsdp=True,
    )
    # Configure for FSDP mode
    ema_callback._is_fsdp = True

    # Modify network parameters (simulate training step)
    with torch.no_grad():
        for p in net.parameters():
            # The modification happens via the sharded parameter
            p.data.add_(torch.randn_like(p.data) * 0.1)
    synchronize()

    # Create a mock model object with net and ema attributes
    class MockModel:
        def __init__(self, net, ema):
            self.net = net
            self.ema = ema
            self.ema_enabled = True
            self.resume_iter = 0

    mock_model = MockModel(net, ema)

    # Run EMA update
    ema_callback.on_training_step_end(
        mock_model,
        data_batch=None,
        output_batch=None,
        loss_dict=None,
        iteration=1,
    )
    synchronize()

    # Verify results
    assert ema is not None, "EMA should exist"

    # Check that EMA was updated (should be different from initial)
    final_ema_state = ema.state_dict()

    # Check if EMA was updated (at least one param should differ)
    params_changed = 0
    for k in initial_ema_state:
        if not torch.allclose(initial_ema_state[k], final_ema_state[k], atol=1e-6):
            params_changed += 1

    results = {
        "ema_updated": True,
        "ema_different_from_initial": params_changed > 0,
        "params_changed": params_changed,
        "total_params": len(initial_ema_state),
        "ema_no_grad": not any(p.requires_grad for p in ema.parameters()),
        "model_type": "EDM",
    }
    return results


@RunIf(min_gpus=2)
def test_ema_callback_fsdp_distributed():
    """Test EMA callback with real FSDP distributed training using EDM model.

    This test requires at least 2 GPUs and uses the actual EDM network
    architecture (same as other callback tests) to verify that the EMA
    callback correctly handles FSDP-sharded parameters by:
    1. Gathering full tensors from all shards using full_tensor()
    2. Performing EMA updates
    3. Maintaining proper synchronization across ranks
    """
    gc.collect()
    torch.cuda.empty_cache()
    result = run_distributed_test(
        test_fn=_test_ema_callback_fsdp_distributed_impl,
        world_size=2,
        timeout=180,  # Slightly longer for model instantiation
        setup_fn=set_env_vars,
    )
    assert result is not None, "Test did not return a result"
    assert result.get("model_type") == "EDM", "Test should use EDM model"
    assert result["ema_updated"], "EMA callback should have run without errors"
    assert result["ema_different_from_initial"], (
        f"EMA should have been updated after training step. "
        f"Only {result.get('params_changed', 0)}/{result.get('total_params', 0)} params changed."
    )
    assert result["ema_no_grad"], "EMA parameters should not require gradients"
    gc.collect()
    torch.cuda.empty_cache()


# =============================================================================
# Non-Distributed Tests (continue below)
# =============================================================================


def test_ema_checkpoint_save_load(get_model_data):
    """Test that EMA state is correctly saved and loaded from checkpoints."""
    model, data, config = get_model_data

    # Initialize EMA callback and run a few updates
    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config
    ema_callback.on_app_begin()
    ema_callback.on_model_init_end(model)

    # Modify network and update EMA a few times
    for i in range(3):
        for p in model.net.parameters():
            torch.nn.init.normal_(p)
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=i + 1,
        )

    # Store EMA state before saving
    ema_state_before = {k: v.clone() for k, v in model.ema.state_dict().items()}

    # Create a temporary directory for checkpoint
    with tempfile.TemporaryDirectory() as tmpdir:
        from fastgen.utils.checkpointer import Checkpointer
        from omegaconf import OmegaConf

        # Create checkpointer config
        ckpt_config = OmegaConf.create(
            {
                "save_dir": tmpdir,
                "use_s3": False,
            }
        )
        checkpointer = Checkpointer(ckpt_config)

        # Save checkpoint
        checkpointer.save(
            model_dict=model.model_dict,
            optimizer_dict=None,
            scheduler_dict=None,
            grad_scaler=None,
            callbacks=None,
            path=os.path.join(tmpdir, "test_ema.pth"),
            iteration=100,
        )

        # Verify checkpoint file exists
        assert os.path.exists(os.path.join(tmpdir, "test_ema.pth"))

        # Reset EMA state to verify loading works.
        # state_dict() returns references to the live tensors, so zeroing
        # them in place mutates the EMA module itself.
        for v in model.ema.state_dict().values():
            v.zero_()

        # Verify EMA is zeroed
        for k, v in model.ema.state_dict().items():
            assert torch.all(v == 0), f"EMA {k} should be zeroed"

        # Load checkpoint
        loaded_iter = checkpointer.load(
            model_dict=model.model_dict,
            optimizer_dict=None,
            scheduler_dict=None,
            grad_scaler=None,
            callbacks=None,
            path=os.path.join(tmpdir, "test_ema.pth"),
        )
        assert loaded_iter == 100

        # Verify EMA state was restored
        ema_state_after = model.ema.state_dict()
        for k, v_before in ema_state_before.items():
            assert torch.allclose(v_before, ema_state_after[k]), f"EMA state mismatch for {k}"


def test_ema_callback_beta_types(get_model_data):
    """Test EMA callback with different beta calculation types."""
    model, data, config = get_model_data

    # Test power function beta
    ema_callback_power = instantiate(EMA_CALLBACK["ema"])
    ema_callback_power.type = "power"
    ema_callback_power.config = config
    ema_callback_power.on_app_begin()
    # Power function should return beta = (1 - 1/iteration)^(gamma + 1)
    iteration = 10
    expected_power_beta = (1 - 1 / iteration) ** (ema_callback_power.gamma + 1)
    actual_power_beta = ema_callback_power._power_function_beta(iteration)
    assert np.isclose(expected_power_beta, actual_power_beta)

    # Test halflife beta
    ema_callback_halflife = instantiate(EMA_CALLBACK["ema"])
    ema_callback_halflife.type = "halflife"
    ema_callback_halflife.config = config
    ema_callback_halflife.on_app_begin()
    # Halflife beta should use the formula 0.5^(batch_size / ema_halflife_nimg)
    iteration = 100
    halflife_beta = ema_callback_halflife._halflife_beta(iteration)
    assert 0 < halflife_beta < 1, f"Halflife beta should be between 0 and 1, got {halflife_beta}"


def test_ct_schedule_callback(get_model_data):
    """Test CT schedule callback stage/ratio progression across iterations."""
    model, data, config = get_model_data
    for callback_name, callback_config in CTSchedule_CALLBACK.items():
        assert callback_name == "ct_schedule"
        assert config.dataloader_train.batch_size == 256
        ct_schedule_callback = instantiate(callback_config)
        ct_schedule_callback.config = config
        assert ct_schedule_callback.q == 2.0
        assert ct_schedule_callback.ratio_limit == 0.999
        assert ct_schedule_callback.kimg_per_stage == 12500
        ct_schedule_callback.on_train_begin(model, iteration=0)
        assert np.isclose(ct_schedule_callback.stage, 0)
        assert model.ratio == 0.5
        # Resuming from a later iteration should fast-forward the stage.
        model.resume_iter = 100000
        ct_schedule_callback.on_train_begin(model, iteration=0)
        assert np.isclose(ct_schedule_callback.stage, 2)
        assert model.ratio == 0.875
        ct_schedule_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=100000,
        )
        assert np.isclose(ct_schedule_callback.stage, 4)
        assert np.isclose(model.ratio, 0.96875)


def test_grad_clip_callback(get_model_data):
    """Test gradient clipping callback configuration and invocation."""
    model, data, config = get_model_data
    for callback_name, callback_config in GradClip_CALLBACK.items():
        assert callback_name == "grad_clip"
        callback_config.grad_norm = 10.0
        grad_clip_callback = instantiate(callback_config)
        grad_clip_callback.config = config
        assert grad_clip_callback.grad_norm == 10.0
        assert grad_clip_callback.model_key == "net"
        grad_clip_callback.on_optimizer_step_begin(model)


@RunIf(min_gpus=1)
def test_gpu_stats_callback(get_model_data):
    """Test GPU stats callback configuration and a single logging step."""
    model, data, config = get_model_data
    for callback_name, callback_config in GPUStats_CALLBACK.items():
        assert callback_name == "gpu_stats"
        assert callback_config.every_n == 100
        gpu_stats_callback = instantiate(callback_config)
        gpu_stats_callback.config = config
        assert gpu_stats_callback.every_n == 100
        gpu_stats_callback.on_train_begin(model, iteration=0)
        gpu_stats_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=0,
        )


def test_param_count_callback(get_model_data):
    """Test that the parameter-count callback runs at train begin."""
    model, data, config = get_model_data
    for callback_name, callback_config in ParamCount_CALLBACK.items():
        assert callback_name == "param_count"
        param_count_callback = instantiate(callback_config)
        param_count_callback.config = config
        param_count_callback.on_train_begin(model)


def test_train_profiler_callback(get_model_data):
    """Test train profiler callback timing state across a training step."""
    model, data, config = get_model_data
    for callback_name, callback_config in TrainProfiler_CALLBACK.items():
        assert callback_name == "train_profiler"
        assert callback_config.every_n == 100
        train_profiler_callback = instantiate(callback_config)
        train_profiler_callback.config = config
        assert train_profiler_callback.last_log_time is None
        assert train_profiler_callback.every_n == 100
        train_profiler_callback.on_train_begin(model, iteration=0)
        # on_train_begin syncs the logging cadence with the trainer config.
        assert train_profiler_callback.every_n == config.trainer.logging_iter
        train_profiler_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=0,
        )
        assert train_profiler_callback.last_log_time is not None


def test_forced_weight_norm_callback(get_model_data):
    """Test forced weight norm callback on the default net and an EDM2 net."""
    model, data, config = get_model_data
    for callback_name, callback_config in ForcedWeightNorm_CALLBACK.items():
        assert callback_name == "forced_weight_norm"
        forced_weight_norm_callback = instantiate(callback_config)
        forced_weight_norm_callback.config = config
        forced_weight_norm_callback.on_training_accum_step_begin(model, data)
        # Swap in a tiny EDM2 network and make sure the callback still runs.
        net_config = EDM2_IN64_S_Config
        net_config = override_config_with_opts(net_config, ["-", "img_resolution=2", "channel_mult=[1]"])
        net = instantiate(net_config)
        model.net = net
        forced_weight_norm_callback.on_training_accum_step_begin(model, data)


def test_wandb_callback(get_model_data):
    """Test wandb callback initialization (disabled mode) and a logging hook."""
    model, data, config = get_model_data
    config.log_config.wandb_mode = "disabled"
    for callback_name, callback_config in WANDB_CALLBACK.items():
        assert callback_name == "wandb"
        wandb_callback = instantiate(callback_config)
        wandb_callback.config = config
        if os.path.isfile(config.log_config.wandb_credential):
            wandb_callback.on_app_begin()
        else:
            # No credential file on this machine: point the config at a
            # temporary empty file so on_app_begin can still proceed.
            with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
                config.log_config.wandb_credential = tmp_file.name
                wandb_callback.on_app_begin()
        wandb_callback.on_optimizer_step_begin(model)


def test_callback_list(get_model_data):
    """Test that CallbackDict builds all configured callbacks and dispatches hooks."""
    model, data, config = get_model_data
    config.trainer.callbacks = DictConfig({**GradClip_CALLBACK, **ParamCount_CALLBACK})
    config.trainer.callbacks.update({**ForcedWeightNorm_CALLBACK})
    trainer = Trainer(config)
    callbacks = CallbackDict(config=config, trainer=trainer)
    assert len(callbacks._callbacks) == 3
    callbacks.on_train_begin(model)