# NOTE(review): removed web-scrape residue (page header, file size, commit hash,
# line-number gutter) that was not Python and made the file unparseable.
import torch
from diffusers.models import AutoencoderKL
from torch import nn
from torchvision.transforms import ToPILImage
from torch import Tensor
class SynchronizedGroupNorm(nn.Module):
    """GroupNorm whose statistics are synchronized across multiple views.

    The input batch dimension is assumed to be laid out as ``B * num_views``;
    mean and variance are computed jointly over all views of the same sample
    (plus the channels of a group and the spatial positions), so every view
    is normalized with identical statistics.

    Accepts 3D ``(B*V, C, D)`` or 4D ``(B*V, C, H, W)`` inputs.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        if original_group_norm.weight is None or original_group_norm.bias is None:
            # forward() applies an affine transform unconditionally, so a
            # non-affine GroupNorm cannot be wrapped; fail early and clearly
            # instead of with an AttributeError on .detach() below.
            raise ValueError("SynchronizedGroupNorm requires an affine nn.GroupNorm")
        self.num_views = num_views
        # Inherit the grouping hyper-parameters from the wrapped layer.
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        self.group_size = self.num_channels // self.num_groups
        # Copy the affine parameters, reshaped to (groups, channels-per-group)
        # so they broadcast against the grouped activations in forward().
        self.weight = nn.Parameter(original_group_norm.weight.detach().view(self.num_groups, self.group_size))
        self.bias = nn.Parameter(original_group_norm.bias.detach().view(self.num_groups, self.group_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a ``(B*num_views, C, ...)`` tensor with view-shared stats.

        The trailing spatial dims (D, or H and W) are flattened so the 3D and
        4D cases share one code path, then restored on return.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")
        BxT, C = x.shape[:2]
        if BxT % self.num_views != 0:
            # Previously a non-divisible batch produced a confusing reshape
            # error; report the actual contract violation instead.
            raise ValueError(
                f"Batch dimension {BxT} is not divisible by num_views={self.num_views}"
            )
        B = BxT // self.num_views
        spatial = x.shape[2:]
        # reshape (not view) tolerates non-contiguous inputs.
        x = x.reshape(B, self.num_views, self.num_groups, self.group_size, -1)
        # Statistics over views (dim 1), channels within a group (dim 3) and
        # flattened spatial positions (dim 4) — i.e. synchronized across views.
        mean = x.mean(dim=(1, 3, 4), keepdim=True)
        var = x.var(dim=(1, 3, 4), keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Per-(group, channel) affine transform, broadcast over batch, views
        # and spatial positions.
        weight = self.weight.view(1, 1, self.num_groups, self.group_size, 1)
        bias = self.bias.view(1, 1, self.num_groups, self.group_size, 1)
        x = x * weight + bias
        return x.reshape(BxT, C, *spatial)
class CubemapVAE(AutoencoderKL):
    """AutoencoderKL adapted for cubemap (multi-view) images.

    Reuses the encoder/decoder/quant convolutions of a pretrained VAE and
    replaces every GroupNorm with :class:`SynchronizedGroupNorm` so that
    normalization statistics are shared across the ``num_views`` cube faces
    packed into the batch dimension.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        # Mirror the standard SD-style VAE architecture hyper-parameters so the
        # pretrained submodules plugged in below are structurally compatible.
        super().__init__(
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels,
        )
        self.num_views = num_views
        self.in_channels = in_channels
        # Reuse the pretrained submodules directly rather than retraining.
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv
        # Swap every GroupNorm for a cross-view synchronized version.
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images, return_dict: bool = True):
        """Encode ``(B, V, C, H, W)`` multi-view images by folding views into the batch."""
        batch_size, num_views, num_channels, height, width = images.shape
        images = images.view(batch_size * num_views, num_channels, height, width)
        return super().encode(images, return_dict=return_dict)

    def decode(self, latents, return_dict=True, **kwargs):
        """Decode latents with the standard VAE decoder.

        Extra channels (e.g. appended UV channels) beyond the configured
        latent width are dropped before decoding.
        """
        # Was a hard-coded 4; use the configured latent width so this also
        # works for VAEs with a different latent_channels setting.
        latent_channels = self.config.latent_channels
        if latents.shape[1] > latent_channels:
            latents = latents[:, :latent_channels, :, :]
        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        """Decode and split the ``(B*num_views, 3, H, W)`` result into per-view tensors."""
        decoded = self.decode(latents).sample
        # Was hard-coded to 6; respect the configured number of views.
        batch = latents.shape[0] // self.num_views
        return torch.split(decoded, batch, dim=0)  # tuple of num_views tensors

    def decode_to_pil_images(self, latents: Tensor):
        """Decode and return the first sample of each view as a PIL image."""
        images = self.decode_to_tensor(latents)
        to_pil = ToPILImage()
        return [to_pil(img[0].cpu().detach()) for img in images]
def replace_group_norm_with_sgn(model, num_views):
    """Swap every nn.GroupNorm in *model* for a SynchronizedGroupNorm."""
    # Collect the dotted names first: mutating the module tree while
    # named_modules() is iterating would be unsafe.
    targets = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.GroupNorm)
    ]
    for target in targets:
        parent, attr = get_parent_module(model, target)
        original = getattr(parent, attr)
        setattr(parent, attr, SynchronizedGroupNorm(original, num_views))
def get_parent_module(model, module_name):
    """Resolve a dotted module path to (parent module, final attribute name)."""
    *parent_path, attr_name = module_name.split(".")
    parent = model
    # Walk down to the second-to-last component; the last one stays an
    # attribute name so the caller can setattr() on the parent.
    for part in parent_path:
        parent = getattr(parent, part)
    return parent, attr_name
def flatten_face_names(face_names):
    """Flatten a mix of strings and lists/tuples of strings into one flat list."""
    flat = []
    for entry in face_names:
        if isinstance(entry, (list, tuple)):
            # Nested container: splice its elements in.
            flat.extend(entry)
        elif isinstance(entry, str):
            # Plain string: keep as-is.
            flat.append(entry)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(entry)}")
    return flat