File size: 4,977 Bytes
e14f899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import numpy as np
import torch
import torch.nn as nn
from .diffusion_utils import list2batch


def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: Tensor whose last dimension is indexed.
        t: Index tensor handed to ``gather``; its first dimension is taken
            as the batch size.
        x_shape: Shape of the tensor the result is meant to broadcast
            against. Only its length (rank) is used.

    Returns:
        The gathered values reshaped to ``(b, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions, where ``b`` is ``t.shape[0]``.
    """
    batch, *_ = t.shape
    picked = a.gather(dim=-1, index=t)
    # Leading batch axis followed by one singleton axis per remaining dim.
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape)

def get_phase_endpoint(index, num_teacher_timesteps=32, multiphase=8):
    """Map a timestep index to the first index of the phase it falls in.

    The range ``[0, num_teacher_timesteps)`` is cut into phases of length
    ``num_teacher_timesteps // multiphase``.

    Args:
        index: Timestep index to look up.
        num_teacher_timesteps: Total number of teacher timesteps.
        multiphase: Number of phases.

    Returns:
        ``index`` rounded down to a multiple of the phase length, except that
        any ``index`` at or beyond ``num_teacher_timesteps - phase_length``
        yields exactly that value.
    """
    phase_length = num_teacher_timesteps // multiphase
    last_phase_start = num_teacher_timesteps - phase_length
    return (
        last_phase_start
        if index >= last_phase_start
        else (index // phase_length) * phase_length
    )

class EulerSolver:
    """Euler stepping over a sub-sampled sigma schedule.

    Attributes (all ``torch`` tensors after construction):
        euler_timesteps: Ascending int64 indices into the full schedule that
            were kept for Euler stepping.
        sigmas: Schedule values at ``euler_timesteps``.
        sigmas_prev: For each kept step, the sigma one kept step earlier; the
            first entry is ``sigmas[0]`` of the full schedule.
        sigmas_all: Copy of the full input schedule.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        # Per the original note: `sigmas` runs 0.0 -> 1.0 with length 1001
        # (i.e. timesteps + 1) and must be a NumPy array.
        self.num_timesteps = timesteps

        stride = timesteps / euler_timesteps
        # timesteps, timesteps - stride, ... (descending, exclusive of 0),
        # then flipped so the kept indices ascend, e.g. 20, 40, ..., 1000.
        descending = np.round(np.arange(timesteps, 0, -stride)).astype(np.int64)
        kept = descending[::-1].copy()

        sigma_at_step = sigmas[kept]
        # Built through a Python list, so this array is float64 regardless of
        # the dtype of `sigmas`.
        sigma_before_step = np.asarray([sigmas[0], *sigmas[kept[:-1]].tolist()])

        self.euler_timesteps = torch.from_numpy(kept).long()
        self.sigmas = torch.from_numpy(sigma_at_step)
        self.sigmas_prev = torch.from_numpy(sigma_before_step)
        self.sigmas_all = torch.from_numpy(sigmas.copy())

    def to(self, device):
        """Move every schedule tensor to ``device`` and return ``self``."""
        for name in ("euler_timesteps", "sigmas", "sigmas_prev", "sigmas_all"):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def _advance(self, sample, model_pred, from_index, to_index):
        # Shared Euler update: move `sample` from sigmas[from_index] to
        # sigmas_prev[to_index] along `model_pred`.
        sigma_from = extract_into_tensor(self.sigmas, from_index, model_pred.shape)
        sigma_to = extract_into_tensor(self.sigmas_prev, to_index, model_pred.shape)
        return sample + (sigma_to - sigma_from) * model_pred

    def euler_step(self, sample, model_pred, timestep_index):
        """Take one Euler step from ``sigmas[i]`` to ``sigmas_prev[i]``.

        Args:
            sample: Current sample.
            model_pred: Model output used as the update direction.
            timestep_index: Index tensor into the kept schedule.

        Returns:
            ``sample + (sigma_prev - sigma) * model_pred``.
        """
        return self._advance(sample, model_pred, timestep_index, timestep_index)

    def euler_step_to_target(self, sample, model_pred, timestep_index, target_timestep_index):
        """Take one Euler step from ``sigmas[i]`` to ``sigmas_prev[target]``.

        Args:
            sample: Current sample.
            model_pred: Model output used as the update direction.
            timestep_index: Index tensor of the current kept step.
            target_timestep_index: Index tensor of the step whose
                ``sigmas_prev`` entry is the destination sigma.

        Returns:
            ``sample + (sigma_target - sigma) * model_pred``.
        """
        return self._advance(sample, model_pred, timestep_index, target_timestep_index)


class DiscriminatorHead(nn.Module):
    """Single discriminator head producing one value per input sample.

    Pipeline: 1x1x1 channel reduction, three stride-(1, 2, 2) 3x3x3
    convolutions with LeakyReLU(0.2) that double the channel count each time,
    global average pooling, then a linear layer with one output.

    Args:
        in_channels: Channel count of the incoming feature map.
        reduced_channels: Channel count after the 1x1x1 reduction.
    """

    def __init__(self, in_channels=1280, reduced_channels=512):
        super().__init__()

        # Pointwise convolution that brings the input down to `reduced_channels`.
        self.reduce_ch_conv = nn.Conv3d(in_channels, reduced_channels, kernel_size=(1, 1, 1))

        # Channel widths c, 2c, 4c, 8c; each consecutive pair becomes one
        # Conv3d + LeakyReLU stage. The first axis keeps stride 1, the last
        # two are strided by 2.
        widths = [reduced_channels * factor for factor in (1, 2, 4, 8)]
        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.Conv3d(width_in, width_out, kernel_size=(3, 3, 3), stride=(1, 2, 2))
            )
            stages.append(nn.LeakyReLU(0.2))
        self.conv_layers = nn.Sequential(*stages)

        # Collapse every remaining non-channel axis to size 1.
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # Final projection to a single output value.
        self.fc = nn.Linear(reduced_channels * 8, 1)

    def forward(self, feature):
        """Run the head on ``feature`` and return a ``(batch, 1)`` tensor."""
        pooled = self.global_pool(self.conv_layers(self.reduce_ch_conv(feature)))
        # Drop the singleton pooled axes before the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
    

