Camellia997's picture
Upload folder using huggingface_hub
e14f899 verified
import numpy as np
import torch
import torch.nn as nn
from .diffusion_utils import list2batch
def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: Tensor of values; entries are gathered along its last dimension.
        t: Index tensor. Its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be combined with. Only
            its length is used.

    Returns:
        The gathered values reshaped to a leading batch dimension followed by
        ``len(x_shape) - 1`` singleton dimensions.
    """
    batch_size, *_ = t.shape
    picked = a.gather(-1, t)
    # One singleton axis for every non-batch dimension of the target shape.
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.reshape(batch_size, *singleton_dims)
def get_phase_endpoint(index, num_teacher_timesteps=32, multiphase=8):
    """Map a timestep index to the starting index of its phase.

    The phase length is ``num_teacher_timesteps // multiphase``. Indices at or
    beyond ``num_teacher_timesteps - phase_length`` all map to that value;
    every other index is rounded down to a multiple of the phase length.

    Args:
        index: Timestep index to map.
        num_teacher_timesteps: Total number of teacher timesteps.
        multiphase: Number of phases the teacher timesteps are split into.

    Returns:
        The phase endpoint for ``index``.
    """
    phase_length = num_teacher_timesteps // multiphase
    last_endpoint = num_teacher_timesteps - phase_length

    # Everything in (or past) the final phase collapses onto its start.
    if index >= last_endpoint:
        return last_endpoint

    return (index // phase_length) * phase_length
class EulerSolver:
    """Euler integrator over a sub-sampled sigma schedule.

    Attributes set by the constructor (all ``torch`` tensors except
    ``num_timesteps``):
        num_timesteps: The ``timesteps`` argument.
        euler_timesteps: Ascending solver timesteps, ending at ``timesteps``.
        sigmas: ``sigmas`` evaluated at ``euler_timesteps``.
        sigmas_prev: ``sigmas[0]`` followed by ``sigmas`` at all but the last
            solver timestep, i.e. the sigma one solver step earlier.
        sigmas_all: A copy of the full input schedule.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        """Build the solver tables.

        Args:
            sigmas: NumPy array indexed by timestep; it is indexed up to and
                including ``timesteps``, so it needs at least
                ``timesteps + 1`` entries (the original note describes it as
                running 0.0 -> 1.0 with length 1001).
            timesteps: Number of timesteps in the full schedule.
            euler_timesteps: Number of solver steps to sub-sample.
        """
        self.num_timesteps = timesteps

        # Evenly spaced timesteps, built descending and then flipped to
        # ascending order and shifted back into the 1..timesteps range.
        stride = timesteps / euler_timesteps
        descending = np.round(np.arange(timesteps, 0, -stride)).astype(np.int64) - 1
        ascending = descending[::-1].copy() + 1

        sigma_now = sigmas[ascending]
        sigma_before = np.asarray([sigmas[0]] + sigmas[ascending[:-1]].tolist())
        sigma_full = sigmas.copy()

        self.euler_timesteps = torch.from_numpy(ascending).long()
        self.sigmas = torch.from_numpy(sigma_now)
        self.sigmas_prev = torch.from_numpy(sigma_before)
        self.sigmas_all = torch.from_numpy(sigma_full)

    def to(self, device):
        """Move every stored tensor to ``device`` and return ``self``."""
        for name in ("euler_timesteps", "sigmas", "sigmas_prev", "sigmas_all"):
            setattr(self, name, getattr(self, name).to(device))
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        """Take one Euler step from ``timestep_index`` to its predecessor.

        Args:
            sample: Current sample.
            model_pred: Model prediction; its shape sets the broadcast rank.
            timestep_index: Index tensor into the solver tables.

        Returns:
            ``sample + (sigma_prev - sigma) * model_pred``.
        """
        shape = model_pred.shape
        sigma_from = extract_into_tensor(self.sigmas, timestep_index, shape)
        sigma_to = extract_into_tensor(self.sigmas_prev, timestep_index, shape)
        delta = sigma_to - sigma_from
        return sample + delta * model_pred

    def euler_step_to_target(self, sample, model_pred, timestep_index, target_timestep_index):
        """Take one Euler step from ``timestep_index`` to a chosen target.

        The target sigma is read from ``sigmas_prev`` at
        ``target_timestep_index``; the starting sigma is read from ``sigmas``
        at ``timestep_index``.

        Returns:
            ``sample + (sigma_target - sigma) * model_pred``.
        """
        shape = model_pred.shape
        sigma_from = extract_into_tensor(self.sigmas, timestep_index, shape)
        sigma_to = extract_into_tensor(self.sigmas_prev, target_timestep_index, shape)
        delta = sigma_to - sigma_from
        return sample + delta * model_pred
