File size: 3,327 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import json
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from diffusers.models.attention_processor import Attention, AttnProcessor
from .replace import custom_prepare_attention_mask, custom_get_attention_scores
import cv2
import torch
import numpy as np


def replace_unet_conv_in(unet, num):
    # replace the first layer to accept 8 in_channels
    _weight = unet.conv_in.weight.clone()  # [320, 4, 3, 3]
    _bias = unet.conv_in.bias.clone()  # [320]
    _weight = _weight.repeat((1, num, 1, 1))  # Keep selected channel(s)
    # half the activation magnitude
    _weight = _weight / num
    # new conv_in channel
    _n_convin_out_channel = unet.conv_in.out_channels
    _new_conv_in = Conv2d(4 * num, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    _new_conv_in.weight = Parameter(_weight)
    _new_conv_in.bias = Parameter(_bias)
    unet.conv_in = _new_conv_in
    print("Unet conv_in layer is replaced")
    # replace config
    unet.config["in_channels"] = 4 * num
    print("Unet config is updated")
    return unet


def add_aux_conv_in(unet):
    aux_conv_in = nn.Conv2d(in_channels=4, out_channels=1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    aux_conv_in.weight.data[:320, :, :, :] = unet.conv_in.weight.data.clone()
    aux_conv_in.weight.data[320:, :, :, :] = 0.0
    aux_conv_in.bias.data[:320] = unet.conv_in.bias.data.clone()
    aux_conv_in.bias.data[320:] = 0.0
    unet.aux_conv_in = aux_conv_in
    print("add aux_conv_in layer for unet")
    return unet


def replace_attention_mask_method(module, residual_connection):
    if isinstance(module, Attention):
        module.processor = AttnProcessor()
        if hasattr(module, "prepare_attention_mask"):
            module.prepare_attention_mask = custom_prepare_attention_mask.__get__(module)
        if hasattr(module, "cross_attention_dim") and module.cross_attention_dim == 320:
            module.residual_connection = residual_connection
        if hasattr(module, "get_attention_scores"):
            module.get_attention_scores = custom_get_attention_scores.__get__(module)

    # 递归遍历所有子模块
    for child_name, child_module in module.named_children():
        replace_attention_mask_method(child_module, residual_connection)


erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1, 30)]


def get_unknown_tensor_from_pred(pred, rand_width=30, train_mode=True):
    ### pred: N, 1 ,H, W
    N, C, H, W = pred.shape

    pred = pred.data.cpu().numpy()
    uncertain_area = np.ones_like(pred, dtype=np.uint8)
    uncertain_area[pred < 1.0 / 255.0] = 0
    uncertain_area[pred > 1 - 1.0 / 255.0] = 0

    for n in range(N):
        uncertain_area_ = uncertain_area[n, 0, :, :]  # H, W
        if train_mode:
            width = np.random.randint(1, rand_width)
        else:
            width = rand_width // 2
        uncertain_area_ = cv2.dilate(uncertain_area_, erosion_kernels[width])
        uncertain_area[n, 0, :, :] = uncertain_area_

    weight = np.zeros_like(uncertain_area)
    weight[uncertain_area == 1] = 1
    weight = torch.from_numpy(weight).float().cuda()
    return weight