File size: 29,191 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from warnings import warn

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, _mesh_resources

from physicsnemo.distributed import DistributedManager
from physicsnemo.utils.profiling import annotate, profile
from physicsnemo.utils.version_check import check_module_requirements

# Prevent importing this module if the minimum version of pytorch is not met.
check_module_requirements("physicsnemo.distributed.shard_tensor")

from torch.distributed.tensor import DTensor  # noqa: E402
from torch.distributed.tensor._dtensor_spec import (  # noqa: E402
    TensorMeta,
)
from torch.distributed.tensor.placement_types import (  # noqa: E402
    Placement,
    Replicate,
    Shard,
)

from physicsnemo.distributed._shard_redistribute import (  # noqa: E402
    ShardRedistribute,
)
from physicsnemo.distributed._shard_tensor_spec import (  # noqa: E402
    ShardTensorSpec,
    _infer_shard_tensor_spec_from_local_chunks,
    _stride_from_contiguous_shape_C_style,
)

# Shorthand for the ATen operator namespace, used when matching dispatched ops.
aten = torch.ops.aten


class _ToTorchTensor(torch.autograd.Function):
    """Autograd function to convert a ShardTensor to a regular PyTorch tensor.

    This class handles the conversion from ShardTensor to torch.Tensor in both forward
    and backward passes, maintaining proper gradient flow.  Slices the ShardTensor
    to the local component only on the current rank.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        input: "ShardTensor",
        grad_placements: Optional[Sequence[Placement]] = None,
    ) -> torch.Tensor:
        """Convert ShardTensor to torch.Tensor in forward pass.

        Args:
            ctx: Autograd context for saving tensors/variables for backward
            input: ShardTensor to convert
            grad_placements: Optional sequence of placements to use for gradients

        Returns:
            torch.Tensor: Local tensor representation of the ShardTensor
        """
        ctx.shard_tensor_spec = input._spec
        ctx.grad_placements = grad_placements
        local_tensor = input._local_tensor

        # JUST LIKE DTENSOR:
        # We need to return a fresh Tensor object there as autograd metadata
        # will be inplaced into it. So we don't want to pollute the Tensor
        # object stored in the _local_tensor of this ShardTensor.
        return local_tensor.view_as(local_tensor)

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
    ) -> Tuple["ShardTensor", None]:
        """Convert gradient torch.Tensor back to ShardTensor in backward pass.

        Args:
            ctx: Autograd context containing saved tensors/variables from forward
            grad_output: Gradient tensor to convert back to ShardTensor

        Returns:
            Tuple containing:
            - ShardTensor gradient
            - None for grad_placements gradient (not needed)
        """
        shard_tensor_spec = ctx.shard_tensor_spec
        mesh = shard_tensor_spec.mesh
        if ctx.grad_placements is not None:
            if ctx.grad_placements != shard_tensor_spec.placements:
                grad_placements = ctx.grad_placements
                grad_sharding_shapes = "infer"
            else:
                # If the placements are the same as the input placements,
                # we reuse the sharding sizes from the input placements.
                grad_placements = ctx.grad_placements
                grad_sharding_shapes = shard_tensor_spec._sharding_shapes
        else:
            grad_placements = shard_tensor_spec.placements
            grad_sharding_shapes = shard_tensor_spec._sharding_shapes
        if grad_sharding_shapes is None:
            grad_sharding_shapes = "infer"
        # Generate a spec based on grad outputs and the expected placements:
        grad_tensor_spec = _infer_shard_tensor_spec_from_local_chunks(
            grad_output, mesh, grad_placements, grad_sharding_shapes
        )

        return (
            ShardTensor(
                grad_output, grad_tensor_spec, requires_grad=grad_output.requires_grad
            ),
            None,
        )


class _FromTorchTensor(torch.autograd.Function):
    """Autograd function for converting a torch.Tensor to a ShardTensor.

    This class handles the forward and backward passes for converting between
    torch.Tensor and ShardTensor types, maintaining gradient information.

    Global shape information is inferred using collective communication on
    the specified device mesh.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        local_input: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Tuple[Placement, ...],
        sharding_shapes: Union[str, Dict[int, List[Tuple[int, ...]]]] = "chunk",
    ) -> "ShardTensor":
        """Convert a local torch.Tensor to a ShardTensor in forward pass.

        Args:
            ctx: Autograd context for saving tensors/variables for backward
            local_input: Local tensor to convert to ShardTensor
            device_mesh: Device mesh specifying process groups
            placements: Tuple of placement rules for sharding
            sharding_shapes: Controls how shard tensor spec is generated:
                - "chunk": Use torch.chunk shapes to infer shapes from global shape (no communication)
                - "infer": Use collective communication to infer shapes from mesh neighbors.
                - Manual dict mapping mesh dim to list of shard shapes: Use provided shapes.  Must pass on each rank!

        Returns:
            ShardTensor constructed from the local input tensor
        """
        ctx.previous_placement = placements
        ctx.previous_mesh = device_mesh

        # This function is simpler than the corresponding DTensor implementation on the surface
        # because under the hood, we have some logic here to infer the sharding shapes.
        shard_tensor_spec = _infer_shard_tensor_spec_from_local_chunks(
            local_input, device_mesh, placements, sharding_shapes
        )

        shard_tensor = ShardTensor(
            local_input,
            shard_tensor_spec,
            requires_grad=local_input.requires_grad,
        )

        return shard_tensor

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_output: "ShardTensor",
    ) -> Tuple[torch.Tensor, None, None]:
        """Convert gradient ShardTensor back to torch.Tensor in backward pass.

        Args:
            ctx: Autograd context containing saved tensors/variables from forward
            grad_output: Gradient ShardTensor to convert back to torch.Tensor

        Returns:
            Tuple containing:
            - Local tensor gradient
            - None for device_mesh gradient (not needed)
            - None for placements gradient (not needed)

        Raises:
            RuntimeError: If gradient tensor has different placement than original
        """
        previous_placement = ctx.previous_placement
        if grad_output.placements != previous_placement:
            # Automatically redistribute to the previous placement as long as it's not a partial.
            if not any(p.is_partial() for p in previous_placement):
                grad_output = grad_output.redistribute(
                    grad_output._spec.mesh, previous_placement
                )
            else:
                raise RuntimeError(
                    "Resharding gradients with partial placements not implemented"
                )

        return grad_output.to_local(), None, None, None