class DiscriminatorHead(nn.Module):
    """Convolutional head that scores one feature map with a single logit.

    The head applies a 1x1x1 channel-reducing convolution, three strided 3-D
    convolutions (each doubling the channel count and followed by
    LeakyReLU(0.2)), global average pooling, and a linear layer producing one
    value per sample.
    """

    def __init__(self, in_channels=1280, reduced_channels=512):
        """Create the layers.

        Args:
            in_channels: Channel count of the incoming feature map.
            reduced_channels: Channel count after the initial 1x1x1
                convolution; the final conv stage outputs eight times this.
        """
        super().__init__()

        # Point-wise convolution used purely to shrink the channel dimension.
        self.reduce_ch_conv = nn.Conv3d(in_channels, reduced_channels, kernel_size=(1, 1, 1))

        # Three stages; each doubles the width and strides by 2 over the last
        # two spatial axes while keeping stride 1 on the first one.
        stages = []
        width = reduced_channels
        for _ in range(3):
            stages.append(nn.Conv3d(width, width * 2, kernel_size=(3, 3, 3), stride=(1, 2, 2)))
            stages.append(nn.LeakyReLU(0.2))
            width = width * 2
        self.conv_layers = nn.Sequential(*stages)

        # Collapse every remaining spatial position to a single vector.
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # Final projection to a single score.
        self.fc = nn.Linear(width, 1)

    def forward(self, feature):
        """Score ``feature``.

        Args:
            feature: Input to ``nn.Conv3d`` with ``in_channels`` channels;
                expected to be a batched 5-D tensor since the result is
                flattened per leading index.

        Returns:
            Tensor with one row per sample and a single column.
        """
        hidden = self.conv_layers(self.reduce_ch_conv(feature))
        pooled = self.global_pool(hidden)
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)
class Discriminator(nn.Module):
    """Collection of ``DiscriminatorHead`` groups, one group per feature level.

    ``heads`` is a ``ModuleList`` of ``ModuleList``s: the outer list has one
    entry per adapter channel dimension (after repetition, see ``__init__``),
    and each inner list holds ``num_h_per_head`` heads for that level.
    """

    def __init__(
        self,
        num_h_per_head=1,
        selected_layers=(20, 30, 40),
        adapter_channel_dims=(1280,),
    ):
        """Build the head groups.

        Args:
            num_h_per_head: Number of heads created for every feature level.
            selected_layers: Sequence of selected layers. Only its length is
                used here: the channel-dimension list is repeated that many
                times.
            adapter_channel_dims: Input channel count(s) for the heads. A bare
                int is treated as a one-element list. The sequence is repeated
                ``len(selected_layers)`` times, so a one-element value yields
                one group per selected layer.
        """
        super().__init__()
        if isinstance(adapter_channel_dims, int):
            adapter_channel_dims = [adapter_channel_dims]
        # Defaults are immutable tuples; work on a fresh list so nothing the
        # caller passed in (or the defaults) is ever shared or modified.
        adapter_channel_dims = list(adapter_channel_dims) * len(selected_layers)

        self.num_h_per_head = num_h_per_head
        self.head_num = len(adapter_channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([DiscriminatorHead(adapter_channel) for _ in range(self.num_h_per_head)])
            for adapter_channel in adapter_channel_dims
        ])

    def forward(self, features):
        """Run every head on its corresponding feature level.

        Args:
            features: Indexable collection with one entry per head group. An
                entry that is a ``list`` is first converted with
                ``list2batch``; any other entry is passed to the heads as is.

        Returns:
            A flat list of head outputs, ordered by feature level and then by
            head within that level.

        Raises:
            AssertionError: If ``len(features)`` differs from the number of
                head groups.
        """
        # Explicit raise (rather than a bare ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if len(features) != len(self.heads):
            raise AssertionError(
                f"expected {len(self.heads)} feature levels, got {len(features)}"
            )

        outputs = []
        for level, level_heads in enumerate(self.heads):
            feature = features[level]
            # Convert once per level instead of once per head.
            # NOTE(review): assumes list2batch is a pure conversion with no
            # per-call side effects — confirm against diffusion_utils.
            if isinstance(feature, list):
                input_features = list2batch(feature)
            else:
                input_features = feature
            outputs.extend(head(input_features) for head in level_heads)
        return outputs