Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

from .cubemap_unet_attention_processor import InflatedAttentionProcessor
class CubemapUNet(UNet2DConditionModel):
    """UNet2DConditionModel adapted for cubemap inputs.

    Construction proceeds in three steps (matching the original design):
    1. Build the UNet from ``pretrained_unet``'s config, keeping the default
       input channel count so the pretrained shapes still match.
    2. Load the full pretrained state dict (``strict=True``).
    3. Expand ``conv_in`` to ``in_channels`` input channels, then install the
       inflated (cubemap) attention processor on every transformer block.
    """

    def __init__(self, pretrained_unet, in_channels=11):
        """
        Args:
            pretrained_unet: Pretrained ``UNet2DConditionModel`` whose config
                and weights are copied into this model.
            in_channels: Input channel count after expanding ``conv_in``
                (default 11).
        """
        # Copy the config so the source model's config is not mutated, and
        # drop private bookkeeping keys ("_class_name", "_diffusers_version",
        # ...) that a loaded config may carry but __init__ does not accept.
        unet_config = {
            k: v for k, v in dict(pretrained_unet.config).items()
            if not k.startswith("_")
        }
        super().__init__(**unet_config)

        # Step 2: load the pretrained weights verbatim. This must happen
        # before conv_in is expanded so every tensor shape still matches.
        print("开始加载预训练权重")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ UNet 预训练权重加载成功!")

        # Step 3: expand conv_in and record the new channel count in the
        # config so save/reload round-trips with the expanded shape.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)

        # NOTE(review): one processor instance is shared by every attn1/attn2
        # module — fine iff InflatedAttentionProcessor is stateless; verify.
        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Replace ``self.conv_in`` with a conv accepting ``new_in_channels``.

        The pretrained weights fill the leading input channels; the added
        channels get small random values (std 0.01) so early training is
        dominated by the pretrained features.
        """
        old_conv_in = pretrained_unet.conv_in
        old_weight = old_conv_in.weight  # [out_channels, old_in, kH, kW]
        old_channels = old_weight.shape[1]

        # Preserve the original conv's geometry and bias presence (the
        # nn.Conv2d default bias=True would otherwise invent a random bias
        # when the pretrained conv has none).
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
            dilation=old_conv_in.dilation,
            bias=old_conv_in.bias is not None,
        )

        # Build the expanded weight on the pretrained weight's device/dtype;
        # a bare torch.zeros(...) would default to float32 on CPU and break
        # fp16 or GPU-resident models.
        new_weight = torch.zeros(
            (new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size),
            dtype=old_weight.dtype,
            device=old_weight.device,
        )
        # Copy the pretrained channels, then small-random-init the new ones.
        new_weight[:, :old_channels, :, :] = old_weight
        new_weight[:, old_channels:, :, :] = (
            torch.randn_like(new_weight[:, old_channels:, :, :]) * 0.01
        )
        new_conv_in.weight = nn.Parameter(new_weight)
        if old_conv_in.bias is not None:
            new_conv_in.bias = nn.Parameter(old_conv_in.bias.clone())

        self.conv_in = new_conv_in
        print(f"✅ `conv_in` 扩展成功!新输入通道: {new_in_channels}")

    def _modify_attn_processor(self, processor):
        """Install ``processor`` on every self-/cross-attention block.

        Walks all submodules and sets the processor on ``attn1`` (self-attn)
        and ``attn2`` (cross-attn) of each ``BasicTransformerBlock``.
        """
        for _name, module in self.named_modules():
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, 'attn1'):
                    module.attn1.set_processor(processor)
                if hasattr(module, 'attn2'):
                    module.attn2.set_processor(processor)