import copy
import os
import gc
import tempfile

import pytest
import numpy as np
import torch
from omegaconf import DictConfig
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)

from fastgen.configs.methods.config_dmd2 import create_config
from fastgen.configs.config_utils import override_config_with_opts
from fastgen.configs.net import EDM2_IN64_S_Config
from fastgen.methods import DMD2Model
from fastgen.trainer import Trainer
from fastgen.utils import instantiate
from fastgen.utils.io_utils import set_env_vars
from fastgen.configs.callbacks import (
    CTSchedule_CALLBACK,
    GradClip_CALLBACK,
    ParamCount_CALLBACK,
    WANDB_CALLBACK,
    EMA_CALLBACK,
    TrainProfiler_CALLBACK,
    GPUStats_CALLBACK,
    ForcedWeightNorm_CALLBACK,
)
from fastgen.callbacks.callback import CallbackDict
from fastgen.utils.test_utils import RunIf, run_distributed_test


@pytest.fixture
def get_model_data():
    gc.collect()
    dmd_config = create_config()
    dmd_config.log_config.name = "test"
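    # Shrink the network and discriminator configs so the test model stays tiny.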
    instance = dmd_config.model
    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    instance.net = override_config_with_opts(instance.net, opts)
    opts_discriminator = ["-", "feature_indices=[0]", "all_res=[8]", "in_channels=128"]
    instance.discriminator = override_config_with_opts(instance.discriminator, opts_discriminator)
    instance.use_ema = True
    instance.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    instance.precision = "float32" if instance.device == torch.device("cpu") else "bfloat16"
    instance.pretrained_model_path = ""
    instance.input_shape = [3, 8, 8]

    dmd_model = DMD2Model(instance)
    dmd_model.on_train_begin()
    dmd_model.init_optimizers()

    batch_size = 1
    labels = torch.randint(0, 10, (batch_size,))
    labels = torch.nn.functional.one_hot(labels, num_classes=10)
    neg_condition = torch.zeros(batch_size, 10)
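    # Dummy training batch on the model's device and precision.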
    data = {
        "real": torch.randn(batch_size, 3, 8, 8).to(dmd_model.device, dmd_model.precision),
        "condition": labels.to(dmd_model.device, dmd_model.precision),
        "neg_condition": neg_condition.to(dmd_model.device, dmd_model.precision),
    }
    return dmd_model, data, dmd_config


def test_ema_callback(get_model_data):
    """Test EMA callback basic functionality (non-FSDP mode)."""
    model, data, config = get_model_data

    for callback_name, callback_config in EMA_CALLBACK.items():
        assert callback_name == "ema"
        assert model.ema is not None

        ema_callback = instantiate(callback_config)
        ema_callback.config = config

        ema_callback.on_app_begin()
        assert ema_callback._is_fsdp is False

        assert ema_callback.beta == 0.9999
        assert ema_callback.type == "constant"
        assert ema_callback.gamma == 16.97
        assert ema_callback.ema_halflife_kimg == 500
        assert ema_callback.ema_rampup_ratio == 0.05

        ema_callback.on_model_init_end(model)
        assert ema_callback._enabled is True
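        # Right after initialization, the EMA weights equal the net weights and require no grad.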
        ema_state = model.ema.state_dict()
        net_state = model.net.state_dict()
        assert all(torch.allclose(net_state[k], p_ema) for k, p_ema in ema_state.items())
        assert not any(p_ema.requires_grad for p_ema in ema_state.values())
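        # Randomize the net and compute the expected EMA state by hand:
        # buffers are copied, parameters are lerped with weight (1 - beta).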
        buffers = [k for k, _ in model.net.named_buffers()]
        expected_ema_state = {}
        for k, p_net in net_state.items():
            torch.nn.init.normal_(p_net)
            if k in buffers:
                expected_ema_state[k] = p_net.detach().clone()
            else:
                expected_ema_state[k] = torch.lerp(ema_state[k], p_net.detach(), 1.0 - ema_callback.beta)
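        # Run one EMA update through the callback.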
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=1,
        )
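        # The callback's update must reproduce the manual computation.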
        new_ema_state = model.ema.state_dict()
        assert all(torch.allclose(expected_ema_state[k], p_ema) for k, p_ema in new_ema_state.items())
        assert not any(p_ema.requires_grad for p_ema in new_ema_state.values())
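        # Without an EMA module the callback disables itself and the step hook becomes a no-op.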
        model.ema = None
        ema_callback.on_model_init_end(model)
        assert ema_callback._enabled is False
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=1,
        )
        assert model.ema is None


def test_ema_initialization_after_build(get_model_data):
    """Test that EMA is correctly initialized from net state during model build."""
    model, data, config = get_model_data

    assert model.ema is not None
    assert model.use_ema == ["ema"]

    ema_state = model.ema.state_dict()
    net_state = model.net.state_dict()

    for k in net_state.keys():
        assert k in ema_state, f"Key {k} not found in EMA state"
        assert torch.allclose(net_state[k], ema_state[k]), f"EMA state mismatch for {k}"

    assert not any(p.requires_grad for p in model.ema.parameters())
    assert model.ema.training is False


def test_ema_callback_multiple_steps(get_model_data):
    """Test EMA callback over multiple training steps to verify accumulation."""
    model, data, config = get_model_data

    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config
    ema_callback.on_app_begin()

    beta = ema_callback.beta

    initial_ema_state = {k: v.clone() for k, v in model.ema.state_dict().items()}
    buffers = [k for k, _ in model.net.named_buffers()]
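    # Simulate several training steps: re-randomize the net and track the expected EMA state by hand.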
    for iteration in range(1, 5):
        for p in model.net.parameters():
            torch.nn.init.normal_(p)

        net_state = model.net.state_dict()
        for k in initial_ema_state.keys():
            if k in buffers:
                initial_ema_state[k] = net_state[k].clone()
            else:
                initial_ema_state[k].lerp_(net_state[k], 1.0 - beta)

        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=iteration,
        )

    final_ema_state = model.ema.state_dict()
    for k, expected in initial_ema_state.items():
        assert torch.allclose(expected, final_ema_state[k], atol=1e-6), f"Mismatch at {k}"


@RunIf(min_gpus=1)
def test_ema_callback_fsdp_mode_mocked(get_model_data):
    """Test EMA callback FSDP mode behavior with mocked FSDP tensors.

    This test mocks the FSDP behavior by adding a `full_tensor()` method to parameters.
    In real FSDP, parameters are DTensors with `full_tensor()` that gathers from all ranks.
    """
    model, data, config = get_model_data
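    # Mock the DTensor API: give each parameter a full_tensor() method that returns its data.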
    original_params = {}
    for name, param in model.net.named_parameters():
        original_params[name] = param.data.clone()
        param.full_tensor = lambda p=param: p.data.clone()

    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config

    config.trainer.fsdp = True
    ema_callback.on_app_begin()
    assert ema_callback._is_fsdp is True

    initial_ema_state = {k: v.clone() for k, v in model.ema.state_dict().items()}
    buffers = [k for k, _ in model.net.named_buffers()]

    for _name, p in model.net.named_parameters():
        torch.nn.init.normal_(p)

    expected_ema_state = {}
    net_state = model.net.state_dict()
    for k in initial_ema_state.keys():
        if k in buffers:
            expected_ema_state[k] = net_state[k].clone()
        else:
            expected_ema_state[k] = torch.lerp(initial_ema_state[k], net_state[k], 1.0 - ema_callback.beta)

    ema_callback.on_training_step_end(
        model,
        data_batch=None,
        output_batch=None,
        loss_dict=None,
        iteration=1,
    )

    final_ema_state = model.ema.state_dict()
    for k, expected in expected_ema_state.items():
        assert torch.allclose(expected, final_ema_state[k], atol=1e-6), f"Mismatch at {k}"

    config.trainer.fsdp = False


def _test_ema_callback_fsdp_distributed_impl(rank: int, world_size: int) -> dict:
    """Test EMA callback with real FSDP in a distributed setting using the EDM model.

    This test uses the same EDM model architecture as the other callback tests to ensure
    we exercise the actual model code paths. It verifies that:
    1. The EMA callback correctly gathers full tensors from FSDP-sharded parameters.
    2. The EMA state remains consistent after the update.
    3. Synchronization barriers work correctly.

    Args:
        rank: Process rank.
        world_size: Total number of processes.

    Returns:
        dict with test results.
    """
    from fastgen.callbacks.ema import EMACallback
    from fastgen.configs.methods.config_dmd2 import create_config
    from fastgen.configs.config_utils import override_config_with_opts
    from fastgen.utils.distributed import synchronize, is_rank0

    device_mesh = init_device_mesh("cuda", (world_size,))
    device = torch.cuda.current_device()

    # Build a tiny EDM config, matching the non-distributed callback tests.
    dmd_config = create_config()
    instance = dmd_config.model
    opts = ["-", "img_resolution=8", "channel_mult=[1]", "channel_mult_noise=1"]
    instance.net = override_config_with_opts(instance.net, opts)
    instance.device = torch.device(f"cuda:{rank}")
    instance.precision = "float32"
    instance.pretrained_model_path = ""

    net = instantiate(instance.net).to(device)

    # Keep a copy of the rank-0 weights so every rank can be initialized identically.
    if is_rank0():
        broadcast_state_dict = copy.deepcopy(net.state_dict())
    else:
        broadcast_state_dict = None

    synchronize()

    # Shard the network with a float32 mixed-precision policy.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.float32,
        reduce_dtype=torch.float32,
        output_dtype=torch.float32,
        cast_forward_inputs=True,
    )
    net.fully_shard(mesh=device_mesh, mp_policy=mp_policy)

    net.model.to_empty(device=device)
    if hasattr(net, "reset_parameters"):
        net.reset_parameters()
    synchronize()

    # Broadcast the rank-0 weights into the sharded model; the inner module is
    # addressed without the "model." prefix.
    if broadcast_state_dict is not None:
        inner_model_prefix = "model."
        inner_broadcast_state_dict = {
            k[len(inner_model_prefix) :]: v for k, v in broadcast_state_dict.items() if k.startswith(inner_model_prefix)
        }
    else:
        inner_broadcast_state_dict = None

    options = StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
        cpu_offload=False,
    )
    set_model_state_dict(net.model, model_state_dict=inner_broadcast_state_dict, options=options)
    synchronize()
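    # Build the initial EMA state from full (unsharded) tensors.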
    ema_init_state = {}
    for name, p in net.named_parameters():
        if hasattr(p, "full_tensor"):
            # DTensor: gather the full parameter from all shards.
            full_p = p.full_tensor().detach().clone()
        else:
            full_p = p.detach().clone()
        ema_init_state[name] = full_p

    for name, buf in net.named_buffers():
        ema_init_state[name] = buf.detach().clone()
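    # Build an unsharded EMA copy of the network and freeze it.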
    ema = instantiate(instance.net).to(device)
    ema.eval()
    for p in ema.parameters():
        p.requires_grad = False
    ema.load_state_dict(ema_init_state)
    initial_ema_state = {k: v.clone() for k, v in ema.state_dict().items()}

    synchronize()

    ema_callback = EMACallback(
        type="constant",
        beta=0.9,
        fsdp=True,
    )
    # on_app_begin() is not invoked in this low-level test, so enable FSDP mode directly.
    ema_callback._is_fsdp = True
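    # Perturb the sharded parameters so the EMA update has something to track.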
    with torch.no_grad():
        for p in net.parameters():
            p.data.add_(torch.randn_like(p.data) * 0.1)

    synchronize()
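    # Minimal stand-in exposing the attributes the EMA callback reads from the model.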
    class MockModel:
        def __init__(self, net, ema):
            self.net = net
            self.ema = ema
            self.ema_enabled = True
            self.resume_iter = 0

    mock_model = MockModel(net, ema)

    # Run one EMA update through the callback.
    ema_callback.on_training_step_end(
        mock_model,
        data_batch=None,
        output_batch=None,
        loss_dict=None,
        iteration=1,
    )

    synchronize()

    assert ema is not None, "EMA should exist"

    final_ema_state = ema.state_dict()

    # Count how many EMA entries moved away from their initial values.
    params_changed = 0
    for k in initial_ema_state:
        if not torch.allclose(initial_ema_state[k], final_ema_state[k], atol=1e-6):
            params_changed += 1

    results = {
        "ema_updated": True,
        "ema_different_from_initial": params_changed > 0,
        "params_changed": params_changed,
        "total_params": len(initial_ema_state),
        "ema_no_grad": not any(p.requires_grad for p in ema.parameters()),
        "model_type": "EDM",
    }
    return results


@RunIf(min_gpus=2)
def test_ema_callback_fsdp_distributed():
    """Test EMA callback with real FSDP distributed training using the EDM model.

    This test requires at least 2 GPUs and uses the actual EDM network architecture
    (same as other callback tests) to verify that the EMA callback correctly handles
    FSDP-sharded parameters by:
    1. Gathering full tensors from all shards using full_tensor()
    2. Performing EMA updates
    3. Maintaining proper synchronization across ranks
    """
    gc.collect()
    torch.cuda.empty_cache()

    result = run_distributed_test(
        test_fn=_test_ema_callback_fsdp_distributed_impl,
        world_size=2,
        timeout=180,
        setup_fn=set_env_vars,
    )

    assert result is not None, "Test did not return a result"
    assert result.get("model_type") == "EDM", "Test should use EDM model"
    assert result["ema_updated"], "EMA callback should have run without errors"
    assert result["ema_different_from_initial"], (
        f"EMA should have been updated after training step. "
        f"Only {result.get('params_changed', 0)}/{result.get('total_params', 0)} params changed."
    )
    assert result["ema_no_grad"], "EMA parameters should not require gradients"

    gc.collect()
    torch.cuda.empty_cache()


def test_ema_checkpoint_save_load(get_model_data):
    """Test that EMA state is correctly saved and loaded from checkpoints."""
    model, data, config = get_model_data

    ema_callback = instantiate(EMA_CALLBACK["ema"])
    ema_callback.config = config
    ema_callback.on_app_begin()
    ema_callback.on_model_init_end(model)
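    # Run a few EMA updates so the EMA state drifts away from its initial values.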
    for i in range(3):
        for p in model.net.parameters():
            torch.nn.init.normal_(p)
        ema_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=i + 1,
        )

    ema_state_before = {k: v.clone() for k, v in model.ema.state_dict().items()}

    with tempfile.TemporaryDirectory() as tmpdir:
        from fastgen.utils.checkpointer import Checkpointer
        from omegaconf import OmegaConf

        ckpt_config = OmegaConf.create(
            {
                "save_dir": tmpdir,
                "use_s3": False,
            }
        )
        checkpointer = Checkpointer(ckpt_config)

        checkpointer.save(
            model_dict=model.model_dict,
            optimizer_dict=None,
            scheduler_dict=None,
            grad_scaler=None,
            callbacks=None,
            path=os.path.join(tmpdir, "test_ema.pth"),
            iteration=100,
        )

        assert os.path.exists(os.path.join(tmpdir, "test_ema.pth"))
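        # Corrupt the EMA weights in place, then restore them from the checkpoint.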
        for v in model.ema.state_dict().values():
            v.zero_()

        for k, v in model.ema.state_dict().items():
            assert torch.all(v == 0), f"EMA {k} should be zeroed"

        loaded_iter = checkpointer.load(
            model_dict=model.model_dict,
            optimizer_dict=None,
            scheduler_dict=None,
            grad_scaler=None,
            callbacks=None,
            path=os.path.join(tmpdir, "test_ema.pth"),
        )

        assert loaded_iter == 100

        ema_state_after = model.ema.state_dict()
        for k, v_before in ema_state_before.items():
            assert torch.allclose(v_before, ema_state_after[k]), f"EMA state mismatch for {k}"


def test_ema_callback_beta_types(get_model_data):
    """Test EMA callback with different beta calculation types."""
    model, data, config = get_model_data

    # Power-function schedule: beta = (1 - 1 / t) ** (gamma + 1).
    ema_callback_power = instantiate(EMA_CALLBACK["ema"])
    ema_callback_power.type = "power"
    ema_callback_power.config = config
    ema_callback_power.on_app_begin()

    iteration = 10
    expected_power_beta = (1 - 1 / iteration) ** (ema_callback_power.gamma + 1)
    actual_power_beta = ema_callback_power._power_function_beta(iteration)
    assert np.isclose(expected_power_beta, actual_power_beta)

    # Half-life schedule: the resulting beta must be a valid EMA coefficient in (0, 1).
    ema_callback_halflife = instantiate(EMA_CALLBACK["ema"])
    ema_callback_halflife.type = "halflife"
    ema_callback_halflife.config = config
    ema_callback_halflife.on_app_begin()

    iteration = 100
    halflife_beta = ema_callback_halflife._halflife_beta(iteration)
    assert 0 < halflife_beta < 1, f"Halflife beta should be between 0 and 1, got {halflife_beta}"


def test_ct_schedule_callback(get_model_data):
    model, data, config = get_model_data

    for callback_name, callback_config in CTSchedule_CALLBACK.items():
        assert callback_name == "ct_schedule"
        assert config.dataloader_train.batch_size == 256

        ct_schedule_callback = instantiate(callback_config)
        ct_schedule_callback.config = config

        assert ct_schedule_callback.q == 2.0
        assert ct_schedule_callback.ratio_limit == 0.999
        assert ct_schedule_callback.kimg_per_stage == 12500

        ct_schedule_callback.on_train_begin(model, iteration=0)
        assert np.isclose(ct_schedule_callback.stage, 0)
        assert model.ratio == 0.5

        model.resume_iter = 100000
        ct_schedule_callback.on_train_begin(model, iteration=0)
        assert np.isclose(ct_schedule_callback.stage, 2)
        assert model.ratio == 0.875

        ct_schedule_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=100000,
        )

        assert np.isclose(ct_schedule_callback.stage, 4)
        assert np.isclose(model.ratio, 0.96875)


def test_grad_clip_callback(get_model_data):
    model, data, config = get_model_data
    for callback_name, callback_config in GradClip_CALLBACK.items():
        assert callback_name == "grad_clip"
        callback_config.grad_norm = 10.0

        grad_clip_callback = instantiate(callback_config)
        grad_clip_callback.config = config

        assert grad_clip_callback.grad_norm == 10.0
        assert grad_clip_callback.model_key == "net"
        grad_clip_callback.on_optimizer_step_begin(model)


@RunIf(min_gpus=1)
def test_gpu_stats_callback(get_model_data):
    model, data, config = get_model_data
    for callback_name, callback_config in GPUStats_CALLBACK.items():
        assert callback_name == "gpu_stats"
        assert callback_config.every_n == 100

        gpu_stats_callback = instantiate(callback_config)
        gpu_stats_callback.config = config
        assert gpu_stats_callback.every_n == 100

        gpu_stats_callback.on_train_begin(model, iteration=0)
        gpu_stats_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=0,
        )


def test_param_count_callback(get_model_data):
    model, data, config = get_model_data
    for callback_name, callback_config in ParamCount_CALLBACK.items():
        assert callback_name == "param_count"
        param_count_callback = instantiate(callback_config)
        param_count_callback.config = config
        param_count_callback.on_train_begin(model)


def test_train_profiler_callback(get_model_data):
    model, data, config = get_model_data
    for callback_name, callback_config in TrainProfiler_CALLBACK.items():
        assert callback_name == "train_profiler"
        assert callback_config.every_n == 100

        train_profiler_callback = instantiate(callback_config)
        train_profiler_callback.config = config
        assert train_profiler_callback.last_log_time is None
        assert train_profiler_callback.every_n == 100

        train_profiler_callback.on_train_begin(model, iteration=0)
        assert train_profiler_callback.every_n == config.trainer.logging_iter
        train_profiler_callback.on_training_step_end(
            model,
            data_batch=None,
            output_batch=None,
            loss_dict=None,
            iteration=0,
        )
        assert train_profiler_callback.last_log_time is not None


def test_forced_weight_norm_callback(get_model_data):
    model, data, config = get_model_data
    for callback_name, callback_config in ForcedWeightNorm_CALLBACK.items():
        assert callback_name == "forced_weight_norm"
        forced_weight_norm_callback = instantiate(callback_config)
        forced_weight_norm_callback.config = config
        forced_weight_norm_callback.on_training_accum_step_begin(model, data)

        net_config = EDM2_IN64_S_Config
        net_config = override_config_with_opts(net_config, ["-", "img_resolution=2", "channel_mult=[1]"])
        net = instantiate(net_config)
        model.net = net

        forced_weight_norm_callback.on_training_accum_step_begin(model, data)


def test_wandb_callback(get_model_data):
    model, data, config = get_model_data
    config.log_config.wandb_mode = "disabled"
    for callback_name, callback_config in WANDB_CALLBACK.items():
        assert callback_name == "wandb"
        wandb_callback = instantiate(callback_config)
        wandb_callback.config = config

        if os.path.isfile(config.log_config.wandb_credential):
            wandb_callback.on_app_begin()
        else:
            with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
                config.log_config.wandb_credential = tmp_file.name
                wandb_callback.on_app_begin()

        wandb_callback.on_optimizer_step_begin(model)


def test_callback_list(get_model_data):
    model, data, config = get_model_data
    config.trainer.callbacks = DictConfig({**GradClip_CALLBACK, **ParamCount_CALLBACK})
    config.trainer.callbacks.update({**ForcedWeightNorm_CALLBACK})

    trainer = Trainer(config)
    callbacks = CallbackDict(config=config, trainer=trainer)
    assert len(callbacks._callbacks) == 3
    callbacks.on_train_begin(model)