File size: 36,740 Bytes
434b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
# -*- coding: utf-8 -*-
# @Organization  : Tongyi Lab, Alibaba
# @Author        : Lingteng Qiu
# @Email         : 220019047@link.cuhk.edu.cn
# @Time          : 2025-10-14 19:43:20
# @Function      : GSPlat-based Renderer.

# gsplat is an optional dependency: record availability instead of failing at
# import time so that this module can still be imported without it.
try:
    from gsplat.rendering import rasterization

    gsplat_enable = True
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still
    # propagate; any genuine import/initialization failure disables gsplat.
    gsplat_enable = False

from collections import defaultdict

import torch
import torch.nn as nn
from pytorch3d.transforms import matrix_to_quaternion

from core.models.rendering.gs_renderer import GS3DRenderer
from core.models.rendering.utils.typing import *
from core.outputs.output import GaussianAppOutput
from core.structures.camera import Camera
from core.structures.gaussian_model import GaussianModel


def scale_intrs(intrs, ratio_x, ratio_y):
    """Rescale camera intrinsics in place and return them.

    Row 0 (fx / cx terms) is multiplied by ``ratio_x`` and row 1 (fy / cy
    terms) by ``ratio_y``.  Accepts either a single intrinsic matrix or a
    batch of them (leading batch dimension).  NOTE: the input tensor is
    modified in place; callers that need the original should clone first.
    """
    batched = intrs.ndim >= 3
    if batched:
        intrs[:, 0] *= ratio_x
        intrs[:, 1] *= ratio_y
    else:
        intrs[0] *= ratio_x
        intrs[1] *= ratio_y
    return intrs


def aabb(xyz):
    """Return the axis-aligned bounding box of a point set.

    Args:
        xyz: Tensor of points, shape (N, D).

    Returns:
        Tuple of (min corner, max corner), each of shape (D,).
    """
    lower = xyz.min(dim=0).values
    upper = xyz.max(dim=0).values
    return lower, upper


