import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

from .cubemap_unet_attention_processor import InflatedAttentionProcessor


class CubemapUNet(UNet2DConditionModel):
    """UNet2DConditionModel with a widened ``conv_in`` and a custom attention processor.

    The architecture is built from ``pretrained_unet.config`` and receives all
    pretrained weights. The input convolution is then widened to ``in_channels``
    inputs, and every ``BasicTransformerBlock`` attention layer is switched to an
    ``InflatedAttentionProcessor``. ``conv_out`` is not modified.
    """

    def __init__(self, pretrained_unet, in_channels=11):
        """Build the cubemap UNet from a pretrained ``UNet2DConditionModel``.

        Steps:
          1. Construct the architecture from a copy of ``pretrained_unet.config``,
             so it starts with the pretrained input channel count.
          2. Load every pretrained weight with ``strict=True``.
          3. Widen ``conv_in`` to ``in_channels`` input channels and record the
             new value in the model config.
          4. Install one shared ``InflatedAttentionProcessor`` on the attention
             layers of all ``BasicTransformerBlock`` modules.

        Args:
            pretrained_unet: Source model providing the config and weights. It is
                only read, never modified.
            in_channels: Number of input channels for the widened ``conv_in``.
                Must be >= the pretrained ``conv_in`` input channel count.

        Raises:
            ValueError: If ``in_channels`` is smaller than the pretrained count.
        """
        # Copy the config into a plain dict so the pretrained model's config is
        # never touched. ``in_channels`` is still the pretrained value here; it is
        # widened in step 3, after the weights have been loaded with matching shapes.
        unet_config = {**pretrained_unet.config}
        super().__init__(**unet_config)

        # Step 2: load all pretrained weights (same config -> same parameter shapes).
        print("开始加载预训练权重")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ UNet 预训练权重加载成功!")

        # Step 3: widen conv_in and keep the saved config consistent with it.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)

        # Step 4: a single processor instance is shared by every attention layer.
        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Replace ``self.conv_in`` with a copy widened to ``new_in_channels`` inputs.

        The leading input-channel slices of the weight are copied from
        ``pretrained_unet.conv_in``, so the pretrained channels behave exactly as
        before. The additional slices are drawn from N(0, 0.01**2). The bias, if
        the pretrained conv has one, is copied unchanged. If it has none, the new
        conv has none either.

        Args:
            pretrained_unet: Model whose ``conv_in`` supplies the pretrained weights.
            new_in_channels: Desired number of input channels.

        Raises:
            ValueError: If ``new_in_channels`` is smaller than the pretrained
                ``conv_in`` input channel count.
        """
        old_conv_in = pretrained_unet.conv_in
        old_weight = old_conv_in.weight  # (out_channels, old_in_channels, kH, kW)
        old_channels = old_weight.shape[1]
        if new_in_channels < old_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= the pretrained "
                f"conv_in input channel count ({old_channels})"
            )

        has_bias = old_conv_in.bias is not None
        # Allocate on the device/dtype of this model's current conv_in, so the new
        # layer matches the rest of ``self``.
        reference = self.conv_in.weight
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
            dilation=old_conv_in.dilation,
            padding_mode=old_conv_in.padding_mode,
            # Without this, Conv2d would add a random bias the pretrained conv lacked.
            bias=has_bias,
            device=reference.device,
            dtype=reference.dtype,
        )

        # Initialise in place without recording an autograd graph.
        # copy_ also handles any device/dtype difference from pretrained_unet.
        with torch.no_grad():
            new_conv_in.weight[:, :old_channels].copy_(old_weight)
            new_conv_in.weight[:, old_channels:].normal_(mean=0.0, std=0.01)
            if has_bias:
                new_conv_in.bias.copy_(old_conv_in.bias)

        self.conv_in = new_conv_in
        print(f"✅ `conv_in` 扩展成功!新输入通道: {new_in_channels}")

    def _modify_attn_processor(self, processor):
        """Install ``processor`` on ``attn1``/``attn2`` of every BasicTransformerBlock.

        The same processor instance is shared by all layers. An attention
        attribute that is missing or ``None`` is skipped; for example, a block
        constructed without cross-attention may have ``attn2 = None``.

        Args:
            processor: Attention processor passed to each layer's ``set_processor``.
        """
        for module in self.modules():
            if not isinstance(module, BasicTransformerBlock):
                continue
            for attn_name in ("attn1", "attn2"):
                attn = getattr(module, attn_name, None)
                if attn is not None:
                    attn.set_processor(processor)