File size: 30,157 Bytes
c3d0544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import random
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor

"""
This module defines utilities, including classes and functions, for domain
decomposition.
"""


class BasePatching2D(ABC):
    """
    Common foundation for 2D image patching strategies.

    The constructor validates the image and patch shapes and stores them on
    the instance. Concrete strategies inherit from this class, must implement
    :meth:`apply`, and may override :meth:`fuse` when patches can be stitched
    back into a full image.

    Parameters
    ----------
    img_shape : Tuple[int, int]
        The height and width of the full input images :math:`(H, W)`.
    patch_shape : Tuple[int, int]
        The height and width of the patches to extract :math:`(H_p, W_p)`.
        Any dimension larger than the image is cropped to the image size
        (a warning is emitted when that happens).

    Raises
    ------
    ValueError
        If ``img_shape`` or ``patch_shape`` does not have exactly two entries.
    """

    def __init__(
        self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int]
    ) -> None:
        # Both shapes must describe exactly two spatial dimensions.
        if len(img_shape) != 2:
            raise ValueError(f"img_shape must be 2D, got {len(img_shape)}D")
        if len(patch_shape) != 2:
            raise ValueError(f"patch_shape must be 2D, got {len(patch_shape)}D")

        # A patch may not exceed the image along either axis; warn the
        # caller before silently shrinking it.
        too_tall = patch_shape[0] > img_shape[0]
        too_wide = patch_shape[1] > img_shape[1]
        if too_tall or too_wide:
            warnings.warn(
                f"Patch shape {patch_shape} is larger than "
                f"image shape {img_shape}. "
                f"Patches will be cropped to fit within the image."
            )

        self.img_shape = img_shape
        self.patch_shape = (
            min(patch_shape[0], img_shape[0]),
            min(patch_shape[1], img_shape[1]),
        )

    @abstractmethod
    def apply(self, input: Tensor, **kwargs) -> Tensor:
        """
        Extract patches from a batch of full images.

        Parameters
        ----------
        input : Tensor
            Batch of full input images of shape :math:`(B, C, H, W)`.
        **kwargs : dict
            Extra keyword arguments understood by the concrete strategy.

        Returns
        -------
        Tensor
            The patched tensor; its shape is defined by the concrete strategy.
        """

    def fuse(self, input: Tensor, **kwargs) -> Tensor:
        """
        Stitch patches back into full images.

        The base class provides no fusion; subclasses that support it
        override this method.

        Parameters
        ----------
        input : Tensor
            Tensor of patches; its shape is defined by the concrete strategy.
        **kwargs : dict
            Extra keyword arguments understood by the concrete strategy.

        Returns
        -------
        Tensor
            The fused tensor; its shape is defined by the concrete strategy.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("'fuse' method must be implemented in subclasses.")

    def global_index(
        self, batch_size: int, device: Union[torch.device, str] = "cpu"
    ) -> Tensor:
        """
        Build the global :math:`(y, x)` pixel coordinates covered by each patch.

        A coordinate grid spanning the full image is created with a batch
        dimension of one and pushed through :meth:`apply`, so every patch
        carries the position of each of its pixels in the original image.

        Parameters
        ----------
        batch_size : int
            The size :math:`B` of the batch of images to patch. It is not
            used by this implementation (the grid always has batch size 1).
        device : Union[torch.device, str], default="cpu"
            Device on which the coordinate grid is created.

        Returns
        -------
        Tensor
            Integer (``torch.int64``) tensor of shape :math:`(P, 2, H_p, W_p)`
            for strategies whose :meth:`apply` returns :math:`P` patches per
            image. ``global_index[:, 0]`` holds the y-coordinates and
            ``global_index[:, 1]`` the x-coordinates.
        """
        row_ids = torch.arange(self.img_shape[0], device=device).int()
        col_ids = torch.arange(self.img_shape[1], device=device).int()
        yy, xx = torch.meshgrid(row_ids, col_ids, indexing="ij")
        # Shape (1, 2, H, W): channel 0 is y, channel 1 is x.
        coords = torch.stack((yy, xx), dim=0)[None]
        return self.apply(coords).long()


class RandomPatching2D(BasePatching2D):
    """
    Class for randomly extracting patches from 2D images.

    This class provides utilities to randomly extract patches from a batch of
    full images represented as 4D tensors. It maintains a list of random patch
    indices that can be reset as needed.

    Parameters
    ----------
    img_shape : Tuple[int, int]
        The height and width :math:`(H, W)` of the full input images.
    patch_shape : Tuple[int, int]
        The height and width :math:`(H_p, W_p)` of the patches to extract.
    patch_num : int
        The number of patches :math:`P` to extract.

    Attributes
    ----------
    patch_indices : List[Tuple[int, int]]
        The indices of the patches to extract from the images. These indices
        correspond to the :math:`(y, x)` coordinates of the upper left corner
        of each patch.

    See Also
    --------
    :class:`physicsnemo.utils.patching.BasePatching2D`
        The base class providing the patching interface.
    :class:`physicsnemo.utils.patching.GridPatching2D`
        Alternative patching strategy using deterministic patch locations.
    """

    def __init__(
        self, img_shape: Tuple[int, int], patch_shape: Tuple[int, int], patch_num: int
    ) -> None:
        super().__init__(img_shape, patch_shape)
        self._patch_num = patch_num
        # Draw the initial set of patch locations.
        self.reset_patch_indices()

    @property
    def patch_num(self) -> int:
        """
        Get the number of patches to extract.

        Returns
        -------
        int
            The number of patches :math:`P` to extract.
        """
        return self._patch_num

    def set_patch_num(self, value: int) -> None:
        """
        Set the number of patches to extract and reset patch indices.
        This is the only way to modify the ``patch_num`` attribute.

        Parameters
        ----------
        value : int
            The new number of patches :math:`P` to extract.
        """
        self._patch_num = value
        self.reset_patch_indices()

    def reset_patch_indices(self) -> None:
        """
        Generate new random indices for the patches to extract. These are the
        starting indices of the patches to extract (upper left corner).

        Indices are drawn with Python's :mod:`random` module, so they follow
        the state of that generator (not of ``torch``'s).
        """
        # Largest admissible upper-left corner such that the patch still fits.
        max_y = self.img_shape[0] - self.patch_shape[0]
        max_x = self.img_shape[1] - self.patch_shape[1]
        self.patch_indices = [
            (random.randint(0, max_y), random.randint(0, max_x))
            for _ in range(self.patch_num)
        ]

    def get_patch_indices(self) -> List[Tuple[int, int]]:
        """
        Get the current list of patch starting indices.

        These are the upper-left coordinates of each extracted patch
        from the full image.

        Returns
        -------
        List[Tuple[int, int]]
            A list of (row, column) tuples representing patch starting positions.
        """
        return self.patch_indices

    def apply(
        self,
        input: Tensor,
        additional_input: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Applies the patching operation by extracting patches specified by
        ``self.patch_indices`` from the ``input`` Tensor. Extracted patches are
        batched along the first dimension of the output. The layout of the
        output assumes that for any patch index ``i``, ``out[B * i: B * (i + 1)]``
        corresponds to the *same patch* extracted from each batch element of
        ``input``.

        Parameters
        ----------
        input : Tensor
            The input tensor representing the full image with shape :math:`(B, C, H, W)`.
        additional_input : Optional[Tensor], optional
            Its shape should be :math:`(B, C_{add}, H_{add}, W_{add})`.
            Must have same batch size as ``input``. Bilinear interpolation is
            used to interpolate ``additional_input`` onto a 2D grid of shape
            :math:`(H_p, W_p)`. It is then channel-wise concatenated to the
            extracted patches.
            *Note: ``additional_input`` is not patched or decomposed.*

        Returns
        -------
        Tensor
            A tensor of shape :math:`(P \times B, C [+ C_{add}], H_p, W_p)`.
            If ``additional_input`` is provided, it is channel-wise concatenated
            to the extracted patches. The output has the dtype of ``input``
            (promoted with the dtype of ``additional_input`` when it is given),
            lives on the device of ``input`` and is ``channels_last`` when
            ``input`` is.

        Raises
        ------
        ValueError
            If ``additional_input`` does not have the same batch size as
            ``input``.
        """
        B, C = input.shape[0], input.shape[1]
        patch_h, patch_w = self.patch_shape

        add_input_interp = None
        out_dtype = input.dtype
        num_channels = C
        if additional_input is not None:
            if additional_input.shape[0] != B:
                raise ValueError(
                    f"additional_input batch size ({additional_input.shape[0]}) "
                    f"must match input batch size ({B})"
                )
            add_input_interp = torch.nn.functional.interpolate(
                input=additional_input, size=self.patch_shape, mode="bilinear"
            )
            num_channels += add_input_interp.shape[1]
            # Same promotion rule a channel-wise torch.cat would apply.
            out_dtype = torch.promote_types(input.dtype, add_input_interp.dtype)

        # Allocate with the (promoted) input dtype so that patches are not
        # silently cast to the default float32 (e.g. float64 or integer grids).
        out = torch.zeros(
            B * self.patch_num,
            num_channels,
            patch_h,
            patch_w,
            dtype=out_dtype,
            device=input.device,
        )
        out = out.to(
            memory_format=torch.channels_last
            if input.is_contiguous(memory_format=torch.channels_last)
            else torch.contiguous_format
        )

        for i, (py, px) in enumerate(self.patch_indices):
            rows = slice(B * i, B * (i + 1))
            out[rows, :C] = input[:, :, py : py + patch_h, px : px + patch_w]
            if add_input_interp is not None:
                # The interpolated extra channels are identical for every patch.
                out[rows, C:] = add_input_interp
        return out


class GridPatching2D(BasePatching2D):
    """
    Class for deterministically extracting patches from 2D images in a grid pattern.

    This class provides utilities to extract patches from images in a
    deterministic manner, with configurable overlap and boundary pixels.
    The patches are extracted in a grid-like pattern covering the entire image.

    Parameters
    ----------
    img_shape : Tuple[int, int]
        The height and width of the full input images :math:`(H, W)`.
    patch_shape : Tuple[int, int]
        The height and width of the patches to extract :math:`(H_p, W_p)`.
    overlap_pix : int, optional, default=0
        Number of pixels to overlap between adjacent patches.
    boundary_pix : int, optional, default=0
        Number of pixels to crop as boundary from each patch.

    Attributes
    ----------
    patch_num : int
        Total number of patches :math:`P` that will be extracted from the image,
        calculated as :math:`P = P_x * P_y`.

    Raises
    ------
    ValueError
        If, along either axis, the (cropped) patch size is smaller than
        ``1 + overlap_pix + boundary_pix``.

    See Also
    --------
    :class:`physicsnemo.utils.patching.BasePatching2D`
        The base class providing the patching interface.
    :class:`physicsnemo.utils.patching.RandomPatching2D`
        Alternative patching strategy using random patch locations.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        patch_shape: Tuple[int, int],
        overlap_pix: int = 0,
        boundary_pix: int = 0,
    ) -> None:
        super().__init__(img_shape, patch_shape)
        self.overlap_pix = overlap_pix
        self.boundary_pix = boundary_pix

        # Use the *cropped* patch shape stored by the base class: this is the
        # shape `apply` really extracts, so `patch_num` must be derived from it.
        patch_shape_y, patch_shape_x = self.patch_shape
        stride_x = patch_shape_x - overlap_pix - boundary_pix
        stride_y = patch_shape_y - overlap_pix - boundary_pix
        # Fail early with a clear message instead of a ZeroDivisionError (or
        # an obscure error from unfold) when patches are too small.
        if stride_x < 1:
            raise ValueError(
                f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= "
                f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
            )
        if stride_y < 1:
            raise ValueError(
                f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= "
                f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
            )

        patch_num_x = math.ceil(self.img_shape[1] / stride_x)
        patch_num_y = math.ceil(self.img_shape[0] / stride_y)
        self.patch_num = patch_num_x * patch_num_y

        # Cached once here and reused by every call to `fuse`.
        self._overlap_count = self.get_overlap_count(
            self.patch_shape, self.img_shape, self.overlap_pix, self.boundary_pix
        )

    def apply(
        self,
        input: Tensor,
        additional_input: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Apply deterministic patching to the input tensor.

        Splits the input tensor into patches in a grid-like pattern. Can
        optionally concatenate additional interpolated data to each patch.
        Extracted patches are batched along the first dimension of the output.
        The layout of the output assumes that for any patch index ``i``,
        ``out[B * i: B * (i + 1)]`` corresponds to the *same patch* extracted
        from each batch element of ``input``.

        Parameters
        ----------
        input : Tensor
            Batch of full input images of shape :math:`(B, C, H, W)`.
        additional_input : Optional[Tensor], optional, default=None
            Additional data to concatenate to each patch. Shape must be
            :math:`(B, C_{add}, H_{add}, W_{add})`. Will be interpolated
            (bilinear) to match patch dimensions :math:`(H_p, W_p)`.
            *Note: ``additional_input`` is not patched or decomposed.*

        Returns
        -------
        Tensor
            Tensor containing patches with shape :math:`(P \times B, C [+ C_{add}], H_p, W_p)`.
            If ``additional_input`` is provided, it is channel-wise concatenated
            to the extracted patches.

        See Also
        --------
        :func:`physicsnemo.utils.patching.image_batching`
            The underlying function used to perform the patching operation.
        """
        add_input_interp = None
        if additional_input is not None:
            add_input_interp = torch.nn.functional.interpolate(
                input=additional_input, size=self.patch_shape, mode="bilinear"
            )
        return image_batching(
            input=input,
            patch_shape_y=self.patch_shape[0],
            patch_shape_x=self.patch_shape[1],
            overlap_pix=self.overlap_pix,
            boundary_pix=self.boundary_pix,
            input_interp=add_input_interp,
        )

    def fuse(self, input: Tensor, batch_size: int) -> Tensor:
        r"""
        Fuse patches back into a complete image.

        Reconstructs the original image by stitching together patches,
        accounting for overlapping regions and boundary pixels. In overlapping
        regions, values are averaged.

        Parameters
        ----------
        input : Tensor
            Input tensor containing patches with shape :math:`(P \times B, C, H_p, W_p)`.
            *Note: the patch layout along the batch dimension should be the same
            as the one returned by the method
            :meth:`~physicsnemo.utils.patching.GridPatching2D.apply`.*
        batch_size : int
            The original batch size :math:`B` before patching.

        Returns
        -------
        Tensor
            Reconstructed image tensor with shape :math:`(B, C, H, W)`.

        See Also
        --------
        :func:`physicsnemo.utils.patching.image_fuse`
            The underlying function used to perform the fusion operation.
        """
        return image_fuse(
            input=input,
            img_shape_y=self.img_shape[0],
            img_shape_x=self.img_shape[1],
            batch_size=batch_size,
            overlap_pix=self.overlap_pix,
            boundary_pix=self.boundary_pix,
            overlap_count=self._overlap_count,
        )

    @staticmethod
    def get_overlap_count(
        patch_shape: tuple[int, int],
        img_shape: tuple[int, int],
        overlap_pix: int,
        boundary_pix: int,
    ) -> Tensor:
        r"""
        Compute overlap count map for image patch reconstruction.

        Calculates how many times each pixel in the padded image is covered by
        extracted patches, based on the patch size, overlap size, and boundary
        padding. This is useful for normalizing the reconstructed image after
        folding overlapping patches.

        This is a static method: it only returns the map. The constructor of
        :class:`GridPatching2D` caches the result for use in :meth:`fuse`.

        Parameters
        ----------
        patch_shape : Tuple[int, int]
            The height and width of the patches to extract :math:`(H_p, W_p)`.
        img_shape : Tuple[int, int]
            The height and width of the full input images :math:`(H, W)`.
        overlap_pix : int
            The number of overlapping pixels between adjacent patches.
        boundary_pix : int
            The number of pixels to crop as a boundary from each patch.

        Returns
        -------
        Tensor
            Tensor indicating how many times each pixel of the padded image
            is visited (or covered) by patches. Shape is :math:`(1, 1, H_{pad},
            W_{pad})`, where :math:`H_{pad}` and :math:`W_{pad}` are
            the padded image dimensions. Those are computed as :math:`H_{pad} = (H_p -
            \text{overlap_pix} - \text{boundary_pix}) \times (P_H - 1) + H_p +
            \text{boundary_pix}`, where :math:`P_H` is the number of patches
            along the height of the image (and similarly for :math:`W_{pad}`).
            The tensor is created on the default device with the default
            floating point dtype.
        """
        patch_shape_y, patch_shape_x = patch_shape
        img_shape_y, img_shape_x = img_shape

        # Distance between the upper-left corners of two adjacent patches.
        stride_y = patch_shape_y - overlap_pix - boundary_pix
        stride_x = patch_shape_x - overlap_pix - boundary_pix

        # Number of patches along each dimension.
        patch_num_x = math.ceil(img_shape_x / stride_x)
        patch_num_y = math.ceil(img_shape_y / stride_y)

        # Shape of the image once padded so that the patch grid covers it.
        padded_shape_x = stride_x * (patch_num_x - 1) + patch_shape_x + boundary_pix
        padded_shape_y = stride_y * (patch_num_y - 1) + patch_shape_y + boundary_pix

        # Unfolding then folding a tensor of ones sums, for each pixel, one
        # contribution per patch covering it: exactly the coverage count.
        input_ones = torch.ones((1, 1, padded_shape_y, padded_shape_x))
        unfolded_ones = torch.nn.functional.unfold(
            input=input_ones,
            kernel_size=(patch_shape_y, patch_shape_x),
            stride=(stride_y, stride_x),
        )
        overlap_count = torch.nn.functional.fold(
            input=unfolded_ones,
            output_size=(padded_shape_y, padded_shape_x),
            kernel_size=(patch_shape_y, patch_shape_x),
            stride=(stride_y, stride_x),
        )
        return overlap_count


def image_batching(
    input: Tensor,
    patch_shape_y: int,
    patch_shape_x: int,
    overlap_pix: int,
    boundary_pix: int,
    input_interp: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Split a batch of full images into a batch of grid-aligned patches.

    The images are reflection-padded so that a regular grid of patches covers
    them entirely, then cut into patches. When ``input_interp`` is given, it is
    appended channel-wise to every patch.

    Parameters
    ----------
    input : Tensor
        Batch of full images with shape :math:`(B, C, H, W)`.
    patch_shape_y : int
        The height :math:`H_p` of each image patch.
    patch_shape_x : int
        The width :math:`W_p` of each image patch.
    overlap_pix : int
        The number of overlapping pixels between adjacent patches.
    boundary_pix : int
        The number of pixels to crop as a boundary from each patch.
    input_interp : Optional[Tensor], optional
        Extra data appended to each patch, with shape
        :math:`(B, C_{add}, H_p, W_p)`: it must already have the patch size
        and the batch size of ``input``.
        *Note: ``input_interp`` is not patched or decomposed; the same tensor
        is repeated for every patch.*

    Returns
    -------
    Tensor
        The image patches, with shape :math:`(P \times B, C [+ C_{add}], H_p, W_p)`.
        Patches are ordered with the patch column index varying slowest, then
        the patch row index, then the batch index.

    Raises
    ------
    ValueError
        If a patch dimension is smaller than
        ``1 + overlap_pix + boundary_pix``, if it is not strictly greater than
        ``overlap_pix + 2 * boundary_pix``, or if ``input_interp`` does not
        match the batch size of ``input`` or the patch shape.
    """
    batch_size, _, img_shape_y, img_shape_x = input.shape

    # Distance between the upper-left corners of two adjacent patches.
    stride_y = patch_shape_y - overlap_pix - boundary_pix
    stride_x = patch_shape_x - overlap_pix - boundary_pix

    # Patches must be large enough to hold the overlap and boundary pixels.
    if stride_x < 1:
        raise ValueError(
            f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= "
            f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
        )
    if stride_y < 1:
        raise ValueError(
            f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= "
            f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
        )

    # The optional extra channels must line up with the patches.
    has_extra = input_interp is not None
    if has_extra:
        if input_interp.shape[0] != batch_size:
            raise ValueError(
                f"input_interp batch size ({input_interp.shape[0]}) must match "
                f"input batch size ({batch_size})"
            )
        same_height = input_interp.shape[2] == patch_shape_y
        same_width = input_interp.shape[3] == patch_shape_x
        if not (same_height and same_width):
            raise ValueError(
                f"input_interp patch shape ({input_interp.shape[2]}, {input_interp.shape[3]}) "
                f"must match specified patch shape ({patch_shape_y}, {patch_shape_x})"
            )

    # Otherwise unfold would extract more patches than the grid computed
    # below expects.
    smallest_allowed = overlap_pix + 2 * boundary_pix
    if patch_shape_x <= smallest_allowed:
        raise ValueError(
            f"patch_shape_x ({patch_shape_x}) must verify "
            f"patch_shape_x ({patch_shape_x}) > "
            f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})"
        )
    if patch_shape_y <= smallest_allowed:
        raise ValueError(
            f"patch_shape_y ({patch_shape_y}) must verify "
            f"patch_shape_y ({patch_shape_y}) > "
            f"overlap_pix ({overlap_pix}) + 2 * boundary_pix ({boundary_pix})"
        )

    # Grid layout and the padded image size it requires.
    patch_num_x = math.ceil(img_shape_x / stride_x)
    patch_num_y = math.ceil(img_shape_y / stride_y)
    patch_num = patch_num_x * patch_num_y
    padded_shape_x = stride_x * (patch_num_x - 1) + patch_shape_x + boundary_pix
    padded_shape_y = stride_y * (patch_num_y - 1) + patch_shape_y + boundary_pix

    # ReflectionPad2d takes (left, right, top, bottom).
    pad_x_right = padded_shape_x - img_shape_x - boundary_pix
    pad_y_right = padded_shape_y - img_shape_y - boundary_pix
    reflect = torch.nn.ReflectionPad2d(
        (boundary_pix, pad_x_right, boundary_pix, pad_y_right)
    )
    input_padded = reflect(input)

    # int32/int64 data is reinterpreted (same bits, not a numeric cast) as a
    # float of the same width before unfold, and reinterpreted back afterwards.
    float_twin = {torch.int32: torch.float32, torch.int64: torch.float64}.get(
        input.dtype
    )
    if float_twin is not None:
        input_padded = input_padded.view(float_twin)

    patches = torch.nn.functional.unfold(
        input=input_padded,
        kernel_size=(patch_shape_y, patch_shape_x),
        stride=(stride_y, stride_x),
    )

    if float_twin is not None:
        patches = patches.view(input.dtype)

    # Move the patch axes in front of the batch axis.
    patches = rearrange(
        patches,
        "b (c p_h p_w) (nb_p_h nb_p_w) -> (nb_p_w nb_p_h b) c p_h p_w",
        p_h=patch_shape_y,
        p_w=patch_shape_x,
        nb_p_h=patch_num_y,
        nb_p_w=patch_num_x,
    )

    if not has_extra:
        return patches
    # Repeat the extra channels once per patch and append them.
    extra = input_interp.repeat(patch_num, 1, 1, 1)
    return torch.cat((patches, extra), dim=1)


def image_fuse(
    input: Tensor,
    img_shape_y: int,
    img_shape_x: int,
    batch_size: int,
    overlap_pix: int,
    boundary_pix: int,
    overlap_count: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Reconstructs a full image from a batch of patched images. Reverts the patching
    operation performed by :func:`~physicsnemo.utils.patching.image_batching`.

    It assumes that the patches are extracted in a grid-like pattern, and that
    their layout along the batch dimension is the same as the one returned by
    :func:`~physicsnemo.utils.patching.image_batching`.

    This function takes a batch of image patches and reconstructs the full
    image by stitching the patches together. The function accounts for
    overlapping and boundary pixels, ensuring that overlapping areas are
    averaged.
    *Note: a simple unweighted average between overlapping patches is used to
    fuse the patches.*

    Parameters
    ----------
    input : Tensor
        The input tensor containing the image patches with shape :math:`(P \times B, C, H_p, W_p)`.
    img_shape_y : int
        The height :math:`H` of the original full image.
    img_shape_x : int
        The width :math:`W` of the original full image.
    batch_size : int
        The original batch size :math:`B` before patching. Kept for API
        compatibility: the batch size is inferred from ``input`` and the
        number of patches.
    overlap_pix : int
        The number of overlapping pixels between adjacent patches.
    boundary_pix : int
        The number of pixels to crop as a boundary from each patch.
    overlap_count : Tensor, optional, default=None
        A tensor of shape :math:`(1, 1, H_{pad}, W_{pad})` (the *padded*
        image size) containing the number of patches that cover each pixel.
        This is typically computed by
        :meth:`~physicsnemo.utils.patching.GridPatching2D.get_overlap_count`.
        If not provided, it will be computed internally.

    Returns
    -------
    Tensor
        The reconstructed full image tensor with shape :math:`(B, C, H, W)`.
        Because of the final division by ``overlap_count``, integer patches
        yield a floating point result.

    Raises
    ------
    ValueError
        If a patch dimension is smaller than
        ``1 + overlap_pix + boundary_pix``.
    """
    # Infer sizes from input image shape
    num_channels = input.shape[1]
    patch_shape_y, patch_shape_x = input.shape[2], input.shape[3]

    # Distance between the upper-left corners of two adjacent patches.
    stride_x = patch_shape_x - overlap_pix - boundary_pix
    stride_y = patch_shape_y - overlap_pix - boundary_pix
    if stride_x < 1:
        raise ValueError(
            f"patch_shape_x must verify patch_shape_x ({patch_shape_x}) >= "
            f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
        )
    if stride_y < 1:
        raise ValueError(
            f"patch_shape_y must verify patch_shape_y ({patch_shape_y}) >= "
            f"1 + overlap_pix ({overlap_pix}) + boundary_pix ({boundary_pix})"
        )

    # Calculate the number of patches in each dimension
    patch_num_x = math.ceil(img_shape_x / stride_x)
    patch_num_y = math.ceil(img_shape_y / stride_y)

    # Calculate the shape of the input after padding
    padded_shape_x = stride_x * (patch_num_x - 1) + patch_shape_x + boundary_pix
    padded_shape_y = stride_y * (patch_num_y - 1) + patch_shape_y + boundary_pix

    # Count local overlaps between patches
    if overlap_count is None:
        overlap_count = GridPatching2D.get_overlap_count(
            (patch_shape_y, patch_shape_x),
            (img_shape_y, img_shape_x),
            overlap_pix,
            boundary_pix,
        )

    if overlap_count.device != input.device:
        overlap_count = overlap_count.to(input.device)

    # Reshape input to make it 3D to apply fold. Equivalent to the einops
    # pattern "(nb_p_w nb_p_h b) c p_h p_w -> b (c p_h p_w) (nb_p_h nb_p_w)":
    # split the leading axis (patch column slowest, batch fastest), then move
    # batch/channel/pixel axes to the front and the patch grid to the back.
    inferred_batch = input.shape[0] // (patch_num_x * patch_num_y)
    x = input.reshape(
        patch_num_x,
        patch_num_y,
        inferred_batch,
        num_channels,
        patch_shape_y,
        patch_shape_x,
    )
    x = x.permute(2, 3, 4, 5, 1, 0)
    x = x.reshape(
        inferred_batch,
        num_channels * patch_shape_y * patch_shape_x,
        patch_num_y * patch_num_x,
    )

    # fold *sums* overlapping patches, so integer data must be converted
    # numerically (not bit-reinterpreted) to a float type. float64 represents
    # every int32 and every int64 up to 2**53 exactly.
    is_integer_input = input.dtype in (torch.int32, torch.int64)
    if is_integer_input:
        x = x.to(torch.float64)

    # Stitch patches together (by summing over overlapping patches)
    x_folded = torch.nn.functional.fold(
        input=x,
        output_size=(padded_shape_y, padded_shape_x),
        kernel_size=(patch_shape_y, patch_shape_x),
        stride=(stride_y, stride_x),
    )

    # Cast back to original dtype
    if is_integer_input:
        x_folded = x_folded.to(input.dtype)

    # Remove padding: `boundary_pix` pixels were added on the top and left.
    rows = slice(boundary_pix, boundary_pix + img_shape_y)
    cols = slice(boundary_pix, boundary_pix + img_shape_x)
    x_no_padding = x_folded[..., rows, cols]
    overlap_count_no_padding = overlap_count[..., rows, cols]

    # Normalize by overlap count
    return x_no_padding / overlap_count_no_padding