# SpatialDiffusion / scripts / cubemap_vae.py
# (scraped file-listing header, kept as a comment so the module stays importable:
#  author zimhe, commit a521a3f "add scripts and examples", raw/history/blame, 7.07 kB)
import torch
from diffusers.models import AutoencoderKL
from torch import nn
from torchvision.transforms import ToPILImage
from torch import Tensor
class SynchronizedGroupNorm(nn.Module):
    """GroupNorm whose statistics are shared across the faces of a cubemap.

    Wraps an existing ``nn.GroupNorm``: the grouping configuration and the
    affine parameters are inherited, but mean/variance are computed jointly
    over all ``num_views`` consecutive batch entries (plus the group channels
    and spatial dims), so all faces of one cubemap are normalized identically.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        """
        Args:
            original_group_norm: the GroupNorm layer to replace; its
                ``num_groups``/``num_channels``/``eps`` and affine weight/bias
                are carried over.
            num_views: number of consecutive batch entries that belong to the
                same cubemap and must share normalization statistics.
        """
        super().__init__()
        self.num_views = num_views
        # Inherit the grouping configuration from the wrapped layer.
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        self.group_size = self.num_channels // self.num_groups
        # Keep the original affine parameters, reshaped to (groups, group_size)
        # so they can be broadcast after the grouped view below.
        self.weight = nn.Parameter(original_group_norm.weight.detach().view(self.num_groups, self.group_size))
        self.bias = nn.Parameter(original_group_norm.bias.detach().view(self.num_groups, self.group_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize a 3D ``(B*V, C, D)`` or 4D ``(B*V, C, H, W)`` tensor.

        Statistics are reduced over (views, group channels, spatial dims) per
        (batch, group); the output has the same shape as the input.

        Raises:
            ValueError: if the input is neither 3D nor 4D, or if the leading
                dimension is not a multiple of ``num_views``.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")
        orig_shape = x.shape
        BxT, C = x.shape[:2]
        if BxT % self.num_views != 0:
            raise ValueError(
                f"Leading dimension {BxT} is not divisible by num_views={self.num_views}."
            )
        B = BxT // self.num_views
        spatial = x.shape[2:]
        # Expose views and channel groups as their own axes:
        # (B, V, groups, group_size, *spatial).
        x = x.view(B, self.num_views, self.num_groups, self.group_size, *spatial)
        # Reduce over views (1), group channels (3) and all spatial axes —
        # this is what synchronizes the statistics across cubemap faces.
        reduce_dims = (1, 3) + tuple(range(4, x.dim()))
        mean = x.mean(dim=reduce_dims, keepdim=True)
        var = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Broadcast the per-channel affine parameters over batch/views/spatial.
        affine_shape = (1, self.num_groups, self.group_size) + (1,) * len(spatial)
        x = x * self.weight.view(affine_shape) + self.bias.view(affine_shape)
        return x.view(orig_shape)
class CubemapVAE(AutoencoderKL):
    """AutoencoderKL variant for cubemap (multi-view) images.

    The pretrained encoder/decoder/quant convs are reused directly, and every
    ``nn.GroupNorm`` inside them is replaced by :class:`SynchronizedGroupNorm`
    so normalization statistics are shared across the ``num_views`` faces.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        """
        Args:
            pretrained_vae: an ``AutoencoderKL`` whose submodules and latent
                channel count are adopted.
            num_views: faces per cubemap (batch entries that share stats).
            in_channels: image channel count (also used for the output).
            image_size: unused here; kept for interface compatibility.
        """
        super().__init__(  # configure the AutoencoderKL base
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D"
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D"
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels
        )
        self.num_views = num_views
        self.in_channels = in_channels
        # Discard the freshly initialized submodules from super().__init__()
        # and reuse the pretrained weights directly.
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv
        # Swap every GroupNorm for the view-synchronized variant.
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images, return_dict: bool = True):
        """Encode a ``(B, V, C, H, W)`` batch by folding views into the batch dim."""
        batch_size, num_views, num_channels, height, width = images.shape
        images = images.view(batch_size * num_views, num_channels, height, width)
        return super().encode(images, return_dict=return_dict)

    def decode(self, latents, return_dict=True, **kwargs):
        """Decode latents with the standard VAE pipeline.

        Latents with more than 4 channels are truncated to the first 4: the
        extra channels are auxiliary (UV) channels the pretrained decoder
        cannot consume.
        """
        if latents.shape[1] > 4:
            latents = latents[:, :4, :, :]  # drop UV channels, keep the 4 latent ones
        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        """Decode and split the result into ``num_views`` per-face tensors."""
        decoded = self.decode(latents).sample  # (B * num_views, C, H, W)
        batch = latents.shape[0] // self.num_views
        # Chunks of size `batch` along dim 0 -> one (B, C, H, W) tensor per face.
        return torch.split(decoded, batch, dim=0)

    def decode_to_pil_images(self, latents: Tensor):
        """Decode and return the first sample of each face as a PIL image."""
        images = self.decode_to_tensor(latents)
        to_pil = ToPILImage()
        return [to_pil(img[0].cpu().detach()) for img in images]
def replace_group_norm_with_sgn(model, num_views):
    """Replace every ``nn.GroupNorm`` in *model* with a ``SynchronizedGroupNorm``.

    Args:
        model: any ``nn.Module`` tree to rewrite in place.
        num_views: passed through to each ``SynchronizedGroupNorm``.
    """
    # Collect the target names first: mutating the module tree while
    # iterating named_modules() would invalidate the traversal.
    gn_names = [n for n, m in model.named_modules() if isinstance(m, nn.GroupNorm)]
    for gn_name in gn_names:
        parent, attr = get_parent_module(model, gn_name)
        original = getattr(parent, attr)
        setattr(parent, attr, SynchronizedGroupNorm(original, num_views))
def get_parent_module(model, module_name):
    """Resolve the module that directly owns *module_name*.

    Args:
        model: root module of the lookup.
        module_name: dotted path, e.g. ``"encoder.block.0.norm"``.

    Returns:
        A ``(parent_module, attribute_name)`` pair such that
        ``getattr(parent_module, attribute_name)`` is the named module.
    """
    # Split off the final attribute; everything before it is the parent path.
    parent_path, _, attr_name = module_name.rpartition(".")
    parent = model
    if parent_path:
        for part in parent_path.split("."):
            parent = getattr(parent, part)
    return parent, attr_name
def flatten_face_names(face_names):
    """Flatten one level of nesting in a list of face names.

    Args:
        face_names: iterable whose entries are either strings or
            lists/tuples of strings.

    Returns:
        A flat list with nested lists/tuples expanded in order.

    Raises:
        ValueError: if an entry is neither a string nor a list/tuple.
    """
    flat = []
    for entry in face_names:
        if isinstance(entry, (list, tuple)):
            # One level of nesting: splice the contents in place.
            flat.extend(entry)
        elif isinstance(entry, str):
            flat.append(entry)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(entry)}")
    return flat