import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from .cubemap_unet_attention_processor import InflatedAttentionProcessor

class CubemapUNet(UNet2DConditionModel):
    """UNet2DConditionModel variant for cubemap generation: widens `conv_in`
    for extra input channels and installs `InflatedAttentionProcessor` on
    every attention layer."""

    def __init__(self, pretrained_unet, in_channels=11):
        """
        1. Initialize from `pretrained_unet`'s config with its default channel count.
        2. Load the full set of pretrained UNet weights.
        3. Expand `conv_in` to `in_channels` and install the cubemap attention processor.
        """
        # Copy every entry of `pretrained_unet.config` into a fresh dict (so the
        # original model's config is untouched), dropping the private
        # `_`-prefixed bookkeeping entries (e.g. `_class_name`) so that only
        # real architecture arguments are passed on.
        unet_config = {k: v for k, v in pretrained_unet.config.items() if not k.startswith("_")}

        super().__init__(**unet_config)

        # Step 2: load all of `pretrained_unet`'s weights.
        print("Loading pretrained UNet weights...")
        self.load_state_dict(pretrained_unet.state_dict(), strict=True)
        print("✅ Pretrained UNet weights loaded!")

        # Step 3: expand `conv_in` to `in_channels` (11 by default) and record
        # the new value in the config.
        self._expand_conv_in(pretrained_unet, in_channels)
        self.register_to_config(in_channels=in_channels)

        # Replace the default attention processors with the cubemap-aware one.
        cubemap_attn_processor = InflatedAttentionProcessor()
        self._modify_attn_processor(cubemap_attn_processor)

    def _expand_conv_in(self, pretrained_unet, new_in_channels):
        """Expand `conv_in` to accept `new_in_channels` input channels."""
        old_conv_in = pretrained_unet.conv_in
        old_weight = old_conv_in.weight  # [out_channels, old_in_channels, kH, kW]

        # Build the replacement `conv_in`, keeping a bias only if the original
        # layer had one.
        new_conv_in = nn.Conv2d(
            new_in_channels,
            old_conv_in.out_channels,
            kernel_size=old_conv_in.kernel_size,
            stride=old_conv_in.stride,
            padding=old_conv_in.padding,
            bias=old_conv_in.bias is not None,
        )

        old_channels = old_weight.shape[1]

        # Copy the pretrained weights into the first `old_channels` channels
        # (4 for a standard latent-diffusion UNet).
        new_weight = torch.zeros(
            (new_conv_in.out_channels, new_in_channels, *old_conv_in.kernel_size),
            dtype=old_weight.dtype,
            device=old_weight.device,
        )
        new_weight[:, :old_channels, :, :] = old_weight.detach()

        # Randomly initialize the added channels with small values so they
        # barely perturb the pretrained behavior at the start of training.
        new_weight[:, old_channels:, :, :] = torch.randn_like(new_weight[:, old_channels:, :, :]) * 0.01

        new_conv_in.weight = nn.Parameter(new_weight)
        if old_conv_in.bias is not None:
            new_conv_in.bias = nn.Parameter(old_conv_in.bias.detach().clone())

        self.conv_in = new_conv_in
        print(f"✅ `conv_in` expanded! New input channels: {new_in_channels}")
        
    def _modify_attn_processor(self, processor):
        """Install `processor` on the self-attention (`attn1`) and
        cross-attention (`attn2`) layers of every transformer block."""
        for module in self.modules():
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, 'attn1'):
                    module.attn1.set_processor(processor)
                if hasattr(module, 'attn2'):
                    module.attn2.set_processor(processor)
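
# Usage sketch (illustrative assumptions, run from a script that imports this
# module): wrap a pretrained Stable Diffusion UNet and denoise a 6-face
# cubemap batch. The checkpoint id, spatial size, and embedding shapes below
# are placeholders for an SD 1.5-style setup, not values from this repo.
#
#   from diffusers import UNet2DConditionModel
#   base_unet = UNet2DConditionModel.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="unet"
#   )
#   model = CubemapUNet(base_unet, in_channels=11)
#   faces = torch.randn(6, 11, 64, 64)       # one batch entry per cube face
#   t = torch.zeros(6, dtype=torch.long)     # diffusion timesteps
#   text_emb = torch.randn(6, 77, 768)       # placeholder CLIP embeddings
#   out = model(faces, t, encoder_hidden_states=text_emb).sample
#   # out.shape -> torch.Size([6, 4, 64, 64])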