class ShardTensor(DTensor):
    """
    A class similar to pytorch's native DTensor but with more
    flexibility for uneven data sharding.

    Leverages very similar API to DTensor (identical, where possible)
    but deliberately tweaking routines to avoid implicit assumptions
    about tensor sharding.

    The key differences from DTensor are:
    - Supports uneven sharding where different ranks can have different local tensor sizes
    - Tracks and propagates shard size information across operations
    - Handles redistribution of unevenly sharded tensors
    - Provides custom collective operations optimized for uneven sharding

    Like DTensor, operations are dispatched through PyTorch's dispatcher system.
    Most operations work by:
    1. Converting inputs to local tensors
    2. Performing the operation locally
    3. Constructing a new ShardTensor with appropriate sharding spec
    4. Handling any needed communication between ranks

    The class provides methods for:
    - Converting to/from local tensors
    - Redistributing between different sharding schemes
    - Performing collective operations like all_gather and reduce_scatter
    - Basic tensor operations that maintain sharding information
    """

    _local_tensor: torch.Tensor
    _spec: ShardTensorSpec
    __slots__ = ["_local_tensor", "_spec"]

    # For torch.ops.aten operators (low-level dispatch)
    _dispatch_registry: Dict[torch._ops.OpOverload, Callable] = {}

    # For Python-level functions (torch.mean, tensor.mean, etc.)
    _function_registry: Dict[Callable, Callable] = {}

    # For custom functions registered with PyTorch,
    # it is sometimes necessary to match by name.
    # For instance, if you declare an op with
    #
    # @torch.library.custom_op(
    #    "module::function_name", mutates_args=()
    # )
    # def function_external_to_torch(
    #
    # Then, you likely want to register the handler with
    #
    # ShardTensor.register_named_function_handler("module.function_name.default", handler)
    _named_function_registry: dict[str, Callable] = {}

    # Upon construction of any ShardTensor objects, this will be set to true.
    # Wrappers are triggered dynamically, so the wrapping will be pass-through
    # exclusively until true.
    _enable_shard_patches: bool = False

    @classmethod
    def patches_enabled(cls) -> bool:
        """
        Report whether the ShardTensor function patches are active.

        The flag starts as False and is set to True the first time a
        ShardTensor is created (via ``__new__``, ``from_dtensor`` or
        ``from_local``).
        """
        return cls._enable_shard_patches

    @classmethod
    def register_dispatch_handler(
        cls, op: torch._ops.OpOverload, handler: Callable
    ) -> None:
        """Register a handler for a specific PyTorch operator in the dispatch system.

        The handler is invoked from ``__torch_dispatch__`` as
        ``handler(*args, **kwargs)`` with the operator's arguments.
        """
        cls._dispatch_registry[op] = handler

    @classmethod
    def register_function_handler(cls, func: Callable, handler: Callable) -> None:
        """Register a handler for a Python-level function or method.

        The handler is invoked from ``__torch_function__`` as
        ``handler(func, types, args, kwargs)``.
        """
        cls._function_registry[func] = handler

    @classmethod
    def register_named_function_handler(cls, func_name: str, handler: Callable) -> None:
        """Register a named function that has been named via torch.library.custom_op

        The name is matched against ``str(func)`` in ``__torch_function__`` and
        the handler is invoked as ``handler(func, types, args, kwargs)``.
        """
        cls._named_function_registry[func_name] = handler

    @staticmethod
    def __new__(
        cls,
        local_tensor: torch.Tensor,
        spec: ShardTensorSpec,
        *,
        requires_grad: bool,
    ) -> "ShardTensor":
        """
        Construct a new ShardTensor from a local tensor and a sharding spec.

        The global shape and strides are taken from ``spec.tensor_meta``; the
        dtype, device and layout are taken from ``local_tensor``.

        Heavily derived from torch DTensor

        Args:
            local_tensor: Local tensor to use as the data
            spec: ShardTensorSpec defining the sharding scheme
            requires_grad: Whether the tensor requires gradients

        Returns:
            A new ShardTensor instance

        Raises:
            ValueError: If ``spec.tensor_meta`` is None
        """
        if local_tensor.requires_grad and not requires_grad:
            warn(
                "To construct a new ShardTensor from torch.Tensor, "
                "it's recommended to use local_tensor.detach() and "
                "make requires_grad consistent."
            )

        if spec.tensor_meta is None:
            raise ValueError("TensorMeta should not be None!")

        # The wrapper advertises the global shape/stride recorded in the spec.
        ret = torch.Tensor._make_wrapper_subclass(
            cls,
            spec.tensor_meta.shape,
            strides=spec.tensor_meta.stride,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
            requires_grad=requires_grad,
        )

        ret._spec = spec
        ret._local_tensor = local_tensor

        # Creating any ShardTensor switches the registered patches on.
        cls._enable_shard_patches = True

        return ret

    def __repr__(self) -> str:
        return f"ShardTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

    @classmethod
    def from_dtensor(cls, dtensor: DTensor) -> "ShardTensor":
        """
        Convert a DTensor to a ShardTensor.  We assume the DTensor is properly constructed.

        Args:
            dtensor: DTensor to convert

        Returns:
            Equivalent ShardTensor
        """
        # Always ensure sharding is turned on:
        cls._enable_shard_patches = True

        # DTensor is locked to sharding a tensor according to chunk format.
        # We can use that to infer sharding sizes with no communication.

        # Create the spec by inferring the sharding sizes from the DTensor:
        spec = _infer_shard_tensor_spec_from_local_chunks(
            dtensor._local_tensor,
            dtensor._spec.mesh,
            dtensor._spec.placements,
            sharding_shapes="chunk",
            global_shape=dtensor.shape,
        )

        return ShardTensor.__new__(
            cls,
            local_tensor=dtensor._local_tensor,
            spec=spec,
            requires_grad=dtensor.requires_grad,
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route Python-level torch functions to registered handlers.

        Handlers registered by callable take precedence over handlers
        registered by name (matched against ``str(func)``). Handlers are only
        consulted once patches are enabled; otherwise, or when no handler
        matches, the parent implementation is used.
        """
        # Use a fresh dict per call instead of a shared mutable default.
        if kwargs is None:
            kwargs = {}

        with annotate(f"__torch_function___{func.__name__}"):
            # Check for overrides:
            if func in cls._function_registry and cls._enable_shard_patches:
                return cls._function_registry[func](func, types, args, kwargs)
            elif (
                str(func) in cls._named_function_registry and cls._enable_shard_patches
            ):
                return cls._named_function_registry[str(func)](
                    func, types, args, kwargs
                )
            # Fall back to the default behavior:
            return super().__torch_function__(func, types, args, kwargs)

    @classmethod
    @torch._disable_dynamo
    @profile
    def __torch_dispatch__(
        cls,
        func: torch._ops.OpOverload,
        types: Tuple[type, ...],
        args: Tuple[object, ...] = (),
        kwargs: Optional[Dict[str, object]] = None,
    ) -> Union["ShardTensor", Iterable["ShardTensor"], object]:
        """Dispatch an operator, returning ShardTensors instead of DTensors.

        Operators with a registered dispatch handler are sent to that handler.
        Everything else goes through the DTensor op dispatcher, and any
        DTensor in the result is converted to a ShardTensor.
        """
        # ``kwargs`` defaults to None; normalize once so it can be unpacked.
        kwargs = kwargs or {}

        with annotate(f"__torch_dispatch___{func.__name__}"):
            # Leverage DTensor Dispatch as much as possible, but, enable
            # the ability to operate on this output in the future:
            if func in cls._dispatch_registry:
                return cls._dispatch_registry[func](*args, **kwargs)

            # We assume that if we reach this point, the operator has not been
            # intercepted by a wrapper or in the registry.  So the DTensor
            # default behavior is likely to be correct.

            if func == aten.view.default:
                # For view, we need input tensors to be contiguous:
                for arg in args:
                    if isinstance(arg, ShardTensor) or isinstance(arg, DTensor):
                        if not arg._local_tensor.is_contiguous():
                            arg._local_tensor = arg._local_tensor.contiguous()

            dispatch_res = DTensor._op_dispatcher.dispatch(func, args, kwargs)

            # Return a shard tensor instead of a dtensor.
            def _convert_dtensor_with_input_check(dtensor, input_args):
                """
                This function searches the input for ShardTensors that match output shapes.
                It prevents collectives, since we can copy the sharding shapes for irregular shards.
                The idea here is that, if the global shape is unchanged, and
                the placements are unchanged, the sharding shapes should be unchanged.

                If no matches are found, it falls back to inference based on DTensor.

                This is only used when we already went back through the DTensor dispatch.
                """
                # Check if this matches any input ShardTensor
                for arg in input_args:
                    if (
                        isinstance(arg, ShardTensor)
                        and dtensor._spec.tensor_meta == arg._spec.tensor_meta
                        and dtensor._spec.placements == arg._spec.placements
                    ):
                        return ShardTensor.__new__(
                            ShardTensor,
                            local_tensor=dtensor._local_tensor,
                            spec=arg._spec,
                            requires_grad=dtensor.requires_grad,
                        )

                # Fall back to default conversion
                return ShardTensor.from_dtensor(dtensor)

            if isinstance(dispatch_res, DTensor):
                return _convert_dtensor_with_input_check(dispatch_res, args)

            # Plain tensors and strings also count as Iterable, but they are
            # single results rather than containers of results, so they must
            # not be rebuilt element by element.
            if isinstance(dispatch_res, Iterable) and not isinstance(
                dispatch_res, (torch.Tensor, str, bytes)
            ):
                return type(dispatch_res)(
                    _convert_dtensor_with_input_check(d, args)
                    if isinstance(d, DTensor)
                    else d
                    for d in dispatch_res
                )

            return dispatch_res

    @staticmethod
    def from_local(
        local_tensor: torch.Tensor,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        sharding_shapes: Union[str, Dict[int, List[Tuple[int, ...]]]] = "infer",
    ) -> "ShardTensor":
        """
        Generate a new ShardTensor from local torch tensors. Uses
        device mesh and placements to infer global tensor properties.

        No restriction is made on forcing tensors to have equal shapes
        locally. Instead, the requirement is that tensor shapes could
        be concatenated into a single tensor according to the placements.

        Args:
            local_tensor: Local chunk of tensor. All participating tensors must be
                of the same rank and concatable across the mesh dimensions
            device_mesh: Target Device Mesh, if not specified will use the current mesh
            placements: Target placements, must have same number of elements as device_mesh.ndim.
                Defaults to Replicate on every mesh dimension.
            sharding_shapes: Controls how shard tensor spec is generated:
                - "chunk": Use torch.chunk shapes to infer shapes from global shape (no communication)
                - "infer": Use collective communication to infer shapes from mesh neighbors.
                - Manual dict mapping mesh dim to list of shard shapes: Use provided shapes.  Must pass on each rank!
        Returns:
            A new ShardTensor instance
        """

        # this turns on shard patches globally for this process.
        ShardTensor._enable_shard_patches = True

        # This implementation follows the pytorch DTensor Implementation Closely.
        device_mesh = device_mesh or _mesh_resources.get_current_mesh()
        device_type = device_mesh.device_type

        # convert the local tensor to desired device base on device mesh's device_type
        if device_type != local_tensor.device.type and not local_tensor.is_meta:
            local_tensor = local_tensor.to(device_type)

        # set default placements to replicated if not specified
        if placements is None:
            placements = [Replicate() for _ in range(device_mesh.ndim)]
        else:
            placements = list(placements)
            for idx, placement in enumerate(placements):
                # normalize shard dim to be positive
                if placement.is_shard():
                    placement = cast(Shard, placement)
                    if placement.dim < 0:
                        placements[idx] = Shard(placement.dim + local_tensor.ndim)

        # `from_local` is differentiable, and the gradient of the dist tensor this function
        # created should flow back the gradients to the local_tensor, so we call an autograd
        # function to construct the dist tensor instead.
        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
            local_tensor,
            device_mesh,
            tuple(placements),
            sharding_shapes,
        )

    def offsets(self, mesh_dim: Optional[int] = None) -> List[int]:
        """
        Get offsets of shards along a mesh dimension.

        Delegates to ``self._spec.offsets``.

        Args:
            mesh_dim: Mesh dimension to get offsets for. If None, returns all offsets.

        Returns:
            List of offsets for shards along specified dimension
        """
        return self._spec.offsets(mesh_dim)

    def redistribute(
        self,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        *,
        async_op: bool = False,
    ) -> "ShardTensor":
        """
        Redistribute tensor across device mesh with new placement scheme.
        Like DTensor redistribute but uses custom layer for shard redistribution.

        Args:
            device_mesh: Target device mesh. Uses this tensor's mesh if None.
            placements: Target placement scheme. Required.
            async_op: Whether to run asynchronously

        Returns:
            Redistributed ShardTensor

        Raises:
            RuntimeError: If placements is not specified or contains a
                partial placement
        """

        # if device_mesh is not specified, use the current device_mesh
        device_mesh = device_mesh or self.device_mesh
        # raise error if new placements not specified
        if placements is None:
            raise RuntimeError("placements is needed for redistribute!")

        placements = list(placements)
        for i, placement in enumerate(placements):
            if placement.is_partial():
                raise RuntimeError(
                    "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
                )
            elif isinstance(placement, Shard) and placement.dim < 0:
                # normalize shard dim to be positive
                placements[i] = Shard(placement.dim + self.ndim)
        placements = tuple(placements)

        return ShardRedistribute.apply(self, device_mesh, placements, async_op)

    def to_local(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Get local tensor from this ShardTensor.

        When gradients are disabled the stored local tensor is returned
        directly; otherwise the conversion goes through an autograd function
        so gradients can flow back to this ShardTensor.

        Args:
            grad_placements: Future layout of gradients. Optional.

        Returns:
            Local torch.Tensor. Shape may vary between ranks for sharded tensors.
        """

        if not torch.is_grad_enabled():
            return self._local_tensor

        if grad_placements is not None and not isinstance(grad_placements, tuple):
            grad_placements = tuple(grad_placements)

        return _ToTorchTensor.apply(self, grad_placements)

    def full_tensor(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Redistribute to Replicate on every mesh dimension and return the
        resulting local tensor.

        Need to re-implement here to ensure a ShardTensor is used as the output
        of redistribute.

        Args:
            grad_placements: Future layout of gradients. Optional.

        Returns:
            Local torch.Tensor of the replicated ShardTensor.
        """

        # Same normalization as ``to_local``: the backward pass compares these
        # placements with the spec's placements, and a list never compares
        # equal to a tuple.
        if grad_placements is not None and not isinstance(grad_placements, tuple):
            grad_placements = tuple(grad_placements)

        redist_res = self.redistribute(
            placements=[Replicate()] * self.device_mesh.ndim, async_op=False
        )
        return _ToTorchTensor.apply(redist_res, grad_placements)

    def backward(self, *args, **kwargs):
        """
        Backward pass for ShardTensor.

        Any partial placement is first resolved by redistributing to
        Replicate on that mesh dimension (other placements are kept). Then
        ``backward`` is called on the local tensor with the given arguments.
        """

        # Before calling backward, we need to resolve any partial placements.
        current = self._spec.placements
        target = self
        if any(placement.is_partial() for placement in current):
            resolved = [
                Replicate() if placement.is_partial() else placement
                for placement in current
            ]
            target = self.redistribute(placements=resolved)

        return target.to_local().backward(*args, **kwargs)


def scatter_tensor(
    tensor: torch.Tensor,
    global_src: int,
    mesh: DeviceMesh,
    placements: Tuple[Placement, ...],
    global_shape: Optional[torch.Size] = None,
    dtype: Optional[torch.dtype] = None,
    requires_grad: bool = False,
) -> "ShardTensor":
    """
    Distribute a tensor held by one rank across a device mesh.

    The tensor is broadcast from ``global_src`` to every rank of the mesh
    group, wrapped as a fully replicated ShardTensor, and then redistributed
    to ``placements``.

    Args:
        tensor: The tensor to distribute. Must exist on the source rank; may be
            None on other ranks, in which case a receive buffer is allocated
            on the DistributedManager device.
        global_src: Global rank ID of the source process
        mesh: Device mesh defining the process topology
        placements: Tuple of placement specifications defining how to distribute the tensor
        global_shape: Global shape of the tensor. When given together with
            ``dtype``, the tensor metadata is built locally and the metadata
            broadcast is skipped.
        dtype: Data type of the tensor; see ``global_shape``.
        requires_grad: If True, the redistributed result is detached and then
            marked as requiring gradients before it is returned.

    Returns:
        ShardTensor: The distributed tensor with specified placements

    Raises:
        ValueError: If global_src is not an integer or not in the mesh
    """
    manager = DistributedManager()

    if not isinstance(global_src, int):
        raise ValueError("Global source must be an integer rank")
    if global_src not in mesh.mesh:
        raise ValueError("Please specify a tensor source in this mesh")

    on_source = manager.rank == global_src

    # Multi-dimensional meshes are addressed through a flattened process group.
    group = manager.get_mesh_group(mesh)

    if global_shape is not None and dtype is not None:
        # The caller described the tensor, so no metadata exchange is needed.
        contiguous_stride = _stride_from_contiguous_shape_C_style(global_shape)
        tensor_meta = TensorMeta(global_shape, contiguous_stride, dtype)
    else:
        # Only the source knows the metadata; share it with every rank.
        payload = (
            [TensorMeta(tensor.shape, tensor.stride(), tensor.dtype)]
            if on_source
            else [None]
        )
        dist.broadcast_object_list(payload, src=global_src, group=group)
        tensor_meta = payload[0]

    # Known inefficiency: the whole tensor is broadcast even though the target
    # placement probably shards at least one mesh dimension. Scattering along
    # the Shard dimensions one at a time would move less data, but this call
    # is expected to be a small part of a full application.

    # Ranks other than the source may pass None and receive into a new buffer.
    if tensor is None and not on_source:
        tensor = torch.empty(
            tensor_meta.shape, dtype=tensor_meta.dtype, device=manager.device
        )

    dist.broadcast(tensor, src=global_src, group=group)

    # After the broadcast every rank holds the full tensor.
    replicated_spec = ShardTensorSpec(
        mesh=mesh,
        placements=[Replicate() for _ in range(mesh.ndim)],
        tensor_meta=tensor_meta,
        _sharding_shapes={},
    )
    replicated = ShardTensor.__new__(
        ShardTensor,
        local_tensor=tensor,
        spec=replicated_spec,
        requires_grad=requires_grad,
    )

    # Move from the replicated layout to the requested placements.
    sharded = replicated.redistribute(mesh, placements, async_op=False)
    if not requires_grad:
        return sharded

    # Not optimal, but functional: detach and re-enable gradients.
    detached = sharded.detach()
    detached.requires_grad = True
    return detached