File size: 40,458 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import sys
from abc import ABC, abstractmethod
from contextlib import suppress
from copy import deepcopy
from enum import Enum, auto
from math import inf
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
    TrainingState_,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    transformer_auto_wrap_policy,
    wrap,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    TEST_SKIPS,
    MultiProcessTestCase,
)
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms


class FSDPInitMode(Enum):
    """Selects whether a test model is wrapped with FSDP when it is built
    through a model's ``init()`` method."""
    # No FSDP wrapping
    NO_FSDP = auto()
    # FSDP recursive wrapping
    RECURSIVE = auto()
    # TODO: FSDP non-recursive wrapping
    # NONRECURSIVE = auto()


class CUDAInitMode(Enum):
    """Selects when, if ever, a test model is moved to CUDA relative to its
    being passed to the FSDP constructor."""
    # Move model to CUDA before passing to the FSDP constructor
    CUDA_BEFORE = auto()
    # Move model to CUDA after passing to the FSDP constructor
    CUDA_AFTER = auto()
    # Keep on CPU
    CUDA_NEVER = auto()


class FSDPTestModel(nn.Module, ABC):
    """This defines the interface expected from all models used commonly for
    FSDP unit tests."""
    @abstractmethod
    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
        """Returns an input for the model as a tuple."""
        ...

    @abstractmethod
    def get_loss(self, input, output) -> torch.Tensor:
        """Returns the loss given the input and output."""
        ...

    @abstractmethod
    def run_backward(self, loss) -> None:
        """Runs the backward pass (e.g. including ``loss.backward()``)."""
        ...

    @staticmethod
    @abstractmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        *init_args: Any,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        **init_kwargs: Any,
    ) -> nn.Module:
        """Initializes an instance of this model.

        Args:
            group (dist.ProcessGroup): Process group for the model.
            fsdp_init_mode (FSDPInitMode): Selects the FSDP wrapping applied
                during construction.
            init_args: Extra positional arguments for the concrete model.
            cuda_init_mode (CUDAInitMode): Selects model movement to CUDA
                (keyword-only in this abstract signature).
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                intended for the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            init_kwargs: Extra keyword arguments for the concrete model.
        """
        ...



