# CubemapUNet: a UNet2DConditionModel with an expanded `conv_in` for
# multi-channel (cubemap) conditioning inputs.
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from .cubemap_unet_attention_processor import InflatedAttentionProcessor
class CubemapUNet(UNet2DConditionModel):
    """UNet2DConditionModel whose ``conv_in`` is widened to extra input channels.

    Construction steps:
      1. Rebuild the UNet from ``pretrained_unet.config`` (default channel count).
      2. Load the pretrained weights verbatim (``strict=True`` — shapes match
         because the model was built from the same config).
      3. Expand ``conv_in`` to ``in_channels`` inputs: the original input
         channels keep their pretrained weights, the new ones get small random
         init so pretrained behaviour is only mildly perturbed at the start.
      4. Route every transformer block's attention through
         ``InflatedAttentionProcessor``.
    """

    def __init__(self, pretrained_unet, in_channels=11):
        """
        Args:
            pretrained_unet: a loaded ``UNet2DConditionModel`` whose config and
                weights are copied into this model.
            in_channels: total number of input channels after expansion; must
                be >= the pretrained model's own ``in_channels``.
        """
        # Copy the config dict so we never mutate the source model's config.
        unet_config = {**pretrained_unet.config}
        super().__init__(**unet_config)
        # Step 2: load the full pretrained state dict.
        print("开始加载预训练权重")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ UNet 预训练权重加载成功!")
        # Step 3: widen `conv_in` to `in_channels` inputs and record the new
        # channel count in the config so save/load round-trips correctly.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)
        # Step 4: install the cubemap attention processor everywhere.
        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Replace ``self.conv_in`` with a conv accepting ``new_in_channels``.

        Pretrained weights fill the first ``old_channels`` input slices; the
        additional slices are initialized with small noise (std 0.01).

        Raises:
            ValueError: if ``new_in_channels`` is smaller than the pretrained
                layer's input channel count.
        """
        old_conv_in = pretrained_unet.conv_in
        old_weight = old_conv_in.weight  # [out_channels, old_in, kH, kW]
        old_channels = old_weight.shape[1]
        if new_in_channels < old_channels:
            raise ValueError(
                f"new_in_channels ({new_in_channels}) must be >= the "
                f"pretrained in_channels ({old_channels})"
            )
        # Rebuild the conv with ALL of the original layer's hyperparameters.
        # BUGFIX: the previous version dropped dilation/groups/padding_mode and
        # always created a bias even when the pretrained layer had none.
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
            dilation=old_conv_in.dilation,
            groups=old_conv_in.groups,
            bias=old_conv_in.bias is not None,
            padding_mode=old_conv_in.padding_mode,
        )
        # BUGFIX: allocate on the pretrained weight's dtype/device instead of
        # torch.zeros' defaults (fp32/CPU), so fp16/bf16 or GPU models stay
        # consistent. `detach()` avoids pulling autograd history into the copy.
        new_weight = old_weight.new_zeros(
            (new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size)
        )
        new_weight[:, :old_channels] = old_weight.detach()
        # Small random init for the newly added channels.
        new_weight[:, old_channels:] = (
            torch.randn_like(new_weight[:, old_channels:]) * 0.01
        )
        new_conv_in.weight = nn.Parameter(new_weight)
        if old_conv_in.bias is not None:
            new_conv_in.bias = nn.Parameter(old_conv_in.bias.detach().clone())
        self.conv_in = new_conv_in
        print(f"✅ `conv_in` 扩展成功!新输入通道: {new_in_channels}")

    def _modify_attn_processor(self, processor):
        """Install ``processor`` on attn1/attn2 of every BasicTransformerBlock."""
        for name, module in self.named_modules():
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, 'attn1'):
                    module.attn1.set_processor(processor)
                if hasattr(module, 'attn2'):
                    module.attn2.set_processor(processor)