# NOTE: extraction artifact removed here (file-size line, commit hash, and a
# dump of line numbers were prepended to this file and are not valid Python).
import numpy as np
import torch
import torch.nn as nn
from .diffusion_utils import list2batch
def extract_into_tensor(a, t, x_shape):
    """Gather per-batch values from a 1-D tensor and shape them for broadcasting.

    Picks ``a[t[i]]`` for each batch element ``i`` and returns the result with
    shape ``(batch, 1, 1, ...)`` — one trailing singleton axis per non-batch
    dimension of ``x_shape`` — so it broadcasts against a tensor of that shape.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
def get_phase_endpoint(index, num_teacher_timesteps=32, multiphase=8):
    """Snap a teacher-timestep index to the start of its phase.

    The teacher schedule is split into ``multiphase`` equal phases of
    ``num_teacher_timesteps // multiphase`` steps each; the result is the
    first index of the phase containing ``index``, clamped to the start of
    the last phase for indices at or beyond it.
    """
    interval = num_teacher_timesteps // multiphase
    last_phase_start = num_teacher_timesteps - interval
    phase_start = (index // interval) * interval
    # Indices in (or past) the final phase all map to its start.
    return min(phase_start, last_phase_start)
class EulerSolver:
    """Euler-method solver over a subsampled sigma schedule.

    Subsamples ``euler_timesteps`` points from a dense schedule of
    ``timesteps`` steps and precomputes, for each kept step, its sigma and
    the sigma of the preceding kept step, enabling single Euler updates
    between consecutive (or arbitrary) kept steps.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        # sigmas: 0.0 -> 1.0, length = timesteps + 1 (e.g. 1001)
        self.num_timesteps = timesteps
        stride = timesteps / euler_timesteps
        # Walk down from `timesteps` in steps of `stride`, round to ints,
        # shift to 0-based (999, ..., 19), then flip to ascending 1-based
        # indices (e.g. 20, 40, ..., 1000).
        descending = np.round(np.arange(timesteps, 0, -stride)).astype(np.int64) - 1
        ascending = descending[::-1].copy() + 1
        selected = sigmas[ascending]  # sigma at each kept step
        # Sigma of the previous kept step; the first step falls back to sigmas[0].
        previous = np.concatenate(([sigmas[0]], sigmas[ascending[:-1]]))
        self.euler_timesteps = torch.from_numpy(ascending).long()
        self.sigmas = torch.from_numpy(selected)
        self.sigmas_prev = torch.from_numpy(previous)
        self.sigmas_all = torch.from_numpy(sigmas.copy())

    def to(self, device):
        """Move all precomputed tensors to ``device``; returns self for chaining."""
        for name in ("euler_timesteps", "sigmas", "sigmas_prev", "sigmas_all"):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        """One Euler update from the step at ``timestep_index`` to the previous kept step."""
        sigma = extract_into_tensor(self.sigmas, timestep_index, model_pred.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index, model_pred.shape)
        return sample + (sigma_prev - sigma) * model_pred

    def euler_step_to_target(self, sample, model_pred, timestep_index, target_timestep_index):
        """Euler update jumping directly from ``timestep_index`` to ``target_timestep_index``."""
        sigma = extract_into_tensor(self.sigmas, timestep_index, model_pred.shape)
        sigma_target = extract_into_tensor(self.sigmas_prev, target_timestep_index, model_pred.shape)
        return sample + (sigma_target - sigma) * model_pred
class DiscriminatorHead(nn.Module):
    """Scalar discriminator head over a 5-D feature map (N, C, T, H, W).

    Projects the input to ``reduced_channels`` with a 1x1x1 conv, runs three
    strided conv stages that double the channels while downsampling H and W,
    then global-average-pools and maps to a single logit per sample.
    """

    def __init__(self, in_channels=1280, reduced_channels=512):
        super().__init__()
        # Cheap channel projection before the heavier strided convolutions.
        self.reduce_ch_conv = nn.Conv3d(in_channels, reduced_channels, kernel_size=(1, 1, 1))
        # Three stages: channels x2 each, spatial stride 2 (time stride 1).
        stages = []
        width = reduced_channels
        for _ in range(3):
            stages.append(nn.Conv3d(width, width * 2, kernel_size=(3, 3, 3), stride=(1, 2, 2)))
            stages.append(nn.LeakyReLU(0.2))
            width *= 2
        self.conv_layers = nn.Sequential(*stages)
        # Collapse whatever spatial extent remains to a single vector.
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(reduced_channels * 8, 1)

    def forward(self, feature):
        """Return a (N, 1) logit tensor for the input feature map."""
        x = self.reduce_ch_conv(feature)
        x = self.conv_layers(x)
        x = self.global_pool(x)
        return self.fc(x.flatten(1))
class Discriminator(nn.Module):
    """Multi-head discriminator: one group of heads per selected feature layer.

    Builds ``len(selected_layers)`` groups of ``num_h_per_head``
    :class:`DiscriminatorHead` modules; ``forward`` applies every head of
    group ``i`` to ``features[i]`` and returns the flat list of logits.

    Fixes vs. the previous version (behavior-compatible):
    - mutable default arguments replaced with tuples;
    - ``list2batch`` collation hoisted out of the per-head inner loop so a
      list input is batched once per layer, not once per head.
    """

    def __init__(
        self,
        num_h_per_head=1,
        selected_layers=(20, 30, 40),
        adapter_channel_dims=(1280,),
    ):
        super().__init__()
        # A bare int means "same channel count for every selected layer".
        if isinstance(adapter_channel_dims, int):
            adapter_channel_dims = [adapter_channel_dims]
        # Repeat the channel spec once per selected layer.
        adapter_channel_dims = list(adapter_channel_dims) * len(selected_layers)
        self.num_h_per_head = num_h_per_head
        self.head_num = len(adapter_channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([DiscriminatorHead(adapter_channel) for _ in range(self.num_h_per_head)])
            for adapter_channel in adapter_channel_dims
        ])

    def forward(self, features):
        """Apply every head to its layer's features.

        ``features`` is a sequence aligned with ``self.heads``; each entry is
        either a tensor or a list of per-sample tensors (collated via
        ``list2batch``). Returns the logits of all heads as a flat list,
        ordered layer-major then head.
        """
        assert len(features) == len(self.heads)
        outputs = []
        for feature, head_group in zip(features, self.heads):
            # Collate once per layer — not once per head as before.
            batched = list2batch(feature) if isinstance(feature, list) else feature
            for head in head_group:
                outputs.append(head(batched))
        return outputs