import torch
from diffusers.models import AutoencoderKL
from torch import Tensor, nn
from torchvision.transforms import ToPILImage


class SynchronizedGroupNorm(nn.Module):
    """GroupNorm whose statistics are shared across all views of a sample.

    The leading dimension of the input is interpreted as
    ``batch * num_views`` with the views of one sample stored next to each
    other (batch-major). For every (sample, group) pair the mean and the
    biased variance are computed jointly over all views, the channels of
    the group and every trailing position. The per-channel affine
    parameters are taken from the wrapped ``nn.GroupNorm``.

    Args:
        original_group_norm: The ``nn.GroupNorm`` whose configuration and
            affine parameters are reused. Its ``weight`` and ``bias`` must
            not be ``None``.
        num_views: Number of views that share normalisation statistics.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        self.num_views = num_views

        # Inherit the grouping configuration of the wrapped layer.
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        self.group_size = self.num_channels // self.num_groups

        # The affine parameters are stored as (num_groups, group_size).
        group_shape = (self.num_groups, self.group_size)
        self.weight = nn.Parameter(
            original_group_norm.weight.detach().view(*group_shape)
        )
        self.bias = nn.Parameter(
            original_group_norm.bias.detach().view(*group_shape)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` with statistics shared across views.

        Args:
            x: A 3D ``(B * num_views, C, D)`` or 4D
                ``(B * num_views, C, H, W)`` tensor.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 3D nor 4D, or if its leading
                dimension is not a multiple of ``num_views``.
        """
        if x.dim() not in (3, 4):
            raise ValueError(
                f"Unsupported input shape: {x.shape}, "
                "expected 3D (B, C, D) or 4D (B, C, H, W)."
            )

        input_shape = x.shape
        stacked = input_shape[0]
        if stacked % self.num_views != 0:
            raise ValueError(
                f"Leading dimension {stacked} is not a multiple of "
                f"num_views={self.num_views}."
            )
        batch = stacked // self.num_views

        # Flatten every trailing dimension so 3D and 4D inputs share one
        # code path. reshape (unlike view) also accepts non-contiguous
        # inputs.
        positions = input_shape[2:].numel()
        grouped = x.reshape(
            batch, self.num_views, self.num_groups, self.group_size, positions
        )

        # Statistics over views (1), channels in the group (3) and
        # positions (4).
        reduce_dims = (1, 3, 4)
        mean = grouped.mean(dim=reduce_dims, keepdim=True)
        var = grouped.var(dim=reduce_dims, keepdim=True, unbiased=False)
        normalized = (grouped - mean) / torch.sqrt(var + self.eps)

        # Shape the affine parameters explicitly as
        # (1, 1, num_groups, group_size, 1) for broadcasting.
        affine_shape = (1, 1, self.num_groups, self.group_size, 1)
        weight = self.weight.view(affine_shape)
        bias = self.bias.view(affine_shape)

        return (normalized * weight + bias).reshape(input_shape)