class GSPlatRenderer(GS3DRenderer):
    """GS3DRenderer subclass that rasterizes with the gsplat backend."""

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,  # only use for dense_smaple_smplx
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Initializes the GSPlatRenderer, an extension of GS3DRenderer for Gaussian Splatting rendering.

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for base mesh.
            smpl_type (str): Type of SMPL/SMPL-X/other model to use.
            feat_dim (int): Dimension of feature embeddings.
            query_dim (int): Dimension of query points/features.
            use_rgb (bool): Whether to use RGB channels.
            sh_degree (int): Spherical harmonics degree for appearance.
            xyz_offset_max_step (float): Max offset per step for position.
            mlp_network_config (dict or None): MLP configuration for feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Output scaling for decoder. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use MLP in decoder cross-attention. Default False.
            skip_decoder (bool, optional): Whether to skip decoder and cross-attn layers. Default False.
            fix_opacity (bool, optional): Fix opacity during training. Default False.
            fix_rotation (bool, optional): Fix rotation during training. Default False.
            decode_with_extra_info (dict or None, optional): Provide extra info to decoder. Default None.
            gradient_checkpointing (bool, optional): Enable gradient checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape. Default False.
            dense_sample_pts (int, optional): Dense sample points for mesh/voxel. Default 40000.
            gs_deform_scale (float, optional): Deformation scale for Gaussian Splatting. Default 0.005.
            render_features (bool, optional): Output additional features in renderer. Default False.

        Raises:
            ImportError: If the optional ``gsplat`` package is not available.
        """
        # Guard clause: fail fast when the optional gsplat backend is missing,
        # instead of wrapping the whole super().__init__ call in an else-branch.
        if not gsplat_enable:
            raise ImportError("GSPlat is not installed, please install it first.")

        # Python 3 zero-argument super() (was: super(GSPlatRenderer, self)).
        super().__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,  # only use for dense_smaple_smplx
            gs_deform_scale,
            render_features,
        )

    def get_gaussians_properties(self, viewpoint_camera, gaussian_model):
        """
        Extracts and returns the 3D Gaussian properties for rendering from a GaussianModel instance.

        Args:
            viewpoint_camera (Camera): The viewpoint camera for rendering. (Unused here,
                kept for interface compatibility with other renderers.)
            gaussian_model (GaussianModel): The GaussianModel object containing the properties to extract.

        Returns:
            Tuple:
                xyz (Tensor): The 3D coordinates of the Gaussians.
                shs (Tensor or None): The spherical harmonics coefficients or None.
                colors_precomp (Tensor): The precomputed RGB colors (if use_rgb).
                opacity (Tensor): The opacities of the Gaussians.
                scales (Tensor): The scaling factors per Gaussian.
                rotations (Tensor): Quaternion rotations per Gaussian.
                cov3D_precomp (None): Reserved for covariance data (not used here).

        Raises:
            NotImplementedError: If ``gaussian_model.use_rgb`` is False — only the
                precomputed-RGB path is supported by this backend.
        """
        xyz = gaussian_model.xyz
        opacity = gaussian_model.opacity
        scales = gaussian_model.scaling
        rotations = gaussian_model.rotation
        cov3D_precomp = None
        shs = None
        if gaussian_model.use_rgb:
            colors_precomp = gaussian_model.shs
        else:
            # SH-based shading is not wired into the gsplat path yet.
            raise NotImplementedError
        return xyz, shs, colors_precomp, opacity, scales, rotations, cov3D_precomp

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
    ):
        """
        Renders a single view using the provided GaussianModel and camera parameters.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to be rendered.
            viewpoint_camera (Camera): The viewpoint camera used for rendering (contains intrinsics, extrinsics, size).
            background_color (Optional[Float[Tensor, "3"]]): Optional background color for the rendered image.
            ret_mask (bool, optional): Unused; kept for interface compatibility. Default: True.
            features (Optional[Tensor], optional): Optional feature tensor to concatenate with the Gaussian appearance.

        Returns:
            dict:
                {
                    "comp_rgb": Rendered RGB image (H, W, 3),
                    "comp_rgb_bg": Background color actually used,
                    "comp_mask": Rendered alpha mask (H, W),
                    "comp_features": (optional) Rendered additional features (only if features is not None)
                }
        """
        xyz, shs, colors_precomp, opacity, scales, rotations, cov3D_precomp = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )

        intrinsics = viewpoint_camera.intrinsic
        extrinsics = viewpoint_camera.world_view_transform.transpose(
            0, 1
        ).contiguous()  # c2w -> w2c

        img_height = int(viewpoint_camera.height)
        img_width = int(viewpoint_camera.width)

        colors_precomp = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)

        if features is not None:
            # Render features alongside RGB by widening the color channels;
            # the background is broadcast from its first component to match.
            colors_precomp = torch.cat([colors_precomp, features], dim=1)
            channel = colors_precomp.shape[1]
            background_color = background_color[0].repeat(channel)

        # gsplat requires fp32 inputs; force fp32 even under autocast regions.
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            render_rgbd, render_alphas, meta = rasterization(
                means=xyz.float(),
                quats=rotations.float(),
                scales=scales.float(),
                opacities=opacity.float(),
                colors=colors_precomp.float(),
                viewmats=extrinsics.unsqueeze(0).float(),
                Ks=intrinsics.float().unsqueeze(0)[:, :3, :3],
                width=img_width,
                height=img_height,
                near_plane=viewpoint_camera.znear,
                far_plane=viewpoint_camera.zfar,
                # radius_clip=3.0,
                eps2d=0.3,  # 3 pixel
                render_mode="RGB",
                # render_mode="RGB+D",
                backgrounds=background_color.unsqueeze(0).float(),
                camera_model="pinhole",
            )

        # Drop the singleton camera/batch dimension added above.
        render_rgbd = render_rgbd.squeeze(0)
        render_alphas = render_alphas.squeeze(0)

        rendered_image = render_rgbd[:, :, :3]
        # rendered_depth = render_rgbd[:, :, -1:]

        ret = {
            "comp_rgb": rendered_image,  # [H, W, 3]
            "comp_rgb_bg": background_color,
            "comp_mask": render_alphas,
            # "comp_depth": rendered_depth,
        }
        if features is not None:
            # Extra channels beyond RGB are the rendered feature map.
            ret["comp_features"] = render_rgbd[:, :, 3:]

        return ret


class GSPlatFeatRenderer(GSPlatRenderer):
    """
    GSPlatFeatRenderer extends GSPlatRenderer to support rendering outputs with additional learned feature channels,
    in addition to RGB images, masks, and optional background handling.

    This class modifies the behavior of `_render_views` and related functions to enable rendering per-pixel features,
    which can be used for downstream tasks like representation learning or multimodal perception.

    Typical usage mirrors GSPlatRenderer but expects `features` to be provided as input for rendering and
    produces feature maps in the output dictionary under the key `"comp_features"`, along with the rendered RGB image
    and mask.

    The rest of the pipeline, including Gaussian attribute animation with SMPL-X or other mesh deformation, is inherited.
    """

    def forward_animate_gs(
        self,
        gs_attr_list: List[GaussianAppOutput],
        query_points: Dict[str, Tensor],
        smplx_data: Dict[str, Tensor],
        c2w: Float[Tensor, "B Nv 4 4"],
        intrinsic: Float[Tensor, "B Nv 4 4"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
        debug: bool = False,
        df_data: Optional[Dict] = None,
        patch_size=14,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Animate and render Gaussian Splatting (GS) models with feature support for a batch of frames/views.

        Args:
            gs_attr_list (List[GaussianAppOutput]): List of Gaussian attribute outputs, one per batch item.
                Each element contains predicted Gaussian parameters such as offset positions, opacity,
                rotation, scaling, and appearance for canonical points.
            query_points (Dict[str, Tensor]): Dictionary containing query information, must include:
                - 'neutral_coords': Tensor of canonical coordinate positions, shape [B, N, 3].
                - 'mesh_meta': Dictionary with mesh region meta-info as required by the skinning/posing models.
            smplx_data (Dict[str, Tensor]): Per-batch SMPL-X (or similar model) data for the current
                animation/frame. Used for pose and shape transformation.
            c2w (Float[Tensor, "B Nv 4 4"]): Camera-to-world matrices for the views to render.
            intrinsic (Float[Tensor, "B Nv 4 4"]): Intrinsic camera matrices, shape matches c2w.
            height (int): Height of output render images (in pixels).
            width (int): Width of output render images (in pixels).
            background_color (Optional[Float[Tensor, "B Nv 3"]]): Optional RGB background color per batch/view.
            debug (bool, optional): If True, enables debug behavior (e.g., simplified opacities,
                debug visualizations).
            df_data (Optional[Dict]): Optional dictionary of additional deformation/feature data.
            patch_size (int, optional): Size of patches to use for rendering (default 14).
            **kwargs: Additional keyword arguments. Can optionally contain 'features' for feature maps.

        Returns:
            Dict[str, Tensor]: Dictionary of rendered outputs, including rendered features
            (under 'comp_features'), images, masks, etc., plus '3dgs' — the list of all
            canonical-space GaussianModel instances for the batch.
        """
        batch_size = len(gs_attr_list)
        out_list, cano_out_list = [], []
        query_points_pos = query_points["neutral_coords"]
        mesh_meta = query_points["mesh_meta"]

        # Optional per-batch feature maps; resolved once instead of per iteration.
        all_features = kwargs.get("features")

        gs_list = []
        for b in range(batch_size):
            # Animate GS models for this batch item.
            anim_models, cano_models = self.animate_gs_model(
                gs_attr_list[b],
                query_points_pos[b],
                self._get_single_batch_data(smplx_data, b),
                debug=debug,
                mesh_meta=mesh_meta,
            )
            gs_list.extend(cano_models)
            features = all_features[b] if all_features is not None else None
            # Render animated views.
            out_list.append(
                self._render_views(
                    anim_models[: c2w.shape[1]],  # Only keep requested views
                    c2w[b],
                    intrinsic[b],
                    height,
                    width,
                    background_color[b] if background_color is not None else None,
                    debug,
                    patch_size=patch_size,
                    features=features,
                )
            )

        results = self._combine_outputs(out_list, cano_out_list)
        results["3dgs"] = gs_list

        return results

    def animate_gs_model(
        self,
        gs_attr: GaussianAppOutput,
        query_points,
        smplx_data,
        debug=False,
        mesh_meta=None,
    ):
        """
        Animates the Gaussian Splatting (GS) model by transforming canonical (neutral) points and attributes into the posed space using SMPL-X model deformations.

        Args:
            gs_attr (GaussianAppOutput): Gaussian attribute output for canonical points, including offset positions, opacity, rotation, scaling, and appearance.
            query_points (Tensor): Canonical query point coordinates, shape (N, 3).
            smplx_data (dict): SMPL-X input data for the current animation frame, including body pose, shape, etc.
            debug (bool, optional): If True, use debug mode (force all opacities to 1.0, use identity rotations). Default: False.
            mesh_meta (dict, optional): Additional mesh region meta-information (e.g., for constraints). Default: None.

        Returns:
            Tuple[List[GaussianModel], List[GaussianModel]]:
                - gs_list: List of posed-space GaussianModel instances (one per camera/view except canonical view).
                - cano_gs_list: List of canonical-space GaussianModel instances (last view is canonical).
        """
        device = gs_attr.offset_xyz.device

        if debug:
            # Simplify attributes so geometry problems are visible in isolation:
            # zero offsets, fully opaque, identity rotations.
            N = gs_attr.offset_xyz.shape[0]
            gs_attr.xyz = torch.zeros_like(gs_attr.offset_xyz)
            gs_attr.opacity = torch.ones((N, 1), device=device)
            gs_attr.rotation = matrix_to_quaternion(
                torch.eye(3, device=device).expand(N, 3, 3)
            )

        # build cano_dependent_pose
        merge_smplx_data = self._prepare_smplx_data(smplx_data)

        posed_points = self._transform_points(
            merge_smplx_data, query_points, gs_attr.offset_xyz, device, mesh_meta
        )
        rotation_pose_verts = self._compute_rotations(
            posed_points["transform_mat_posed"],
            gs_attr.rotation,
            device,
            posed_points["mesh_meta"],
        )

        return self._create_gaussian_models(
            posed_points["posed_coords"],
            gs_attr,
            rotation_pose_verts,
            merge_smplx_data["body_pose"].shape[0],
        )

    def _render_views(
        self,
        gs_list: List[GaussianModel],
        c2w: Tensor,
        intrinsic: Tensor,
        height: int,
        width: int,
        bg_color: Optional[Tensor],
        debug: bool,
        patch_size: int,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Renders multiple views for a list of GaussianModel instances.

        Args:
            gs_list (List[GaussianModel]): List of GaussianModel instances to render for each view.
            c2w (Tensor): Camera-to-world matrices for each view, shape [Nv, 4, 4].
            intrinsic (Tensor): Intrinsic camera matrices for each view, shape [Nv, 4, 4].
            height (int): Output image height (in pixels).
            width (int): Output image width (in pixels).
            bg_color (Optional[Tensor]): Optional background color tensor, shape [Nv, 3] or None.
            debug (bool): If True, enable debug visualizations/saving intermediate images.
            patch_size (int): Patch size used in rendering for downsampling/aggregation.
            **kwargs: Additional keyword arguments, can include 'features' if rendering extra feature channels.

        Returns:
            Dict[str, Tensor]:
                Dictionary with rendering outputs for all views. Main key 'render' is a list of
                dictionaries per view, each with keys such as 'comp_rgb', 'comp_mask', and,
                if features are provided, 'comp_features'.
        """
        # Obtain device from the first model.
        self.device = gs_list[0].xyz.device
        # The feature tensor is shared by all views — resolve it once, not per view.
        render_features = kwargs["features"] if self.render_features else None

        results = defaultdict(list)
        for v_idx, gs in enumerate(gs_list):
            camera = Camera.from_c2w(c2w[v_idx], intrinsic[v_idx], height, width)
            results["render"].append(
                self.forward_single_view(
                    gs,
                    camera,
                    bg_color[v_idx],
                    features=render_features,
                    patch_size=patch_size,
                )
            )

            # Save only the first view to keep debug output small.
            if debug and v_idx == 0:
                self._debug_save_image(results["render"][-1]["comp_rgb"])

        return results

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
        patch_size=14,
    ):
        """
        Renders a single view for a GaussianModel with optional extra features, at patch resolution.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to be rendered.
            viewpoint_camera (Camera): The viewpoint camera containing intrinsics, extrinsics, and image size.
            background_color (Optional[Float[Tensor, "3"]]): RGB background color for the rendered image.
            ret_mask (bool, optional): Unused; kept for interface compatibility. Default: True.
            features (Optional[Tensor], optional): Additional features to concatenate with Gaussian appearance. Default: None.
            patch_size (int, optional): Size of patches for rendering downsampling. Default: 14.

        Returns:
            dict:
                {
                    "comp_rgb": Rendered RGB image (patch_height, patch_width, 3),
                    "comp_mask": Rendered alpha mask (patch_height, patch_width),
                    "comp_features": (optional) Rendered feature image if features are provided,
                    "comp_rgb_bg": (only when features is None) background color used,
                }
        """
        xyz, shs, colors_precomp, opacity, scales, rotations, cov3D_precomp = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )

        extrinsics = viewpoint_camera.world_view_transform.transpose(
            0, 1
        ).contiguous()  # c2w -> w2c

        # Render at patch resolution: shrink the image by patch_size and scale
        # the intrinsics accordingly (clone first — scale_intrs mutates in place).
        img_height = int(viewpoint_camera.height)
        img_width = int(viewpoint_camera.width)

        patch_height = img_height // patch_size
        patch_width = img_width // patch_size

        scale_y = img_height / patch_height
        scale_x = img_width / patch_width
        intrinsics = scale_intrs(
            viewpoint_camera.intrinsic.clone(), 1.0 / scale_x, 1.0 / scale_y
        )

        colors_precomp = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)

        if features is not None:
            # Render features alongside RGB by widening the color channels;
            # the background is broadcast from its first component to match.
            colors_precomp = torch.cat([colors_precomp, features], dim=1)
            channel = colors_precomp.shape[1]
            background_color = background_color[0].repeat(channel)

        # gsplat requires fp32 inputs; force fp32 even under autocast regions.
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            render_rgbd, render_alphas, meta = rasterization(
                means=xyz.float(),
                quats=rotations.float(),
                scales=scales.float(),
                opacities=opacity.float(),
                colors=colors_precomp.float(),
                viewmats=extrinsics.unsqueeze(0).float(),
                Ks=intrinsics.float().unsqueeze(0)[:, :3, :3],
                width=patch_width,
                height=patch_height,
                near_plane=viewpoint_camera.znear,
                far_plane=viewpoint_camera.zfar,
                # radius_clip=3.0,
                eps2d=0.3,  # 3 pixel
                render_mode="RGB",
                # render_mode="RGB+D",
                backgrounds=background_color.unsqueeze(0).float(),
                camera_model="pinhole",
            )

        # Drop the singleton camera/batch dimension added above.
        render_rgbd = render_rgbd.squeeze(0)
        render_alphas = render_alphas.squeeze(0)

        rendered_image = render_rgbd[:, :, :3]

        if features is not None:
            ret = {
                "comp_features": render_rgbd[:, :, 3:],
                "comp_mask": render_alphas,
                "comp_rgb": rendered_image,  # [H, W, 3]
                # "comp_depth": rendered_depth,
            }
        else:
            ret = {
                "comp_rgb": rendered_image,  # [H, W, 3]
                "comp_rgb_bg": background_color,
                "comp_mask": render_alphas,
                # "comp_depth": rendered_depth,
            }

        return ret

    def forward(
        self,
        gs_hidden_features: Float[Tensor, "B Np Cp"],
        query_points: dict,
        smplx_data,  # e.g., body_pose:[B, Nv, 21, 3], betas:[B, 100]
        c2w: Float[Tensor, "B Nv 4 4"],
        intrinsic: Float[Tensor, "B Nv 4 4"],
        height,
        width,
        additional_features: Optional[Float[Tensor, "B C H W"]] = None,
        background_color: Optional[Float[Tensor, "B Nv 3"]] = None,
        debug: bool = False,
        **kwargs,
    ):
        """
        Forward rendering pass: decode Gaussian attributes, then animate and render.

        Args:
            gs_hidden_features (Float[Tensor, "B Np Cp"]): Gaussian hidden features for the batch.
            query_points (dict): Query data, typically canonical coordinates and mesh metadata.
            smplx_data: SMPL-X (or similar parametric model) data; e.g. pose, shape parameters.
            c2w (Float[Tensor, "B Nv 4 4"]): Camera-to-world matrices per batch and view.
            intrinsic (Float[Tensor, "B Nv 4 4"]): Intrinsic camera matrices, matching batch and view dims.
            height (int): Height of the output images.
            width (int): Width of the output images.
            additional_features (Optional[Float[Tensor, "B C H W"]]): Extra feature maps, if provided.
            background_color (Optional[Float[Tensor, "B Nv 3"]]): Optional background RGB per batch/view.
            debug (bool, optional): If True, enables debug information or simplified behavior.
            **kwargs: Additional keyword arguments, optionally containing:
                - df_data: Deformation/feature data for advanced driving fields (optional).
                - patch_size: Patch size for chunked rendering (default 14).

        Returns:
            Dict[str, Tensor]: Rendered images (e.g., 'comp_rgb'), optional features
            ('comp_features'), masks, Gaussian attributes ('gs_attr'), and mesh metadata.
        """
        # Need shape_params of smplx_data to get query points and
        # "transform_mat_neutral_pose"; only forward gs params here.
        gs_attr_list, query_points, smplx_data = self.forward_gs(
            gs_hidden_features,
            query_points,
            smplx_data=smplx_data,
            additional_features=additional_features,
            debug=debug,
        )

        out = self.forward_animate_gs(
            gs_attr_list,
            query_points,
            smplx_data,
            c2w,
            intrinsic,
            height,
            width,
            background_color,
            debug,
            # df_data is documented as optional (forward_animate_gs defaults it
            # to None) — use .get() so a missing key no longer raises KeyError.
            df_data=kwargs.get("df_data"),
            patch_size=kwargs.get("patch_size", 14),
            features=gs_hidden_features,
        )

        out["gs_attr"] = gs_attr_list
        out["mesh_meta"] = query_points["mesh_meta"]

        return out

    def _combine_outputs(
        self, out_list: List[Dict], cano_out_list: List[Dict]
    ) -> Dict[str, Tensor]:
        """
        Combines the outputs from multiple rendered views (out_list) into a single batched output dictionary.

        Args:
            out_list (List[Dict]): One output dictionary per batch item, each holding the
                rendered outputs for every viewpoint under the 'render' key.
            cano_out_list (List[Dict]): Canonical (neutral pose) outputs per batch item.
                (Currently unused; kept for interface compatibility.)

        Returns:
            Dict[str, Tensor]: Batched output tensors organized by key — rendered RGB
            images, masks, and features reshaped to [batch_size, num_views, C, H, W].
            Keys whose stacked tensors have fewer than 4 dims (e.g. background colors)
            are dropped.
        """
        batch_size = len(out_list)

        combined = defaultdict(list)
        for out in out_list:
            # Collect render outputs across all views of all batch items.
            for render_item in out["render"]:
                for k, v in render_item.items():
                    combined[k].append(v)

        # Stack each key once (previously torch.stack was computed twice per key:
        # once for the dim check and once for the value), then reshape
        # [B*Nv, H, W, C] -> [B, Nv, C, H, W].
        result = {}
        for k, v in combined.items():
            stacked = torch.stack(v)
            if stacked.dim() >= 4:
                result[k] = stacked.view(batch_size, -1, *v[0].shape).permute(
                    0, 1, 4, 2, 3
                )

        return result


class GSPlatBackFeatRenderer(GSPlatFeatRenderer):
    """Adding an embedding to model background color"""

    def __init__(
        self,
        human_model_path,
        subdivide_num,
        smpl_type,
        feat_dim,
        query_dim,
        use_rgb,
        sh_degree,
        xyz_offset_max_step,
        mlp_network_config,
        expr_param_dim,
        shape_param_dim,
        clip_scaling=0.2,
        cano_pose_type=0,
        decoder_mlp=False,
        skip_decoder=False,
        fix_opacity=False,
        fix_rotation=False,
        decode_with_extra_info=None,
        gradient_checkpointing=False,
        apply_pose_blendshape=False,
        dense_sample_pts=40000,  # only used by the dense-sample SMPL-X path
        gs_deform_scale=0.005,
        render_features=False,
    ):
        """
        Build a GSPlatFeatRenderer and add a learnable background embedding.

        All constructor arguments are forwarded unchanged to
        ``GSPlatFeatRenderer.__init__``; see the parent class for their
        meaning:

        Args:
            human_model_path (str): Path to human model files.
            subdivide_num (int): Subdivision number for the base mesh.
            smpl_type (str): Which SMPL-family model to use.
            feat_dim (int): Feature embedding dimension.
            query_dim (int): Query point/feature dimension.
            use_rgb (bool): Whether RGB channels are used.
            sh_degree (int): Spherical-harmonics degree for appearance.
            xyz_offset_max_step (float): Maximum per-step position offset.
            mlp_network_config (dict or None): MLP config for feature mapping.
            expr_param_dim (int): Expression parameter dimension.
            shape_param_dim (int): Shape parameter dimension.
            clip_scaling (float, optional): Decoder output scaling. Default 0.2.
            cano_pose_type (int, optional): Canonical pose type. Default 0.
            decoder_mlp (bool, optional): Use an MLP in the decoder
                cross-attention. Default False.
            skip_decoder (bool, optional): Skip decoder/cross-attn layers.
                Default False.
            fix_opacity (bool, optional): Freeze opacity. Default False.
            fix_rotation (bool, optional): Freeze rotation. Default False.
            decode_with_extra_info (dict or None, optional): Extra decoder
                inputs. Default None.
            gradient_checkpointing (bool, optional): Enable gradient
                checkpointing. Default False.
            apply_pose_blendshape (bool, optional): Apply pose blendshape.
                Default False.
            dense_sample_pts (int, optional): Dense sample point count.
                Default 40000.
            gs_deform_scale (float, optional): Gaussian deformation scale.
                Default 0.005.
            render_features (bool, optional): Also render feature channels.
                Default False.

        Adds:
            self.background_embedding (nn.Parameter): One learnable 128-d
                vector per background-color index, shape (3, 128).
        """

        super().__init__(
            human_model_path,
            subdivide_num,
            smpl_type,
            feat_dim,
            query_dim,
            use_rgb,
            sh_degree,
            xyz_offset_max_step,
            mlp_network_config,
            expr_param_dim,
            shape_param_dim,
            clip_scaling,
            cano_pose_type,
            decoder_mlp,
            skip_decoder,
            fix_opacity,
            fix_rotation,
            decode_with_extra_info,
            gradient_checkpointing,
            apply_pose_blendshape,
            dense_sample_pts,
            gs_deform_scale,
            render_features,
        )

        # Learnable background embedding: one row per background-color index,
        # Xavier-initialized. Rows are selected in forward_single_view.
        self.background_embedding = nn.Parameter(torch.zeros(3, 128))
        nn.init.xavier_uniform_(self.background_embedding)

    def forward_single_view(
        self,
        gaussian_model: GaussianModel,
        viewpoint_camera: Camera,
        background_color: Optional[Float[Tensor, "3"]],
        ret_mask: bool = True,
        features=None,
        patch_size=14,
    ):
        """
        Render a single view of the Gaussian model at patch (feature-map)
        resolution, optionally rasterizing extra feature channels with a
        learnable background embedding.

        Args:
            gaussian_model (GaussianModel): The 3D Gaussian model to render.
            viewpoint_camera (Camera): Camera providing intrinsics, the
                world-view transform, image size, and near/far planes.
            background_color (Optional[Float[Tensor, "3"]]): RGB background
                color. NOTE(review): despite the Optional annotation it must
                not be None — it is indexed and passed to the rasterizer
                unconditionally.
            ret_mask (bool, optional): Unused; kept for interface
                compatibility with the base renderer. Default True.
            features (Optional[Tensor], optional): Per-Gaussian feature
                channels concatenated to the RGB colors before rasterization.
            patch_size (int, optional): Downscale factor; the render target is
                (H // patch_size, W // patch_size). Default 14.

        Returns:
            dict: When ``features`` is given:
                {
                    "comp_features": rendered feature channels (h, w, F),
                    "comp_mask": rendered alpha mask (h, w, 1),
                    "comp_rgb": rendered RGB image (h, w, 3),
                }
            Otherwise:
                {
                    "comp_rgb": rendered RGB image (h, w, 3),
                    "comp_rgb_bg": the background color used,
                    "comp_mask": rendered alpha mask (h, w, 1),
                }
        """

        xyz, shs, colors_precomp, opacity, scales, rotations, cov3D_precomp = (
            self.get_gaussians_properties(viewpoint_camera, gaussian_model)
        )

        extrinsics = viewpoint_camera.world_view_transform.transpose(
            0, 1
        ).contiguous()  # c2w -> w2c

        # Render at feature-map resolution: one output pixel per patch.
        img_height = int(viewpoint_camera.height)
        img_width = int(viewpoint_camera.width)

        patch_height = img_height // patch_size
        patch_width = img_width // patch_size

        # Shrink the intrinsics to match the downscaled render target.
        scale_y = img_height / patch_height
        scale_x = img_width / patch_width
        intrinsics = scale_intrs(
            viewpoint_camera.intrinsic.clone(), 1.0 / scale_x, 1.0 / scale_y
        )

        colors_precomp = colors_precomp.squeeze(1)
        opacity = opacity.squeeze(1)

        if features is not None:
            # Append feature channels to the colors; extend the background
            # with the matching learnable embedding row so the rasterizer's
            # background has one value per channel.
            colors_precomp = torch.cat([colors_precomp, features], dim=1)
            # Maps background red values 0.0 / 0.5 / 1.0 to rows 0 / 1 / 2.
            # Assumes the background is one of these grayscale levels —
            # TODO(review): confirm against callers.
            bg_idx = (background_color[0] / 0.5).int().item()
            background_embedding = self.background_embedding[bg_idx]
            background_color = torch.cat(
                [background_color, background_embedding], dim=0
            )

        # gsplat rasterization is run in float32 regardless of ambient AMP.
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            render_rgbd, render_alphas, meta = rasterization(
                means=xyz.float(),
                quats=rotations.float(),
                scales=scales.float(),
                opacities=opacity.float(),
                colors=colors_precomp.float(),
                viewmats=extrinsics.unsqueeze(0).float(),
                Ks=intrinsics.float().unsqueeze(0)[:, :3, :3],
                width=patch_width,
                height=patch_height,
                near_plane=viewpoint_camera.znear,
                far_plane=viewpoint_camera.zfar,
                # radius_clip=3.0,
                eps2d=0.3,  # 2D anti-aliasing blur (pixels^2)
                render_mode="RGB",
                # render_mode="RGB+D",
                backgrounds=background_color.unsqueeze(0).float(),
                camera_model="pinhole",
            )

        render_rgbd = render_rgbd.squeeze(0)
        render_alphas = render_alphas.squeeze(0)

        # First 3 channels are RGB; any remainder are the extra features.
        rendered_image = render_rgbd[:, :, :3]
        rendered_features = render_rgbd[:, :, 3:]

        if features is not None:
            ret = {
                "comp_features": rendered_features,
                "comp_mask": render_alphas,
                "comp_rgb": rendered_image,  # [H, W, 3]
                # "comp_depth": rendered_depth,
            }
        else:
            ret = {
                "comp_rgb": rendered_image,  # [H, W, 3]
                "comp_rgb_bg": background_color,
                "comp_mask": render_alphas,
                # "comp_depth": rendered_depth,
            }

        return ret