def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    Compares every rank's module states against those of rank 0.

    Each rank collects its named parameters followed by its named buffers (as
    detached CPU copies). The per-rank lists are all-gathered as Python objects
    over ``process_group``, and ``assert_fn(rank0_state, other_state)`` is
    called for each positionally corresponding pair of tensors from rank 0 and
    every other rank. Passing ``self.assertEqual`` therefore checks that all
    module states match across ranks.

    Args:
        model (nn.Module): Module whose parameters and buffers are compared.
        process_group (dist.ProcessGroup): Group over which the states are
            all-gathered.
        assert_fn (Callable): Two-argument callable applied to each pair of
            tensors.
    """
    # Names travel with the tensors purely as a debugging aid
    local_states = []
    for name, param in model.named_parameters():
        local_states.append((name, param.detach().cpu()))
    for name, buf in model.named_buffers():
        local_states.append((name, buf.detach().cpu()))

    gathered = [None] * dist.get_world_size(process_group)
    dist.all_gather_object(gathered, local_states, group=process_group)

    reference_states = gathered[0]
    for other_states in gathered[1:]:
        # Pairs are matched by position; ``zip`` stops at the shorter list
        for (_, ref_tensor), (_, other_tensor) in zip(reference_states, other_states):
            assert_fn(ref_tensor, other_tensor)

def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    with FSDP.summon_full_params(model):
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()

def _get_state_dict(model, cpu_offload=False, half=False):
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()

    return model.state_dict()

def subtest_name(test_name_mapping, *args):
    """
    Builds an underscore-joined subtest name from ``args``.

    Each argument is looked up in ``test_name_mapping`` under its ``str()``
    form. A ``None`` argument contributes the literal ``"none"`` without a
    lookup.
    """
    parts = []
    for arg in args:
        if arg is None:
            parts.append("none")
        else:
            parts.append(test_name_mapping[str(arg)])
    return "_".join(parts)

def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns deep copies of ``model``'s parameters, taken while inside
    ``FSDP.summon_full_params()`` so that they are the full unsharded
    parameters.

    NOTE(review): the previous documentation stated that FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list; that
    depends on ``summon_full_params()`` and is not established by this
    function itself.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate to
            ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        # Copy inside the context so the copies hold the unsharded values
        unsharded_params = list(model.parameters())
        return deepcopy(unsharded_params)

def _maybe_cuda(model: nn.Module, move_to_cuda: bool):
    return model.cuda() if move_to_cuda else model

def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
    return (
        model if not wrap_fsdp
        else FSDP(model, *args, **kwargs)
    )

class DummyProcessGroup:
    """Stand-in object exposing ``rank()``, ``size()``, and ``allreduce()``
    without performing any communication."""

    def __init__(self, rank: int, size: int):
        self._size = size
        self._rank = rank

    def rank(self) -> int:
        """Returns the rank supplied at construction."""
        return self._rank

    def size(self) -> int:
        """Returns the size supplied at construction."""
        return self._size

    def allreduce(self, *args, **kwargs):
        """Ignores all arguments and returns a mock work handle whose
        ``get_future()`` produces a new, already-completed future holding
        ``1`` on every call."""

        def _completed_future():
            done = torch.futures.Future()
            done.set_result(1)
            return done

        work_handle = mock.Mock()
        work_handle.get_future = _completed_future
        return work_handle

class DeterministicModel(torch.nn.Module):
    """Two stacked ``Linear(2, 2)`` layers on CUDA, built under a fixed seed.
    The first layer may optionally be wrapped with FSDP."""

    def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        super().__init__()
        # Seed before creating any layer so initialization is reproducible
        torch.manual_seed(0)
        inner_linear = torch.nn.Linear(2, 2).cuda()
        self.inner: Union[torch.nn.Linear, FSDP] = (
            FSDP(inner_linear, cpu_offload=cpu_offload) if wrap_fsdp else inner_linear
        )
        self.outer = torch.nn.Linear(2, 2).cuda()

    def forward(self, x):
        """Applies ``inner`` and then ``outer`` to ``x``."""
        hidden = self.inner(x)
        return self.outer(hidden)

class TransformerWithSharedParams(FSDPTestModel):
    """Small ``nn.Transformer``-based test model whose embedding weight is
    shared with the output projection. It also registers a float buffer and a
    ``long`` buffer that take part in the forward pass."""

    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            # ``nn.Module.cuda()`` moves the module in place
            self.cuda()
        if deterministic:
            # Eval mode so that dropout does not randomize the outputs
            self.eval()

    def get_input(self, device):
        """Returns a fixed ``(src, tgt)`` pair of token-id tensors on
        ``device``, shaped ``(6, bs)`` and ``(4, bs)``."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(self.bs * 6, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        """Embeds both id tensors, adds the buffers to the source embedding,
        applies ``bn`` to the target embedding, runs the transformer, and
        projects the result with the shared weight."""
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        """Returns the summed cross entropy between the flattened ``output``
        and the flattened target ids taken from ``input``."""
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        """Runs ``loss.backward()``."""
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        add_bn: bool = True,
    ) -> Union[nn.Module, FSDP]:
        """
        Initializes a :class:`TransformerWithSharedParams` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps with
                top-level FSDP. By default, the top-level FSDP uses the
                ``transformer_auto_wrap_policy()`` for encoder and decoder
                layers, but a different auto wrap policy may be specified via
                ``fsdp_kwargs``.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor. The passed dict is not
                modified.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            add_bn (bool): Whether to include batch norm in the model.

        Raises:
            ValueError: If ``fsdp_init_mode`` is not a supported mode.
        """
        # Work on a shallow copy: ``auto_wrap_policy`` is popped below, and
        # doing that on the caller's dict would drop the policy from any later
        # call that reuses the same dict
        fsdp_kwargs = {} if fsdp_kwargs is None else dict(fsdp_kwargs)
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return TransformerWithSharedParams(group, cuda_init_mode, add_bn, deterministic)
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Default to the `transformer_auto_wrap_policy()`
            if "auto_wrap_policy" not in fsdp_kwargs:
                auto_wrap_policy = functools.partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    },
                )
            else:
                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")
            fsdp_model = FSDP(
                TransformerWithSharedParams(group, cuda_init_mode, add_bn, deterministic),
                group,
                auto_wrap_policy=auto_wrap_policy,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

    def get_ignored_modules(self):
        """Returns a list containing only the transformer submodule."""
        return [self.transformer]


class NestedWrappedModule(FSDPTestModel):
    """Stack of linear layers in which some nested submodules are optionally
    wrapped with FSDP. The top-level module itself is never wrapped here."""

    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _place(layer):
            # Moves the layer to CUDA only for ``CUDA_BEFORE``
            return _maybe_cuda(layer, move_to_cuda)

        def _maybe_wrap(layer):
            return FSDP(layer, group, **fsdp_kwargs) if wrap_fsdp else layer

        if deterministic:
            torch.manual_seed(0)
        # Layers are created in the order they appear in the final sequence so
        # that seeded initialization stays reproducible
        first = _place(nn.Linear(8, 4))
        nested = _maybe_wrap(
            nn.Sequential(
                _maybe_wrap(_place(nn.Linear(4, 16))),
                _place(nn.Linear(16, 16)),
            ),
        )
        third = _maybe_wrap(_place(nn.Linear(16, 4)))
        last = _place(nn.Linear(4, 8))
        self.module = nn.Sequential(first, nested, third, last)

    def get_input(self, device):
        """Returns a one-element tuple holding a random ``(4, 8)`` tensor on
        ``device``, seeded by this rank."""
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        batch = torch.rand(4, 8, device=device)
        return (batch,)

    def forward(self, x):
        """Applies the layer stack to ``x``."""
        return self.module(x)

    def get_loss(self, input, output):
        """Returns the sum of ``output``; ``input`` is unused."""
        return output.sum()

    def run_backward(self, loss):
        """Runs ``loss.backward()``."""
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ) -> nn.Module:
        """
        Initializes a :class:`NestedWrappedModule` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): ``NO_FSDP`` builds the model with
                no FSDP wrapping. ``RECURSIVE`` wraps some nested modules with
                FSDP but leaves the top-level module unwrapped, so a caller
                may add a top-level FSDP wrapper afterwards if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.

        Raises:
            ValueError: If ``fsdp_init_mode`` is not a supported mode.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return NestedWrappedModule(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
        if fsdp_init_mode != FSDPInitMode.RECURSIVE:
            raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
        # Does not wrap with top-level FSDP
        nested_fsdp_model = NestedWrappedModule(
            group,
            wrap_fsdp=True,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
            **fsdp_kwargs,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            nested_fsdp_model = nested_fsdp_model.cuda()
        return nested_fsdp_model


class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
    """Variant of :class:`NestedWrappedModule` whose ``init()`` applies FSDP
    at the top level with the ``always_wrap_policy()`` auto wrap policy."""

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
        policy.

        Args:
            group (dist.ProcessGroup): Process group handed to the unwrapped
                model's constructor.
            fsdp_init_mode (FSDPInitMode): ``NO_FSDP`` returns the unwrapped
                model; ``RECURSIVE`` returns it wrapped with top-level FSDP.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the top-level FSDP constructor; ``None`` is
                treated as no extra arguments.
            deterministic (bool): Whether to make the model deterministic
                across constructions.

        Raises:
            ValueError: If ``fsdp_init_mode`` is not a supported mode.
        """
        # ``**None`` is a ``TypeError``, so normalize the default up front
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        # Always start from the unwrapped model; any FSDP wrapping is applied
        # below by the auto wrap policy rather than inside the model
        model = NestedWrappedModule.init(
            group=group,
            fsdp_init_mode=FSDPInitMode.NO_FSDP,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
        )
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class ModuleWithDelay(FSDPTestModel):
    """Wrapper around a :class:`FSDPTestModel` that can insert a CUDA sleep
    after the loss is computed and/or before each gradient reduce-scatter."""

    def __init__(
        self,
        module: nn.Module,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
    ):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        """Delegates to the wrapped module."""
        return self.module.get_input(device)

    def forward(self, x):
        """Delegates to the wrapped module."""
        return self.module(x)

    def get_loss(self, input, output):
        """Returns the wrapped module's loss, sleeping on the CUDA stream
        afterwards when ``delay_after_loss_ms`` is positive."""
        loss = self.module.get_loss(input, output)
        delay_ms = self.delay_after_loss_ms
        if delay_ms > 0:
            cycles = int(delay_ms * get_cycles_per_ms())
            torch.cuda._sleep(cycles)
        return loss

    def run_backward(self, loss):
        """Runs the wrapped module's backward pass while
        ``torch.distributed._reduce_scatter_base`` is patched to sleep first
        when ``delay_before_reduction_ms`` is positive."""
        real_reduce_scatter = torch.distributed._reduce_scatter_base

        def _sleep_then_reduce_scatter(*args, **kwargs):
            # The delay is read at call time, as each reduce-scatter is issued
            if self.delay_before_reduction_ms > 0:
                cycles = int(self.delay_before_reduction_ms * get_cycles_per_ms())
                torch.cuda._sleep(cycles)
            return real_reduce_scatter(*args, **kwargs)

        patcher = mock.patch(
            "torch.distributed._reduce_scatter_base", _sleep_then_reduce_scatter
        )
        with patcher:
            self.module.run_backward(loss)

    @staticmethod
    def init(
        module_class: Type[FSDPTestModel],
        *model_args: Any,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
        **model_kwargs: Any,
    ):
        """
        Builds the wrapped model via ``module_class.init()`` and wraps it in a
        :class:`ModuleWithDelay`.

        Args:
            module_class (Type[FSDPTestModel]): Wrapped module class to which
                to add delays.
            model_args: Positional arguments forwarded to the ``module_class``
                ``init()``.
            delay_after_loss_ms (int): Delay after computing the loss/before
                the optimizer step (in ms).
            delay_before_reduction_ms (int): Delay before reduce-scattering
                gradients (in ms).
            model_kwargs: Keyword arguments forwarded to the ``module_class``
                ``init()``.
        """
        wrapped = module_class.init(*model_args, **model_kwargs)
        return ModuleWithDelay(
            wrapped,
            delay_after_loss_ms,
            delay_before_reduction_ms,
        )

class NestedWrappedModuleWithDelay(ModuleWithDelay):
    """Convenience subclass whose ``init()`` builds a
    :class:`NestedWrappedModule` wrapped in a :class:`ModuleWithDelay`."""

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_after_loss_ms: int = 0,
        delay_before_reduction_ms: int = 0,
    ):
        """Forwards every argument to :meth:`ModuleWithDelay.init` with
        :class:`NestedWrappedModule` as the wrapped module class."""
        nested_model_kwargs = dict(
            group=group,
            fsdp_init_mode=fsdp_init_mode,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
        )
        return ModuleWithDelay.init(
            NestedWrappedModule,
            delay_after_loss_ms=delay_after_loss_ms,
            delay_before_reduction_ms=delay_before_reduction_ms,
            **nested_model_kwargs,
        )


class DummyDDP(nn.Module):
    """Pass-through wrapper that stores a module as ``self.module`` and
    forwards every call to it unchanged."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **named_inputs):
        """Calls the wrapped module with the same arguments."""
        wrapped = self.module
        return wrapped(*inputs, **named_inputs)


class MixtureOfExperts(NestedWrappedModule):
    """Test model with a per-rank "expert" linear layer and a "shared" linear
    layer between input and output projections. When FSDP-wrapped, the expert
    layer uses its own single-rank process group and the shared layer uses the
    given group."""

    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        # Tag expert parameters so `run_backward()` can skip reducing them
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

        # Replaces the layer stack built by the parent constructor
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda)
        )

    def forward(self, x):
        """Runs the layer stack on ``x``. When ``delay_before_free_ms`` is
        positive and the expert layer is an FSDP instance, the expert's
        ``_reshard()`` is patched for the duration of the call so that it
        sleeps on the CUDA stream before resharding."""
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FSDP):
                # Check before the attribute is read, so a missing method
                # gives this message rather than a bare `AttributeError`
                assert hasattr(
                    expert, "_reshard"
                ), "expert FSDP module should have a `_reshard()` method"
                orig_reshard = expert._reshard

                def _free_full_params_with_delay(*args, **kwargs):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_reshard(*args, **kwargs)

                with mock.patch.object(
                    expert, "_reshard", _free_full_params_with_delay
                ):
                    return self.module(x)

        return self.module(x)

    def run_backward(self, loss):
        """Runs ``loss.backward()``. Without FSDP wrapping, it then averages
        the gradients of all non-expert parameters across ``self.group``."""
        loss.backward()
        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    if p.grad is None:
                        continue  # no gradient was produced for this param
                    p.grad.div_(self.world_size)
                    torch.distributed.all_reduce(p.grad, group=self.group)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_before_free_ms: int = 0,
    ):
        """
        Initializes a :class:`MixtureOfExperts` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP, including the expert and shared layers, but
                not the top-level module. The model may later be wrapped with a
                top-level FSDP external to this method if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            delay_before_free_ms (int): Delay before resharding expert
                parameters in the forward pass (in ms).

        Raises:
            ValueError: If ``fsdp_init_mode`` is not a supported mode.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return MixtureOfExperts(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = MixtureOfExperts(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class FSDPTest(MultiProcessTestCase):
    def setUp(self):
        """Runs the parent ``setUp()`` and then spawns the test processes."""
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self):
        """
        Number of ranks to use: the CUDA device count when CUDA is available,
        otherwise a fixed value of 4.
        """
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        return 4

    @property
    def process_group(self):
        """Returns the default process group."""
        default_group = dist.distributed_c10d._get_default_group()
        return default_group

    @property
    def init_method(self):
        """
        Returns the init-method string for ``init_process_group()``, formed by
        prefixing ``self.file_name`` with ``FILE_SCHEMA``.
        """
        return f"{FILE_SCHEMA}{self.file_name}"

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        """Asserts that ``fsdp_model.cpu_offload`` equals ``cpu_offload``."""
        configured = fsdp_model.cpu_offload
        self.assertEqual(cpu_offload, configured)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        """
        Asserts that ``fsdp_model.backward_prefetch`` equals
        ``backward_prefetch``.
        """
        configured = fsdp_model.backward_prefetch
        self.assertEqual(backward_prefetch, configured)

    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
        """
        Asserts that ``fsdp_model.forward_prefetch`` equals
        ``forward_prefetch``.
        """
        configured = fsdp_model.forward_prefetch
        self.assertEqual(forward_prefetch, configured)

    def run_subtests(
        self,
        subtest_config: Dict[str, List[Any]],
        test_fn: Callable,
        *test_args,
        **test_kwargs: Any,
    ):
        """
        Invokes ``test_fn`` once per combination in the Cartesian product of
        the value lists in ``subtest_config``, each invocation inside its own
        ``subTest`` context. Running the combinations inside one test
        amortizes the costly setup overhead (including process spawn and
        initializing the process group) over the subtests.

        Args:
            subtest_config (Dict[str, List[Any]]): A mapping from subtest
                keyword argument name to a list of its possible values.
            test_fn (Callable): A callable that runs the actual test.
            test_args: Positional arguments to pass to ``test_fn``.
            test_kwargs: Keyword arguments to pass to ``test_fn``.
        """
        # Snapshot the keys once so that names and value lists stay aligned
        # in a fixed order for every combination
        kwarg_names: List[str] = list(subtest_config.keys())
        candidate_values: List[List[Any]] = [
            subtest_config[name] for name in kwarg_names
        ]
        for combination in itertools.product(*candidate_values):
            # Pair each keyword name with the value chosen for this subtest
            chosen_kwargs = dict(zip(kwarg_names, combination))
            with self.subTest(**chosen_kwargs):
                test_fn(*test_args, **test_kwargs, **chosen_kwargs)
            # Synchronize after every subtest before moving to the next one
            dist.barrier()

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        """
        Per-process entry point: builds the test case, initializes the process
        group, runs ``test_name``, tears the process group down, and exits the
        process with code 0.

        Args:
            rank: Rank assigned to this process.
            test_name: Name of the test method to construct and run.
            file_name: File name used to build the init method.
            pipe: Passed through to ``run_test()``.
        """
        self = cls(test_name)
        self.rank, self.file_name = rank, file_name

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Fall back to the gloo backend so that `init_process_group()` still
        # succeeds without CUDA; the actual tests will be skipped if there
        # are not enough GPUs.
        if torch.cuda.is_available():
            backend = "nccl"
        else:
            backend = "gloo"

        try:
            dist.init_process_group(
                init_method=self.init_method,
                backend=backend,
                world_size=int(self.world_size),
                rank=self.rank,
            )
        except RuntimeError as err:
            # Only an error mentioning "recompile" is converted into a
            # "backend unavailable" skip; anything else propagates
            if "recompile" not in err.args[0]:
                raise
            sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

        if torch.cuda.is_available():
            num_devices = torch.cuda.device_count()
            if num_devices:
                torch.cuda.set_device(self.rank % num_devices)

        # Execute barrier prior to running test to ensure that every process
        # has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        dist.barrier()

        self.run_test(test_name, pipe)

        dist.barrier()

        dist.destroy_process_group()
        sys.exit(0)

    def _train_for_several_steps(
        self,
        model: nn.Module,
        num_steps: int,
        autocast: bool,
        lr: float = 0.01,
        fsdp_cpu_offload: Optional[CPUOffload] = None,
        norm_type: Optional[Union[float, int]] = None,
        save_model: bool = False,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
    ):
        """
        Trains ``model`` for ``num_steps`` iterations with SGD (momentum 0.9)
        and returns the detached loss from the last iteration.

        Args:
            model (nn.Module): Wrapped model; its ``module`` attribute must
                provide ``get_input()``, ``get_loss()``, and
                ``run_backward()``.
            num_steps (int): Number of training iterations.
            autocast (bool): Whether to enable ``torch.cuda.amp.autocast``
                around the forward pass and loss computation.
            lr (float): Learning rate for the SGD optimizer.
            fsdp_cpu_offload (Optional[CPUOffload]): If its ``offload_params``
                is set and ``model`` is an FSDP instance, asserts that every
                parameter is on CPU after forward and after backward.
            norm_type (Optional[Union[float, int]]): If not ``None``, clips
                gradients to a max norm of 0.3 using this norm type and
                asserts that the norm after clipping does not exceed it.
            save_model (bool): If ``True``, simulates a save and load of the
                state dict after every optimizer step.
            mixed_precision (Optional[MixedPrecision]): Mixed precision
                config; used for input casting and the expected loss dtype.
            enable_sharded_grad_scaler (bool): Whether the
                ``ShardedGradScaler`` is enabled.
            use_pure_fp16 (bool): Whether inputs are cast to FP16 and the loss
                is expected to be FP16.
        """
        is_fsdp = isinstance(model, FSDP)
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
        verify_params_on_cpu = cpu_offload_params and is_fsdp
        # Non-FSDP models under mixed precision and pure FP16 runs both get
        # half-precision inputs
        cast_input_to_half = use_pure_fp16 or (mixed_precision and not is_fsdp)

        def _assert_all_params_on_cpu():
            # Params should always be on CPU when offloading parameters
            expected_device = torch.device("cpu")
            for param in model.parameters():
                self.assertEqual(param.device, expected_device)

        model_device = next(model.parameters()).device
        grad_scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of cpu offloading, or model.device
                batch = model.module.get_input(torch.device("cuda"))
                if cast_input_to_half:
                    batch = (
                        batch.half()
                        if isinstance(batch, torch.Tensor)
                        else tuple(x.half() for x in batch)
                    )
                output = model(*batch)
                # Post-forward, if CPU offloading model param should be on CPU.
                if verify_params_on_cpu:
                    _assert_all_params_on_cpu()

                loss = model.module.get_loss(batch, output).to(model_device)
            loss = grad_scaler.scale(loss)

            if not mixed_precision and not use_pure_fp16:
                assert (
                    loss.dtype == torch.float32
                ), "loss data type should be float32, as the original \
                    parameter data type is float32."
            elif use_pure_fp16:
                self.assertEqual(loss.dtype, torch.float16)
            elif is_fsdp:
                # FSDP loss is fp16, DDP AMP loss is fp32
                self.assertEqual(loss.dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            if norm_type is not None:
                max_norm = 0.3
                if is_fsdp:
                    model.clip_grad_norm_(max_norm, norm_type)
                    total_norm_after_clip = _collect_total_grad_norm_fsdp(
                        model, norm_type, self.rank
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), max_norm, norm_type
                    )
                    total_norm_after_clip = _collect_total_grad_norm_local(
                        model, norm_type
                    )
                self.assertTrue(total_norm_after_clip <= max_norm)
            # Post-backward, if CPU offloading model params should be on CPU.
            if verify_params_on_cpu:
                _assert_all_params_on_cpu()
            # Unscale the gradients and step
            grad_scaler.step(optimizer)
            # Update the scale factor
            grad_scaler.update()
            # if save_model, simulate save + load.
            if save_model:
                saved_state = {
                    name: tensor.clone()
                    for name, tensor in model.state_dict().items()
                }
                # Zero params, if save/load state_dict did not work properly, this
                # would break the parity test with DDP.
                _zero_model(model)
                model.load_state_dict(saved_state)

        if is_fsdp:
            model._assert_state(TrainingState_.IDLE)
        return loss.detach()

    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        forward_prefetch: bool = False,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        norm_type: Optional[Union[float, int]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        The reference model and the FSDP model are each trained for
        ``num_iters`` iterations; the final losses are compared, and, when
        running in full precision, the final parameters are compared as well.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            cuda_init_mode (CUDAInitMode): Passed to ``model_class.init()``
                for the FSDP model. With ``CUDA_AFTER``, the model is moved to
                CUDA after FSDP wrapping.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
            num_iters (int): Number of training iterations for both models.
            save_model (bool): Whether the FSDP training loop simulates a
                state dict save and load after every step.
            cpu_offload (CPUOffload): Forwarded to FSDP and to the training
                loops; also determines the parameter device checks.
            backward_prefetch (Optional[BackwardPrefetch]): Forwarded to FSDP.
            sharding_strategy (Optional[ShardingStrategy]): Forwarded to FSDP.
            mixed_precision (Optional[MixedPrecision]): Forwarded to FSDP; if
                not ``None``, the reference model trains under autocast.
            forward_prefetch (bool): Forwarded to FSDP.
            enable_sharded_grad_scaler (bool): Forwarded to both training
                loops.
            use_pure_fp16 (bool): If ``True``, both models are cast with
                ``half()`` before training.
            norm_type (Optional[Union[float, int]]): Forwarded to both
                training loops for gradient clipping.
            init_kwargs (Optional[Dict[str, Any]]): Extra keyword arguments
                for ``model_class.init()``.
            fsdp_kwargs: Additional keyword arguments for the FSDP
                constructor.

        Raises:
            ValueError: If ``model_class.init()`` fails for the FSDP model;
                the original exception is chained as the cause.
        """
        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update(
            {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "sharding_strategy": sharding_strategy,
                "mixed_precision": mixed_precision,
                "forward_prefetch": forward_prefetch,
            }
        )
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            # Chain the original exception so that its traceback is preserved
            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        # Defined unconditionally since it is used by both device checks below
        cpu_device = torch.device("cpu")
        if expects_cpu_device:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (
            self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
            if expects_device_error else suppress()
        )
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                norm_type=norm_type,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        torch.testing.assert_allclose(ref_loss, fsdp_loss)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32. The same `half()`
        # cast is applied to the reference model for pure FP16, so parameter
        # parity is only checked for full precision runs.
        if mixed_precision is None and not use_pure_fp16:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )


class SkipModule(nn.Module):
    """Module holding a single bias-free 10x10 linear layer, ``lin``."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(in_features=10, out_features=10, bias=False)

    def forward(self, x):
        """Applies ``lin`` to ``x`` and returns the result."""
        out = self.lin(x)
        return out


class NestedLinear(nn.Module):
    """
    Module holding a single bias-free 10x10 linear layer on CUDA,
    ``nested_linear``, which is passed through ``wrap()`` when ``fsdp_wrap``
    is truthy.
    """

    def __init__(self, fsdp_wrap):
        super().__init__()
        linear = nn.Linear(10, 10, bias=False).cuda()
        self.nested_linear = wrap(linear) if fsdp_wrap else linear

    def forward(self, x):
        """Applies ``nested_linear`` to ``x`` and returns the result."""
        out = self.nested_linear(x)
        return out


class SkipModel(nn.Module):
    """
    Model chaining three submodules in order: a plain linear layer
    (``linear``), a :class:`SkipModule` (``linear_skip``), and a
    :class:`NestedLinear` passed through ``wrap()`` (``nested_linear``).

    Args:
        double_nest: Forwarded to :class:`NestedLinear` as ``fsdp_wrap``.
    """

    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        inner = NestedLinear(fsdp_wrap=double_nest)
        self.nested_linear = wrap(inner)

    def forward(self, x):
        """Runs ``x`` through ``linear``, ``linear_skip``, and ``nested_linear``."""
        for layer in (self.linear, self.linear_skip, self.nested_linear):
            x = layer(x)
        return x


def _collect_total_grad_norm_fsdp(model, norm_type, rank):
    """
    Combines the local gradient norms of all ranks into one total norm.

    The local norm is raised to the ``norm_type`` power, summed across ranks
    with an all-reduce, and the root is taken afterwards. For the infinity
    norm, the local norms are max-reduced instead.

    Args:
        model: Model whose parameter gradients are measured.
        norm_type: Norm order; may be ``inf``.
        rank: Used as the device for the tensor that is all-reduced.
    """
    local_total = _collect_total_grad_norm_local(model, norm_type)
    if norm_type == inf:
        # Max-reduce the local maxima; an exponent of 1 leaves them unchanged
        reduce_op, exponent = torch.distributed.ReduceOp.MAX, 1.0
    else:
        reduce_op, exponent = torch.distributed.ReduceOp.SUM, norm_type
    reduced = torch.tensor(local_total ** exponent, device=rank)
    dist.all_reduce(reduced, op=reduce_op)
    return reduced ** (1.0 / exponent)


def _collect_total_grad_norm_local(model, norm_type):
    if norm_type == inf:
        return max(p.grad.abs().max() for p in model.parameters())
    else:
        total_norm = 0.0
        for p in model.parameters():
            local_norm = torch.linalg.vector_norm(p.grad, norm_type, dtype=torch.float32)
            total_norm += local_norm ** norm_type
        return total_norm ** (1.0 / norm_type)