class Discriminator(nn.Module):
    """Collection of ``DiscriminatorHead`` groups, one group per feature map.

    Args:
        num_h_per_head: Number of ``DiscriminatorHead`` modules in each group.
        selected_layers: Only its length is used here; the channel list is
            repeated once per entry.
        adapter_channel_dims: Input channel count(s) for the heads. An ``int``
            is treated as a one-element list. The list is repeated
            ``len(selected_layers)`` times and one head group is created per
            resulting entry.
    """

    def __init__(
        self,
        num_h_per_head=1,
        selected_layers=[20, 30, 40],
        adapter_channel_dims=[1280],
    ):
        super().__init__()
        if isinstance(adapter_channel_dims, int):
            adapter_channel_dims = [adapter_channel_dims]

        # Repetition builds a new list, so the list defaults above are only
        # read and never mutated.
        adapter_channel_dims = adapter_channel_dims * len(selected_layers)
        self.num_h_per_head = num_h_per_head
        self.head_num = len(adapter_channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([DiscriminatorHead(adapter_channel) for _ in range(self.num_h_per_head)])
            for adapter_channel in adapter_channel_dims
        ])

    def forward(self, features):
        """Apply every head group to its matching feature entry.

        Args:
            features: Indexable collection with one entry per head group. An
                entry that is a ``list`` is first converted with
                ``list2batch``; any other entry is passed to the heads as is.

        Returns:
            Flat list of head outputs, ordered by group and then by head
            within the group.

        Raises:
            AssertionError: If ``len(features)`` differs from the number of
                head groups.
        """
        assert len(features) == len(self.heads), (
            f"expected {len(self.heads)} feature entries, got {len(features)}"
        )
        outputs = []
        for i, head_group in enumerate(self.heads):
            feature = features[i]
            # Batch list inputs once per layer rather than once per head in
            # the group; the result is identical for every head.
            if isinstance(feature, list):
                input_features = list2batch(feature)
            else:
                input_features = feature
            outputs.extend(h(input_features) for h in head_group)
        return outputs