File size: 38,038 Bytes
f4cade0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.placement_types import Placement


# Public API of this module: the names exported by ``from <module> import *``.
# NOTE(review): "PrepareModuleInputOutput" is exported here but its definition is not
# within the portion of the file reviewed -- confirm it is defined further down.
__all__ = [
    "ParallelStyle",
    "RowwiseParallel",
    "SequenceParallel",
    "ColwiseParallel",
    "PrepareModuleInput",
    "PrepareModuleInputOutput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """

    The parallel style contract defines how the module or submodule should be parallelized.



    It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum

    flexibility for different kind of style implementations.

    """

    src_data_rank: Optional[int] = 0

    @abstractmethod
    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ...


class ColwiseParallel(ParallelStyle):
    """

    Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.

    Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules.

    (i.e. MLP, Attention)



    Keyword Args:

        input_layouts (Placement, optional):

            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to

            become a DTensor. If not specified, we assume the input tensor to be replicated.

        output_layouts (Placement, optional):

            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module

            with the user desired layout. If not specified, the output tensor is sharded on the last dimension.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.

    Returns:

        A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor

        >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.

        >>>

        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})

        >>> ...



    .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not

        specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``),

        keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size.

    """

    def __init__(

        self,

        *,

        input_layouts: Optional[Placement] = None,

        output_layouts: Optional[Placement] = None,

        use_local_output: bool = True,

    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Shard(-1),)
        # colwise linear runtime sharding (desired sharding):
        # 1. requires replicate input
        # 2. shard output on last dim
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(

        input_layouts, desired_input_layouts, mod, inputs, device_mesh

    ):
        # TODO: figure out dynamo support for instance method and switch this to instance method

        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(0)
        # means Colwise as Linear is input * weight^T + bias, where
        # weight would become Shard(1)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    def _partition_embedding_fn(self, name, module, device_mesh):
        # colwise shard embedding.weight is straight forward as Shard(1)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(1)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
        else:
            raise NotImplementedError(
                "ColwiseParallel currently only support nn.Linear and nn.Embedding!"
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


class RowwiseParallel(ParallelStyle):
    """

    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.

    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.

    (i.e. MLP, Attention)



    Keyword Args:

        input_layouts (Placement, optional):

            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to

            become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.

        output_layouts (Placement, optional):

            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module

            with the user desired layout. If not specified, the output tensor is replicated.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.

    Returns:

        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> m = Model(...)  # m is a nn.Module that contains a "w2" nn.Linear submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim

        >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.

        >>>

        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}),

        >>> ...

    """

    def __init__(

        self,

        *,

        input_layouts: Optional[Placement] = None,

        output_layouts: Optional[Placement] = None,

        use_local_output: bool = True,

    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(-1),)
        self.output_layouts = (output_layouts or Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(

        input_layouts, desired_input_layouts, mod, inputs, device_mesh

    ):
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts, run_check=False
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                placements=desired_input_layouts, async_op=True
            )
        return input_tensor

    def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
        # means Rowwise as nn.Linear is input * weight^T + bias, where
        # weight would become Shard(0)
        module.register_parameter(
            "weight",
            nn.Parameter(
                distribute_tensor(
                    module.weight,
                    device_mesh,
                    [Shard(1)],
                    src_data_rank=self.src_data_rank,
                )
            ),
        )
        if getattr(module, "bias", None) is not None:
            # The Linear module has bias
            module.register_parameter(
                "bias",
                nn.Parameter(
                    distribute_tensor(
                        module.bias,
                        device_mesh,
                        [Replicate()],
                        src_data_rank=self.src_data_rank,
                    )
                ),
            )

    def _partition_embedding_fn(self, name, module, device_mesh):
        # rowwise shard embedding.weight is Shard(0)
        for name, param in module.named_parameters():
            dist_param = nn.Parameter(
                distribute_tensor(
                    param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank
                )
            )
            module.register_parameter(name, dist_param)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Rowwise sharding produces partial output, depending on output layouts:
        # 1. to replicate -> allreduce
        # 2. to shard -> reduce_scatter
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor if use_local_output is True
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, nn.Linear):
            partition_fn = self._partition_linear_fn
            # rowwise linear runtime sharding requires input tensor shard on last dim
            self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),)
        elif isinstance(module, nn.Embedding):
            partition_fn = self._partition_embedding_fn
            # rowwise embedding runtime sharding requires input tensor replicated
            self.desired_input_layouts = (Replicate(),)
        else:
            raise NotImplementedError(
                "RowwiseParallel currently only support nn.Linear and nn.Embedding!"
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            partial(
                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
            ),
            partial(
                self._prepare_output_fn, self.output_layouts, self.use_local_output
            ),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


class SequenceParallel(ParallelStyle):
    """

    SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with

    input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the

    `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__



    This style implements the operation that is described in the paper

    `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__



    If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded

    on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input

    passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would

    redistribute the input to be sharded on the sequence dimension.



    The output of the ``nn.Module`` will be sharded on the sequence dimension.



    Keyword Args:

        sequence_dim (int, optional):

            The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to

            become a DTensor that is sharded on the sequence dimension, default: 1.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.

    Returns:

        A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> m = Model(...)  # m is a nn.Module that contains a "norm" nn.LayerNorm submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim

        >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.

        >>>

        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),

        >>> ...



    .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.

        ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom

        inits for the weights on those modules, you need to broadcast the weights before/after parallelizing

        to ensure that they are replicated.

    """

    def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False):
        super().__init__()
        self.sequence_sharding = (Shard(sequence_dim),)
        self.use_local_output = use_local_output

    def _replicate_module_fn(

        self, name: str, module: nn.Module, device_mesh: DeviceMesh

    ):
        for p_name, param in module.named_parameters():
            # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow
            # us to simply just use from_local
            replicated_param = torch.nn.Parameter(
                DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
            )
            module.register_parameter(p_name, replicated_param)

    @staticmethod
    def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
            if input_tensor.placements != sequence_sharding:
                input_tensor = input_tensor.redistribute(
                    placements=sequence_sharding, async_op=True
                )
            return input_tensor
        elif isinstance(input_tensor, torch.Tensor):
            # assume the input passed in already sharded on the sequence dim and create the DTensor
            return DTensor.from_local(
                input_tensor, device_mesh, sequence_sharding, run_check=False
            )
        else:
            raise ValueError(
                f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
            )

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._replicate_module_fn,
            partial(self._prepare_input_fn, self.sequence_sharding),
            partial(self._prepare_output_fn, self.use_local_output),
        )

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        if len(self.sequence_sharding) == 1:
            tmpstr += f"sequence_dim={self.sequence_sharding[0].dim}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


class PrepareModuleInput(ParallelStyle):
    """

    Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to

    ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.



    Keyword Args:

        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):

            The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to

            DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified

            as a placeholder. default: None.

        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):

            The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module

            have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None.

        input_kwarg_layouts (Dict[str, Placement]):

            The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors.

            default: None

        desired_input_kwarg_layouts: (Dict[str, Placement]):

            The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module

            have the desired DTensor layouts. default: None.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.

    Returns:

        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor

        >>> # and then redistributed to Replicated DTensor.

        >>> parallelize_module(

        >>>     block, # this can be a submodule or module

        >>>     tp_mesh,

        >>>     parallelize_plan={

        >>>         "attn": PrepareModuleInput(

        >>>             input_layouts=(Shard(0), None, None, ...),

        >>>             desired_input_layouts=(Replicate(), None, None, ...)

        >>>         ),

        >>>     }

        >>> )

    """

    def __init__(

        self,

        *,

        input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None,

        desired_input_layouts: Optional[

            Union[Placement, tuple[Optional[Placement]]]

        ] = None,

        input_kwarg_layouts: Optional[dict[str, Placement]] = None,

        desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,

        use_local_output: bool = False,

    ):
        self.input_layouts = (
            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        )
        self.desired_input_layouts = (
            (desired_input_layouts,)
            if isinstance(desired_input_layouts, Placement)
            else desired_input_layouts
        )
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
            assert self.desired_input_layouts is not None, (
                "desired module inputs should not be None!"
            )
            assert len(self.input_layouts) == len(self.desired_input_layouts), (
                "input_layouts and desired_input_layouts should have same length!"
            )
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            assert len(self.input_kwarg_layouts) == len(
                self.desired_input_kwarg_layouts
            ), (
                "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
            )

    def _prepare_input_arg(

        self,

        input: Any,

        mesh: DeviceMesh,

        input_layout: Optional[Placement],

        desired_layout: Optional[Placement],

    ):
        if input_layout is not None:
            if isinstance(input, DTensor):
                # TODO: re-enable the check once we fix the compile path
                # assert inp.placements[0] == input_layout
                dt_inp = input
            else:
                assert isinstance(input, torch.Tensor), (
                    "expecting input to be a torch.Tensor!"
                )
                dt_inp = DTensor.from_local(
                    input, mesh, (input_layout,), run_check=False
                )

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(placements=(desired_layout,))

            return dt_inp.to_local() if self.use_local_output else dt_inp
        else:
            return input

    def _prepare_input_fn(self, inputs, device_mesh):
        if self.input_layouts is None:
            return inputs
        prepared_inputs = []
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")

        assert self.desired_input_layouts is not None, (
            "desired module inputs should not be None!"
        )
        for inp, input_layout, desired_layout in zip(
            inputs, self.input_layouts, self.desired_input_layouts
        ):
            prepared_inputs.append(
                self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
            )
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs = {}
        for kwarg_key in kwarg_inputs.keys():
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)

            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
                kwarg_val, device_mesh, input_layout, desired_input_layout
            )

        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if self.with_kwargs:
            module.register_forward_pre_hook(
                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
                    inputs, kwargs, device_mesh
                ),
                with_kwargs=True,
            )  # type: ignore[misc]
        else:
            module.register_forward_pre_hook(
                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
            )  # type: ignore[misc, call-arg]
        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.input_layouts}, "
        tmpstr += f"desired_input_layouts={self.desired_input_layouts}, "
        tmpstr += f"input_kwarg_layouts={self.input_kwarg_layouts}, "
        tmpstr += f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


class PrepareModuleOutput(ParallelStyle):
    """

    Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to

    ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.



    Keyword Args:

        output_layouts (Union[Placement, Tuple[Placement]]):

            The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to

            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors,

            ``None`` need to be specified as a placeholder.

        desired_output_layouts (Union[Placement, Tuple[Placement]]):

            The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module

            have the desired DTensor layouts.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.

    Returns:

        A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor

        >>> # and then redistributed to Sharded DTensor.

        >>> parallelize_module(

        >>>     block, # this can be a submodule or module

        >>>     tp_mesh,

        >>>     parallelize_plan = PrepareModuleOutput(

        >>>         output_layouts=Replicate(),

        >>>         desired_output_layouts=Shard(0)

        >>>     )

        >>> )

    """

    def __init__(

        self,

        *,

        output_layouts: Union[Placement, tuple[Placement]],

        desired_output_layouts: Union[Placement, tuple[Placement]],

        use_local_output: bool = True,

    ):
        self.output_layouts = (
            (output_layouts,)
            if isinstance(output_layouts, Placement)
            else output_layouts
        )
        self.desired_output_layouts = (
            (desired_output_layouts,)
            if isinstance(desired_output_layouts, Placement)
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
        assert len(self.output_layouts) == len(self.desired_output_layouts), (
            "output_layouts and desired_output_layouts should have same length!"
        )

    def _prepare_out_fn(self, outputs, device_mesh):
        prepared_outputs = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError(
                "module outputs and output_layouts should have same length!"
            )
        for out, out_layout, desired_out_layout in zip(
            outputs, self.output_layouts, self.desired_output_layouts
        ):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    # TODO: re-enable the check once we fix the compile path
                    # assert out.placements[0] == out_layout
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(
                        out, device_mesh, (out_layout,), run_check=False
                    )

                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
                prepared_outputs.append(
                    dt_out.to_local() if self.use_local_output else dt_out
                )
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        else:
            return tuple(prepared_outputs)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        module.register_forward_hook(
            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
        )  # type: ignore[misc, call-arg]
        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"output_layouts={self.output_layouts}, "
        tmpstr += f"desired_output_layouts={self.desired_output_layouts}, "
        tmpstr += f"use_local_output={self.use_local_output}"
        tmpstr += ")"
        return tmpstr


class PrepareModuleInputOutput(ParallelStyle):
    """

    Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module

    to DTensors at runtime according to ``input_layouts`` (and output_layouts, respectively), and perform layout redistribution

    according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of

    :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`.



    Keyword Args:

        input_layouts (Union[Placement, Tuple[Optional[Placement]]]):

            The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to

            DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified

            as a placeholder. default: None.

        desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):

            The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module

            have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None.

        input_kwarg_layouts (Dict[str, Placement]):

            The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors.

            default: None

        desired_input_kwarg_layouts: (Dict[str, Placement]):

            The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module

            have the desired DTensor layouts. default: None.

        use_local_input (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.

        output_layouts (Union[Placement, Tuple[Placement]]):

            The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to

            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors,

            ``None`` need to be specified as a placeholder.

        desired_output_layouts (Union[Placement, Tuple[Placement]]):

            The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module

            have the desired DTensor layouts.

        use_local_output (bool, optional):

            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.

    Returns:

        A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs.



    Example::

        >>> # xdoctest: +SKIP(failing)

        >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput

        >>> from torch.distributed.device_mesh import init_device_mesh

        >>> ...

        >>> block = TransformerBlock(...)  # block is a nn.Module that contains an "attn" Attention submodule

        >>> tp_mesh = init_device_mesh("cuda", (8,))

        >>>

        >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor

        >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated

        >>> # as Replicated DTensor and then redistributed to Sharded DTensor.

        >>> parallelize_module(

        >>>     block, # this can be a submodule or module

        >>>     tp_mesh,

        >>>     parallelize_plan={

        >>>         "attn": PrepareModuleInputOutput(

        >>>             input_layouts=(Shard(0), None, None, ...),

        >>>             desired_input_layouts=(Replicate(), None, None, ...),

        >>>             output_layouts=Replicate(),

        >>>             desired_output_layouts=Shard(0),

        >>>         ),

        >>>     }

        >>> )

    """

    def __init__(

        self,

        *,

        input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None,

        desired_input_layouts: Optional[

            Union[Placement, tuple[Optional[Placement]]]

        ] = None,

        input_kwarg_layouts: Optional[dict[str, Placement]] = None,

        desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,

        use_local_input: bool = False,

        output_layouts: Union[Placement, tuple[Placement]],

        desired_output_layouts: Union[Placement, tuple[Placement]],

        use_local_output: bool = True,

    ):
        self.prepare_module_input = PrepareModuleInput(
            input_layouts=input_layouts,
            desired_input_layouts=desired_input_layouts,
            input_kwarg_layouts=input_kwarg_layouts,
            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
            use_local_output=use_local_input,
        )
        self.prepare_module_output = PrepareModuleOutput(
            output_layouts=output_layouts,
            desired_output_layouts=desired_output_layouts,
            use_local_output=use_local_output,
        )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        self.prepare_module_input._apply(module, device_mesh)
        self.prepare_module_output._apply(module, device_mesh)

        return module

    def __repr__(self) -> str:
        tmpstr = self.__class__.__name__ + "("
        tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, "
        tmpstr += (
            f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, "
        )
        tmpstr += (
            f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, "
        )
        tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, "
        tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, "
        tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, "
        tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, "
        tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}"
        tmpstr += ")"
        return tmpstr