File size: 7,070 Bytes
a521a3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import torch
from diffusers.models import AutoencoderKL
from torch import nn
from torchvision.transforms import ToPILImage
from torch import Tensor

class SynchronizedGroupNorm(nn.Module):
    """GroupNorm whose statistics are shared across the views of one sample.

    Inputs arrive with views folded into the batch dimension, i.e. shape
    (B * num_views, C, D) or (B * num_views, C, H, W). Mean/variance are
    computed per (sample, channel-group) but pooled over the view dimension
    as well, so every view of a sample is normalized with identical
    statistics. Affine parameters are copied from the wrapped nn.GroupNorm.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        self.num_views = num_views

        # Inherit the grouping hyper-parameters of the wrapped layer.
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        if self.num_channels % self.num_groups != 0:
            raise ValueError(
                f"num_channels ({self.num_channels}) must be divisible by "
                f"num_groups ({self.num_groups})."
            )
        self.group_size = self.num_channels // self.num_groups

        # Copy the affine parameters, reshaped to (groups, channels-per-group).
        # A non-affine GroupNorm has weight/bias == None; in that case skip
        # the scale/shift entirely instead of crashing on .detach().
        if original_group_norm.weight is not None and original_group_norm.bias is not None:
            self.weight = nn.Parameter(
                original_group_norm.weight.detach().view(self.num_groups, self.group_size)
            )
            self.bias = nn.Parameter(
                original_group_norm.bias.detach().view(self.num_groups, self.group_size)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a 3D (B*V, C, D) or 4D (B*V, C, H, W) tensor.

        Raises:
            ValueError: on unsupported rank, channel mismatch, or a batch
                size that is not a multiple of ``num_views``.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")

        BxT, C = x.shape[:2]
        if C != self.num_channels:
            raise ValueError(f"Expected {self.num_channels} channels, got {C}.")
        if BxT % self.num_views != 0:
            raise ValueError(
                f"Leading dimension {BxT} is not divisible by num_views={self.num_views}."
            )
        B = BxT // self.num_views
        spatial = x.shape[2:]

        # Fold to (B, V, G, group_size, S); S is the flattened spatial size.
        # reshape (not view) tolerates non-contiguous inputs; pooling over a
        # flattened spatial axis is identical to pooling over (H, W) or (D).
        y = x.reshape(B, self.num_views, self.num_groups, self.group_size, -1)

        # Statistics pooled over views (dim 1), channels within the group
        # (dim 3) and spatial positions (dim 4) — this is what synchronizes
        # the views of each sample.
        mean = y.mean(dim=(1, 3, 4), keepdim=True)
        var = y.var(dim=(1, 3, 4), keepdim=True, unbiased=False)
        y = (y - mean) / torch.sqrt(var + self.eps)

        if self.weight is not None:
            # (1, G, group_size, 1) broadcasts over batch, views and space.
            y = y * self.weight.view(1, self.num_groups, self.group_size, 1)
            y = y + self.bias.view(1, self.num_groups, self.group_size, 1)

        return y.reshape(BxT, C, *spatial)
        

class CubemapVAE(AutoencoderKL):
    """AutoencoderKL variant for cubemap (multi-view) images.

    Reuses the encoder/decoder/quant convs of a pretrained VAE and replaces
    every GroupNorm with a SynchronizedGroupNorm so normalization statistics
    are shared across the ``num_views`` faces of each cubemap sample.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        # NOTE(review): the block layout below is the SD-style default; it
        # only seeds the base-class config, because the encoder/decoder are
        # overwritten with the pretrained modules right afterwards.
        super().__init__(
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels,
        )
        self.num_views = num_views
        self.in_channels = in_channels

        # Reuse the pretrained submodules directly instead of re-training.
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv
        # Swap every GroupNorm for the view-synchronized variant.
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images, return_dict: bool = True):
        """Encode a (B, V, C, H, W) batch by folding views into the batch dim.

        Raises:
            ValueError: if the view dimension does not match ``num_views``.
        """
        batch_size, num_views, num_channels, height, width = images.shape
        if num_views != self.num_views:
            raise ValueError(f"Expected {self.num_views} views, got {num_views}.")
        images = images.view(batch_size * num_views, num_channels, height, width)
        return super().encode(images, return_dict=return_dict)

    def decode(self, latents, return_dict=True, **kwargs):
        """Decode latents with the standard VAE pipeline.

        Any channels beyond ``config.latent_channels`` (e.g. appended UV
        channels) are dropped before decoding.
        """
        latent_channels = self.config.latent_channels
        if latents.shape[1] > latent_channels:
            # Keep only the VAE latent channels; extra channels are metadata.
            latents = latents[:, :latent_channels, :, :]
        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        """Decode and split the result into one tensor per view.

        Returns a tuple of ``num_views`` tensors of shape (B, C, H, W).
        """
        decoded = self.decode(latents).sample  # (B * num_views, C, H, W)

        # Was hard-coded to 6; use the configured view count instead.
        B = latents.shape[0] // self.num_views
        images = torch.split(decoded, B, dim=0)  # num_views chunks of size B

        return images

    def decode_to_pil_images(self, latents: Tensor):
        """Decode latents and return the first sample of each view as PIL images."""
        images = self.decode_to_tensor(latents)
        to_pil = ToPILImage()

        return [to_pil(img[0].cpu().detach()) for img in images]



def replace_group_norm_with_sgn(model, num_views):
    """Swap every nn.GroupNorm in *model* for a SynchronizedGroupNorm.

    Names are collected first so the module tree is not mutated while
    ``named_modules()`` is still iterating over it.
    """
    targets = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.GroupNorm)
    ]

    for qualified_name in targets:
        parent, attr = get_parent_module(model, qualified_name)
        original = getattr(parent, attr)
        setattr(parent, attr, SynchronizedGroupNorm(original, num_views))

def get_parent_module(model, module_name):
    """Resolve a dotted *module_name* to (direct parent module, leaf attribute name)."""
    *ancestors, leaf = module_name.split(".")
    parent = model
    # Walk down every component except the last one.
    for part in ancestors:
        parent = getattr(parent, part)
    return parent, leaf

            

def flatten_face_names(face_names):
    """Flatten a mix of face-name strings and nested lists/tuples into one list.

    Raises:
        ValueError: if an entry is neither a string nor a list/tuple.
    """
    flattened = []
    for item in face_names:
        if isinstance(item, (list, tuple)):
            # Nested group of names: splice its elements in order.
            flattened.extend(item)
        elif isinstance(item, str):
            flattened.append(item)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(item)}")
    return flattened