SpatialDiffusion / scripts /cubemap_unet.py
zimhe
add scripts and examples
a521a3f
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from .cubemap_unet_attention_processor import InflatedAttentionProcessor
class CubemapUNet(UNet2DConditionModel):
    """UNet2DConditionModel built from a pretrained UNet with a widened ``conv_in``.

    The model is created with the pretrained UNet's own config, the pretrained
    weights are loaded strictly, ``conv_in`` is then replaced by a convolution
    that accepts ``in_channels`` inputs, and finally every attention layer
    inside a ``BasicTransformerBlock`` is switched to an
    ``InflatedAttentionProcessor``.
    """

    def __init__(self, pretrained_unet, in_channels=11):
        """Initialise from ``pretrained_unet`` and widen the input convolution.

        Args:
            pretrained_unet: Source model. Its ``config`` is used to build this
                model and its ``state_dict()`` supplies the initial weights.
            in_channels: Number of input channels ``conv_in`` should accept
                after expansion (default 11).

        Raises:
            ValueError: If ``in_channels`` is smaller than the number of input
                channels of the pretrained ``conv_in``.
        """
        # Copy the config into a plain dict so the source model's config is
        # not touched. The channel count is deliberately left unchanged here:
        # the architecture must match the checkpoint for the strict load below.
        unet_config = dict(pretrained_unet.config)
        super().__init__(**unet_config)

        # Load every pretrained weight; strict=True fails loudly on mismatch.
        print("开始加载预训练权重")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ UNet 预训练权重加载成功!")

        # Only now widen `conv_in`, then record the new channel count in the
        # config so it reflects the actual layer.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)

        # A single processor instance is shared by all attention layers.
        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Replace ``self.conv_in`` with a convolution taking ``new_in_channels`` inputs.

        The weights of the pretrained input channels are copied over, the
        additional input channels get small random weights (standard normal
        scaled by 0.01), and the bias is copied when the pretrained layer has
        one.

        Args:
            pretrained_unet: Model whose ``conv_in`` provides the kernel
                geometry and the weights to copy.
            new_in_channels: Input channel count of the replacement layer.

        Raises:
            ValueError: If ``new_in_channels`` is smaller than the pretrained
                layer's input channel count (the pretrained weights would not
                fit).
        """
        old_conv_in = pretrained_unet.conv_in
        # Detach so the copy below never records an autograd graph that links
        # the new parameter's source tensor to the pretrained model.
        old_weight = old_conv_in.weight.detach()  # [out_channels, in_channels, kH, kW]
        old_channels = old_weight.shape[1]

        if new_in_channels < old_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= the pretrained "
                f"conv_in input channels ({old_channels})"
            )

        has_bias = old_conv_in.bias is not None
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
            # Mirror the pretrained layer instead of silently adding a
            # randomly initialised bias when the original has none.
            bias=has_bias,
        )

        with torch.no_grad():
            new_weight = torch.zeros(
                (new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size)
            )
            # Keep the pretrained behaviour on the original input channels.
            new_weight[:, :old_channels, :, :] = old_weight
            # Small random init for the newly added input channels.
            new_weight[:, old_channels:, :, :] = (
                torch.randn_like(new_weight[:, old_channels:, :, :]) * 0.01
            )

        new_conv_in.weight = nn.Parameter(new_weight)
        if has_bias:
            new_conv_in.bias = nn.Parameter(old_conv_in.bias.detach().clone())

        self.conv_in = new_conv_in
        print(f"✅ `conv_in` 扩展成功!新输入通道: {new_in_channels}")

    def _modify_attn_processor(self, processor):
        """Install ``processor`` on the attention layers of every transformer block.

        For each ``BasicTransformerBlock`` in the model, ``attn1`` and
        ``attn2`` are given ``processor`` via ``set_processor``.

        Args:
            processor: Attention processor instance to install.
        """
        for module in self.modules():
            if not isinstance(module, BasicTransformerBlock):
                continue
            for attn_name in ("attn1", "attn2"):
                # The attribute can exist but be None (diffusers may leave
                # `attn2` unset-as-None when a block has no second attention),
                # so test the value rather than mere attribute presence.
                attn = getattr(module, attn_name, None)
                if attn is not None:
                    attn.set_processor(processor)