| |
| |
| """ |
| Unit tests for FSDP2 weight synchronization and forward pass correctness. |
| |
| These tests verify: |
| 1. FSDP weight synchronization from rank 0 to all ranks |
| 2. Buffer synchronization across ranks (e.g., RoPE freqs_cos/freqs_sin) |
| 3. Forward pass consistency between FSDP and reference models |
| 4. State dict gathering and comparison |
| 5. Meta device initialization for memory-efficient loading (fsdp_meta_init) |
| 6. End-to-end model_to_fsdp code path with real model instantiation |
| 7. Standard FSDP path without meta init (all ranks load independently) |
| 8. Multi-network models (DMD2 with net, teacher, fake_score) |
| |
| Tests require at least 2 GPUs and are run using torch.multiprocessing.spawn |
| to simulate distributed training. |
| |
| Usage: |
| pytest tests/test_fsdp.py -v |
| |
| # Or run with specific tests: |
| pytest tests/test_fsdp.py::test_fsdp_weight_sync -v |
| pytest tests/test_fsdp.py::test_fsdp_meta_device_init -v |
| pytest tests/test_fsdp.py::test_model_to_fsdp_code_path -v |
| pytest tests/test_fsdp.py::test_model_to_fsdp_no_meta_init -v |
| pytest tests/test_fsdp.py::test_dmd2_model_to_fsdp -v |
| """ |
|
|
| import contextlib |
| import copy |
| from typing import Dict |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.device_mesh import init_device_mesh |
| from torch.distributed.fsdp import MixedPrecisionPolicy |
| from torch.distributed.checkpoint.state_dict import ( |
| StateDictOptions, |
| get_model_state_dict, |
| set_model_state_dict, |
| ) |
|
|
| from fastgen.utils import instantiate |
| from fastgen.utils.test_utils import RunIf, run_distributed_test |
| from fastgen.utils.basic_utils import clear_gpu_memory |
| from fastgen.utils.io_utils import set_env_vars |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _get_meta_init_context(fsdp_meta_init: bool): |
| """Get context manager for FSDP meta initialization. |
| |
| When fsdp_meta_init is enabled, non-rank-0 processes use meta device |
| for memory-efficient loading. Rank 0 loads weights normally, then |
| FSDP syncs weights to other ranks via sync_module_states. |
| |
| Args: |
| fsdp_meta_init: Whether to use meta initialization. |
| |
| Returns: |
| Context manager (meta device for non-rank-0, nullcontext otherwise) |
| """ |
| rank = dist.get_rank() if dist.is_initialized() else 0 |
| use_meta = fsdp_meta_init and rank != 0 |
| if use_meta: |
| return torch.device("meta") |
| return contextlib.nullcontext() |
|
|
|
|
| def create_edm_network(fsdp_meta_init: bool = False, apply_checkpointing: bool = False): |
| """Create an EDM CIFAR10 network using the internal library config. |
| |
| This is a small, fast network suitable for testing FSDP functionality |
| without the overhead of loading large models like Wan. |
| |
| Args: |
| fsdp_meta_init: If True, non-rank-0 processes use meta device |
| apply_checkpointing: If True, enable gradient checkpointing (unused for EDM) |
| |
| Returns: |
| Instantiated EDM network |
| """ |
| from fastgen.configs.net import EDM_CIFAR10_Config |
|
|
| |
| config = copy.deepcopy(EDM_CIFAR10_Config) |
| |
|
|
| |
| with _get_meta_init_context(fsdp_meta_init): |
| network = instantiate(config) |
|
|
| return network |
|
|
|
|
| def apply_fsdp_to_network(network, device_mesh): |
| """Apply FSDP sharding to a FastGenNetwork. |
| |
| Uses the network's built-in fully_shard method. |
| """ |
| mp_policy = MixedPrecisionPolicy( |
| param_dtype=torch.float32, |
| reduce_dtype=torch.float32, |
| output_dtype=torch.float32, |
| cast_forward_inputs=True, |
| ) |
|
|
| |
| network.fully_shard(mesh=device_mesh, mp_policy=mp_policy) |
|
|
|
|
| def gather_fsdp_state_dict(model) -> Dict[str, torch.Tensor]: |
| """Gather full state dict from FSDP model.""" |
| options = StateDictOptions( |
| full_state_dict=True, |
| cpu_offload=True, |
| ) |
| return get_model_state_dict(model, options=options) |
|
|
|
|
| def generate_dummy_inputs( |
| batch_size: int = 2, |
| img_resolution: int = 32, |
| img_channels: int = 3, |
| label_dim: int = 10, |
| device: torch.device = None, |
| dtype: torch.dtype = torch.float32, |
| ): |
| """Generate dummy input data for EDM CIFAR10 model forward pass.""" |
| if device is None: |
| device = torch.cuda.current_device() |
|
|
| generator = torch.Generator(device="cpu").manual_seed(42) |
|
|
| |
| x = torch.randn(batch_size, img_channels, img_resolution, img_resolution, generator=generator, dtype=dtype).to( |
| device |
| ) |
|
|
| |
| timestep = torch.tensor([0.5] * batch_size, dtype=dtype, device=device) |
|
|
| |
| labels = torch.zeros(batch_size, label_dim, dtype=dtype, device=device) |
| labels[:, 0] = 1.0 |
|
|
| return { |
| "x_t": x, |
| "t": timestep, |
| "condition": labels, |
| } |
|
|
|
|
| def load_reference_network(): |
| """Load a fresh reference network on CPU for comparison. |
| |
| Note: This loads a fresh network with NEW random weights for any |
| randomly-initialized layers. For tests that need to compare against |
| the exact weights that were synced via FSDP, use the original |
| broadcast_state_dict instead. |
| """ |
| from fastgen.configs.net import EDM_CIFAR10_Config |
|
|
| config = copy.deepcopy(EDM_CIFAR10_Config) |
|
|
| |
| network = instantiate(config) |
|
|
| return network |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _test_fsdp_weight_sync_impl( |
| rank: int, |
| world_size: int, |
| apply_checkpointing: bool = False, |
| ) -> Dict: |
| """Test that FSDP weights are correctly synced from rank 0 to all ranks. |
| |
| This test: |
| 1. Rank 0 loads full model, other ranks use meta device |
| 2. Apply FSDP sharding with state dict broadcast |
| 3. Gather state dict back and compare to the ORIGINAL state dict from rank 0 |
| |
| This verifies that ALL weights (including randomly-initialized ones like |
| logvar_linear) are correctly synced via FSDP2's set_model_state_dict. |
| """ |
| device_mesh = init_device_mesh("cuda", (world_size,)) |
|
|
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=apply_checkpointing) |
|
|
| |
| |
| |
| |
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| |
| reference_state_dict = {k: v.cpu().clone() for k, v in broadcast_state_dict.items()} |
| else: |
| broadcast_state_dict = None |
| reference_state_dict = None |
|
|
| dist.barrier() |
|
|
| |
| apply_fsdp_to_network(network, device_mesh) |
|
|
| |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| |
| |
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| gathered_state_dict = gather_fsdp_state_dict(network) |
| dist.barrier() |
|
|
| |
| result = {"weights_match": True, "mismatches": [], "num_keys_compared": 0} |
| if rank == 0: |
| |
| |
| reference_keys = set(reference_state_dict.keys()) |
| gathered_keys = set(gathered_state_dict.keys()) |
| common_keys = reference_keys & gathered_keys |
|
|
| mismatches = [] |
| for key in common_keys: |
| |
| reference_tensor = reference_state_dict[key].float() |
| gathered_tensor = gathered_state_dict[key].cpu().float() |
|
|
| if reference_tensor.shape != gathered_tensor.shape: |
| mismatches.append((key, "shape_mismatch", str(reference_tensor.shape), str(gathered_tensor.shape))) |
| continue |
|
|
| max_diff = (reference_tensor - gathered_tensor).abs().max().item() |
| if max_diff > 1e-5: |
| mismatches.append((key, "value_mismatch", max_diff)) |
|
|
| |
| missing_keys = reference_keys - gathered_keys |
| extra_keys = gathered_keys - reference_keys |
| for key in missing_keys: |
| mismatches.append((key, "missing_in_gathered")) |
| for key in extra_keys: |
| mismatches.append((key, "extra_in_gathered")) |
|
|
| result["weights_match"] = len(mismatches) == 0 |
| result["mismatches"] = mismatches |
| result["num_keys_compared"] = len(common_keys) |
| result["missing_keys"] = len(missing_keys) |
| result["extra_keys"] = len(extra_keys) |
|
|
| torch.cuda.empty_cache() |
|
|
| return result |
|
|
|
|
| def _test_fsdp_tensor_sharding_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test that FSDP actually shards tensors across ranks. |
| |
| This test verifies that: |
| 1. After FSDP sharding, local tensor shapes differ from full tensor shapes |
| 2. The total number of local elements is less than full elements |
| 3. When gathered, the full shapes are restored |
| |
| This is important to verify that FSDP is actually sharding the model |
| and not just wrapping it without sharding. |
| """ |
| from torch.distributed.tensor import DTensor |
|
|
| device_mesh = init_device_mesh("cuda", (world_size,)) |
|
|
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=False) |
|
|
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| |
| original_shapes = {k: v.shape for k, v in broadcast_state_dict.items()} |
| else: |
| broadcast_state_dict = None |
| original_shapes = None |
|
|
| dist.barrier() |
|
|
| |
| apply_fsdp_to_network(network, device_mesh) |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| dtensor_params = [] |
| local_vs_full_diffs = [] |
| total_local_numel = 0 |
| total_full_numel = 0 |
|
|
| for name, param in network.named_parameters(): |
| if isinstance(param, DTensor): |
| dtensor_params.append(name) |
| local_shape = tuple(param._local_tensor.shape) |
| full_shape = tuple(param.shape) |
| local_numel = param._local_tensor.numel() |
| full_numel = param.numel() |
|
|
| total_local_numel += local_numel |
| total_full_numel += full_numel |
|
|
| |
| if local_shape != full_shape: |
| local_vs_full_diffs.append( |
| { |
| "name": name, |
| "local_shape": local_shape, |
| "full_shape": full_shape, |
| "local_numel": local_numel, |
| "full_numel": full_numel, |
| } |
| ) |
|
|
| |
| gathered_state_dict = gather_fsdp_state_dict(network) |
| dist.barrier() |
|
|
| result = { |
| "rank": rank, |
| "world_size": world_size, |
| "num_dtensor_params": len(dtensor_params), |
| "num_params_with_shape_diff": len(local_vs_full_diffs), |
| "total_local_numel": total_local_numel, |
| "total_full_numel": total_full_numel, |
| "local_vs_full_diffs": local_vs_full_diffs[:5], |
| } |
|
|
| |
| |
| has_sharded_params = len(local_vs_full_diffs) > 0 |
| result["has_sharded_params"] = has_sharded_params |
|
|
| |
| |
| if total_full_numel > 0: |
| sharding_ratio = total_local_numel / total_full_numel |
| expected_ratio = 1.0 / world_size |
| |
| ratio_is_reasonable = sharding_ratio < 0.8 |
| result["sharding_ratio"] = sharding_ratio |
| result["expected_ratio"] = expected_ratio |
| result["ratio_is_reasonable"] = ratio_is_reasonable |
| else: |
| result["ratio_is_reasonable"] = False |
| result["sharding_ratio"] = None |
|
|
| |
| if rank == 0 and original_shapes is not None: |
| shapes_match = True |
| shape_mismatches = [] |
| for key in gathered_state_dict: |
| if key in original_shapes: |
| gathered_shape = tuple(gathered_state_dict[key].shape) |
| orig_shape = tuple(original_shapes[key]) |
| if gathered_shape != orig_shape: |
| shapes_match = False |
| shape_mismatches.append((key, orig_shape, gathered_shape)) |
| result["gathered_shapes_match"] = shapes_match |
| result["shape_mismatches"] = shape_mismatches[:5] |
| else: |
| result["gathered_shapes_match"] = True |
| result["shape_mismatches"] = [] |
|
|
| result["all_passed"] = ( |
| has_sharded_params and result.get("ratio_is_reasonable", False) and result.get("gathered_shapes_match", False) |
| ) |
|
|
| return result |
|
|
|
|
| def _test_fsdp_buffer_sync_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test that FSDP buffers are correctly synced across ranks. |
| |
| This is important for non-persistent buffers like RoPE freqs_cos/freqs_sin |
| which are NOT included in state_dict and need the reset_parameters() call. |
| """ |
| device_mesh = init_device_mesh("cuda", (world_size,)) |
|
|
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=False) |
|
|
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| else: |
| broadcast_state_dict = None |
|
|
| dist.barrier() |
|
|
| |
| apply_fsdp_to_network(network, device_mesh) |
|
|
| |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| buffer_max_diff = 0.0 |
| buffer_issues = [] |
|
|
| for name, buffer in network.named_buffers(): |
| |
| rank0_buffer = buffer.clone() |
| dist.broadcast(rank0_buffer, src=0) |
|
|
| |
| local_vs_rank0_diff = (buffer - rank0_buffer).abs().max().item() |
|
|
| if local_vs_rank0_diff > buffer_max_diff: |
| buffer_max_diff = local_vs_rank0_diff |
| if local_vs_rank0_diff > 1e-6: |
| buffer_issues.append((name, local_vs_rank0_diff)) |
|
|
| |
| buffer_max_diff_tensor = torch.tensor([buffer_max_diff], device=torch.cuda.current_device()) |
| dist.all_reduce(buffer_max_diff_tensor, op=dist.ReduceOp.MAX) |
| global_buffer_max_diff = buffer_max_diff_tensor.item() |
|
|
| result = { |
| "buffers_synced": global_buffer_max_diff < 1e-6, |
| "max_diff": global_buffer_max_diff, |
| "issues": buffer_issues if rank == 0 else [], |
| } |
|
|
| return result |
|
|
|
|
| def _test_fsdp_forward_pass_impl( |
| rank: int, |
| world_size: int, |
| apply_checkpointing: bool = False, |
| ) -> Dict: |
| """Test that FSDP forward pass produces consistent results. |
| |
| Tests: |
| 1. FSDP forward is deterministic (same input -> same output) |
| 2. All ranks produce the same output |
| 3. FSDP output matches reference model output |
| """ |
| device_mesh = init_device_mesh("cuda", (world_size,)) |
|
|
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=apply_checkpointing) |
|
|
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| else: |
| broadcast_state_dict = None |
|
|
| dist.barrier() |
|
|
| apply_fsdp_to_network(network, device_mesh) |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| inputs = generate_dummy_inputs( |
| batch_size=2, |
| img_resolution=32, |
| img_channels=3, |
| label_dim=10, |
| device=torch.cuda.current_device(), |
| dtype=torch.float32, |
| ) |
|
|
| network.eval() |
|
|
| |
| with torch.no_grad(): |
| output1 = network( |
| x_t=inputs["x_t"], |
| t=inputs["t"], |
| condition=inputs["condition"], |
| return_features_early=False, |
| feature_indices=set(), |
| ) |
|
|
| output2 = network( |
| x_t=inputs["x_t"], |
| t=inputs["t"], |
| condition=inputs["condition"], |
| return_features_early=False, |
| feature_indices=set(), |
| ) |
|
|
| consistency_diff = (output1.cpu() - output2.cpu()).abs().max().item() |
|
|
| |
| if hasattr(output1, "full_tensor"): |
| output_full = output1.full_tensor() |
| else: |
| output_full = output1 |
|
|
| rank0_output = output_full.clone() |
| dist.broadcast(rank0_output, src=0) |
| rank_consistency_diff = (output_full - rank0_output).abs().max().item() |
|
|
| fsdp_output_cpu = output_full.cpu().float() |
|
|
| dist.barrier() |
|
|
| |
| result = { |
| "deterministic": consistency_diff < 1e-6, |
| "consistency_diff": consistency_diff, |
| "ranks_consistent": rank_consistency_diff < 1e-6, |
| "rank_consistency_diff": rank_consistency_diff, |
| } |
|
|
| if rank == 0: |
| |
| network.to("cpu") |
| torch.cuda.empty_cache() |
|
|
| |
| ref_network = load_reference_network() |
| ref_network = ref_network.to(device=torch.cuda.current_device(), dtype=torch.float32) |
| ref_network.eval() |
|
|
| ref_inputs = {k: v.to(torch.cuda.current_device()) for k, v in inputs.items()} |
|
|
| with torch.no_grad(): |
| ref_output = ref_network( |
| x_t=ref_inputs["x_t"], |
| t=ref_inputs["t"], |
| condition=ref_inputs["condition"], |
| return_features_early=False, |
| feature_indices=set(), |
| ) |
|
|
| ref_output_cpu = ref_output.cpu().float() |
| del ref_network |
| torch.cuda.empty_cache() |
|
|
| |
| max_diff = (fsdp_output_cpu - ref_output_cpu).abs().max().item() |
| mean_diff = (fsdp_output_cpu - ref_output_cpu).abs().mean().item() |
| ref_norm = ref_output_cpu.abs().mean().item() |
| relative_error = mean_diff / (ref_norm + 1e-8) |
|
|
| result["max_diff"] = max_diff |
| result["mean_diff"] = mean_diff |
| result["relative_error"] = relative_error |
| result["forward_matches"] = relative_error < 1e-3 and max_diff < 1e-2 |
|
|
| return result |
|
|
|
|
| def _test_fsdp_meta_device_init_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test that fsdp_meta_init correctly uses meta device for non-rank-0 processes. |
| |
| This test verifies: |
| 1. Rank 0 loads real tensors on CPU/CUDA |
| 2. Non-rank-0 processes have tensors on meta device |
| 3. After FSDP sync, all ranks have real tensors |
| """ |
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=False) |
|
|
| |
| meta_tensors_before = [] |
| real_tensors_before = [] |
|
|
| for name, param in network.named_parameters(): |
| if param.device.type == "meta": |
| meta_tensors_before.append(name) |
| else: |
| real_tensors_before.append(name) |
|
|
| |
| if rank == 0: |
| |
| has_correct_device_before = len(meta_tensors_before) == 0 |
| else: |
| |
| has_correct_device_before = len(meta_tensors_before) > 0 and len(real_tensors_before) == 0 |
|
|
| result = { |
| "rank": rank, |
| "meta_tensors_before_fsdp": len(meta_tensors_before), |
| "real_tensors_before_fsdp": len(real_tensors_before), |
| "has_correct_device_before": has_correct_device_before, |
| } |
|
|
| |
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| else: |
| broadcast_state_dict = None |
|
|
| dist.barrier() |
|
|
| device_mesh = init_device_mesh("cuda", (world_size,)) |
| apply_fsdp_to_network(network, device_mesh) |
|
|
| |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| |
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| meta_tensors_after = [] |
| real_tensors_after = [] |
|
|
| for name, param in network.named_parameters(): |
| |
| |
| if hasattr(param, "_local_tensor"): |
| |
| device = param._local_tensor.device |
| else: |
| device = param.device |
|
|
| if device.type == "meta": |
| meta_tensors_after.append(name) |
| else: |
| real_tensors_after.append(name) |
|
|
| |
| has_correct_device_after = len(meta_tensors_after) == 0 |
|
|
| result["meta_tensors_after_fsdp"] = len(meta_tensors_after) |
| result["real_tensors_after_fsdp"] = len(real_tensors_after) |
| result["has_correct_device_after"] = has_correct_device_after |
| result["all_passed"] = has_correct_device_before and has_correct_device_after |
|
|
| return result |
|
|
|
|
| def _test_fsdp_gathered_weights_forward_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test that a model loaded with gathered FSDP weights produces correct output. |
| |
| This isolates whether issues come from FSDP forward computation vs weight gathering. |
| """ |
| device_mesh = init_device_mesh("cuda", (world_size,)) |
|
|
| |
| network = create_edm_network(fsdp_meta_init=True, apply_checkpointing=False) |
|
|
| if rank == 0: |
| broadcast_state_dict = copy.deepcopy(network.state_dict()) |
| else: |
| broadcast_state_dict = None |
|
|
| dist.barrier() |
|
|
| apply_fsdp_to_network(network, device_mesh) |
| network.to_empty(device=torch.cuda.current_device()) |
| network.reset_parameters() |
| dist.barrier() |
|
|
| options = StateDictOptions( |
| full_state_dict=True, |
| broadcast_from_rank0=True, |
| cpu_offload=False, |
| ) |
| set_model_state_dict(network, model_state_dict=broadcast_state_dict, options=options) |
| dist.barrier() |
|
|
| |
| gathered_state_dict = gather_fsdp_state_dict(network) |
| dist.barrier() |
|
|
| result = {} |
|
|
| if rank == 0: |
| |
| inputs = generate_dummy_inputs( |
| batch_size=2, |
| img_resolution=32, |
| img_channels=3, |
| label_dim=10, |
| device=torch.cuda.current_device(), |
| dtype=torch.float32, |
| ) |
|
|
| |
| gathered_network = load_reference_network() |
| gathered_network = gathered_network.to(device=torch.cuda.current_device(), dtype=torch.float32) |
|
|
| missing, unexpected = gathered_network.load_state_dict(gathered_state_dict, strict=False) |
|
|
| gathered_network.eval() |
| with torch.no_grad(): |
| gathered_output = gathered_network( |
| x_t=inputs["x_t"], |
| t=inputs["t"], |
| condition=inputs["condition"], |
| return_features_early=False, |
| feature_indices=set(), |
| ) |
|
|
| gathered_output_cpu = gathered_output.cpu().float() |
| del gathered_network |
| torch.cuda.empty_cache() |
|
|
| |
| ref_network = load_reference_network() |
| ref_network = ref_network.to(device=torch.cuda.current_device(), dtype=torch.float32) |
|
|
| ref_network.eval() |
| with torch.no_grad(): |
| ref_output = ref_network( |
| x_t=inputs["x_t"], |
| t=inputs["t"], |
| condition=inputs["condition"], |
| return_features_early=False, |
| feature_indices=set(), |
| ) |
|
|
| ref_output_cpu = ref_output.cpu().float() |
| del ref_network |
| torch.cuda.empty_cache() |
|
|
| |
| max_diff = (gathered_output_cpu - ref_output_cpu).abs().max().item() |
| mean_diff = (gathered_output_cpu - ref_output_cpu).abs().mean().item() |
|
|
| result = { |
| "missing_keys": len(missing), |
| "unexpected_keys": len(unexpected), |
| "max_diff": max_diff, |
| "mean_diff": mean_diff, |
| "matches": max_diff < 1e-4, |
| } |
|
|
| return result |
|
|
|
|
| |
| |
| |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_weight_sync(): |
| """Test that FSDP correctly synchronizes weights from rank 0 to all ranks.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_weight_sync_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| apply_checkpointing=False, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result[ |
| "weights_match" |
| ], f"Weight sync failed with {len(result['mismatches'])} mismatches: {result['mismatches'][:5]}" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_weight_sync_with_checkpointing(): |
| """Test FSDP weight sync with activation checkpointing enabled.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_weight_sync_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| apply_checkpointing=True, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result["weights_match"], f"Weight sync with checkpointing failed: {result['mismatches'][:5]}" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_tensor_sharding(): |
| """Test that FSDP actually shards tensors across ranks. |
| |
| This verifies that: |
| - Local tensor shapes differ from full tensor shapes for sharded parameters |
| - The sharding ratio is reasonable (local < full) |
| - Gathered shapes match original shapes |
| """ |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_tensor_sharding_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_sharded_params", False), ( |
| f"No parameters were sharded. num_dtensor_params={result.get('num_dtensor_params', 'N/A')}, " |
| f"num_params_with_shape_diff={result.get('num_params_with_shape_diff', 'N/A')}" |
| ) |
| assert result.get("ratio_is_reasonable", False), ( |
| f"Sharding ratio is not reasonable: {result.get('sharding_ratio', 'N/A')} " |
| f"(expected ~{result.get('expected_ratio', 'N/A')})" |
| ) |
| assert result.get( |
| "gathered_shapes_match", False |
| ), f"Gathered shapes don't match original shapes: {result.get('shape_mismatches', 'N/A')}" |
| assert result.get("all_passed", False), ( |
| f"Tensor sharding test failed. Details: " f"local_vs_full_diffs={result.get('local_vs_full_diffs', 'N/A')}" |
| ) |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_buffer_sync(): |
| """Test that FSDP correctly synchronizes buffers (e.g., RoPE) across ranks.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_buffer_sync_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result["buffers_synced"], f"Buffer sync failed with max_diff={result['max_diff']}: {result['issues']}" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_forward_deterministic(): |
| """Test that FSDP forward pass is deterministic.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_forward_pass_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| apply_checkpointing=False, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result["deterministic"], f"FSDP forward not deterministic: diff={result['consistency_diff']}" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_forward_rank_consistency(): |
| """Test that all FSDP ranks produce the same output.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_forward_pass_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| apply_checkpointing=False, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result["ranks_consistent"], f"Ranks inconsistent: diff={result['rank_consistency_diff']}" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_forward_matches_reference(): |
| """Test that FSDP forward pass matches reference model output.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_forward_pass_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| apply_checkpointing=False, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("forward_matches", False), ( |
| f"FSDP forward doesn't match reference: " |
| f"max_diff={result.get('max_diff', 'N/A')}, " |
| f"relative_error={result.get('relative_error', 'N/A')}" |
| ) |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_gathered_weights_forward(): |
| """Test that model loaded with gathered FSDP weights produces correct output.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_gathered_weights_forward_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("matches", False), ( |
| f"Gathered weights model doesn't match reference: " |
| f"max_diff={result.get('max_diff', 'N/A')}, " |
| f"missing_keys={result.get('missing_keys', 'N/A')}" |
| ) |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_meta_device_init(): |
| """Test that fsdp_meta_init correctly uses meta device for non-rank-0 processes. |
| |
| This verifies: |
| - Rank 0 loads real tensors (not meta) |
| - Non-rank-0 processes use meta device for memory-efficient loading |
| - After FSDP sync, all tensors are properly materialized |
| """ |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_meta_device_init_impl, |
| world_size=2, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_correct_device_before", False), ( |
| f"Meta device not used correctly before FSDP: " |
| f"rank={result.get('rank', 'N/A')}, " |
| f"meta_tensors={result.get('meta_tensors_before_fsdp', 'N/A')}, " |
| f"real_tensors={result.get('real_tensors_before_fsdp', 'N/A')}" |
| ) |
| assert result.get("has_correct_device_after", False), ( |
| f"Tensors not properly materialized after FSDP: " f"meta_tensors={result.get('meta_tensors_after_fsdp', 'N/A')}" |
| ) |
|
|
| clear_gpu_memory() |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _test_model_to_fsdp_impl( |
| rank: int, |
| world_size: int, |
| sharding_group_size: int = None, |
| ) -> Dict: |
| """Test the complete model_to_fsdp code path with a real model. |
| |
| This test verifies: |
| 1. Model instantiation with fsdp_meta_init=True |
| 2. meta device usage for non-rank-0 processes |
| 3. model_to_fsdp correctly wraps and syncs the model |
| 4. After FSDP, all tensors are properly materialized |
| 5. Forward pass produces consistent results across ranks |
| |
| Args: |
| sharding_group_size: If set, creates a 2D mesh with (replicate, shard) dimensions. |
| """ |
| import copy |
| from fastgen.configs.methods.config_sft import ModelConfig |
| from fastgen.configs.net import EDM_CIFAR10_Config |
| from fastgen.methods import SFTModel |
| from fastgen.utils.distributed.fsdp import model_to_fsdp |
|
|
| |
| config = ModelConfig() |
| config.net = copy.deepcopy(EDM_CIFAR10_Config) |
| config.input_shape = [3, 32, 32] |
| config.precision = "float32" |
| config.device = "cuda" |
| config.fsdp_meta_init = True |
| config.pretrained_model_path = "" |
| config.use_ema = False |
|
|
| result = { |
| "rank": rank, |
| "world_size": world_size, |
| "sharding_group_size": sharding_group_size, |
| } |
|
|
| |
| |
| |
| model = SFTModel(config) |
|
|
| |
| meta_tensors_before = [] |
| real_tensors_before = [] |
|
|
| for name, param in model.net.named_parameters(): |
| if param.device.type == "meta": |
| meta_tensors_before.append(name) |
| else: |
| real_tensors_before.append(name) |
|
|
| if rank == 0: |
| |
| has_correct_device_before = len(meta_tensors_before) == 0 |
| else: |
| |
| has_correct_device_before = len(meta_tensors_before) > 0 and len(real_tensors_before) == 0 |
|
|
| result["meta_tensors_before_fsdp"] = len(meta_tensors_before) |
| result["real_tensors_before_fsdp"] = len(real_tensors_before) |
| result["has_correct_device_before"] = has_correct_device_before |
|
|
| dist.barrier() |
|
|
| |
| |
| model = model_to_fsdp( |
| model, |
| min_num_params=1_000_000, |
| apply_cpu_offload=False, |
| sync_module_states=True, |
| sharding_group_size=sharding_group_size, |
| ) |
|
|
| dist.barrier() |
|
|
| |
| meta_tensors_after = [] |
| real_tensors_after = [] |
|
|
| for name, param in model.net.named_parameters(): |
| |
| if hasattr(param, "_local_tensor"): |
| device = param._local_tensor.device |
| else: |
| device = param.device |
|
|
| if device.type == "meta": |
| meta_tensors_after.append(name) |
| else: |
| real_tensors_after.append(name) |
|
|
| has_correct_device_after = len(meta_tensors_after) == 0 |
| result["meta_tensors_after_fsdp"] = len(meta_tensors_after) |
| result["real_tensors_after_fsdp"] = len(real_tensors_after) |
| result["has_correct_device_after"] = has_correct_device_after |
|
|
| |
| model.net.eval() |
|
|
| |
| generator = torch.Generator(device="cpu").manual_seed(42) |
| dummy_input = torch.randn(2, 3, 32, 32, generator=generator, dtype=torch.float32).cuda() |
| dummy_t = torch.tensor([0.5, 0.5], dtype=torch.float32, device="cuda") |
| dummy_labels = torch.zeros(2, 10, dtype=torch.float32, device="cuda") |
|
|
| with torch.no_grad(): |
| output = model.net(dummy_input, dummy_t, condition=dummy_labels) |
|
|
| |
| if hasattr(output, "full_tensor"): |
| output_full = output.full_tensor() |
| else: |
| output_full = output |
|
|
| rank0_output = output_full.clone() |
| dist.broadcast(rank0_output, src=0) |
| rank_consistency_diff = (output_full - rank0_output).abs().max().item() |
|
|
| result["forward_rank_consistent"] = rank_consistency_diff < 1e-5 |
| result["rank_consistency_diff"] = rank_consistency_diff |
| result["all_passed"] = has_correct_device_before and has_correct_device_after and result["forward_rank_consistent"] |
|
|
| |
| del model |
| torch.cuda.empty_cache() |
|
|
| return result |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_model_to_fsdp_code_path(): |
| """Test the complete model_to_fsdp code path with a real model. |
| |
| This is an end-to-end test that verifies: |
| - Model builds correctly with fsdp_meta_init=True |
| - Non-rank-0 processes use meta device |
| - model_to_fsdp correctly wraps and syncs weights |
| - Forward pass is consistent across ranks |
| """ |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_model_to_fsdp_impl, |
| world_size=2, |
| timeout=300, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_correct_device_before", False), ( |
| f"Meta device not used correctly before FSDP: " |
| f"rank={result.get('rank', 'N/A')}, " |
| f"meta_tensors={result.get('meta_tensors_before_fsdp', 'N/A')}, " |
| f"real_tensors={result.get('real_tensors_before_fsdp', 'N/A')}" |
| ) |
| assert result.get("has_correct_device_after", False), ( |
| f"Tensors not properly materialized after model_to_fsdp: " |
| f"meta_tensors={result.get('meta_tensors_after_fsdp', 'N/A')}" |
| ) |
| assert result.get("forward_rank_consistent", False), ( |
| f"Forward pass inconsistent across ranks: " f"diff={result.get('rank_consistency_diff', 'N/A')}" |
| ) |
| assert result.get("all_passed", False), "model_to_fsdp test failed" |
|
|
| clear_gpu_memory() |
|
|
|
|
| @RunIf(min_gpus=4) |
| def test_model_to_fsdp_sharding_group(): |
| """Test model_to_fsdp with a sharding group size (2D mesh). |
| |
| This test verifies FSDP with a 2D device mesh where: |
| - sharding_group_size=2 means weights are sharded within groups of 2 GPUs |
| - The mesh has (replicate, shard) dimensions |
| |
| Requires at least 4 GPUs and world_size must be divisible by 2. |
| """ |
| import pytest |
|
|
| clear_gpu_memory() |
|
|
| |
| num_gpus = torch.cuda.device_count() |
| sharding_group_size = 2 |
| if num_gpus % sharding_group_size != 0: |
| pytest.skip(f"Number of GPUs ({num_gpus}) not divisible by sharding_group_size ({sharding_group_size})") |
|
|
| result = run_distributed_test( |
| test_fn=_test_model_to_fsdp_impl, |
| world_size=4, |
| timeout=300, |
| setup_fn=set_env_vars, |
| sharding_group_size=sharding_group_size, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_correct_device_before", False), ( |
| f"Meta device not used correctly before FSDP: " |
| f"rank={result.get('rank', 'N/A')}, " |
| f"meta_tensors={result.get('meta_tensors_before_fsdp', 'N/A')}, " |
| f"real_tensors={result.get('real_tensors_before_fsdp', 'N/A')}" |
| ) |
| assert result.get("has_correct_device_after", False), ( |
| f"Tensors not properly materialized after model_to_fsdp: " |
| f"meta_tensors={result.get('meta_tensors_after_fsdp', 'N/A')}" |
| ) |
| assert result.get("forward_rank_consistent", False), ( |
| f"Forward pass inconsistent across ranks: " f"diff={result.get('rank_consistency_diff', 'N/A')}" |
| ) |
| assert result.get("all_passed", False), "model_to_fsdp sharding group test failed" |
|
|
| clear_gpu_memory() |
|
|
|
|
| def _test_model_to_fsdp_no_meta_init_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test model_to_fsdp code path WITHOUT meta init. |
| |
| This test verifies that the standard (non-meta-init) FSDP path works: |
| 1. All ranks load real tensors (no meta device) |
| 2. model_to_fsdp correctly wraps the model |
| 3. Forward pass produces consistent results across ranks |
| """ |
| import copy |
| from fastgen.configs.methods.config_sft import ModelConfig |
| from fastgen.configs.net import EDM_CIFAR10_Config |
| from fastgen.methods import SFTModel |
| from fastgen.utils.distributed.fsdp import model_to_fsdp |
|
|
| |
| |
| config = ModelConfig() |
| config.net = copy.deepcopy(EDM_CIFAR10_Config) |
| config.input_shape = [3, 32, 32] |
| config.precision = "float32" |
| config.device = "cuda" |
| config.fsdp_meta_init = False |
| config.pretrained_model_path = "" |
| config.use_ema = False |
|
|
| result = { |
| "rank": rank, |
| "world_size": world_size, |
| } |
|
|
| |
| model = SFTModel(config) |
|
|
| |
| meta_tensors_before = [] |
| real_tensors_before = [] |
|
|
| for name, param in model.net.named_parameters(): |
| if param.device.type == "meta": |
| meta_tensors_before.append(name) |
| else: |
| real_tensors_before.append(name) |
|
|
| |
| has_correct_device_before = len(meta_tensors_before) == 0 |
| result["meta_tensors_before_fsdp"] = len(meta_tensors_before) |
| result["real_tensors_before_fsdp"] = len(real_tensors_before) |
| result["has_correct_device_before"] = has_correct_device_before |
|
|
| dist.barrier() |
|
|
| |
| |
| model = model_to_fsdp( |
| model, |
| min_num_params=1_000_000, |
| apply_cpu_offload=False, |
| sync_module_states=False, |
| ) |
|
|
| dist.barrier() |
|
|
| |
| meta_tensors_after = [] |
| real_tensors_after = [] |
|
|
| for name, param in model.net.named_parameters(): |
| if hasattr(param, "_local_tensor"): |
| device = param._local_tensor.device |
| else: |
| device = param.device |
|
|
| if device.type == "meta": |
| meta_tensors_after.append(name) |
| else: |
| real_tensors_after.append(name) |
|
|
| has_correct_device_after = len(meta_tensors_after) == 0 |
| result["meta_tensors_after_fsdp"] = len(meta_tensors_after) |
| result["real_tensors_after_fsdp"] = len(real_tensors_after) |
| result["has_correct_device_after"] = has_correct_device_after |
|
|
| |
| model.net.eval() |
|
|
| generator = torch.Generator(device="cpu").manual_seed(42) |
| dummy_input = torch.randn(2, 3, 32, 32, generator=generator, dtype=torch.float32).cuda() |
| dummy_t = torch.tensor([0.5, 0.5], dtype=torch.float32, device="cuda") |
| dummy_labels = torch.zeros(2, 10, dtype=torch.float32, device="cuda") |
|
|
| with torch.no_grad(): |
| output = model.net(dummy_input, dummy_t, condition=dummy_labels) |
|
|
| if hasattr(output, "full_tensor"): |
| output_full = output.full_tensor() |
| else: |
| output_full = output |
|
|
| rank0_output = output_full.clone() |
| dist.broadcast(rank0_output, src=0) |
| rank_consistency_diff = (output_full - rank0_output).abs().max().item() |
|
|
| result["forward_rank_consistent"] = rank_consistency_diff < 1e-5 |
| result["rank_consistency_diff"] = rank_consistency_diff |
| result["all_passed"] = has_correct_device_before and has_correct_device_after and result["forward_rank_consistent"] |
|
|
| del model |
| torch.cuda.empty_cache() |
|
|
| return result |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_model_to_fsdp_no_meta_init(): |
| """Test model_to_fsdp code path WITHOUT meta init. |
| |
| This verifies the standard FSDP path where all ranks load weights independently. |
| """ |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_model_to_fsdp_no_meta_init_impl, |
| world_size=2, |
| timeout=300, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_correct_device_before", False), ( |
| f"All ranks should have real tensors before FSDP: " |
| f"meta_tensors={result.get('meta_tensors_before_fsdp', 'N/A')}" |
| ) |
| assert result.get("has_correct_device_after", False), ( |
| f"Tensors not on CUDA after model_to_fsdp: " f"meta_tensors={result.get('meta_tensors_after_fsdp', 'N/A')}" |
| ) |
| assert result.get("forward_rank_consistent", False), ( |
| f"Forward pass inconsistent across ranks: " f"diff={result.get('rank_consistency_diff', 'N/A')}" |
| ) |
|
|
| clear_gpu_memory() |
|
|
|
|
| def _test_dmd2_model_to_fsdp_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Test model_to_fsdp with DMD2Model which has multiple networks. |
| |
| DMD2 has: |
| - net: student network |
| - teacher: frozen teacher network |
| - fake_score: trainable fake score network |
| |
| This tests that FSDP correctly handles multiple networks in fsdp_dict. |
| """ |
| import copy |
| from fastgen.configs.methods.config_dmd2 import ModelConfig |
| from fastgen.configs.net import EDM_CIFAR10_Config |
| from fastgen.methods import DMD2Model |
| from fastgen.utils.distributed.fsdp import model_to_fsdp |
|
|
| |
| config = ModelConfig() |
| config.net = copy.deepcopy(EDM_CIFAR10_Config) |
| config.input_shape = [3, 32, 32] |
| config.precision = "float32" |
| config.device = "cuda" |
| config.fsdp_meta_init = True |
| config.pretrained_model_path = "" |
| config.use_ema = False |
| config.add_teacher_to_fsdp_dict = True |
| |
| config.gan_loss_weight_gen = 0.0 |
|
|
| result = { |
| "rank": rank, |
| "world_size": world_size, |
| } |
|
|
| |
| model = DMD2Model(config) |
|
|
| |
| networks_to_check = ["net", "teacher", "fake_score"] |
| meta_status = {} |
|
|
| for net_name in networks_to_check: |
| if hasattr(model, net_name): |
| network = getattr(model, net_name) |
| meta_count = sum(1 for p in network.parameters() if p.device.type == "meta") |
| real_count = sum(1 for p in network.parameters() if p.device.type != "meta") |
| meta_status[net_name] = { |
| "meta": meta_count, |
| "real": real_count, |
| } |
|
|
| |
| if rank == 0: |
| |
| has_correct_device_before = all(status["meta"] == 0 for status in meta_status.values()) |
| else: |
| |
| has_correct_device_before = all(status["real"] == 0 for status in meta_status.values()) |
|
|
| result["meta_status_before"] = meta_status |
| result["has_correct_device_before"] = has_correct_device_before |
|
|
| dist.barrier() |
|
|
| |
| model = model_to_fsdp( |
| model, |
| min_num_params=1_000_000, |
| apply_cpu_offload=False, |
| sync_module_states=True, |
| ) |
|
|
| dist.barrier() |
|
|
| |
| meta_status_after = {} |
| for net_name in networks_to_check: |
| if hasattr(model, net_name): |
| network = getattr(model, net_name) |
| meta_count = 0 |
| real_count = 0 |
| for p in network.parameters(): |
| if hasattr(p, "_local_tensor"): |
| device = p._local_tensor.device |
| else: |
| device = p.device |
| if device.type == "meta": |
| meta_count += 1 |
| else: |
| real_count += 1 |
| meta_status_after[net_name] = { |
| "meta": meta_count, |
| "real": real_count, |
| } |
|
|
| has_correct_device_after = all(status["meta"] == 0 for status in meta_status_after.values()) |
| result["meta_status_after"] = meta_status_after |
| result["has_correct_device_after"] = has_correct_device_after |
|
|
| |
| model.net.eval() |
| model.fake_score.eval() |
| model.teacher.eval() |
|
|
| generator = torch.Generator(device="cpu").manual_seed(42) |
| dummy_input = torch.randn(2, 3, 32, 32, generator=generator, dtype=torch.float32).cuda() |
| dummy_t = torch.tensor([0.5, 0.5], dtype=torch.float32, device="cuda") |
| dummy_labels = torch.zeros(2, 10, dtype=torch.float32, device="cuda") |
|
|
| with torch.no_grad(): |
| net_output = model.net(dummy_input, dummy_t, condition=dummy_labels) |
| fake_score_output = model.fake_score(dummy_input, dummy_t, condition=dummy_labels) |
| teacher_output = model.teacher(dummy_input, dummy_t, condition=dummy_labels) |
|
|
| |
| forward_results = {} |
| for name, output in [("net", net_output), ("fake_score", fake_score_output), ("teacher", teacher_output)]: |
| if hasattr(output, "full_tensor"): |
| output_full = output.full_tensor() |
| else: |
| output_full = output |
|
|
| rank0_output = output_full.clone() |
| dist.broadcast(rank0_output, src=0) |
| diff = (output_full - rank0_output).abs().max().item() |
| forward_results[name] = { |
| "consistent": diff < 1e-5, |
| "diff": diff, |
| } |
|
|
| result["forward_results"] = forward_results |
| all_forward_consistent = all(r["consistent"] for r in forward_results.values()) |
| result["all_forward_consistent"] = all_forward_consistent |
|
|
| result["all_passed"] = has_correct_device_before and has_correct_device_after and all_forward_consistent |
|
|
| del model |
| torch.cuda.empty_cache() |
|
|
| return result |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_dmd2_model_to_fsdp(): |
| """Test model_to_fsdp with DMD2Model which has multiple networks. |
| |
| This verifies that FSDP correctly handles models with multiple networks |
| (net, teacher, fake_score) in their fsdp_dict. |
| """ |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_dmd2_model_to_fsdp_impl, |
| world_size=2, |
| timeout=300, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
| assert result.get("has_correct_device_before", False), ( |
| f"Meta device not used correctly before FSDP: " f"meta_status={result.get('meta_status_before', 'N/A')}" |
| ) |
| assert result.get("has_correct_device_after", False), ( |
| f"Tensors not properly materialized after model_to_fsdp: " |
| f"meta_status={result.get('meta_status_after', 'N/A')}" |
| ) |
| assert result.get("all_forward_consistent", False), ( |
| f"Forward pass inconsistent across ranks: " f"forward_results={result.get('forward_results', 'N/A')}" |
| ) |
| assert result.get("all_passed", False), "DMD2 model_to_fsdp test failed" |
|
|
| clear_gpu_memory() |
|
|
|
|
| |
| |
| |
|
|
|
|
| def _test_fsdp_full_verification_impl( |
| rank: int, |
| world_size: int, |
| ) -> Dict: |
| """Full FSDP verification test combining all checks.""" |
| results = {} |
|
|
| |
| weight_result = _test_fsdp_weight_sync_impl( |
| rank=rank, |
| world_size=world_size, |
| apply_checkpointing=False, |
| ) |
| results["weight_sync"] = weight_result.get("weights_match", False) |
|
|
| dist.barrier() |
| torch.cuda.empty_cache() |
|
|
| |
| buffer_result = _test_fsdp_buffer_sync_impl( |
| rank=rank, |
| world_size=world_size, |
| ) |
| results["buffer_sync"] = buffer_result.get("buffers_synced", False) |
|
|
| dist.barrier() |
| torch.cuda.empty_cache() |
|
|
| |
| forward_result = _test_fsdp_forward_pass_impl( |
| rank=rank, |
| world_size=world_size, |
| apply_checkpointing=False, |
| ) |
| results["forward_deterministic"] = forward_result.get("deterministic", False) |
| results["forward_rank_consistent"] = forward_result.get("ranks_consistent", False) |
| results["forward_matches_reference"] = forward_result.get("forward_matches", False) |
|
|
| results["all_passed"] = all(results.values()) |
|
|
| return results |
|
|
|
|
| @RunIf(min_gpus=2) |
| def test_fsdp_full_verification(): |
| """Full FSDP verification test - combines all individual tests.""" |
| clear_gpu_memory() |
|
|
| result = run_distributed_test( |
| test_fn=_test_fsdp_full_verification_impl, |
| world_size=2, |
| timeout=600, |
| setup_fn=set_env_vars, |
| ) |
|
|
| assert result is not None, "Test did not return a result" |
|
|
| |
| assert result.get("weight_sync", False), "Weight synchronization failed" |
| assert result.get("buffer_sync", False), "Buffer synchronization failed" |
| assert result.get("forward_deterministic", False), "Forward pass not deterministic" |
| assert result.get("forward_rank_consistent", False), "Forward outputs differ across ranks" |
| assert result.get("forward_matches_reference", False), "Forward output doesn't match reference" |
|
|
| assert result["all_passed"], "Not all verification checks passed" |
|
|
| clear_gpu_memory() |
|
|