class CubemapVAE(AutoencoderKL):
    """AutoencoderKL wrapper operating on multi-view (cubemap) batches.

    The encoder, decoder, ``quant_conv`` and ``post_quant_conv`` of
    ``pretrained_vae`` are reused by reference, and every ``nn.GroupNorm``
    reachable from this model is replaced by ``SynchronizedGroupNorm``.
    Because the sub-modules are shared, that replacement is also visible
    through ``pretrained_vae``.

    Args:
        pretrained_vae: Model providing ``encoder``, ``decoder``,
            ``quant_conv``, ``post_quant_conv`` and
            ``config.latent_channels``.
        num_views: Number of views per sample.
        in_channels: Channel count passed to the parent constructor as both
            ``in_channels`` and ``out_channels``.
        image_size: Accepted for interface compatibility; not used here.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        super().__init__(
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels,
        )
        self.num_views = num_views
        self.in_channels = in_channels

        # Replace the freshly built sub-modules with the pretrained ones.
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv

        # Swap every GroupNorm for its view-synchronised counterpart.
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images, return_dict: bool = True):
        """Encode a ``(B, V, C, H, W)`` batch of views.

        The batch and view axes are merged into one leading axis of size
        ``B * V`` (batch-major) before delegating to the parent ``encode``.
        """
        batch_size, num_views, num_channels, height, width = images.shape
        # reshape also handles non-contiguous inputs, e.g. after a permute.
        images = images.reshape(
            batch_size * num_views, num_channels, height, width
        )
        return super().encode(images, return_dict=return_dict)

    def decode(self, latents, return_dict=True, **kwargs):
        """Decode latents, dropping channels beyond the VAE's latent channels.

        Latents may carry extra trailing channels (the original notes call
        them UV channels). Only the first ``config.latent_channels``
        channels are passed to the parent ``decode``.
        """
        print("Decoder Recieve Latent Shape:", latents.shape)
        latent_channels = self.config.latent_channels
        if latents.shape[1] > latent_channels:
            latents = latents[:, :latent_channels, :, :]
        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        """Decode latents and return one tensor per view.

        Args:
            latents: Tensor whose leading dimension is ``B * num_views`` in
                the batch-major layout produced by ``encode``.

        Returns:
            A tuple of ``num_views`` tensors, each of shape
            ``(B, C, H, W)``; element ``v`` holds view ``v`` of every
            sample.

        Raises:
            ValueError: If the decoded batch size is not a multiple of
                ``num_views``.
        """
        decoded = self.decode(latents).sample  # (B * num_views, C, H, W)
        stacked = decoded.shape[0]
        if stacked % self.num_views != 0:
            raise ValueError(
                f"Decoded batch size {stacked} is not a multiple of "
                f"num_views={self.num_views}."
            )
        # Rows are ordered sample-by-sample, so the view axis has to be
        # recovered by reshaping; splitting into contiguous runs would mix
        # views whenever B > 1.
        per_sample = decoded.reshape(
            stacked // self.num_views, self.num_views, *decoded.shape[1:]
        )
        return per_sample.unbind(dim=1)

    def decode_to_pil_images(self, latents: Tensor):
        """Decode latents and return the first sample's views as PIL images."""
        images = self.decode_to_tensor(latents)
        to_pil = ToPILImage()
        # NOTE(review): ToPILImage presumably expects float tensors in
        # [0, 1]; confirm the decoder output is already in that range.
        return [to_pil(img[0].cpu().detach()) for img in images]


def replace_group_norm_with_sgn(model, num_views):
    """Replace every ``nn.GroupNorm`` inside ``model`` in place.

    Names are collected first so the module tree is not mutated while it is
    being traversed. The root module (empty name) is skipped because it has
    no parent attribute that could be reassigned.
    """
    targets = [
        name
        for name, module in model.named_modules()
        if name and isinstance(module, nn.GroupNorm)
    ]
    for name in targets:
        parent_module, attr_name = get_parent_module(model, name)
        synchronized = SynchronizedGroupNorm(
            getattr(parent_module, attr_name), num_views
        )
        setattr(parent_module, attr_name, synchronized)


def get_parent_module(model, module_name):
    """Return ``(parent, attribute_name)`` for a dotted module path."""
    *parent_path, attr_name = module_name.split(".")
    parent_module = model
    for part in parent_path:
        parent_module = getattr(parent_module, part)
    return parent_module, attr_name


def flatten_face_names(face_names):
    """Flatten one level of nesting in ``face_names``.

    Strings are kept as they are; the items of lists and tuples are
    appended individually.

    Raises:
        ValueError: If an element is not a ``str``, ``list`` or ``tuple``.
    """
    flat_face_names = []
    for item in face_names:
        if isinstance(item, str):
            flat_face_names.append(item)
        elif isinstance(item, (list, tuple)):
            flat_face_names.extend(item)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(item)}")
    return flat_face_names